Split out circular dependency between op and op_kernel. Add tensorflow/core/framework:op target.

PiperOrigin-RevId: 290350328
Change-Id: Id04a6f8ce3689f658dacfa159269b1d468d3dc91
This commit is contained in:
Anna R 2020-01-17 15:09:50 -08:00 committed by TensorFlower Gardener
parent 0197b66905
commit 1e22bd9c0a
5 changed files with 51 additions and 7 deletions
tensorflow/core

View File

@ -2317,9 +2317,11 @@ tf_cuda_library(
"//tensorflow/core/framework:common_shape_fns",
"//tensorflow/core/framework:node_def_util",
"//tensorflow/core/framework:numeric_types",
"//tensorflow/core/framework:op",
"//tensorflow/core/framework:op_def_builder",
"//tensorflow/core/framework:op_def_util",
"//tensorflow/core/framework:resource_handle",
"//tensorflow/core/framework:selective_registration",
"//tensorflow/core/framework:shape_inference",
"//tensorflow/core/framework:tensor",
"//tensorflow/core/framework:tensor_shape",

View File

@ -46,7 +46,6 @@ exports_files(
"model.h",
"node_def_builder.h",
"numeric_op.h",
"op.h",
"op_kernel.h",
"op_segment.h",
"ops_util.h",
@ -61,7 +60,6 @@ exports_files(
"resource_var.h",
"run_handler.h",
"run_handler_util.h",
"selective_registration.h",
"session_state.h",
"shared_ptr_variant.h",
"stats_aggregator.h",
@ -235,7 +233,6 @@ filegroup(
"memory_types.cc",
"model.cc",
"node_def_builder.cc",
"op.cc",
"op_kernel.cc",
"op_segment.cc",
"ops_util.cc",
@ -878,6 +875,34 @@ cc_library(
],
)
cc_library(
name = "selective_registration",
hdrs = ["selective_registration.h"],
)
cc_library(
name = "op",
srcs = ["op.cc"],
hdrs = ["op.h"],
deps = [
":op_def_builder",
":op_def_util",
":selective_registration",
"//tensorflow/core/lib/core:errors",
"//tensorflow/core/lib/core:status",
"//tensorflow/core/lib/gtl:map_util",
"//tensorflow/core/lib/strings:str_util",
"//tensorflow/core/lib/strings:strcat",
"//tensorflow/core/platform:logging",
"//tensorflow/core/platform:macros",
"//tensorflow/core/platform:mutex",
"//tensorflow/core/platform:platform_port",
"//tensorflow/core/platform:protobuf",
"//tensorflow/core/platform:thread_annotations",
"//tensorflow/core/platform:types",
],
)
# Files whose users still need to be migrated from core:framework to the
# above targets.
# TODO(gonnet): Remove these files once targets depending on them have

View File

@ -20,7 +20,6 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
@ -32,6 +31,11 @@ limitations under the License.
namespace tensorflow {
Status DefaultValidator(const OpRegistryInterface& op_registry) {
LOG(WARNING) << "No kernel validator registered with OpRegistry.";
return Status::OK();
}
// OpRegistry -----------------------------------------------------------------
OpRegistryInterface::~OpRegistryInterface() {}
@ -45,7 +49,8 @@ Status OpRegistryInterface::LookUpOpDef(const string& op_type_name,
return Status::OK();
}
OpRegistry::OpRegistry() : initialized_(false) {}
OpRegistry::OpRegistry()
: initialized_(false), op_registry_validator_(DefaultValidator) {}
OpRegistry::~OpRegistry() {
for (const auto& e : registry_) delete e.second;
@ -114,7 +119,7 @@ const OpRegistrationData* OpRegistry::LookUpSlow(
// Note: Can't hold mu_ while calling Export() below.
}
if (first_call) {
TF_QCHECK_OK(ValidateKernelRegistrations(*this));
TF_QCHECK_OK(op_registry_validator_(*this));
}
if (res == nullptr) {
if (first_unregistered) {

View File

@ -95,6 +95,12 @@ class OpRegistry : public OpRegistryInterface {
// Get all `OpRegistrationData`s.
void GetOpRegistrationData(std::vector<OpRegistrationData>* op_data);
// Registers a function that validates op registry.
void RegisterValidator(
std::function<Status(const OpRegistryInterface&)> validator) {
op_registry_validator_ = std::move(validator);
}
// Watcher, a function object.
// The watcher, if set by SetWatcher(), is called every time an op is
// registered via the Register function. The watcher is passed the Status
@ -159,6 +165,8 @@ class OpRegistry : public OpRegistryInterface {
// Registry watcher.
mutable Watcher watcher_ GUARDED_BY(mu_);
std::function<Status(const OpRegistryInterface&)> op_registry_validator_;
};
// An adapter to allow an OpList to be used as an OpRegistryInterface.

View File

@ -1212,7 +1212,11 @@ void LoadDynamicKernels() {
}
void* GlobalKernelRegistry() {
static KernelRegistry* global_kernel_registry = new KernelRegistry;
static KernelRegistry* global_kernel_registry = []() {
KernelRegistry* registry = new KernelRegistry;
OpRegistry::Global()->RegisterValidator(ValidateKernelRegistrations);
return registry;
}();
return global_kernel_registry;
}