Split out circular dependency between op and op_kernel. Add tensorflow/core/framework:op target.
PiperOrigin-RevId: 290350328 Change-Id: Id04a6f8ce3689f658dacfa159269b1d468d3dc91
This commit is contained in:
parent
0197b66905
commit
1e22bd9c0a
tensorflow/core
@ -2317,9 +2317,11 @@ tf_cuda_library(
|
||||
"//tensorflow/core/framework:common_shape_fns",
|
||||
"//tensorflow/core/framework:node_def_util",
|
||||
"//tensorflow/core/framework:numeric_types",
|
||||
"//tensorflow/core/framework:op",
|
||||
"//tensorflow/core/framework:op_def_builder",
|
||||
"//tensorflow/core/framework:op_def_util",
|
||||
"//tensorflow/core/framework:resource_handle",
|
||||
"//tensorflow/core/framework:selective_registration",
|
||||
"//tensorflow/core/framework:shape_inference",
|
||||
"//tensorflow/core/framework:tensor",
|
||||
"//tensorflow/core/framework:tensor_shape",
|
||||
|
@ -46,7 +46,6 @@ exports_files(
|
||||
"model.h",
|
||||
"node_def_builder.h",
|
||||
"numeric_op.h",
|
||||
"op.h",
|
||||
"op_kernel.h",
|
||||
"op_segment.h",
|
||||
"ops_util.h",
|
||||
@ -61,7 +60,6 @@ exports_files(
|
||||
"resource_var.h",
|
||||
"run_handler.h",
|
||||
"run_handler_util.h",
|
||||
"selective_registration.h",
|
||||
"session_state.h",
|
||||
"shared_ptr_variant.h",
|
||||
"stats_aggregator.h",
|
||||
@ -235,7 +233,6 @@ filegroup(
|
||||
"memory_types.cc",
|
||||
"model.cc",
|
||||
"node_def_builder.cc",
|
||||
"op.cc",
|
||||
"op_kernel.cc",
|
||||
"op_segment.cc",
|
||||
"ops_util.cc",
|
||||
@ -878,6 +875,34 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "selective_registration",
|
||||
hdrs = ["selective_registration.h"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "op",
|
||||
srcs = ["op.cc"],
|
||||
hdrs = ["op.h"],
|
||||
deps = [
|
||||
":op_def_builder",
|
||||
":op_def_util",
|
||||
":selective_registration",
|
||||
"//tensorflow/core/lib/core:errors",
|
||||
"//tensorflow/core/lib/core:status",
|
||||
"//tensorflow/core/lib/gtl:map_util",
|
||||
"//tensorflow/core/lib/strings:str_util",
|
||||
"//tensorflow/core/lib/strings:strcat",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"//tensorflow/core/platform:macros",
|
||||
"//tensorflow/core/platform:mutex",
|
||||
"//tensorflow/core/platform:platform_port",
|
||||
"//tensorflow/core/platform:protobuf",
|
||||
"//tensorflow/core/platform:thread_annotations",
|
||||
"//tensorflow/core/platform:types",
|
||||
],
|
||||
)
|
||||
|
||||
# Files whose users still need to be migrated from core:framework to the
|
||||
# above targets.
|
||||
# TODO(gonnet): Remove these files once targets depending on them have
|
||||
|
@ -20,7 +20,6 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/op_def_builder.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
@ -32,6 +31,11 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Status DefaultValidator(const OpRegistryInterface& op_registry) {
|
||||
LOG(WARNING) << "No kernel validator registered with OpRegistry.";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// OpRegistry -----------------------------------------------------------------
|
||||
|
||||
OpRegistryInterface::~OpRegistryInterface() {}
|
||||
@ -45,7 +49,8 @@ Status OpRegistryInterface::LookUpOpDef(const string& op_type_name,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
OpRegistry::OpRegistry() : initialized_(false) {}
|
||||
OpRegistry::OpRegistry()
|
||||
: initialized_(false), op_registry_validator_(DefaultValidator) {}
|
||||
|
||||
OpRegistry::~OpRegistry() {
|
||||
for (const auto& e : registry_) delete e.second;
|
||||
@ -114,7 +119,7 @@ const OpRegistrationData* OpRegistry::LookUpSlow(
|
||||
// Note: Can't hold mu_ while calling Export() below.
|
||||
}
|
||||
if (first_call) {
|
||||
TF_QCHECK_OK(ValidateKernelRegistrations(*this));
|
||||
TF_QCHECK_OK(op_registry_validator_(*this));
|
||||
}
|
||||
if (res == nullptr) {
|
||||
if (first_unregistered) {
|
||||
|
@ -95,6 +95,12 @@ class OpRegistry : public OpRegistryInterface {
|
||||
// Get all `OpRegistrationData`s.
|
||||
void GetOpRegistrationData(std::vector<OpRegistrationData>* op_data);
|
||||
|
||||
// Registers a function that validates op registry.
|
||||
void RegisterValidator(
|
||||
std::function<Status(const OpRegistryInterface&)> validator) {
|
||||
op_registry_validator_ = std::move(validator);
|
||||
}
|
||||
|
||||
// Watcher, a function object.
|
||||
// The watcher, if set by SetWatcher(), is called every time an op is
|
||||
// registered via the Register function. The watcher is passed the Status
|
||||
@ -159,6 +165,8 @@ class OpRegistry : public OpRegistryInterface {
|
||||
|
||||
// Registry watcher.
|
||||
mutable Watcher watcher_ GUARDED_BY(mu_);
|
||||
|
||||
std::function<Status(const OpRegistryInterface&)> op_registry_validator_;
|
||||
};
|
||||
|
||||
// An adapter to allow an OpList to be used as an OpRegistryInterface.
|
||||
|
@ -1212,7 +1212,11 @@ void LoadDynamicKernels() {
|
||||
}
|
||||
|
||||
void* GlobalKernelRegistry() {
|
||||
static KernelRegistry* global_kernel_registry = new KernelRegistry;
|
||||
static KernelRegistry* global_kernel_registry = []() {
|
||||
KernelRegistry* registry = new KernelRegistry;
|
||||
OpRegistry::Global()->RegisterValidator(ValidateKernelRegistrations);
|
||||
return registry;
|
||||
}();
|
||||
return global_kernel_registry;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user