diff --git a/tensorflow/compiler/jit/xla_kernel_creator.cc b/tensorflow/compiler/jit/xla_kernel_creator.cc index 6ee1db2c7c5..fd6fd4b5b58 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator.cc @@ -20,15 +20,17 @@ limitations under the License. namespace tensorflow { -bool XlaKernelCreator::CanCreateKernel(const FunctionLibraryRuntime& flr, - const NodeDef& node_def) const { - return CanCreateXlaKernel(node_def); +bool XlaKernelCreator::CanCreateKernel( + const FunctionLibraryRuntime& flr, + const std::shared_ptr& props) const { + return CanCreateXlaKernel(props->node_def); } -Status XlaKernelCreator::CreateKernel(FunctionLibraryRuntime* flr, - const NodeDef& node_def, - std::unique_ptr* kernel) const { - return CreateXlaKernel(flr, node_def, kernel); +Status XlaKernelCreator::CreateKernel( + FunctionLibraryRuntime* flr, + const std::shared_ptr& props, + std::unique_ptr* kernel) const { + return CreateXlaKernel(flr, props->node_def, kernel); } namespace { diff --git a/tensorflow/compiler/jit/xla_kernel_creator.h b/tensorflow/compiler/jit/xla_kernel_creator.h index 8815ee49ce5..856701a791d 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator.h +++ b/tensorflow/compiler/jit/xla_kernel_creator.h @@ -29,11 +29,13 @@ class XlaKernelCreator : public CustomKernelCreator { // Given a NodeDef 'node_def' and the function library runtime 'flr', returns // true if 'node_def' is a call to a compilable function defined in 'flr', // with the kXlaCompileAttr set. - bool CanCreateKernel(const FunctionLibraryRuntime& flr, - const NodeDef& node_def) const override; + bool CanCreateKernel( + const FunctionLibraryRuntime& flr, + const std::shared_ptr& props) const override; // Given a supported NodeDef, returns a XlaLaunchOp that computes the node. 
- Status CreateKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def, + Status CreateKernel(FunctionLibraryRuntime* flr, + const std::shared_ptr<NodeProperties>& props, std::unique_ptr<OpKernel>* kernel) const override; }; diff --git a/tensorflow/compiler/jit/xla_kernel_creator_test.cc b/tensorflow/compiler/jit/xla_kernel_creator_test.cc index 7ec37332906..ad94d60d9b5 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator_test.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator_test.cc @@ -30,10 +30,12 @@ limitations under the License. namespace tensorflow { -NodeDef ToNodeDef(const string& text) { +std::shared_ptr<NodeProperties> ToNodeProperties(const string& text) { NodeDef node_def; + DataTypeVector dummy; EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def)); - return node_def; + return std::make_shared<NodeProperties>(nullptr, std::move(node_def), dummy, + dummy); } // Create a FunctionDef that takes one resource and one regular param @@ -98,11 +100,11 @@ TEST_F(XlaKernelCreatorTest, OneFloatOneResourceArgument) { (*fdef.mutable_attr())["_XlaMustCompile"] = BoolAttr(true); Init({fdef}); XlaKernelCreator xla_kernel_creator; - NodeDef callsite = - ToNodeDef(R"pb( + auto callsite = + ToNodeProperties(R"pb( name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b' )pb"); - (*callsite.mutable_attr())["_XlaMustCompile"] = BoolAttr(true); + (*(callsite->node_def.mutable_attr()))["_XlaMustCompile"] = BoolAttr(true); // Note: need to set attribute on the created node. 
Status status = xla_kernel_creator.CreateKernel(flr_, callsite, &kernel_); @@ -127,13 +129,14 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrNotSet) { Init({fdef}); XlaKernelCreator xla_kernel_creator; - Status status = xla_kernel_creator.CreateKernel(flr_, ToNodeDef(R"proto( - name: 'XTimesY' - op: 'XTimesY' - input: 'a' - input: 'b' - )proto"), - &kernel_); + Status status = + xla_kernel_creator.CreateKernel(flr_, ToNodeProperties(R"proto( + name: 'XTimesY' + op: 'XTimesY' + input: 'a' + input: 'b' + )proto"), + &kernel_); EXPECT_TRUE(errors::IsInternal(status)) << status.ToString(); } @@ -143,13 +146,14 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrIsSetToFalse) { Init({fdef}); XlaKernelCreator xla_kernel_creator; - Status status = xla_kernel_creator.CreateKernel(flr_, ToNodeDef(R"proto( - name: 'XTimesY' - op: 'XTimesY' - input: 'a' - input: 'b' - )proto"), - &kernel_); + Status status = + xla_kernel_creator.CreateKernel(flr_, ToNodeProperties(R"proto( + name: 'XTimesY' + op: 'XTimesY' + input: 'a' + input: 'b' + )proto"), + &kernel_); EXPECT_TRUE(errors::IsInternal(status)) << status.ToString(); } diff --git a/tensorflow/compiler/jit/xla_kernel_creator_util.cc b/tensorflow/compiler/jit/xla_kernel_creator_util.cc index 5aab0ff3bd6..de091fc93b4 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator_util.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator_util.cc @@ -218,12 +218,13 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def, TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function)); Device* dev = flr->device(); Status s; - OpKernelConstruction construction( - DeviceType(dev->device_type()), dev, - dev->GetAllocator(AllocatorAttributes()), &node_def, - &fbody->fdef.signature(), flr, dev->resource_manager(), fbody->arg_types, - input_memory_types, fbody->ret_types, output_memory_types, - flr->graph_def_version(), &s); + auto props = std::make_shared( + &fbody->fdef.signature(), node_def, 
fbody->arg_types, fbody->ret_types); + OpKernelConstruction construction(DeviceType(dev->device_type()), dev, + dev->GetAllocator(AllocatorAttributes()), + flr, dev->resource_manager(), props, + input_memory_types, output_memory_types, + flr->graph_def_version(), &s); *kernel = absl::make_unique( &construction, constant_arg_indices, resource_arg_indices, function, diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc index a88f2b5e29e..bc42de6832d 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc @@ -127,9 +127,14 @@ class TRTEngineOpTestBase : public OpsTestBase { private: Status InitOpWithFunctionLibrary() { OpKernel* kernel = nullptr; - Status status = CreateOpKernel(device_type_, device_, allocator(), - pflr_->GetFLR(device_->name()), node_def_, - TF_GRAPH_DEF_VERSION, &kernel); + auto flr = pflr_->GetFLR(device_->name()); + std::shared_ptr props; + Status status = NodeProperties::CreateFromNodeDef( + node_def_, flr->GetFunctionLibraryDefinition(), &props); + if (status.ok()) { + status.Update(CreateOpKernel(device_type_, device_, allocator(), flr, + props, TF_GRAPH_DEF_VERSION, &kernel)); + } kernel_ = std::unique_ptr(kernel); if (kernel_ != nullptr) input_types_ = kernel_->input_types(); return status; diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 34888fc0e2f..f0aebc9b543 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -133,7 +133,7 @@ Status GraphCompiler::Compile() { OpKernel* op_kernel_raw = nullptr; // The kernel is not actually run for functional ops, we just need it // for metadata. 
- Status s = flib_->CreateKernel(n->def(), &op_kernel_raw); + Status s = flib_->CreateKernel(n->properties(), &op_kernel_raw); // Transfer ownership of the kernel to a local smart pointer. std::unique_ptr op_kernel(op_kernel_raw); diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index b89068c7a83..4f0df417037 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -472,6 +472,7 @@ tf_cuda_library( "//tensorflow/core/framework:memory_types.h", "//tensorflow/core/framework:node_def_builder.h", "//tensorflow/core/framework:node_def_util.h", + "//tensorflow/core/framework:node_properties.h", "//tensorflow/core/framework:numeric_op.h", "//tensorflow/core/framework:numeric_types.h", "//tensorflow/core/framework:op.h", @@ -2323,6 +2324,7 @@ tf_cuda_library( "//tensorflow/core/framework:bfloat16", "//tensorflow/core/framework:common_shape_fns", "//tensorflow/core/framework:node_def_util", + "//tensorflow/core/framework:node_properties", "//tensorflow/core/framework:numeric_types", "//tensorflow/core/framework:op", "//tensorflow/core/framework:op_def_builder", diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 098217a607a..a196f74c65b 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -1356,24 +1356,25 @@ Status DirectSession::CreateExecutors( params.session_metadata = session_metadata; params.function_library = lib; auto opseg = device->op_segment(); - params.create_kernel = [this, lib, opseg](const NodeDef& ndef, - OpKernel** kernel) { - // NOTE(mrry): We must not share function kernels (implemented - // using `CallOp`) between subgraphs, because `CallOp::handle_` - // is tied to a particular subgraph. Even if the function itself - // is stateful, the `CallOp` that invokes it is not. 
- if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) { - return lib->CreateKernel(ndef, kernel); - } - auto create_fn = [lib, &ndef](OpKernel** kernel) { - return lib->CreateKernel(ndef, kernel); - }; - // Kernels created for subgraph nodes need to be cached. On - // cache miss, create_fn() is invoked to create a kernel based - // on the function library here + global op registry. - return opseg->FindOrCreate(session_handle_, ndef.name(), kernel, - create_fn); - }; + params.create_kernel = + [this, lib, opseg](const std::shared_ptr& props, + OpKernel** kernel) { + // NOTE(mrry): We must not share function kernels (implemented + // using `CallOp`) between subgraphs, because `CallOp::handle_` + // is tied to a particular subgraph. Even if the function itself + // is stateful, the `CallOp` that invokes it is not. + if (!OpSegment::ShouldOwnKernel(lib, props->node_def.op())) { + return lib->CreateKernel(props, kernel); + } + auto create_fn = [lib, &props](OpKernel** kernel) { + return lib->CreateKernel(props, kernel); + }; + // Kernels created for subgraph nodes need to be cached. On + // cache miss, create_fn() is invoked to create a kernel based + // on the function library here + global op registry. 
+ return opseg->FindOrCreate(session_handle_, props->node_def.name(), + kernel, create_fn); + }; params.delete_kernel = [lib](OpKernel* kernel) { if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string())) delete kernel; diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc index 6e8a5b9689a..8ca02ca51c0 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc @@ -98,7 +98,10 @@ Status KernelAndDeviceOp::Init(const NodeDef& ndef, "A valid FunctionLibraryRuntime must be provided when running ops " "based on OpKernel."); } - TF_RETURN_IF_ERROR(flr_->CreateKernel(ndef, &k)); + std::shared_ptr props; + TF_RETURN_IF_ERROR(NodeProperties::CreateFromNodeDef( + ndef, flr_->GetFunctionLibraryDefinition(), &props)); + TF_RETURN_IF_ERROR(flr_->CreateKernel(props, &k)); kernel_.reset(k); input_alloc_attrs_.resize(kernel_->num_inputs()); diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index bd3e14129b3..3a43a193b9e 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -654,7 +654,7 @@ Status ExecutorImpl::Initialize(const Graph& graph) { item->input_start = frame_info->total_inputs; frame_info->total_inputs += n->num_inputs(); - Status s = params_.create_kernel(n->def(), &item->kernel); + Status s = params_.create_kernel(n->properties(), &item->kernel); if (!s.ok()) { item->kernel = nullptr; s = AttachDef(s, *n); @@ -2974,12 +2974,12 @@ Status NewLocalExecutor(const LocalExecutorParams& params, const Graph& graph, } Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib, - const NodeDef& ndef, int graph_def_version, - OpKernel** kernel) { + const std::shared_ptr& props, + int graph_def_version, OpKernel** kernel) { const auto device_type = DeviceType(device->attributes().device_type()); auto 
allocator = device->GetAllocator(AllocatorAttributes()); return CreateOpKernel(device_type, device, allocator, flib, - device->resource_manager(), ndef, graph_def_version, + device->resource_manager(), props, graph_def_version, kernel); } diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h index a7cb01ec7f0..fcc64b9d986 100644 --- a/tensorflow/core/common_runtime/executor.h +++ b/tensorflow/core/common_runtime/executor.h @@ -145,7 +145,9 @@ struct LocalExecutorParams { // create_kernel returns an instance of op kernel based on NodeDef. // delete_kernel is called for every kernel used by the executor // when the executor is deleted. - std::function<Status(const NodeDef&, OpKernel**)> create_kernel; + std::function<Status(const std::shared_ptr<NodeProperties>&, + OpKernel**)> + create_kernel; std::function<void(OpKernel*)> delete_kernel; Executor::RendezvousFactory rendezvous_factory; @@ -240,12 +242,12 @@ class ExecutorBarrier { // A few helpers to facilitate create/delete kernels. -// Creates a kernel based on "ndef" on device "device". The kernel can +// Creates a kernel based on "props" on device "device". The kernel can // access the functions in the "flib". The caller takes ownership of // returned "*kernel". Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib, - const NodeDef& ndef, int graph_def_version, - OpKernel** kernel); + const std::shared_ptr<NodeProperties>& props, + int graph_def_version, OpKernel** kernel); // Deletes "kernel" returned by CreateKernel. 
void DeleteNonCachedKernel(OpKernel* kernel); diff --git a/tensorflow/core/common_runtime/executor_test.cc b/tensorflow/core/common_runtime/executor_test.cc index e994512a43f..3f143c75714 100644 --- a/tensorflow/core/common_runtime/executor_test.cc +++ b/tensorflow/core/common_runtime/executor_test.cc @@ -61,11 +61,12 @@ class ExecutorTest : public ::testing::Test { const int version = graph->versions().producer(); LocalExecutorParams params; params.device = device_.get(); - params.create_kernel = [this, version](const NodeDef& ndef, - OpKernel** kernel) { - return CreateNonCachedKernel(device_.get(), nullptr, ndef, version, - kernel); - }; + params.create_kernel = + [this, version](const std::shared_ptr& props, + OpKernel** kernel) { + return CreateNonCachedKernel(device_.get(), nullptr, props, version, + kernel); + }; params.delete_kernel = [](OpKernel* kernel) { DeleteNonCachedKernel(kernel); }; diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 14c0a8f5ad2..2140bf7f72b 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -187,7 +187,8 @@ class FunctionLibraryRuntimeOverlay : public FunctionLibraryRuntime { void Run(const Options& opts, Handle handle, CallFrameInterface* call_frame, DoneCallback done) override; - Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) override; + Status CreateKernel(const std::shared_ptr& props, + OpKernel** kernel) override; bool IsStateful(const string& function_name) const override; @@ -256,7 +257,8 @@ void FunctionLibraryRuntimeOverlay::Run(const Options& opts, Handle handle, base_flr_->Run(opts, handle, call_frame, std::move(done)); } -Status FunctionLibraryRuntimeOverlay::CreateKernel(const NodeDef&, OpKernel**) { +Status FunctionLibraryRuntimeOverlay::CreateKernel( + const std::shared_ptr&, OpKernel**) { // We don't have access to base_lib_def_ in base function library runtime (aka // 
FunctionLibraryRuntimeImpl), so to make sure we do not create a kernel with // the wrong lib_def we just disable creation of new kernels through overlays. @@ -344,7 +346,8 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { Status GetRetTypes(Handle handle, DataTypeVector* ret_types) override; - Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) override; + Status CreateKernel(const std::shared_ptr& props, + OpKernel** kernel) override; void Run(const Options& opts, Handle handle, gtl::ArraySlice args, std::vector* rets, DoneCallback done) override; @@ -393,7 +396,9 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { const string device_name_; std::function get_func_sig_; - std::function create_kernel_; + std::function&, + OpKernel**)> + create_kernel_; mutable mutex mu_; @@ -426,8 +431,8 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { // to use for kernel creation and execution. In particular, this method can // accept a FunctionLibraryRuntimeOverlay that overlays a different // FunctionLibraryDefinition. 
- Status CreateKernel(const NodeDef& ndef, FunctionLibraryRuntime* flr, - OpKernel** kernel); + Status CreateKernel(const std::shared_ptr& props, + FunctionLibraryRuntime* flr, OpKernel** kernel); Status FunctionDefToBody(const FunctionDef& fdef, AttrSlice attrs, const FunctionLibraryDefinition* lib_def, std::unique_ptr* fbody); @@ -476,8 +481,9 @@ FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl( get_func_sig_ = [this](const string& op, const OpDef** sig) { return base_lib_def_->LookUpOpDef(op, sig); }; - create_kernel_ = [this](const NodeDef& ndef, OpKernel** kernel) { - return CreateKernel(ndef, kernel); + create_kernel_ = [this](const std::shared_ptr& props, + OpKernel** kernel) { + return CreateKernel(props, kernel); }; thread::ThreadPool* pool = nullptr; if (device_ != nullptr) { @@ -589,20 +595,20 @@ Status FunctionLibraryRuntimeImpl::GetRetTypes(Handle h, return Status::OK(); } -Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef, - OpKernel** kernel) { - return CreateKernel(ndef, this, kernel); +Status FunctionLibraryRuntimeImpl::CreateKernel( + const std::shared_ptr& props, OpKernel** kernel) { + return CreateKernel(props, this, kernel); } -Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef, - FunctionLibraryRuntime* flr, - OpKernel** kernel) { +Status FunctionLibraryRuntimeImpl::CreateKernel( + const std::shared_ptr& props, + FunctionLibraryRuntime* flr, OpKernel** kernel) { // If a custom kernel creator is given, try that. 
Status s; if (custom_kernel_creator_ != nullptr && - custom_kernel_creator_->CanCreateKernel(*this, ndef)) { + custom_kernel_creator_->CanCreateKernel(*this, props)) { std::unique_ptr ret; - s = custom_kernel_creator_->CreateKernel(this, ndef, &ret); + s = custom_kernel_creator_->CreateKernel(this, props, &ret); if (s.ok()) { *kernel = ret.release(); } else { @@ -613,9 +619,9 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef, const FunctionLibraryDefinition* lib_def = flr->GetFunctionLibraryDefinition(); - if (lib_def->Find(ndef.op()) == nullptr) { + if (lib_def->Find(props->node_def.op()) == nullptr) { // A primitive operation. Creates the registered kernel. - return CreateNonCachedKernel(device_, flr, ndef, graph_def_version_, + return CreateNonCachedKernel(device_, flr, props, graph_def_version_, kernel); } @@ -626,8 +632,9 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef, options.lib_def = lib_def; } Handle handle; - TF_RETURN_IF_ERROR( - Instantiate(ndef.op(), AttrSlice(&ndef.attr()), options, &handle)); + TF_RETURN_IF_ERROR(Instantiate(props->node_def.op(), + AttrSlice(&props->node_def.attr()), options, + &handle)); const FunctionBody* fbody = GetFunctionBody(handle); CHECK_NOTNULL(fbody); @@ -647,10 +654,12 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef, // Constructs a CallOp kernel for running the instantiated function. 
auto device_type = DeviceType(device_->attributes().device_type()); + auto new_props = std::make_shared( + &fbody->fdef.signature(), props->node_def, fbody->arg_types, + fbody->ret_types); OpKernelConstruction construction( - device_type, device_, device_->GetAllocator(AllocatorAttributes()), &ndef, - &fbody->fdef.signature(), flr, device_->resource_manager(), - fbody->arg_types, input_memory_types, fbody->ret_types, + device_type, device_, device_->GetAllocator(AllocatorAttributes()), flr, + device_->resource_manager(), props, input_memory_types, output_memory_types, graph_def_version_, &s); if (s.ok()) { *kernel = new CallOp(handle, &construction); @@ -953,9 +962,11 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Item** item) { if (flr == this) { params.create_kernel = create_kernel_; } else { - params.create_kernel = [this, flr](const NodeDef& ndef, OpKernel** kernel) { - return CreateKernel(ndef, flr, kernel); - }; + params.create_kernel = + [this, flr](const std::shared_ptr& props, + OpKernel** kernel) { + return CreateKernel(props, flr, kernel); + }; } params.delete_kernel = [](OpKernel* kernel) { DeleteNonCachedKernel(kernel); diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index c1247190d2d..3e2371a686a 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -90,11 +90,12 @@ class FunctionTest : public ::testing::Test { const int version = g->versions().producer(); LocalExecutorParams params; params.device = device_.get(); - params.create_kernel = [this, version](const NodeDef& ndef, - OpKernel** kernel) { - return CreateNonCachedKernel(device_.get(), nullptr, ndef, version, - kernel); - }; + params.create_kernel = + [this, version](const std::shared_ptr& props, + OpKernel** kernel) { + return CreateNonCachedKernel(device_.get(), nullptr, props, version, + kernel); + }; params.delete_kernel = [](OpKernel* kernel) { 
DeleteNonCachedKernel(kernel); }; diff --git a/tensorflow/core/common_runtime/graph_runner.cc b/tensorflow/core/common_runtime/graph_runner.cc index 0a7d50f9ea4..7ffb860a2ce 100644 --- a/tensorflow/core/common_runtime/graph_runner.cc +++ b/tensorflow/core/common_runtime/graph_runner.cc @@ -157,9 +157,10 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library, params.device = device_; params.function_library = function_library; const int producer = graph_to_run->versions().producer(); - params.create_kernel = [this, function_library, producer](const NodeDef& ndef, - OpKernel** kernel) { - return CreateNonCachedKernel(device_, function_library, ndef, producer, + params.create_kernel = [this, function_library, producer]( + const std::shared_ptr& props, + OpKernel** kernel) { + return CreateNonCachedKernel(device_, function_library, props, producer, kernel); }; params.delete_kernel = [](OpKernel* kernel) { delete kernel; }; diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc index fe703050602..4118534cb3e 100644 --- a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc +++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc @@ -84,9 +84,10 @@ Benchmark::Benchmark(const string& device, Graph* g, LocalExecutorParams params; params.device = device_.get(); params.function_library = nullptr; - params.create_kernel = [this, graph_def_version](const NodeDef& ndef, - OpKernel** kernel) { - return CreateNonCachedKernel(device_.get(), nullptr, ndef, + params.create_kernel = [this, graph_def_version]( + const std::shared_ptr& props, + OpKernel** kernel) { + return CreateNonCachedKernel(device_.get(), nullptr, props, graph_def_version, kernel); }; params.delete_kernel = [](OpKernel* kernel) { diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index 9b28651c597..96fc4f3d4f3 100644 --- 
a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -233,23 +233,25 @@ Status GraphMgr::InitItem( // Construct the root executor for the subgraph. params.device = unit->device; params.function_library = lib; - params.create_kernel = [handle, lib, opseg](const NodeDef& ndef, - OpKernel** kernel) { - // NOTE(mrry): We must not share function kernels (implemented - // using `CallOp`) between subgraphs, because `CallOp::handle_` - // is tied to a particular subgraph. Even if the function itself - // is stateful, the `CallOp` that invokes it is not. - if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) { - return lib->CreateKernel(ndef, kernel); - } - auto create_fn = [lib, &ndef](OpKernel** kernel) { - return lib->CreateKernel(ndef, kernel); - }; - // Kernels created for subgraph nodes need to be cached. On - // cache miss, create_fn() is invoked to create a kernel based - // on the function library here + global op registry. - return opseg->FindOrCreate(handle, ndef.name(), kernel, create_fn); - }; + params.create_kernel = + [handle, lib, opseg](const std::shared_ptr& props, + OpKernel** kernel) { + // NOTE(mrry): We must not share function kernels (implemented + // using `CallOp`) between subgraphs, because `CallOp::handle_` + // is tied to a particular subgraph. Even if the function itself + // is stateful, the `CallOp` that invokes it is not. + if (!OpSegment::ShouldOwnKernel(lib, props->node_def.op())) { + return lib->CreateKernel(props, kernel); + } + auto create_fn = [lib, &props](OpKernel** kernel) { + return lib->CreateKernel(props, kernel); + }; + // Kernels created for subgraph nodes need to be cached. On + // cache miss, create_fn() is invoked to create a kernel based + // on the function library here + global op registry. 
+ return opseg->FindOrCreate(handle, props->node_def.name(), kernel, + create_fn); + }; params.delete_kernel = [lib](OpKernel* kernel) { if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string())) { delete kernel; diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD index 003e4894788..f3207dd657a 100644 --- a/tensorflow/core/framework/BUILD +++ b/tensorflow/core/framework/BUILD @@ -129,6 +129,7 @@ exports_files( "attr_value_util.h", "common_shape_fns.h", "node_def_util.h", + "node_properties.h", "op.h", "op_def_builder.h", "op_def_util.h", @@ -172,6 +173,7 @@ filegroup( "model.h", "node_def_builder.h", "node_def_util.h", + "node_properties.h", "numeric_op.h", "numeric_types.h", "op.h", @@ -338,6 +340,8 @@ filegroup( "node_def_builder.h", "node_def_util.cc", "node_def_util.h", + "node_properties.cc", + "node_properties.h", "numeric_op.h", "op.cc", "op.h", @@ -862,6 +866,21 @@ cc_library( ], ) +cc_library( + name = "node_properties", + srcs = ["node_properties.cc"], + hdrs = ["node_properties.h"], + deps = [ + ":node_def_proto_cc", + ":node_def_util", + ":op", + ":op_def_proto_cc", + ":tensor", + ":types_proto_cc", + "//tensorflow/core/lib/core:status", + ], +) + cc_library( name = "op_def_builder", srcs = ["op_def_builder.cc"], @@ -967,6 +986,7 @@ tf_cc_tests( "model_test.cc", "node_def_builder_test.cc", "node_def_util_test.cc", + "node_properties_test.cc", "op_compatibility_test.cc", "op_def_builder_test.cc", "op_def_util_test.cc", diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 0e260d26592..58cc1bbdaf9 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -722,11 +722,13 @@ class FunctionLibraryRuntime { virtual void Run(const Options& opts, Handle handle, CallFrameInterface* call_frame, DoneCallback done) = 0; - // Creates a "kernel" for the given node def "ndef". + // Creates a "kernel" for the given NodeProperties "props". 
// // If succeeds, returns OK and the caller takes the ownership of the // returned "*kernel". Otherwise, returns an error. - virtual Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) = 0; + virtual Status CreateKernel( + const std::shared_ptr& props, + OpKernel** kernel) = 0; // Returns true iff the function named `function_name` is stateful. // @@ -818,12 +820,15 @@ class CustomKernelCreator { // Given a NodeDef 'node_def' and the function library runtime 'flr', // validate if the class supports creating such a kernel. - virtual bool CanCreateKernel(const FunctionLibraryRuntime& flr, - const NodeDef& node_def) const = 0; + virtual bool CanCreateKernel( + const FunctionLibraryRuntime& flr, + const std::shared_ptr& props) const = 0; // Given a supported NodeDef, returns a kernel that computes the node. - virtual Status CreateKernel(FunctionLibraryRuntime* flr, const NodeDef& ndef, - std::unique_ptr* kernel) const = 0; + virtual Status CreateKernel( + FunctionLibraryRuntime* flr, + const std::shared_ptr& props, + std::unique_ptr* kernel) const = 0; }; // Used to instantiate and run functions in a distributed system. diff --git a/tensorflow/core/framework/node_properties.cc b/tensorflow/core/framework/node_properties.cc new file mode 100644 index 00000000000..bcc81bdbbff --- /dev/null +++ b/tensorflow/core/framework/node_properties.cc @@ -0,0 +1,39 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/node_properties.h" + +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +// static +Status NodeProperties::CreateFromNodeDef( + NodeDef node_def, const OpRegistryInterface* op_registry, + std::shared_ptr* props) { + const OpDef* op_def; + TF_RETURN_IF_ERROR(op_registry->LookUpOpDef(node_def.op(), &op_def)); + DataTypeVector input_types; + DataTypeVector output_types; + TF_RETURN_IF_ERROR( + InOutTypesForNode(node_def, *op_def, &input_types, &output_types)); + props->reset(new NodeProperties(op_def, std::move(node_def), + std::move(input_types), + std::move(output_types))); + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/node_properties.h b/tensorflow/core/framework/node_properties.h new file mode 100644 index 00000000000..0382321f486 --- /dev/null +++ b/tensorflow/core/framework/node_properties.h @@ -0,0 +1,63 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_NODE_PROPERTIES_H_ +#define TENSORFLOW_CORE_FRAMEWORK_NODE_PROPERTIES_H_ + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class OpRegistryInterface; + +struct NodeProperties { + public: + NodeProperties(const OpDef* op_def, NodeDef node_def, + const DataTypeSlice inputs, const DataTypeSlice outputs) + : NodeProperties(op_def, std::move(node_def), + DataTypeVector(inputs.begin(), inputs.end()), + DataTypeVector(outputs.begin(), outputs.end())) {} + + NodeProperties(const OpDef* _op_def, NodeDef&& _node_def, + DataTypeVector inputs, DataTypeVector outputs) + : op_def(_op_def), + node_def(std::move(_node_def)), + input_types(std::move(inputs)), + input_types_slice(input_types), + output_types(std::move(outputs)), + output_types_slice(output_types) {} + + // Resets the 'props' shared pointer to point to a new NodeProperties created + // from the given NodeDef. 'op_registry' is used to look up the OpDef + // corresponding to node_def.op(). Returns an error if OpDef lookup or + // creation failed. + static Status CreateFromNodeDef(NodeDef node_def, + const OpRegistryInterface* op_registry, + std::shared_ptr* props); + + const OpDef* op_def; // not owned. 
+ NodeDef node_def; + DataTypeVector input_types; + DataTypeSlice input_types_slice; + DataTypeVector output_types; + DataTypeSlice output_types_slice; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_NODE_PROPERTIES_H_ diff --git a/tensorflow/core/framework/node_properties_test.cc b/tensorflow/core/framework/node_properties_test.cc new file mode 100644 index 00000000000..9f76b953b06 --- /dev/null +++ b/tensorflow/core/framework/node_properties_test.cc @@ -0,0 +1,128 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/node_properties.h" + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +OpDef ToOpDef(const OpDefBuilder& builder) { + OpRegistrationData op_reg_data; + EXPECT_TRUE(builder.Finalize(&op_reg_data).ok()); + return op_reg_data.op_def; +} + +class MockOpRegistry : public OpRegistryInterface { + public: + MockOpRegistry() + : op_reg_(ToOpDef(OpDefBuilder("Foo") + .Input("f: float") + .Input("i: int32") + .Output("of: double"))) {} + ~MockOpRegistry() override {} + + // Returns an error status and sets *op_reg_data to nullptr if no OpDef is + // registered under that name, otherwise returns the registered OpDef. 
+ // Caller must not delete the returned pointer. + Status LookUp(const string& op_type_name, + const OpRegistrationData** op_reg_data) const override { + if (op_type_name == "Foo") { + *op_reg_data = &op_reg_; + return Status::OK(); + } else { + *op_reg_data = nullptr; + return errors::InvalidArgument("Op type named ", op_type_name, + " not found"); + } + } + + const OpDef* get_op_def_addr() { return &op_reg_.op_def; } + + private: + const OpRegistrationData op_reg_; +}; + +void ValidateNodeProperties(const NodeProperties& props, const OpDef* op_def, + const NodeDef& node_def, + const DataTypeVector& input_types, + const DataTypeVector& output_types) { + EXPECT_EQ(props.op_def, op_def); + EXPECT_EQ(props.node_def.name(), node_def.name()); + ASSERT_EQ(props.input_types.size(), input_types.size()); + for (int i = 0; i < input_types.size(); ++i) { + EXPECT_EQ(props.input_types[i], input_types[i]); + EXPECT_EQ(props.input_types_slice[i], input_types[i]); + } + ASSERT_EQ(props.output_types.size(), output_types.size()); + for (int i = 0; i < output_types.size(); ++i) { + EXPECT_EQ(props.output_types[i], output_types[i]); + EXPECT_EQ(props.output_types_slice[i], output_types[i]); + } +} + +} // namespace + +TEST(NodeProperties, Contructors) { + OpDef op_def; + NodeDef node_def; + node_def.set_name("foo"); + DataTypeVector input_types{DT_FLOAT, DT_INT32}; + DataTypeVector output_types{DT_DOUBLE}; + DataTypeSlice input_types_slice(input_types); + DataTypeSlice output_types_slice(output_types); + + // Construct from slices. + NodeProperties props_from_slices(&op_def, node_def, input_types_slice, + output_types_slice); + ValidateNodeProperties(props_from_slices, &op_def, node_def, input_types, + output_types); + + // Construct from vectors. 
+ NodeProperties props_from_vectors(&op_def, node_def, input_types, + output_types); + ValidateNodeProperties(props_from_vectors, &op_def, node_def, input_types, + output_types); +} + +TEST(NodeProperties, CreateFromNodeDef) { + MockOpRegistry op_registry; + NodeDef node_def; + node_def.set_name("bar"); + node_def.set_op("Foo"); + node_def.add_input("f_in"); + node_def.add_input("i_in"); + + std::shared_ptr props; + EXPECT_TRUE( + NodeProperties::CreateFromNodeDef(node_def, &op_registry, &props).ok()); + + DataTypeVector input_types{DT_FLOAT, DT_INT32}; + DataTypeVector output_types{DT_DOUBLE}; + ValidateNodeProperties(*props, op_registry.get_op_def_addr(), node_def, + input_types, output_types); + + // The OpDef lookup should fail for this one: + node_def.set_op("Baz"); + std::shared_ptr props_bad; + EXPECT_FALSE( + NodeProperties::CreateFromNodeDef(node_def, &op_registry, &props_bad) + .ok()); + EXPECT_EQ(props_bad, nullptr); +} +} // namespace tensorflow diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index 2feb84a1786..38c56eb3b1c 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -35,9 +35,9 @@ limitations under the License. 
#include "tensorflow/core/framework/memory_types.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/node_properties.h" #include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -91,35 +91,53 @@ Status MatchSignatureHelper(const DataTypeSlice expected_inputs, // OpKernel ------------------------------------------------------------------ -OpKernel::OpKernel(OpKernelConstruction* context) - : OpKernel(context, MakeUnique(context->def())) {} +OpKernel::OpKernel(OpKernelConstruction* context) : OpKernel(context, false) {} OpKernel::OpKernel(OpKernelConstruction* context, bool is_deferred) - : OpKernel(context, MakeUnique(context->def()), - is_deferred) {} - -OpKernel::OpKernel(OpKernelConstruction* context, - std::unique_ptr node_def, bool is_deferred) - : def_(std::move(node_def)), - input_types_(context->input_types().begin(), - context->input_types().end()), + : props_(context->props_), input_memory_types_(context->input_memory_types().begin(), context->input_memory_types().end()), - output_types_(context->output_types().begin(), - context->output_types().end()), output_memory_types_(context->output_memory_types().begin(), context->output_memory_types().end()), input_name_map_(context->num_inputs()), output_name_map_(context->num_outputs()), - name_view_(def_->name()), - type_string_view_(def_->op()), + name_view_(props_->node_def.name()), + type_string_view_(props_->node_def.op()), graph_def_version_(context->graph_def_version()), is_deferred_(is_deferred), cost_estimate_(OpKernel::kInitialCostEstimateCycles) { OP_REQUIRES_OK(context, - NameRangesForNode(*def_, *context->op_def_, &input_name_map_, - &output_name_map_)); - OP_REQUIRES_OK(context, 
CheckOpDeprecation(*context->op_def_, + NameRangesForNode(props_->node_def, *props_->op_def, + &input_name_map_, &output_name_map_)); + OP_REQUIRES_OK(context, CheckOpDeprecation(*props_->op_def, + context->graph_def_version())); + + // Kernels executing on GPU/SYCL tie very few resources on the CPU where the + // scheduler runs: we consider them as inexpensive. + expensive_ = context->device_type() != DeviceType(DEVICE_GPU) && + context->device_type() != DeviceType(DEVICE_SYCL); +} + +OpKernel::OpKernel(OpKernelConstruction* context, NodeDef&& custom_def, + bool is_deferred) + : props_(std::make_shared( + context->props_->op_def, std::move(custom_def), + context->props_->input_types, context->props_->output_types)), + input_memory_types_(context->input_memory_types().begin(), + context->input_memory_types().end()), + output_memory_types_(context->output_memory_types().begin(), + context->output_memory_types().end()), + input_name_map_(context->num_inputs()), + output_name_map_(context->num_outputs()), + name_view_(props_->node_def.name()), + type_string_view_(props_->node_def.op()), + graph_def_version_(context->graph_def_version()), + is_deferred_(is_deferred), + cost_estimate_(OpKernel::kInitialCostEstimateCycles) { + OP_REQUIRES_OK(context, + NameRangesForNode(props_->node_def, *props_->op_def, + &input_name_map_, &output_name_map_)); + OP_REQUIRES_OK(context, CheckOpDeprecation(*props_->op_def, context->graph_def_version())); // Kernels executing on GPU/SYCL tie very few resources on the CPU where the @@ -134,10 +152,6 @@ const uint64 OpKernel::kInitialCostEstimateCycles; const uint64 OpKernel::kOpIsExpensiveThresholdCycles; const uint64 OpKernel::kCostDecay; -const string& OpKernel::name() const { return def_->name(); } -const string& OpKernel::type_string() const { return def_->op(); } -const string& OpKernel::requested_device() const { return def_->device(); } -const string& OpKernel::requested_input(int i) const { return def_->input(i); } Status 
OpKernel::InputRange(StringPiece input_name, int* start, int* stop) const { @@ -216,22 +230,18 @@ Tensor* PersistentTensor::AccessTensor(OpKernelContext* context) { OpKernelConstruction::OpKernelConstruction( DeviceType device_type, DeviceBase* device, Allocator* allocator, - const NodeDef* node_def, const OpDef* op_def, FunctionLibraryRuntime* flib, - ResourceMgr* resource_mgr, const DataTypeSlice& input_types, + FunctionLibraryRuntime* flib, ResourceMgr* resource_mgr, + const std::shared_ptr& props, const MemoryTypeSlice& input_memory_types, - const DataTypeSlice& output_types, const MemoryTypeSlice& output_memory_types, int graph_def_version, Status* status) : device_type_(std::move(device_type)), device_(device), allocator_(allocator), - def_(node_def), - op_def_(op_def), flib_(flib), resource_mgr_(resource_mgr), - input_types_(input_types), + props_(props), input_memory_types_(input_memory_types), - output_types_(output_types), output_memory_types_(output_memory_types), graph_def_version_(graph_def_version), status_(status) {} @@ -246,8 +256,8 @@ void OpKernelConstruction::SetStatus(const Status& status) { Status OpKernelConstruction::MatchSignature( const DataTypeSlice expected_inputs, const DataTypeSlice expected_outputs) { - return MatchSignatureHelper(expected_inputs, expected_outputs, input_types_, - output_types_); + return MatchSignatureHelper(expected_inputs, expected_outputs, + props_->input_types, props_->output_types); } Status OpKernelConstruction::allocate_temp(DataType type, @@ -263,7 +273,7 @@ Status OpKernelConstruction::allocate_temp(DataType type, } if (LogMemory::IsEnabled()) { LogMemory::RecordTensorAllocation( - def_->name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp); + def().name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp); } *out_temp = new_temp; return Status::OK(); @@ -288,7 +298,7 @@ Status OpKernelConstruction::allocate_temp(DataType type, } if (LogMemory::IsEnabled()) { LogMemory::RecordTensorAllocation( - 
def_->name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp); + def().name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp); } *out_temp = new_temp; return Status::OK(); @@ -1544,45 +1554,65 @@ string KernelsRegisteredForOp(StringPiece op_name) { return ret; } +/* TODO(rmlarsen): This API is deprecated. Remove it if possible to avoid + * copying the NodeDef. */ std::unique_ptr CreateOpKernel( DeviceType device_type, DeviceBase* device, Allocator* allocator, const NodeDef& node_def, int graph_def_version, Status* status) { + // Look up the Op registered for this op name. + std::shared_ptr props; + status->Update(NodeProperties::CreateFromNodeDef( + node_def, OpRegistry::Global(), &props)); + if (!status->ok()) { + errors::AppendToMessage(status, + " for node: ", FormatNodeDefForError(node_def)); + return nullptr; + } + return CreateOpKernel(device_type, device, allocator, props, + graph_def_version, status); +} + +std::unique_ptr CreateOpKernel( + DeviceType device_type, DeviceBase* device, Allocator* allocator, + const std::shared_ptr& props, int graph_def_version, + Status* status) { OpKernel* kernel = nullptr; - *status = CreateOpKernel(std::move(device_type), device, allocator, nullptr, - node_def, graph_def_version, &kernel); + *status = CreateOpKernel(std::move(device_type), device, allocator, + /*flib=*/nullptr, props, graph_def_version, &kernel); return std::unique_ptr(kernel); } Status CreateOpKernel(DeviceType device_type, DeviceBase* device, Allocator* allocator, FunctionLibraryRuntime* flib, - const NodeDef& node_def, int graph_def_version, - OpKernel** kernel) { + const std::shared_ptr& props, + int graph_def_version, OpKernel** kernel) { return CreateOpKernel(std::move(device_type), device, allocator, flib, - /* resource_mgr= */ nullptr, node_def, - graph_def_version, kernel); + /* resource_mgr= */ nullptr, props, graph_def_version, + kernel); } Status CreateOpKernel(DeviceType device_type, DeviceBase* device, Allocator* allocator, 
FunctionLibraryRuntime* flib, - ResourceMgr* resource_mgr, const NodeDef& node_def, + ResourceMgr* resource_mgr, + const std::shared_ptr& props, int graph_def_version, OpKernel** kernel) { - VLOG(1) << "Instantiating kernel for node: " << SummarizeNodeDef(node_def); - - // Look up the Op registered for this op name. - const OpDef* op_def = nullptr; - TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def)); - - // Validate node_def against OpDef. - TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, *op_def)); - - // Look up kernel registration. - const KernelRegistration* registration; + const NodeDef& node_def = props->node_def; bool was_attr_mismatch; - Status s = FindKernelRegistration(device_type, node_def, ®istration, - &was_attr_mismatch); - if (!s.ok()) { - errors::AppendToMessage(&s, " when instantiating ", node_def.op()); - return s; + const KernelRegistration* registration = nullptr; + Status s; + if (props != nullptr) { + VLOG(1) << "Instantiating kernel for node: " << SummarizeNodeDef(node_def); + + // Validate node_def against OpDef. + TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, *props->op_def)); + + // Look up kernel registration. 
+ s = FindKernelRegistration(device_type, node_def, ®istration, + &was_attr_mismatch); + if (!s.ok()) { + errors::AppendToMessage(&s, " when instantiating ", node_def.op()); + return s; + } } if (registration == nullptr) { s.Update(errors::NotFound("No registered '", node_def.op(), @@ -1599,15 +1629,6 @@ Status CreateOpKernel(DeviceType device_type, DeviceBase* device, return s; } - // Get signature from the OpDef & NodeDef - DataTypeVector inputs; - DataTypeVector outputs; - s.Update(InOutTypesForNode(node_def, *op_def, &inputs, &outputs)); - if (!s.ok()) { - errors::AppendToMessage(&s, " for node: ", FormatNodeDefForError(node_def)); - return s; - } - // We are creating a kernel for an op registered in // OpRegistry::Global(), we consult the kernel registry to decide // the kernel's input and output memory types. @@ -1618,10 +1639,9 @@ Status CreateOpKernel(DeviceType device_type, DeviceBase* device, &output_memory_types)); // Everything needed for OpKernel construction. - OpKernelConstruction context(std::move(device_type), device, allocator, - &node_def, op_def, flib, resource_mgr, inputs, - input_memory_types, outputs, output_memory_types, - graph_def_version, &s); + OpKernelConstruction context(std::move(device_type), device, allocator, flib, + resource_mgr, props, input_memory_types, + output_memory_types, graph_def_version, &s); *kernel = registration->factory->Create(&context); if (!s.ok()) { delete *kernel; diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index 4f1cc91cd19..e0d9742768a 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -31,6 +31,7 @@ limitations under the License. 
#include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/node_properties.h" #include "tensorflow/core/framework/op.h" // TODO(b/62899350): Remove #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/selective_registration.h" @@ -85,19 +86,18 @@ class OpKernel { // expensive initialization in the descendant's constructor. explicit OpKernel(OpKernelConstruction* context); - // Specialized constructor that enables the descendant to provide a different - // `NodeDef` value. For example, this constructor can be used to provide a - // stripped-down `NodeDef` that does not contain the full set of attrs (such - // as tensor values) if the descendant stores them in a different form. - explicit OpKernel(OpKernelConstruction* context, - std::unique_ptr node_def, - bool is_deferred = false); - // Specialized constructor that allows a kernel implementation to mark itself // as a "deferred" op. If true, the executor will provide access to the // `OpKernelContext::inc_num_deferred_ops_function()` and // `OpKernelContext::dec_num_deferred_ops_function()` methods at run-time. - explicit OpKernel(OpKernelConstruction* context, bool is_deferred); + OpKernel(OpKernelConstruction* context, bool is_deferred); + + // Specialized constructor that enables the descendant to provide a custom + // `NodeDef` value. For example, this constructor can be used to provide a + // stripped-down `NodeDef` that does not contain the full set of attrs (such + // as tensor values) if the descendant stores them in a different form. + OpKernel(OpKernelConstruction* context, NodeDef&& custom_def, + bool is_deferred); virtual ~OpKernel(); @@ -170,24 +170,26 @@ class OpKernel { } // Accessors. 
- const NodeDef& def() const { return *def_; } - const string& name() const; // Same as def().name() + const NodeDef& def() const { return props_->node_def; } + const string& name() const { return props_->node_def.name(); } absl::string_view name_view() const { return name_view_; } - const string& type_string() const; // Same as def().op() + const string& type_string() const { return props_->node_def.op(); } absl::string_view type_string_view() const { return type_string_view_; } - const string& requested_device() const; // Same as def().device() + const string& requested_input(int i) const { + return props_->node_def.input(i); + } + const string& requested_device() const { return props_->node_def.device(); } - int num_inputs() const { return input_types_.size(); } - DataType input_type(int i) const { return input_types_[i]; } - const DataTypeVector& input_types() const { return input_types_; } + int num_inputs() const { return props_->input_types.size(); } + DataType input_type(int i) const { return props_->input_types[i]; } + const DataTypeVector& input_types() const { return props_->input_types; } const MemoryTypeVector& input_memory_types() const { return input_memory_types_; } - const string& requested_input(int i) const; // Same as def().input(i) - int num_outputs() const { return output_types_.size(); } - DataType output_type(int o) const { return output_types_[o]; } - const DataTypeVector& output_types() const { return output_types_; } + int num_outputs() const { return props_->output_types.size(); } + DataType output_type(int o) const { return props_->output_types[o]; } + const DataTypeVector& output_types() const { return props_->output_types; } const MemoryTypeVector& output_memory_types() const { return output_memory_types_; } @@ -209,10 +211,8 @@ class OpKernel { string GetTraceArgument(OpKernelContext* ctx); private: - const std::unique_ptr def_; - const DataTypeVector input_types_; + const std::shared_ptr props_; const MemoryTypeVector 
input_memory_types_; - const DataTypeVector output_types_; const MemoryTypeVector output_memory_types_; NameRangeMap input_name_map_; NameRangeMap output_name_map_; @@ -284,12 +284,10 @@ class PersistentTensor { class OpKernelConstruction { public: OpKernelConstruction(DeviceType device_type, DeviceBase* device, - Allocator* allocator, const NodeDef* node_def, - const OpDef* op_def, FunctionLibraryRuntime* flib, + Allocator* allocator, FunctionLibraryRuntime* flib, ResourceMgr* resource_mgr, - const DataTypeSlice& input_types, + const std::shared_ptr& props, const MemoryTypeSlice& input_memory_types, - const DataTypeSlice& output_types, const MemoryTypeSlice& output_memory_types, int graph_def_version, Status* status); @@ -330,20 +328,22 @@ class OpKernelConstruction { Tensor** out_tensor); // User-supplied configuration of this operation. - const NodeDef& def() const { return *def_; } + const NodeDef& def() const { return props_->node_def; } // For inspecting the inputs to this operation. - int num_inputs() const { return input_types_.size(); } - DataType input_type(int i) const { return input_types_[i]; } - const DataTypeSlice& input_types() const { return input_types_; } + int num_inputs() const { return props_->input_types.size(); } + DataType input_type(int i) const { return props_->input_types[i]; } + const DataTypeSlice& input_types() const { return props_->input_types_slice; } const MemoryTypeSlice& input_memory_types() const { return input_memory_types_; } // For inspecting the outputs expected from this operation. 
- int num_outputs() const { return output_types_.size(); } - DataType output_type(int i) const { return output_types_[i]; } - const DataTypeSlice& output_types() const { return output_types_; } + int num_outputs() const { return props_->output_types.size(); } + DataType output_type(int i) const { return props_->output_types[i]; } + const DataTypeSlice& output_types() const { + return props_->output_types_slice; + } const MemoryTypeSlice& output_memory_types() const { return output_memory_types_; } @@ -403,19 +403,15 @@ class OpKernelConstruction { const DeviceType device_type_; DeviceBase* const device_; Allocator* allocator_; - const NodeDef* def_; - const OpDef* op_def_; FunctionLibraryRuntime* flib_; ResourceMgr* const resource_mgr_; - DataTypeSlice input_types_; + std::shared_ptr props_; MemoryTypeSlice input_memory_types_; - DataTypeSlice output_types_; MemoryTypeSlice output_memory_types_; const int graph_def_version_; Status* status_; - // Allow op_def_ across from OpKernel, but not from subclasses. - // TODO(irving): Remove protos from this header entirely. + // Allow access from OpKernel ctor. 
friend class OpKernel; TF_DISALLOW_COPY_AND_ASSIGN(OpKernelConstruction); @@ -1404,15 +1400,23 @@ const Eigen::SyclDevice& OpKernelContext::eigen_device() const; std::unique_ptr CreateOpKernel(DeviceType device_type, DeviceBase* device, Allocator* allocator, - const NodeDef& def, + const NodeDef& node_def, int graph_def_version, Status* status); + +std::unique_ptr CreateOpKernel( + DeviceType device_type, DeviceBase* device, Allocator* allocator, + const std::shared_ptr& props, int graph_def_version, + Status* status); + Status CreateOpKernel(DeviceType device_type, DeviceBase* device, Allocator* allocator, FunctionLibraryRuntime* flib, - const NodeDef& def, int graph_def_version, - OpKernel** kernel); + const std::shared_ptr& props, + int graph_def_version, OpKernel** kernel); + Status CreateOpKernel(DeviceType device_type, DeviceBase* device, Allocator* allocator, FunctionLibraryRuntime* flib, - ResourceMgr* resource_mgr, const NodeDef& def, + ResourceMgr* resource_mgr, + const std::shared_ptr& props, int graph_def_version, OpKernel** kernel); // Returns into 'device_types' the subset of prioritized_types that this diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc index ec887a0ad93..40425cf24e0 100644 --- a/tensorflow/core/framework/op_kernel_test.cc +++ b/tensorflow/core/framework/op_kernel_test.cc @@ -28,6 +28,7 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/tensor_util.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/strings/str_util.h" diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index 6240d0fb1ca..1f8a4d06c7a 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/node_properties.h" #include "tensorflow/core/framework/op_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/versions.pb.h" @@ -37,23 +37,7 @@ namespace tensorflow { const int Graph::kControlSlot = -1; -struct NodeProperties { - public: - NodeProperties(const OpDef* op_def, NodeDef node_def, - const DataTypeSlice inputs, const DataTypeSlice outputs) - : op_def(op_def), - node_def(std::move(node_def)), - input_types(inputs.begin(), inputs.end()), - output_types(outputs.begin(), outputs.end()) {} - - const OpDef* op_def; // not owned - NodeDef node_def; - const DataTypeVector input_types; - const DataTypeVector output_types; -}; - // Node - #define REF_CLASS(key, value) \ {key, value}, { "Ref" key, value } @@ -97,7 +81,8 @@ const std::unordered_map& Node::kNodeClassTable = {"StatelessIf", NC_IF}, {"While", NC_WHILE}, {"StatelessWhile", NC_WHILE}, - // Not using the constants defined in FunctionLibraryDefinition for the + // Not using the constants defined in FunctionLibraryDefinition + // for the // 4 ops below because android inference library does not link // tf.function related files. 
{"_Arg", NC_ARG}, diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index b33c0319c75..235d944bd60 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -43,6 +43,7 @@ limitations under the License. #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/edgeset.h" @@ -67,7 +68,6 @@ class WhileContext; class NeighborIter; // Declared below class NodeIter; // Declared below -struct NodeProperties; // Defined in .cc class Node { public: @@ -229,11 +229,12 @@ class Node { while_ctx_ = while_ctx; } + std::shared_ptr properties() const { return props_; } + private: friend class Graph; Node(); - NodeProperties* properties() const { return props_.get(); } void Initialize(int id, int cost_id, std::shared_ptr props, bool is_function_op); diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc index 5931599c6e2..ccdafdf91c9 100644 --- a/tensorflow/core/kernels/constant_op.cc +++ b/tensorflow/core/kernels/constant_op.cc @@ -47,31 +47,30 @@ namespace tensorflow { namespace { -std::unique_ptr StripTensorDataFromNodeDef( - OpKernelConstruction* ctx) { +NodeDef StripTensorDataFromNodeDef(OpKernelConstruction* ctx) { #ifndef __ANDROID__ DCHECK_EQ(NodeDef::descriptor()->field_count(), 6) << "The NodeDef format has changed, and the attr-stripping code may need " << "to be updated."; #endif const NodeDef& original = ctx->def(); - NodeDef* ret = new NodeDef; - ret->set_name(original.name()); - ret->set_op(original.op()); - ret->set_device(original.device()); + NodeDef ret; + ret.set_name(original.name()); + ret.set_op(original.op()); + ret.set_device(original.device()); // Strip the "value" attr from the returned NodeDef. 
// NOTE(mrry): The present implementation of `OpKernel::OpKernel()` only uses // attrs that affect the cardinality of list-typed inputs and outputs, so it // is safe to drop other attrs from the NodeDef. - AddNodeAttr("dtype", ctx->output_type(0), ret); - MergeDebugInfo(original, ret); - return std::unique_ptr<const NodeDef>(ret); + AddNodeAttr("dtype", ctx->output_type(0), &ret); + MergeDebugInfo(original, &ret); + return ret; } } // namespace ConstantOp::ConstantOp(OpKernelConstruction* ctx) - : OpKernel(ctx, StripTensorDataFromNodeDef(ctx)), + : OpKernel(ctx, StripTensorDataFromNodeDef(ctx), false), tensor_(ctx->output_type(0)) { const TensorProto* proto = nullptr; MEMDEBUG_CACHE_OP(ctx->def().name().c_str()); diff --git a/tensorflow/core/kernels/data/dataset_test_base.cc b/tensorflow/core/kernels/data/dataset_test_base.cc index 7c5d0c3f679..817e075e69b 100644 --- a/tensorflow/core/kernels/data/dataset_test_base.cc +++ b/tensorflow/core/kernels/data/dataset_test_base.cc @@ -304,9 +304,14 @@ Status DatasetOpsTestBase::ExpectEqual(std::vector<Tensor> produced_tensors, Status DatasetOpsTestBase::CreateOpKernel( const NodeDef& node_def, std::unique_ptr<OpKernel>* op_kernel) { OpKernel* kernel; + Status s; + + std::shared_ptr<const NodeProperties> props; + TF_RETURN_IF_ERROR(NodeProperties::CreateFromNodeDef( + node_def, flr_->GetFunctionLibraryDefinition(), &props)); TF_RETURN_IF_ERROR(tensorflow::CreateOpKernel( device_type_, device_.get(), allocator_, flr_, - device_->resource_manager(), node_def, TF_GRAPH_DEF_VERSION, &kernel)); + device_->resource_manager(), props, TF_GRAPH_DEF_VERSION, &kernel)); op_kernel->reset(kernel); return Status::OK(); } @@ -435,9 +440,10 @@ Status DatasetOpsTestBase::RunFunction( LocalExecutorParams params; params.function_library = flr_; params.device = device_.get(); - params.create_kernel = [this, version](const NodeDef& ndef, - OpKernel** kernel) { - return CreateNonCachedKernel(device_.get(), this->flr_, ndef, version, + params.create_kernel = [this, version]( + const std::shared_ptr<const NodeProperties>& 
props, + OpKernel** kernel) { + return CreateNonCachedKernel(device_.get(), this->flr_, props, version, kernel); }; params.delete_kernel = [](OpKernel* kernel) { diff --git a/tensorflow/core/kernels/data/single_threaded_executor.cc b/tensorflow/core/kernels/data/single_threaded_executor.cc index a6b31679fa6..5393d5557eb 100644 --- a/tensorflow/core/kernels/data/single_threaded_executor.cc +++ b/tensorflow/core/kernels/data/single_threaded_executor.cc @@ -108,7 +108,8 @@ class SingleThreadedExecutorImpl : public Executor { KernelState& kernel_state = kernels_[kernel_index]; node_to_index_map[n] = kernel_index; - TF_RETURN_IF_ERROR(params_.create_kernel(n->def(), &kernel_state.kernel)); + TF_RETURN_IF_ERROR( + params_.create_kernel(n->properties(), &kernel_state.kernel)); kernel_state.num_inputs = n->num_inputs(); kernel_state.num_outputs = n->num_outputs(); diff --git a/tensorflow/core/kernels/data/single_threaded_executor_test.cc b/tensorflow/core/kernels/data/single_threaded_executor_test.cc index 1a5059487a4..898a6555265 100644 --- a/tensorflow/core/kernels/data/single_threaded_executor_test.cc +++ b/tensorflow/core/kernels/data/single_threaded_executor_test.cc @@ -58,11 +58,12 @@ class ExecutorTest : public ::testing::Test { const int version = graph->versions().producer(); LocalExecutorParams params; params.device = device_.get(); - params.create_kernel = [this, version](const NodeDef& ndef, - OpKernel** kernel) { - return CreateNonCachedKernel(device_.get(), nullptr, ndef, version, - kernel); - }; + params.create_kernel = + [this, version](const std::shared_ptr<const NodeProperties>& props, + OpKernel** kernel) { + return CreateNonCachedKernel(device_.get(), nullptr, props, version, + kernel); + }; params.delete_kernel = [](OpKernel* kernel) { DeleteNonCachedKernel(kernel); }; diff --git a/tensorflow/python/eager/pywrap_tfe_test.py b/tensorflow/python/eager/pywrap_tfe_test.py index f510f24d777..c2389025a25 100644 --- a/tensorflow/python/eager/pywrap_tfe_test.py +++ 
b/tensorflow/python/eager/pywrap_tfe_test.py @@ -237,8 +237,7 @@ class Tests(test.TestCase): @test_util.assert_no_garbage_created def testInvalidNumOutputs(self): with self.assertRaisesRegexp( - Exception, - "Value for attr 'num_split' of -1 must be at least minimum 1"): + Exception, r"Value for number_attr\(\) -1 < 0 \[Op:Split\]"): array_ops.split(value=[1, 2, 3], num_or_size_splits=-1) with self.assertRaisesRegexp(