diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 1a3b348e8f9..232f53f5c30 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -797,6 +797,7 @@ tf_cuda_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/platform:status", "@com_google_absl//absl/strings", ], ) @@ -816,6 +817,7 @@ tf_cuda_cc_test( ":c_api_test_util", "//tensorflow/c:c_api", "//tensorflow/c:c_test_util", + "//tensorflow/c:tf_status_helper", "//tensorflow/cc/profiler", "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", "//tensorflow/core:lib", diff --git a/tensorflow/c/eager/c_api_unified_experimental.cc b/tensorflow/c/eager/c_api_unified_experimental.cc index 8408f7ef60f..2d290df19ce 100644 --- a/tensorflow/c/eager/c_api_unified_experimental.cc +++ b/tensorflow/c/eager/c_api_unified_experimental.cc @@ -39,7 +39,7 @@ static FactoriesMap& GetFactories() { return *factories; } -static const char* default_factory = ""; +static tracing::FactoryFunction default_factory; void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) { assert((!GetFactories().count(name)) || @@ -48,15 +48,15 @@ void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) { GetFactories()[name] = factory; } -void SetDefaultTracingEngine(const char* name) { default_factory = name; } - -static TracingContext* CreateTracingExecutionContext(const char* fn_name, - TF_Status* s) { - auto entry = GetFactories().find(default_factory); - if (entry != GetFactories().end()) return entry->second(fn_name, s); +Status SetDefaultTracingEngine(const char* name) { + auto entry = GetFactories().find(name); + if (entry != GetFactories().end()) { + default_factory = GetFactories().find(name)->second; + return Status::OK(); + } string msg = absl::StrCat( - "No tracing engine factory has been registered with the key '", - default_factory, "' (available: "); + "No tracing engine factory has been registered with the key '", name, + "' (available: "); // Ensure deterministic (sorted) order in the error message std::set factories_sorted; for (const auto& factory : GetFactories()) @@ -68,7 +68,16 @@ static TracingContext* CreateTracingExecutionContext(const char* fn_name, } msg += ")"; - TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str()); + return errors::InvalidArgument(msg.c_str()); +} + +static TracingContext* CreateTracingExecutionContext(const char* fn_name, + TF_Status* s) { + if (default_factory) { + return default_factory(fn_name, s); + } + Set_TF_Status_from_Status( + s, errors::FailedPrecondition("default_factory is nullptr")); return nullptr; } @@ -99,8 +108,8 @@ using tensorflow::tracing::TracingContext; using tensorflow::tracing::TracingOperation; using tensorflow::tracing::TracingTensorHandle; -void TF_SetTracingImplementation(const char* name) { - SetDefaultTracingEngine(name); +void TF_SetTracingImplementation(const char* name, TF_Status* s) { + Set_TF_Status_from_Status(s, SetDefaultTracingEngine(name)); } // Creates a new TensorFlow function, it is an execution context attached to a diff --git a/tensorflow/c/eager/c_api_unified_experimental.h b/tensorflow/c/eager/c_api_unified_experimental.h index b66869b4290..d216b4e694b 100644 --- a/tensorflow/c/eager/c_api_unified_experimental.h +++ b/tensorflow/c/eager/c_api_unified_experimental.h @@ -52,7 +52,7 @@ typedef struct TF_AbstractFunction TF_AbstractFunction; // This allows the client to swap the implementation of the tracing engine. // Any future call to TF_CreateFunction will use the implementation defined // here. -void TF_SetTracingImplementation(const char* name); +void TF_SetTracingImplementation(const char* name, TF_Status*); // Creates a new TensorFlow function. A Function is an execution context, and as // such it can trace operations through TF_ExecuteOperation. After completing diff --git a/tensorflow/c/eager/c_api_unified_experimental_graph.cc b/tensorflow/c/eager/c_api_unified_experimental_graph.cc index 9d064039141..0e9d6c18157 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_graph.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_graph.cc @@ -365,9 +365,10 @@ class GraphContext : public TracingContext { } auto s = TF_NewStatus(); - func->func = TF_GraphToFunction( - graph_.get(), name_, 0, -1, nullptr, inputs_.size(), inputs_.data(), - graph_outputs.size(), graph_outputs.data(), nullptr, nullptr, name_, s); + func->func = TF_GraphToFunction(graph_.get(), name_.data(), 0, -1, nullptr, + inputs_.size(), inputs_.data(), + graph_outputs.size(), graph_outputs.data(), + nullptr, nullptr, name_.data(), s); TF_RETURN_IF_ERROR(StatusFromTF_Status(s)); TF_DeleteStatus(s); *f = func.release(); @@ -391,7 +392,7 @@ class GraphContext : public TracingContext { private: std::unique_ptr graph_; std::vector inputs_; - const char* name_; + string name_; }; static TracingContext* GraphTracingFactory(const char* name, TF_Status* s) { @@ -401,7 +402,7 @@ static TracingContext* GraphTracingFactory(const char* name, TF_Status* s) { // Register the tracing implemented in this file as the default tracing engine. static bool register_tracing = [] { RegisterTracingEngineFactory("graphdef", GraphTracingFactory); - SetDefaultTracingEngine("graphdef"); + SetDefaultTracingEngine("graphdef").IgnoreError(); return true; }(); diff --git a/tensorflow/c/eager/c_api_unified_experimental_internal.h b/tensorflow/c/eager/c_api_unified_experimental_internal.h index c00e04d98af..9433fe8f120 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_internal.h +++ b/tensorflow/c/eager/c_api_unified_experimental_internal.h @@ -120,7 +120,7 @@ class TracingContext : public AbstractContext { }; typedef TracingContext* (*FactoryFunction)(const char* fn_name, TF_Status*); -void SetDefaultTracingEngine(const char* name); +Status SetDefaultTracingEngine(const char* name); void RegisterTracingEngineFactory(const ::tensorflow::string& name, FactoryFunction factory); } // namespace tracing diff --git a/tensorflow/c/eager/c_api_unified_experimental_test.cc b/tensorflow/c/eager/c_api_unified_experimental_test.cc index 7b3a497a0c5..432ddb4b2d4 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_test.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_test.cc @@ -22,10 +22,15 @@ limitations under the License. #include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_status_helper.h" #include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/test.h" +using tensorflow::Status; using tensorflow::string; +using tensorflow::TF_StatusPtr; namespace tensorflow { namespace { @@ -37,7 +42,10 @@ class UnifiedCAPI : public ::testing::TestWithParam> { protected: void SetUp() override { - TF_SetTracingImplementation(std::get<0>(GetParam())); + TF_StatusPtr status(TF_NewStatus()); + TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); + Status s = StatusFromTF_Status(status.get()); + CHECK_EQ(errors::OK, s.code()) << s.error_message(); } }; diff --git a/tensorflow/c/eager/gradients_test.cc b/tensorflow/c/eager/gradients_test.cc index 56f0b847002..3aedf55e97a 100644 --- a/tensorflow/c/eager/gradients_test.cc +++ b/tensorflow/c/eager/gradients_test.cc @@ -38,13 +38,17 @@ namespace gradients { namespace internal { namespace { using std::vector; +using tensorflow::TF_StatusPtr; using tracing::TracingOperation; class CppGradients : public ::testing::TestWithParam> { protected: void SetUp() override { - TF_SetTracingImplementation(std::get<0>(GetParam())); + TF_StatusPtr status(TF_NewStatus()); + TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); + Status s = StatusFromTF_Status(status.get()); + CHECK_EQ(errors::OK, s.code()) << s.error_message(); } }; diff --git a/tensorflow/c/eager/mnist_gradients_test.cc b/tensorflow/c/eager/mnist_gradients_test.cc index 1f04e25820e..0dc38b21a2e 100644 --- a/tensorflow/c/eager/mnist_gradients_test.cc +++ b/tensorflow/c/eager/mnist_gradients_test.cc @@ -33,12 +33,16 @@ namespace tensorflow { namespace gradients { namespace internal { namespace { +using tensorflow::TF_StatusPtr; class CppGradients : public ::testing::TestWithParam> { protected: void SetUp() override { - TF_SetTracingImplementation(std::get<0>(GetParam())); + TF_StatusPtr status(TF_NewStatus()); + TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); + Status s = StatusFromTF_Status(status.get()); + CHECK_EQ(errors::OK, s.code()) << s.error_message(); } };