From ec2cc2903f54d526dfdcfa314c9e181a8a5f76fa Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Thu, 14 May 2020 07:41:59 -0700 Subject: [PATCH] Introduce a higher-level function handling in the tracing oriented unified API This patch intends to make function tracing more of a first class concept in the API. It tries to move away from the "flat graph" model with "placeholder" operation introduced with the expectation to turn them into function parameters later. Instead the user starts by creating an empty function which is an ExecutionContext (and as such can trace operations). Function parameters can get added to this context using a dedicated API returning an AbstractTensor. The diff in UnifiedCAPI/TestBasicGraph is probably a good illustration of the change from a client point of view. Another important point of this patch is to make it so that no C public API is defined in the `c_api_unified_experimental_graph.cc` file, instead the implementation is dispatched based on a registered factory function to create the tracing context. This will allow to swap the tracing implementation through injection later. PiperOrigin-RevId: 311529850 Change-Id: I822047f4306835abc0e044dc87c14179596f64bd --- tensorflow/c/eager/BUILD | 2 + .../c/eager/c_api_unified_experimental.cc | 69 +++++++++++ .../c/eager/c_api_unified_experimental.h | 26 ++-- .../eager/c_api_unified_experimental_eager.cc | 11 ++ .../eager/c_api_unified_experimental_graph.cc | 111 ++++++++---------- .../c_api_unified_experimental_internal.h | 17 +++ .../eager/c_api_unified_experimental_test.cc | 73 +++++------- 7 files changed, 193 insertions(+), 116 deletions(-) diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index d3059df1bef..69808f6f49f 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -448,6 +448,8 @@ tf_cuda_library( "//conditions:default": [], }) + [ "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/container:flat_hash_map", "//tensorflow/c:tf_status_helper", "//tensorflow/core/distributed_runtime/eager:eager_client", "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", diff --git a/tensorflow/c/eager/c_api_unified_experimental.cc b/tensorflow/c/eager/c_api_unified_experimental.cc index 68afffb28b4..d29c457798e 100644 --- a/tensorflow/c/eager/c_api_unified_experimental.cc +++ b/tensorflow/c/eager/c_api_unified_experimental.cc @@ -17,6 +17,8 @@ limitations under the License. #include +#include "absl/container/flat_hash_map.h" +#include "absl/strings/str_cat.h" #include "tensorflow/c/eager/c_api_unified_experimental_internal.h" #include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_status.h" @@ -26,6 +28,51 @@ using tensorflow::string; using tensorflow::internal::OutputList; using tensorflow::internal::unwrap; +namespace tensorflow { +namespace internal { +typedef absl::flat_hash_map FactoriesMap; + +static FactoriesMap& GetFactories() { + static FactoriesMap* factories = new FactoriesMap; + return *factories; +} + +static const char* default_factory = ""; + +void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) { + assert((!GetFactories().count(name)) || + (GetFactories()[name] == factory) && + "Duplicate tracing factory registration"); + GetFactories()[name] = factory; +} + +void SetDefaultTracingEngine(const char* name) { default_factory = name; } + +static ExecutionContext* CreateTracingExecutionContext(const char* fn_name, + TF_Status* s) { + auto entry = GetFactories().find(default_factory); + if (entry != GetFactories().end()) return entry->second(fn_name, s); + string msg = absl::StrCat( + "No tracing engine factory has been registered with the key '", + default_factory, "' (available: "); + // Ensure deterministic (sorted) order in the error message + std::set factories_sorted; + for (const auto& factory : GetFactories()) + factories_sorted.insert(factory.first); + const char* comma = ""; + for (const string& factory : factories_sorted) { + msg += comma + factory; + comma = ", "; + } + msg += ")"; + + TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str()); + return nullptr; +} + +} // end namespace internal +} // end namespace tensorflow + // ============================================================================= // Public C API entry points // @@ -36,6 +83,28 @@ using tensorflow::internal::unwrap; // // ============================================================================= +void TF_SetTracingImplementation(const char* name) { + tensorflow::internal::SetDefaultTracingEngine(name); +} + +// Creates a new TensorFlow function, it is an execution context attached to a +// given tracing context. +TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* s) { + return wrap(tensorflow::internal::CreateTracingExecutionContext(fn_name, s)); +} + +TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx, + TF_OutputList* outputs, TF_Status* s) { + auto* func = wrap(unwrap(ctx)->Finalize(unwrap(outputs), s)); + TF_DeleteExecutionContext(ctx); + return func; +} + +TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func, + TF_DataType dtype, TF_Status* s) { + return wrap(unwrap(func)->AddParameter(dtype, s)); +} + void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete unwrap(c); } TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) { diff --git a/tensorflow/c/eager/c_api_unified_experimental.h b/tensorflow/c/eager/c_api_unified_experimental.h index be8fc64c2e1..512717caa34 100644 --- a/tensorflow/c/eager/c_api_unified_experimental.h +++ b/tensorflow/c/eager/c_api_unified_experimental.h @@ -49,15 +49,26 @@ typedef struct TF_AbstractOp TF_AbstractOp; // setting functional attributes of other composite ops e.g. control flow. typedef struct TF_AbstractFunction TF_AbstractFunction; -// Creates a context for tracing the execution of operations into a function. -TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s); +// This allows the client to swap the implementation of the tracing engine. +// Any future call to TF_CreateFunction will use the implementation defined +// here. +void TF_SetTracingImplementation(const char* name); + +// Creates a new TensorFlow function. A Function is an execution context, and as +// such it can trace operations through TF_ExecuteOperation. After completing +// tracing, a function can be obtained by TF_FinalizeFunction. +TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* status); // Creates a context for eager execution of operations. TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions*, TF_Status* s); - void TF_DeleteExecutionContext(TF_ExecutionContext*); +// Add a new parameter to a TensorFlow Function. +// TODO(aminim): what about shape? +TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func, + TF_DataType dtype, TF_Status* s); + // Create an operation suitable to use with the provided context. The operation // requires its type (e.g. "AddV2") to be set independently. TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* ctx); @@ -100,13 +111,12 @@ void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs, TF_ExecutionContext* ctx, TF_Status* s); // Creates a new TF_AbstractFunction from the current tracing states in the -// context. The returned TF_GraphToFunction must be deleted by the client. +// context. The provided `ctx` is consumed by this API call and deleted. +// The returned TF_AbstractFunction must be deleted by the client, // TODO(aminim): clarify the contract on the state of the context after this // call. -TF_AbstractFunction* TF_ExecutionContextToFunction( - const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs, - const TF_AbstractTensor* inputs, int num_outputs, - const TF_AbstractTensor* outputs, TF_Status* status); +TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx, + TF_OutputList*, TF_Status*); void TF_DeleteAbstractFunction(TF_AbstractFunction*); diff --git a/tensorflow/c/eager/c_api_unified_experimental_eager.cc b/tensorflow/c/eager/c_api_unified_experimental_eager.cc index 820c61445fb..cf8cf845834 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_eager.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_eager.cc @@ -123,6 +123,17 @@ class EagerContext : public ExecutionContext { } } + AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override { + TF_SetStatus(s, TF_INVALID_ARGUMENT, + "Can't add function parameter on an eager context."); + return nullptr; + } + AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override { + TF_SetStatus(s, TF_INVALID_ARGUMENT, + "Can't use finalize function on an eager context."); + return nullptr; + } + void RegisterFunction(AbstractFunction* afunc, TF_Status* s) override { auto* func = afunc->GetTfFunction(s); if (!func) { diff --git a/tensorflow/c/eager/c_api_unified_experimental_graph.cc b/tensorflow/c/eager/c_api_unified_experimental_graph.cc index 36f8353894b..e38332e3e8e 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_graph.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_graph.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/c/c_api.h" #include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_unified_experimental.h" @@ -114,12 +115,14 @@ struct GraphFunction : public AbstractFunction { static constexpr AbstractFunctionKind kKind = kGraphFunc; }; -// GraphContext wraps a TF_Graph and manages the "execution" of operation, i.e. -// adding them to the graph. +// GraphContext wraps a TF_Graph modeling a single function and manages the +// "execution" of operation, i.e. adding them to the function. class GraphContext : public ExecutionContext { public: - GraphContext() - : ExecutionContext(kKind), graph_(new TF_Graph(), TF_DeleteGraph) {} + explicit GraphContext(const char* name) + : ExecutionContext(kKind), + graph_(new TF_Graph(), TF_DeleteGraph), + name_(name) {} AbstractOp* CreateOperation() override { // TODO(srbs): Should the lifetime of this op be tied to the context. @@ -164,24 +167,38 @@ class GraphContext : public ExecutionContext { } } - TF_Function* ToFunction(const char* fn_name, int num_inputs, - const GraphTensor* inputs, int num_outputs, - const GraphTensor* outputs, TF_Status* status) const { - std::vector graph_inputs; - graph_inputs.resize(num_inputs); + AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override { + TF_OperationDescription* opdesc = + TF_NewOperation(graph_.get(), "Placeholder", + absl::StrCat("_input_", inputs_.size()).c_str()); + TF_SetAttrType(opdesc, "dtype", dtype); + auto* operation = TF_FinishOperation(opdesc, s); + if (!s->status.ok()) return nullptr; + + inputs_.push_back(TF_Output{operation, 0}); + return new GraphTensor(inputs_.back(), this); + } + + AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override { + std::unique_ptr func(new GraphFunction); std::vector graph_outputs; - graph_outputs.resize(num_outputs); - for (int i = 0; i < num_inputs; i++) { - graph_inputs[i] = inputs[i].output; - } - for (int i = 0; i < num_outputs; i++) { - graph_outputs[i] = outputs[i].output; + graph_outputs.reserve(outputs->outputs.size()); + for (AbstractTensor* abstract_output : outputs->outputs) { + GraphTensor* output = dyncast(abstract_output); + if (!output) { + TF_SetStatus(s, TF_UNIMPLEMENTED, + "Returning a non-graph tensor from a function has not " + "been implemented yet."); + return nullptr; + } + graph_outputs.push_back(output->output); } - return TF_GraphToFunction(graph_.get(), fn_name, 0, -1, nullptr, - graph_inputs.size(), graph_inputs.data(), - graph_outputs.size(), graph_outputs.data(), - nullptr, nullptr, fn_name, status); + func->func = TF_GraphToFunction( + graph_.get(), name_, 0, -1, nullptr, inputs_.size(), inputs_.data(), + graph_outputs.size(), graph_outputs.data(), nullptr, nullptr, name_, s); + if (TF_GetCode(s) != TF_OK) return nullptr; + return func.release(); } void RegisterFunction(AbstractFunction* func, TF_Status* s) override { @@ -195,54 +212,20 @@ class GraphContext : public ExecutionContext { private: std::unique_ptr graph_; + std::vector inputs_; + const char* name_; }; -// Helper that converts the graph currently held in the context into a function. -static AbstractFunction* ExecutionContextToFunction( - const ExecutionContext* fn_body, const char* fn_name, int num_inputs, - const AbstractTensor* inputs, int num_outputs, - const AbstractTensor* outputs, TF_Status* status) { - auto* graph_ctx = dyncast(fn_body); - if (graph_ctx == nullptr) { - TF_SetStatus(status, TF_INVALID_ARGUMENT, - "fn_body is not a TF_GraphContext."); - return nullptr; - } - auto* graph_inputs = dyncast(inputs); - if (!graph_inputs) { - TF_SetStatus(status, TF_INVALID_ARGUMENT, "inputs aren't GraphTensors."); - return nullptr; - } - auto* graph_outputs = dyncast(outputs); - if (!graph_outputs) { - TF_SetStatus(status, TF_INVALID_ARGUMENT, "outputs aren't GraphTensors."); - return nullptr; - } - GraphFunction* func = new GraphFunction; - func->func = graph_ctx->ToFunction(fn_name, num_inputs, graph_inputs, - num_outputs, graph_outputs, status); - return func; +static ExecutionContext* GraphTracingFactory(const char* name, TF_Status* s) { + return new GraphContext(name); } +// Register the tracing implemented in this file as the default tracing engine. +static bool register_tracing = [] { + RegisterTracingEngineFactory("graphdef", GraphTracingFactory); + SetDefaultTracingEngine("graphdef"); + return true; +}(); + } // namespace internal } // namespace tensorflow - -// ============================================================================= -// Public C API entry points -// These are only the entry points specific to the Graph API. -// ============================================================================= - -using tensorflow::internal::unwrap; - -TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s) { - return wrap(new tensorflow::internal::GraphContext()); -} - -TF_AbstractFunction* TF_ExecutionContextToFunction( - const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs, - const TF_AbstractTensor* inputs, int num_outputs, - const TF_AbstractTensor* outputs, TF_Status* status) { - return wrap(ExecutionContextToFunction(unwrap(fn_body), fn_name, num_inputs, - unwrap(inputs), num_outputs, - unwrap(outputs), status)); -} diff --git a/tensorflow/c/eager/c_api_unified_experimental_internal.h b/tensorflow/c/eager/c_api_unified_experimental_internal.h index ab085a20ff0..49212a230ee 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_internal.h +++ b/tensorflow/c/eager/c_api_unified_experimental_internal.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/core/platform/casts.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace internal { @@ -148,6 +149,17 @@ struct ExecutionContext { // Creates an empty AbstractOperation suitable to use with this context. virtual AbstractOp* CreateOperation() = 0; + // Add a function parameter and return the corresponding tensor. + // This is only valid with an ExecutionContext obtained from a TracingContext, + // it'll always error out with an eager context. + virtual AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) = 0; + + // Finalize this context and make a function out of it. The context is in a + // invalid state after this call and must be destroyed. + // This is only valid with an ExecutionContext obtained from a TracingContext, + // it'll always error out with an eager context. + virtual AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) = 0; + // Registers a functions with this context, after this the function is // available to be called/referenced by its name in this context. virtual void RegisterFunction(AbstractFunction* func, TF_Status* s) = 0; @@ -156,6 +168,11 @@ struct ExecutionContext { const ExecutionContextKind k; }; +typedef ExecutionContext* (*FactoryFunction)(const char* fn_name, TF_Status*); +void SetDefaultTracingEngine(const char* name); +void RegisterTracingEngineFactory(const ::tensorflow::string& name, + FactoryFunction factory); + // Create utilities to wrap/unwrap: this convert from the C opaque types to the // C++ implementation, and back. #define MAKE_WRAP_UNWRAP(C_TYPEDEF, CPP_CLASS) \ diff --git a/tensorflow/c/eager/c_api_unified_experimental_test.cc b/tensorflow/c/eager/c_api_unified_experimental_test.cc index bd99189852e..9f56c8aa579 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_test.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_test.cc @@ -29,7 +29,12 @@ using tensorflow::string; namespace tensorflow { namespace { -TEST(UnifiedCAPI, TestBasicEager) { +class UnifiedCAPI : public ::testing::TestWithParam { + protected: + void SetUp() override { TF_SetTracingImplementation(GetParam()); } +}; + +TEST_P(UnifiedCAPI, TestBasicEager) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); TFE_ContextOptions* opts = TFE_NewContextOptions(); @@ -81,33 +86,18 @@ TEST(UnifiedCAPI, TestBasicEager) { TF_DeleteExecutionContext(ctx); } -TEST(UnifiedCAPI, TestBasicGraph) { +TEST_P(UnifiedCAPI, TestBasicGraph) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); - TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get()); + // Start a new function / execution context. + string fn_name = "double"; + TF_ExecutionContext* graph_ctx = + TF_CreateFunction(fn_name.c_str(), status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - // Add a placeholder to the graph. - auto* placeholder_op = TF_NewAbstractOp(graph_ctx); - TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get()); + auto* placeholder_t = + TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get()); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TF_AbstractOpSetAttrType(placeholder_op, "dtype", TF_FLOAT, status.get()); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - - // Build inputs and outputs. - TF_OutputList* placeholder_outputs = TF_NewOutputList(); - - // Execute. - TF_ExecuteOperation(placeholder_op, 0, nullptr, placeholder_outputs, - graph_ctx, status.get()); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - ASSERT_EQ(1, TF_OutputListNumOutputs(placeholder_outputs)); - TF_AbstractTensor* placeholder_t = TF_OutputListGet(placeholder_outputs, 0); - - // Delete placeholder op. - TF_DeleteAbstractOp(placeholder_op); // Build an abstract operation. auto* add_op = TF_NewAbstractOp(graph_ctx); @@ -123,17 +113,13 @@ TEST(UnifiedCAPI, TestBasicGraph) { // Execute. TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TF_AbstractTensor* output_t = TF_OutputListGet(add_outputs, 0); // Clean up operation and inputs. TF_DeleteAbstractOp(add_op); - string fn_name = "double"; - TF_AbstractFunction* func = TF_ExecutionContextToFunction( - graph_ctx, fn_name.c_str(), 1, placeholder_t, 1, output_t, status.get()); + TF_AbstractFunction* func = + TF_FinalizeFunction(graph_ctx, add_outputs, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TF_DeleteAbstractTensor(placeholder_t); - TF_DeleteAbstractTensor(output_t); // Build eager context. TFE_ContextOptions* opts = TFE_NewContextOptions(); @@ -174,18 +160,16 @@ TEST(UnifiedCAPI, TestBasicGraph) { ASSERT_EQ(*f_value, 4.0); TF_DeleteOutputList(add_outputs); - TF_DeleteOutputList(placeholder_outputs); TF_DeleteAbstractOp(fn_op); TF_DeleteAbstractTensor(input_t); TF_DeleteAbstractTensor(final_result); TF_DeleteTensor(f_t); TF_DeleteAbstractFunction(func); - TF_DeleteExecutionContext(graph_ctx); TF_DeleteExecutionContext(eager_execution_ctx); } -TEST(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) { +TEST_P(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); TFE_ContextOptions* opts = TFE_NewContextOptions(); @@ -193,18 +177,15 @@ TEST(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) { ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); TFE_DeleteContextOptions(opts); - TF_AbstractFunction* func = TF_ExecutionContextToFunction( - ctx, nullptr, 0, nullptr, 0, nullptr, status.get()); + TF_AbstractFunction* func = TF_FinalizeFunction(ctx, nullptr, status.get()); ASSERT_EQ(nullptr, func); ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); - - TF_DeleteExecutionContext(ctx); } -TEST(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) { +TEST_P(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); - TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get()); + TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Add a placeholder to the graph. @@ -222,10 +203,10 @@ TEST(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) { TF_DeleteExecutionContext(graph_ctx); } -TEST(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) { +TEST_P(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); - TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get()); + TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Add a placeholder to the graph. @@ -243,7 +224,7 @@ TEST(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) { TF_DeleteExecutionContext(graph_ctx); } -TEST(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) { +TEST_P(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) { // Build an Eager context. std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); @@ -273,7 +254,8 @@ TEST(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) { ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Build a Graph context. - TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Execute eager op using graph context. @@ -289,10 +271,11 @@ TEST(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) { TF_DeleteExecutionContext(graph_ctx); } -TEST(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) { +TEST_P(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); - TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Add a placeholder to the graph. @@ -349,5 +332,7 @@ TEST(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) { TF_DeleteExecutionContext(eager_execution_ctx); } +INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI, ::testing::Values("graphdef")); + } // namespace } // namespace tensorflow