1. Do not cache the tracing impl string in c_api_unified_experimental instead actively lookup the factory fn and store a pointer to that in default_factory
. This avoid keeping a copy of the string passed in from python. As a result SetDefaultTracingEngine
now returns a Status which needed to be plumbed through.
2. In `GraphContext` keep a copy of the string passed in from python since the original pointer may be prematurely destroyed. PiperOrigin-RevId: 330647473 Change-Id: I00bf7243b30c3b6b3eb2dd1bb5ba1cbc8b46c008
This commit is contained in:
parent
054738d534
commit
6cbf1daf2d
@ -797,6 +797,7 @@ tf_cuda_cc_test(
|
|||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core/platform:status",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -816,6 +817,7 @@ tf_cuda_cc_test(
|
|||||||
":c_api_test_util",
|
":c_api_test_util",
|
||||||
"//tensorflow/c:c_api",
|
"//tensorflow/c:c_api",
|
||||||
"//tensorflow/c:c_test_util",
|
"//tensorflow/c:c_test_util",
|
||||||
|
"//tensorflow/c:tf_status_helper",
|
||||||
"//tensorflow/cc/profiler",
|
"//tensorflow/cc/profiler",
|
||||||
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
|
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
@ -39,7 +39,7 @@ static FactoriesMap& GetFactories() {
|
|||||||
return *factories;
|
return *factories;
|
||||||
}
|
}
|
||||||
|
|
||||||
static const char* default_factory = "<unset>";
|
static tracing::FactoryFunction default_factory;
|
||||||
|
|
||||||
void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) {
|
void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) {
|
||||||
assert((!GetFactories().count(name)) ||
|
assert((!GetFactories().count(name)) ||
|
||||||
@ -48,15 +48,15 @@ void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) {
|
|||||||
GetFactories()[name] = factory;
|
GetFactories()[name] = factory;
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetDefaultTracingEngine(const char* name) { default_factory = name; }
|
Status SetDefaultTracingEngine(const char* name) {
|
||||||
|
auto entry = GetFactories().find(name);
|
||||||
static TracingContext* CreateTracingExecutionContext(const char* fn_name,
|
if (entry != GetFactories().end()) {
|
||||||
TF_Status* s) {
|
default_factory = GetFactories().find(name)->second;
|
||||||
auto entry = GetFactories().find(default_factory);
|
return Status::OK();
|
||||||
if (entry != GetFactories().end()) return entry->second(fn_name, s);
|
}
|
||||||
string msg = absl::StrCat(
|
string msg = absl::StrCat(
|
||||||
"No tracing engine factory has been registered with the key '",
|
"No tracing engine factory has been registered with the key '", name,
|
||||||
default_factory, "' (available: ");
|
"' (available: ");
|
||||||
// Ensure deterministic (sorted) order in the error message
|
// Ensure deterministic (sorted) order in the error message
|
||||||
std::set<string> factories_sorted;
|
std::set<string> factories_sorted;
|
||||||
for (const auto& factory : GetFactories())
|
for (const auto& factory : GetFactories())
|
||||||
@ -68,7 +68,16 @@ static TracingContext* CreateTracingExecutionContext(const char* fn_name,
|
|||||||
}
|
}
|
||||||
msg += ")";
|
msg += ")";
|
||||||
|
|
||||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
|
return errors::InvalidArgument(msg.c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
static TracingContext* CreateTracingExecutionContext(const char* fn_name,
|
||||||
|
TF_Status* s) {
|
||||||
|
if (default_factory) {
|
||||||
|
return default_factory(fn_name, s);
|
||||||
|
}
|
||||||
|
Set_TF_Status_from_Status(
|
||||||
|
s, errors::FailedPrecondition("default_factory is nullptr"));
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -99,8 +108,8 @@ using tensorflow::tracing::TracingContext;
|
|||||||
using tensorflow::tracing::TracingOperation;
|
using tensorflow::tracing::TracingOperation;
|
||||||
using tensorflow::tracing::TracingTensorHandle;
|
using tensorflow::tracing::TracingTensorHandle;
|
||||||
|
|
||||||
void TF_SetTracingImplementation(const char* name) {
|
void TF_SetTracingImplementation(const char* name, TF_Status* s) {
|
||||||
SetDefaultTracingEngine(name);
|
Set_TF_Status_from_Status(s, SetDefaultTracingEngine(name));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Creates a new TensorFlow function, it is an execution context attached to a
|
// Creates a new TensorFlow function, it is an execution context attached to a
|
||||||
|
@ -52,7 +52,7 @@ typedef struct TF_AbstractFunction TF_AbstractFunction;
|
|||||||
// This allows the client to swap the implementation of the tracing engine.
|
// This allows the client to swap the implementation of the tracing engine.
|
||||||
// Any future call to TF_CreateFunction will use the implementation defined
|
// Any future call to TF_CreateFunction will use the implementation defined
|
||||||
// here.
|
// here.
|
||||||
void TF_SetTracingImplementation(const char* name);
|
void TF_SetTracingImplementation(const char* name, TF_Status*);
|
||||||
|
|
||||||
// Creates a new TensorFlow function. A Function is an execution context, and as
|
// Creates a new TensorFlow function. A Function is an execution context, and as
|
||||||
// such it can trace operations through TF_ExecuteOperation. After completing
|
// such it can trace operations through TF_ExecuteOperation. After completing
|
||||||
|
@ -365,9 +365,10 @@ class GraphContext : public TracingContext {
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto s = TF_NewStatus();
|
auto s = TF_NewStatus();
|
||||||
func->func = TF_GraphToFunction(
|
func->func = TF_GraphToFunction(graph_.get(), name_.data(), 0, -1, nullptr,
|
||||||
graph_.get(), name_, 0, -1, nullptr, inputs_.size(), inputs_.data(),
|
inputs_.size(), inputs_.data(),
|
||||||
graph_outputs.size(), graph_outputs.data(), nullptr, nullptr, name_, s);
|
graph_outputs.size(), graph_outputs.data(),
|
||||||
|
nullptr, nullptr, name_.data(), s);
|
||||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(s));
|
TF_RETURN_IF_ERROR(StatusFromTF_Status(s));
|
||||||
TF_DeleteStatus(s);
|
TF_DeleteStatus(s);
|
||||||
*f = func.release();
|
*f = func.release();
|
||||||
@ -391,7 +392,7 @@ class GraphContext : public TracingContext {
|
|||||||
private:
|
private:
|
||||||
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
|
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
|
||||||
std::vector<TF_Output> inputs_;
|
std::vector<TF_Output> inputs_;
|
||||||
const char* name_;
|
string name_;
|
||||||
};
|
};
|
||||||
|
|
||||||
static TracingContext* GraphTracingFactory(const char* name, TF_Status* s) {
|
static TracingContext* GraphTracingFactory(const char* name, TF_Status* s) {
|
||||||
@ -401,7 +402,7 @@ static TracingContext* GraphTracingFactory(const char* name, TF_Status* s) {
|
|||||||
// Register the tracing implemented in this file as the default tracing engine.
|
// Register the tracing implemented in this file as the default tracing engine.
|
||||||
static bool register_tracing = [] {
|
static bool register_tracing = [] {
|
||||||
RegisterTracingEngineFactory("graphdef", GraphTracingFactory);
|
RegisterTracingEngineFactory("graphdef", GraphTracingFactory);
|
||||||
SetDefaultTracingEngine("graphdef");
|
SetDefaultTracingEngine("graphdef").IgnoreError();
|
||||||
return true;
|
return true;
|
||||||
}();
|
}();
|
||||||
|
|
||||||
|
@ -120,7 +120,7 @@ class TracingContext : public AbstractContext {
|
|||||||
};
|
};
|
||||||
|
|
||||||
typedef TracingContext* (*FactoryFunction)(const char* fn_name, TF_Status*);
|
typedef TracingContext* (*FactoryFunction)(const char* fn_name, TF_Status*);
|
||||||
void SetDefaultTracingEngine(const char* name);
|
Status SetDefaultTracingEngine(const char* name);
|
||||||
void RegisterTracingEngineFactory(const ::tensorflow::string& name,
|
void RegisterTracingEngineFactory(const ::tensorflow::string& name,
|
||||||
FactoryFunction factory);
|
FactoryFunction factory);
|
||||||
} // namespace tracing
|
} // namespace tracing
|
||||||
|
@ -22,10 +22,15 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||||
#include "tensorflow/c/tf_datatype.h"
|
#include "tensorflow/c/tf_datatype.h"
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
|
#include "tensorflow/c/tf_status_helper.h"
|
||||||
#include "tensorflow/c/tf_tensor.h"
|
#include "tensorflow/c/tf_tensor.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
|
#include "tensorflow/core/platform/status.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
using tensorflow::Status;
|
||||||
using tensorflow::string;
|
using tensorflow::string;
|
||||||
|
using tensorflow::TF_StatusPtr;
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace {
|
namespace {
|
||||||
@ -37,7 +42,10 @@ class UnifiedCAPI
|
|||||||
: public ::testing::TestWithParam<std::tuple<const char*, bool>> {
|
: public ::testing::TestWithParam<std::tuple<const char*, bool>> {
|
||||||
protected:
|
protected:
|
||||||
void SetUp() override {
|
void SetUp() override {
|
||||||
TF_SetTracingImplementation(std::get<0>(GetParam()));
|
TF_StatusPtr status(TF_NewStatus());
|
||||||
|
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
|
||||||
|
Status s = StatusFromTF_Status(status.get());
|
||||||
|
CHECK_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -38,13 +38,17 @@ namespace gradients {
|
|||||||
namespace internal {
|
namespace internal {
|
||||||
namespace {
|
namespace {
|
||||||
using std::vector;
|
using std::vector;
|
||||||
|
using tensorflow::TF_StatusPtr;
|
||||||
using tracing::TracingOperation;
|
using tracing::TracingOperation;
|
||||||
|
|
||||||
class CppGradients
|
class CppGradients
|
||||||
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
|
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
|
||||||
protected:
|
protected:
|
||||||
void SetUp() override {
|
void SetUp() override {
|
||||||
TF_SetTracingImplementation(std::get<0>(GetParam()));
|
TF_StatusPtr status(TF_NewStatus());
|
||||||
|
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
|
||||||
|
Status s = StatusFromTF_Status(status.get());
|
||||||
|
CHECK_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -33,12 +33,16 @@ namespace tensorflow {
|
|||||||
namespace gradients {
|
namespace gradients {
|
||||||
namespace internal {
|
namespace internal {
|
||||||
namespace {
|
namespace {
|
||||||
|
using tensorflow::TF_StatusPtr;
|
||||||
|
|
||||||
class CppGradients
|
class CppGradients
|
||||||
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
|
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
|
||||||
protected:
|
protected:
|
||||||
void SetUp() override {
|
void SetUp() override {
|
||||||
TF_SetTracingImplementation(std::get<0>(GetParam()));
|
TF_StatusPtr status(TF_NewStatus());
|
||||||
|
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
|
||||||
|
Status s = StatusFromTF_Status(status.get());
|
||||||
|
CHECK_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user