1. Do not cache the tracing impl string in c_api_unified_experimental instead actively lookup the factory fn and store a pointer to that in default_factory. This avoid keeping a copy of the string passed in from python. As a result SetDefaultTracingEngine now returns a Status which needed to be plumbed through.

2. In `GraphContext` keep a copy of the string passed in from python since the original pointer may be prematurely destroyed.

PiperOrigin-RevId: 330647473
Change-Id: I00bf7243b30c3b6b3eb2dd1bb5ba1cbc8b46c008
This commit is contained in:
Saurabh Saxena 2020-09-08 21:05:14 -07:00 committed by TensorFlower Gardener
parent 054738d534
commit 6cbf1daf2d
8 changed files with 50 additions and 22 deletions

View File

@ -797,6 +797,7 @@ tf_cuda_cc_test(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/platform:status",
"@com_google_absl//absl/strings",
],
)
@ -816,6 +817,7 @@ tf_cuda_cc_test(
":c_api_test_util",
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/c:tf_status_helper",
"//tensorflow/cc/profiler",
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/core:lib",

View File

@ -39,7 +39,7 @@ static FactoriesMap& GetFactories() {
return *factories;
}
static const char* default_factory = "<unset>";
static tracing::FactoryFunction default_factory;
void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) {
assert((!GetFactories().count(name)) ||
@ -48,15 +48,15 @@ void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) {
GetFactories()[name] = factory;
}
void SetDefaultTracingEngine(const char* name) { default_factory = name; }
static TracingContext* CreateTracingExecutionContext(const char* fn_name,
TF_Status* s) {
auto entry = GetFactories().find(default_factory);
if (entry != GetFactories().end()) return entry->second(fn_name, s);
Status SetDefaultTracingEngine(const char* name) {
auto entry = GetFactories().find(name);
if (entry != GetFactories().end()) {
default_factory = GetFactories().find(name)->second;
return Status::OK();
}
string msg = absl::StrCat(
"No tracing engine factory has been registered with the key '",
default_factory, "' (available: ");
"No tracing engine factory has been registered with the key '", name,
"' (available: ");
// Ensure deterministic (sorted) order in the error message
std::set<string> factories_sorted;
for (const auto& factory : GetFactories())
@ -68,7 +68,16 @@ static TracingContext* CreateTracingExecutionContext(const char* fn_name,
}
msg += ")";
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return errors::InvalidArgument(msg.c_str());
}
static TracingContext* CreateTracingExecutionContext(const char* fn_name,
TF_Status* s) {
if (default_factory) {
return default_factory(fn_name, s);
}
Set_TF_Status_from_Status(
s, errors::FailedPrecondition("default_factory is nullptr"));
return nullptr;
}
@ -99,8 +108,8 @@ using tensorflow::tracing::TracingContext;
using tensorflow::tracing::TracingOperation;
using tensorflow::tracing::TracingTensorHandle;
void TF_SetTracingImplementation(const char* name) {
SetDefaultTracingEngine(name);
void TF_SetTracingImplementation(const char* name, TF_Status* s) {
Set_TF_Status_from_Status(s, SetDefaultTracingEngine(name));
}
// Creates a new TensorFlow function, it is an execution context attached to a

View File

@ -52,7 +52,7 @@ typedef struct TF_AbstractFunction TF_AbstractFunction;
// This allows the client to swap the implementation of the tracing engine.
// Any future call to TF_CreateFunction will use the implementation defined
// here.
void TF_SetTracingImplementation(const char* name);
void TF_SetTracingImplementation(const char* name, TF_Status*);
// Creates a new TensorFlow function. A Function is an execution context, and as
// such it can trace operations through TF_ExecuteOperation. After completing

View File

@ -365,9 +365,10 @@ class GraphContext : public TracingContext {
}
auto s = TF_NewStatus();
func->func = TF_GraphToFunction(
graph_.get(), name_, 0, -1, nullptr, inputs_.size(), inputs_.data(),
graph_outputs.size(), graph_outputs.data(), nullptr, nullptr, name_, s);
func->func = TF_GraphToFunction(graph_.get(), name_.data(), 0, -1, nullptr,
inputs_.size(), inputs_.data(),
graph_outputs.size(), graph_outputs.data(),
nullptr, nullptr, name_.data(), s);
TF_RETURN_IF_ERROR(StatusFromTF_Status(s));
TF_DeleteStatus(s);
*f = func.release();
@ -391,7 +392,7 @@ class GraphContext : public TracingContext {
private:
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
std::vector<TF_Output> inputs_;
const char* name_;
string name_;
};
static TracingContext* GraphTracingFactory(const char* name, TF_Status* s) {
@ -401,7 +402,7 @@ static TracingContext* GraphTracingFactory(const char* name, TF_Status* s) {
// Register the tracing implemented in this file as the default tracing engine.
static bool register_tracing = [] {
RegisterTracingEngineFactory("graphdef", GraphTracingFactory);
SetDefaultTracingEngine("graphdef");
SetDefaultTracingEngine("graphdef").IgnoreError();
return true;
}();

View File

@ -120,7 +120,7 @@ class TracingContext : public AbstractContext {
};
typedef TracingContext* (*FactoryFunction)(const char* fn_name, TF_Status*);
void SetDefaultTracingEngine(const char* name);
Status SetDefaultTracingEngine(const char* name);
void RegisterTracingEngineFactory(const ::tensorflow::string& name,
FactoryFunction factory);
} // namespace tracing

View File

@ -22,10 +22,15 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
using tensorflow::Status;
using tensorflow::string;
using tensorflow::TF_StatusPtr;
namespace tensorflow {
namespace {
@ -37,7 +42,10 @@ class UnifiedCAPI
: public ::testing::TestWithParam<std::tuple<const char*, bool>> {
protected:
void SetUp() override {
TF_SetTracingImplementation(std::get<0>(GetParam()));
TF_StatusPtr status(TF_NewStatus());
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
Status s = StatusFromTF_Status(status.get());
CHECK_EQ(errors::OK, s.code()) << s.error_message();
}
};

View File

@ -38,13 +38,17 @@ namespace gradients {
namespace internal {
namespace {
using std::vector;
using tensorflow::TF_StatusPtr;
using tracing::TracingOperation;
class CppGradients
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
protected:
void SetUp() override {
TF_SetTracingImplementation(std::get<0>(GetParam()));
TF_StatusPtr status(TF_NewStatus());
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
Status s = StatusFromTF_Status(status.get());
CHECK_EQ(errors::OK, s.code()) << s.error_message();
}
};

View File

@ -33,12 +33,16 @@ namespace tensorflow {
namespace gradients {
namespace internal {
namespace {
using tensorflow::TF_StatusPtr;
class CppGradients
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
protected:
void SetUp() override {
TF_SetTracingImplementation(std::get<0>(GetParam()));
TF_StatusPtr status(TF_NewStatus());
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
Status s = StatusFromTF_Status(status.get());
CHECK_EQ(errors::OK, s.code()) << s.error_message();
}
};