1. Do not cache the tracing impl string in c_api_unified_experimental instead actively lookup the factory fn and store a pointer to that in default_factory
. This avoid keeping a copy of the string passed in from python. As a result SetDefaultTracingEngine
now returns a Status which needed to be plumbed through.
2. In `GraphContext` keep a copy of the string passed in from python since the original pointer may be prematurely destroyed. PiperOrigin-RevId: 330647473 Change-Id: I00bf7243b30c3b6b3eb2dd1bb5ba1cbc8b46c008
This commit is contained in:
parent
054738d534
commit
6cbf1daf2d
@ -797,6 +797,7 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/platform:status",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
@ -816,6 +817,7 @@ tf_cuda_cc_test(
|
||||
":c_api_test_util",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/cc/profiler",
|
||||
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
|
||||
"//tensorflow/core:lib",
|
||||
|
@ -39,7 +39,7 @@ static FactoriesMap& GetFactories() {
|
||||
return *factories;
|
||||
}
|
||||
|
||||
static const char* default_factory = "<unset>";
|
||||
static tracing::FactoryFunction default_factory;
|
||||
|
||||
void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) {
|
||||
assert((!GetFactories().count(name)) ||
|
||||
@ -48,15 +48,15 @@ void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) {
|
||||
GetFactories()[name] = factory;
|
||||
}
|
||||
|
||||
void SetDefaultTracingEngine(const char* name) { default_factory = name; }
|
||||
|
||||
static TracingContext* CreateTracingExecutionContext(const char* fn_name,
|
||||
TF_Status* s) {
|
||||
auto entry = GetFactories().find(default_factory);
|
||||
if (entry != GetFactories().end()) return entry->second(fn_name, s);
|
||||
Status SetDefaultTracingEngine(const char* name) {
|
||||
auto entry = GetFactories().find(name);
|
||||
if (entry != GetFactories().end()) {
|
||||
default_factory = GetFactories().find(name)->second;
|
||||
return Status::OK();
|
||||
}
|
||||
string msg = absl::StrCat(
|
||||
"No tracing engine factory has been registered with the key '",
|
||||
default_factory, "' (available: ");
|
||||
"No tracing engine factory has been registered with the key '", name,
|
||||
"' (available: ");
|
||||
// Ensure deterministic (sorted) order in the error message
|
||||
std::set<string> factories_sorted;
|
||||
for (const auto& factory : GetFactories())
|
||||
@ -68,7 +68,16 @@ static TracingContext* CreateTracingExecutionContext(const char* fn_name,
|
||||
}
|
||||
msg += ")";
|
||||
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
|
||||
return errors::InvalidArgument(msg.c_str());
|
||||
}
|
||||
|
||||
static TracingContext* CreateTracingExecutionContext(const char* fn_name,
|
||||
TF_Status* s) {
|
||||
if (default_factory) {
|
||||
return default_factory(fn_name, s);
|
||||
}
|
||||
Set_TF_Status_from_Status(
|
||||
s, errors::FailedPrecondition("default_factory is nullptr"));
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@ -99,8 +108,8 @@ using tensorflow::tracing::TracingContext;
|
||||
using tensorflow::tracing::TracingOperation;
|
||||
using tensorflow::tracing::TracingTensorHandle;
|
||||
|
||||
void TF_SetTracingImplementation(const char* name) {
|
||||
SetDefaultTracingEngine(name);
|
||||
void TF_SetTracingImplementation(const char* name, TF_Status* s) {
|
||||
Set_TF_Status_from_Status(s, SetDefaultTracingEngine(name));
|
||||
}
|
||||
|
||||
// Creates a new TensorFlow function, it is an execution context attached to a
|
||||
|
@ -52,7 +52,7 @@ typedef struct TF_AbstractFunction TF_AbstractFunction;
|
||||
// This allows the client to swap the implementation of the tracing engine.
|
||||
// Any future call to TF_CreateFunction will use the implementation defined
|
||||
// here.
|
||||
void TF_SetTracingImplementation(const char* name);
|
||||
void TF_SetTracingImplementation(const char* name, TF_Status*);
|
||||
|
||||
// Creates a new TensorFlow function. A Function is an execution context, and as
|
||||
// such it can trace operations through TF_ExecuteOperation. After completing
|
||||
|
@ -365,9 +365,10 @@ class GraphContext : public TracingContext {
|
||||
}
|
||||
|
||||
auto s = TF_NewStatus();
|
||||
func->func = TF_GraphToFunction(
|
||||
graph_.get(), name_, 0, -1, nullptr, inputs_.size(), inputs_.data(),
|
||||
graph_outputs.size(), graph_outputs.data(), nullptr, nullptr, name_, s);
|
||||
func->func = TF_GraphToFunction(graph_.get(), name_.data(), 0, -1, nullptr,
|
||||
inputs_.size(), inputs_.data(),
|
||||
graph_outputs.size(), graph_outputs.data(),
|
||||
nullptr, nullptr, name_.data(), s);
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(s));
|
||||
TF_DeleteStatus(s);
|
||||
*f = func.release();
|
||||
@ -391,7 +392,7 @@ class GraphContext : public TracingContext {
|
||||
private:
|
||||
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
|
||||
std::vector<TF_Output> inputs_;
|
||||
const char* name_;
|
||||
string name_;
|
||||
};
|
||||
|
||||
static TracingContext* GraphTracingFactory(const char* name, TF_Status* s) {
|
||||
@ -401,7 +402,7 @@ static TracingContext* GraphTracingFactory(const char* name, TF_Status* s) {
|
||||
// Register the tracing implemented in this file as the default tracing engine.
|
||||
static bool register_tracing = [] {
|
||||
RegisterTracingEngineFactory("graphdef", GraphTracingFactory);
|
||||
SetDefaultTracingEngine("graphdef");
|
||||
SetDefaultTracingEngine("graphdef").IgnoreError();
|
||||
return true;
|
||||
}();
|
||||
|
||||
|
@ -120,7 +120,7 @@ class TracingContext : public AbstractContext {
|
||||
};
|
||||
|
||||
typedef TracingContext* (*FactoryFunction)(const char* fn_name, TF_Status*);
|
||||
void SetDefaultTracingEngine(const char* name);
|
||||
Status SetDefaultTracingEngine(const char* name);
|
||||
void RegisterTracingEngineFactory(const ::tensorflow::string& name,
|
||||
FactoryFunction factory);
|
||||
} // namespace tracing
|
||||
|
@ -22,10 +22,15 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
using tensorflow::Status;
|
||||
using tensorflow::string;
|
||||
using tensorflow::TF_StatusPtr;
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
@ -37,7 +42,10 @@ class UnifiedCAPI
|
||||
: public ::testing::TestWithParam<std::tuple<const char*, bool>> {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
TF_SetTracingImplementation(std::get<0>(GetParam()));
|
||||
TF_StatusPtr status(TF_NewStatus());
|
||||
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
|
||||
Status s = StatusFromTF_Status(status.get());
|
||||
CHECK_EQ(errors::OK, s.code()) << s.error_message();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -38,13 +38,17 @@ namespace gradients {
|
||||
namespace internal {
|
||||
namespace {
|
||||
using std::vector;
|
||||
using tensorflow::TF_StatusPtr;
|
||||
using tracing::TracingOperation;
|
||||
|
||||
class CppGradients
|
||||
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
TF_SetTracingImplementation(std::get<0>(GetParam()));
|
||||
TF_StatusPtr status(TF_NewStatus());
|
||||
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
|
||||
Status s = StatusFromTF_Status(status.get());
|
||||
CHECK_EQ(errors::OK, s.code()) << s.error_message();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -33,12 +33,16 @@ namespace tensorflow {
|
||||
namespace gradients {
|
||||
namespace internal {
|
||||
namespace {
|
||||
using tensorflow::TF_StatusPtr;
|
||||
|
||||
class CppGradients
|
||||
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
TF_SetTracingImplementation(std::get<0>(GetParam()));
|
||||
TF_StatusPtr status(TF_NewStatus());
|
||||
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
|
||||
Status s = StatusFromTF_Status(status.get());
|
||||
CHECK_EQ(errors::OK, s.code()) << s.error_message();
|
||||
}
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user