1. Do not cache the tracing impl string in c_api_unified_experimental instead actively lookup the factory fn and store a pointer to that in default_factory. This avoid keeping a copy of the string passed in from python. As a result SetDefaultTracingEngine now returns a Status which needed to be plumbed through.

2. In `GraphContext` keep a copy of the string passed in from python since the original pointer may be prematurely destroyed.

PiperOrigin-RevId: 330647473
Change-Id: I00bf7243b30c3b6b3eb2dd1bb5ba1cbc8b46c008
This commit is contained in:
Saurabh Saxena 2020-09-08 21:05:14 -07:00 committed by TensorFlower Gardener
parent 054738d534
commit 6cbf1daf2d
8 changed files with 50 additions and 22 deletions

View File

@ -797,6 +797,7 @@ tf_cuda_cc_test(
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"//tensorflow/core/platform:status",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
) )
@ -816,6 +817,7 @@ tf_cuda_cc_test(
":c_api_test_util", ":c_api_test_util",
"//tensorflow/c:c_api", "//tensorflow/c:c_api",
"//tensorflow/c:c_test_util", "//tensorflow/c:c_test_util",
"//tensorflow/c:tf_status_helper",
"//tensorflow/cc/profiler", "//tensorflow/cc/profiler",
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/core:lib", "//tensorflow/core:lib",

View File

@ -39,7 +39,7 @@ static FactoriesMap& GetFactories() {
return *factories; return *factories;
} }
static const char* default_factory = "<unset>"; static tracing::FactoryFunction default_factory;
void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) { void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) {
assert((!GetFactories().count(name)) || assert((!GetFactories().count(name)) ||
@ -48,15 +48,15 @@ void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) {
GetFactories()[name] = factory; GetFactories()[name] = factory;
} }
void SetDefaultTracingEngine(const char* name) { default_factory = name; } Status SetDefaultTracingEngine(const char* name) {
auto entry = GetFactories().find(name);
static TracingContext* CreateTracingExecutionContext(const char* fn_name, if (entry != GetFactories().end()) {
TF_Status* s) { default_factory = GetFactories().find(name)->second;
auto entry = GetFactories().find(default_factory); return Status::OK();
if (entry != GetFactories().end()) return entry->second(fn_name, s); }
string msg = absl::StrCat( string msg = absl::StrCat(
"No tracing engine factory has been registered with the key '", "No tracing engine factory has been registered with the key '", name,
default_factory, "' (available: "); "' (available: ");
// Ensure deterministic (sorted) order in the error message // Ensure deterministic (sorted) order in the error message
std::set<string> factories_sorted; std::set<string> factories_sorted;
for (const auto& factory : GetFactories()) for (const auto& factory : GetFactories())
@ -68,7 +68,16 @@ static TracingContext* CreateTracingExecutionContext(const char* fn_name,
} }
msg += ")"; msg += ")";
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str()); return errors::InvalidArgument(msg.c_str());
}
static TracingContext* CreateTracingExecutionContext(const char* fn_name,
TF_Status* s) {
if (default_factory) {
return default_factory(fn_name, s);
}
Set_TF_Status_from_Status(
s, errors::FailedPrecondition("default_factory is nullptr"));
return nullptr; return nullptr;
} }
@ -99,8 +108,8 @@ using tensorflow::tracing::TracingContext;
using tensorflow::tracing::TracingOperation; using tensorflow::tracing::TracingOperation;
using tensorflow::tracing::TracingTensorHandle; using tensorflow::tracing::TracingTensorHandle;
void TF_SetTracingImplementation(const char* name) { void TF_SetTracingImplementation(const char* name, TF_Status* s) {
SetDefaultTracingEngine(name); Set_TF_Status_from_Status(s, SetDefaultTracingEngine(name));
} }
// Creates a new TensorFlow function, it is an execution context attached to a // Creates a new TensorFlow function, it is an execution context attached to a

View File

@ -52,7 +52,7 @@ typedef struct TF_AbstractFunction TF_AbstractFunction;
// This allows the client to swap the implementation of the tracing engine. // This allows the client to swap the implementation of the tracing engine.
// Any future call to TF_CreateFunction will use the implementation defined // Any future call to TF_CreateFunction will use the implementation defined
// here. // here.
void TF_SetTracingImplementation(const char* name); void TF_SetTracingImplementation(const char* name, TF_Status*);
// Creates a new TensorFlow function. A Function is an execution context, and as // Creates a new TensorFlow function. A Function is an execution context, and as
// such it can trace operations through TF_ExecuteOperation. After completing // such it can trace operations through TF_ExecuteOperation. After completing

View File

@ -365,9 +365,10 @@ class GraphContext : public TracingContext {
} }
auto s = TF_NewStatus(); auto s = TF_NewStatus();
func->func = TF_GraphToFunction( func->func = TF_GraphToFunction(graph_.get(), name_.data(), 0, -1, nullptr,
graph_.get(), name_, 0, -1, nullptr, inputs_.size(), inputs_.data(), inputs_.size(), inputs_.data(),
graph_outputs.size(), graph_outputs.data(), nullptr, nullptr, name_, s); graph_outputs.size(), graph_outputs.data(),
nullptr, nullptr, name_.data(), s);
TF_RETURN_IF_ERROR(StatusFromTF_Status(s)); TF_RETURN_IF_ERROR(StatusFromTF_Status(s));
TF_DeleteStatus(s); TF_DeleteStatus(s);
*f = func.release(); *f = func.release();
@ -391,7 +392,7 @@ class GraphContext : public TracingContext {
private: private:
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_; std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
std::vector<TF_Output> inputs_; std::vector<TF_Output> inputs_;
const char* name_; string name_;
}; };
static TracingContext* GraphTracingFactory(const char* name, TF_Status* s) { static TracingContext* GraphTracingFactory(const char* name, TF_Status* s) {
@ -401,7 +402,7 @@ static TracingContext* GraphTracingFactory(const char* name, TF_Status* s) {
// Register the tracing implemented in this file as the default tracing engine. // Register the tracing implemented in this file as the default tracing engine.
static bool register_tracing = [] { static bool register_tracing = [] {
RegisterTracingEngineFactory("graphdef", GraphTracingFactory); RegisterTracingEngineFactory("graphdef", GraphTracingFactory);
SetDefaultTracingEngine("graphdef"); SetDefaultTracingEngine("graphdef").IgnoreError();
return true; return true;
}(); }();

View File

@ -120,7 +120,7 @@ class TracingContext : public AbstractContext {
}; };
typedef TracingContext* (*FactoryFunction)(const char* fn_name, TF_Status*); typedef TracingContext* (*FactoryFunction)(const char* fn_name, TF_Status*);
void SetDefaultTracingEngine(const char* name); Status SetDefaultTracingEngine(const char* name);
void RegisterTracingEngineFactory(const ::tensorflow::string& name, void RegisterTracingEngineFactory(const ::tensorflow::string& name,
FactoryFunction factory); FactoryFunction factory);
} // namespace tracing } // namespace tracing

View File

@ -22,10 +22,15 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h" #include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
using tensorflow::Status;
using tensorflow::string; using tensorflow::string;
using tensorflow::TF_StatusPtr;
namespace tensorflow { namespace tensorflow {
namespace { namespace {
@ -37,7 +42,10 @@ class UnifiedCAPI
: public ::testing::TestWithParam<std::tuple<const char*, bool>> { : public ::testing::TestWithParam<std::tuple<const char*, bool>> {
protected: protected:
void SetUp() override { void SetUp() override {
TF_SetTracingImplementation(std::get<0>(GetParam())); TF_StatusPtr status(TF_NewStatus());
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
Status s = StatusFromTF_Status(status.get());
CHECK_EQ(errors::OK, s.code()) << s.error_message();
} }
}; };

View File

@ -38,13 +38,17 @@ namespace gradients {
namespace internal { namespace internal {
namespace { namespace {
using std::vector; using std::vector;
using tensorflow::TF_StatusPtr;
using tracing::TracingOperation; using tracing::TracingOperation;
class CppGradients class CppGradients
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> { : public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
protected: protected:
void SetUp() override { void SetUp() override {
TF_SetTracingImplementation(std::get<0>(GetParam())); TF_StatusPtr status(TF_NewStatus());
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
Status s = StatusFromTF_Status(status.get());
CHECK_EQ(errors::OK, s.code()) << s.error_message();
} }
}; };

View File

@ -33,12 +33,16 @@ namespace tensorflow {
namespace gradients { namespace gradients {
namespace internal { namespace internal {
namespace { namespace {
using tensorflow::TF_StatusPtr;
class CppGradients class CppGradients
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> { : public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
protected: protected:
void SetUp() override { void SetUp() override {
TF_SetTracingImplementation(std::get<0>(GetParam())); TF_StatusPtr status(TF_NewStatus());
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
Status s = StatusFromTF_Status(status.get());
CHECK_EQ(errors::OK, s.code()) << s.error_message();
} }
}; };