From 5f3a3019baf611d3720e70c902fd8170dfe3c0b4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 18 Feb 2020 13:05:19 -0800 Subject: [PATCH] Replace NodeDef with std::shared_ptr in the kernel creation code paths and try to avoid as many copies of NodeDefs as possible. This will in most cases allow sharing the NodeDef between the OpKernel and the graph Node from which it is created. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reduces the number of allocations in the executor benchmark by about 8%: name old time/op new time/op delta BM_executor/16/1k [Nodes = 9824 ] 911µs ± 3% 911µs ± 1% ~ (p=0.548 n=5+5) BM_executor/32/8k [Nodes = 141991] 17.1ms ± 2% 16.8ms ± 1% -2.17% (p=0.016 n=5+5) BM_executor/1k/16 [Nodes = 6781 ] 1.21ms ± 1% 1.25ms ± 7% ~ (p=0.095 n=5+5) BM_executor/8k/32 [Nodes = 130875] 4.35s ± 0% 4.34s ± 0% ~ (p=0.841 n=5+5) BM_executor/1k/1k [Nodes = 526256] 3.33s ± 1% 3.31s ± 1% ~ (p=0.095 n=5+5) BM_FeedInputFetchOutput 54.0µs ± 7% 56.9µs ±13% ~ (p=0.222 n=5+5) name old allocs/op new allocs/op delta BM_executor/16/1k [Nodes = 9824 ] 15.4k ± 0% 14.1k ± 0% -7.95% (p=0.008 n=5+5) BM_executor/32/8k [Nodes = 141991] 226k ± 0% 208k ± 0% -7.86% (p=0.008 n=5+5) BM_executor/1k/16 [Nodes = 6781 ] 10.2k ± 0% 9.3k ± 0% -8.36% (p=0.008 n=5+5) BM_executor/8k/32 [Nodes = 130875] 197k ± 0% 180k ± 0% -8.31% (p=0.016 n=4+5) BM_executor/1k/1k [Nodes = 526256] 771k ± 0% 706k ± 0% -8.53% (p=0.008 n=5+5) BM_FeedInputFetchOutput 58.0 ± 0% 57.0 ± 0% -1.72% (p=0.008 n=5+5) PiperOrigin-RevId: 295803318 Change-Id: I0d262c6082822023f449f9817dc943d20bd302d5 --- tensorflow/compiler/jit/xla_kernel_creator.cc | 16 +- tensorflow/compiler/jit/xla_kernel_creator.h | 8 +- .../compiler/jit/xla_kernel_creator_test.cc | 42 ++--- .../compiler/jit/xla_kernel_creator_util.cc | 13 +- .../tf2tensorrt/kernels/trt_engine_op_test.cc | 11 +- tensorflow/compiler/tf2xla/graph_compiler.cc | 2 +- tensorflow/core/BUILD | 2 + .../core/common_runtime/direct_session.cc | 37 ++-- .../common_runtime/eager/kernel_and_device.cc | 5 +- tensorflow/core/common_runtime/executor.cc | 8 +- tensorflow/core/common_runtime/executor.h | 10 +- .../core/common_runtime/executor_test.cc | 11 +- tensorflow/core/common_runtime/function.cc | 63 ++++--- .../core/common_runtime/function_test.cc | 11 +- .../core/common_runtime/graph_runner.cc | 7 +- .../kernel_benchmark_testlib.cc | 7 +- .../core/distributed_runtime/graph_mgr.cc | 36 ++-- tensorflow/core/framework/BUILD | 20 +++ tensorflow/core/framework/function.h | 17 +- tensorflow/core/framework/node_properties.cc | 39 +++++ tensorflow/core/framework/node_properties.h | 63 +++++++ .../core/framework/node_properties_test.cc | 128 ++++++++++++++ tensorflow/core/framework/op_kernel.cc | 158 ++++++++++-------- tensorflow/core/framework/op_kernel.h | 92 +++++----- tensorflow/core/framework/op_kernel_test.cc | 1 + tensorflow/core/graph/graph.cc | 21 +-- tensorflow/core/graph/graph.h | 5 +- tensorflow/core/kernels/constant_op.cc | 19 +-- .../core/kernels/data/dataset_test_base.cc | 14 +- .../kernels/data/single_threaded_executor.cc | 3 +- .../data/single_threaded_executor_test.cc | 11 +- tensorflow/python/eager/pywrap_tfe_test.py | 3 +- 32 files changed, 597 insertions(+), 286 deletions(-) create mode 100644 tensorflow/core/framework/node_properties.cc create mode 100644 tensorflow/core/framework/node_properties.h create mode 100644 tensorflow/core/framework/node_properties_test.cc diff --git a/tensorflow/compiler/jit/xla_kernel_creator.cc b/tensorflow/compiler/jit/xla_kernel_creator.cc index 6ee1db2c7c5..fd6fd4b5b58 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator.cc @@ -20,15 +20,17 @@ limitations under the License. namespace tensorflow { -bool XlaKernelCreator::CanCreateKernel(const FunctionLibraryRuntime& flr, - const NodeDef& node_def) const { - return CanCreateXlaKernel(node_def); +bool XlaKernelCreator::CanCreateKernel( + const FunctionLibraryRuntime& flr, + const std::shared_ptr& props) const { + return CanCreateXlaKernel(props->node_def); } -Status XlaKernelCreator::CreateKernel(FunctionLibraryRuntime* flr, - const NodeDef& node_def, - std::unique_ptr* kernel) const { - return CreateXlaKernel(flr, node_def, kernel); +Status XlaKernelCreator::CreateKernel( + FunctionLibraryRuntime* flr, + const std::shared_ptr& props, + std::unique_ptr* kernel) const { + return CreateXlaKernel(flr, props->node_def, kernel); } namespace { diff --git a/tensorflow/compiler/jit/xla_kernel_creator.h b/tensorflow/compiler/jit/xla_kernel_creator.h index 8815ee49ce5..856701a791d 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator.h +++ b/tensorflow/compiler/jit/xla_kernel_creator.h @@ -29,11 +29,13 @@ class XlaKernelCreator : public CustomKernelCreator { // Given a NodeDef 'node_def' and the function library runtime 'flr', returns // true if 'node_def' is a call to a compilable function defined in 'flr', // with the kXlaCompileAttr set. - bool CanCreateKernel(const FunctionLibraryRuntime& flr, - const NodeDef& node_def) const override; + bool CanCreateKernel( + const FunctionLibraryRuntime& flr, + const std::shared_ptr& props) const override; // Given a supported NodeDef, returns a XlaLaunchOp that computes the node. - Status CreateKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def, + Status CreateKernel(FunctionLibraryRuntime* flr, + const std::shared_ptr& props, std::unique_ptr* kernel) const override; }; diff --git a/tensorflow/compiler/jit/xla_kernel_creator_test.cc b/tensorflow/compiler/jit/xla_kernel_creator_test.cc index 7ec37332906..ad94d60d9b5 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator_test.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator_test.cc @@ -30,10 +30,12 @@ limitations under the License. namespace tensorflow { -NodeDef ToNodeDef(const string& text) { +std::shared_ptr ToNodeProperties(const string& text) { NodeDef node_def; + DataTypeVector dummy; EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def)); - return node_def; + return std::make_shared(nullptr, std::move(node_def), dummy, + dummy); } // Create a FunctionDef that takes one resource and one regular param @@ -98,11 +100,11 @@ TEST_F(XlaKernelCreatorTest, OneFloatOneResourceArgument) { (*fdef.mutable_attr())["_XlaMustCompile"] = BoolAttr(true); Init({fdef}); XlaKernelCreator xla_kernel_creator; - NodeDef callsite = - ToNodeDef(R"pb( + auto callsite = + ToNodeProperties(R"pb( name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b' )pb"); - (*callsite.mutable_attr())["_XlaMustCompile"] = BoolAttr(true); + (*(callsite->node_def.mutable_attr()))["_XlaMustCompile"] = BoolAttr(true); // Note: need to set attribute on the created node. Status status = xla_kernel_creator.CreateKernel(flr_, callsite, &kernel_); @@ -127,13 +129,14 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrNotSet) { Init({fdef}); XlaKernelCreator xla_kernel_creator; - Status status = xla_kernel_creator.CreateKernel(flr_, ToNodeDef(R"proto( - name: 'XTimesY' - op: 'XTimesY' - input: 'a' - input: 'b' - )proto"), - &kernel_); + Status status = + xla_kernel_creator.CreateKernel(flr_, ToNodeProperties(R"proto( + name: 'XTimesY' + op: 'XTimesY' + input: 'a' + input: 'b' + )proto"), + &kernel_); EXPECT_TRUE(errors::IsInternal(status)) << status.ToString(); } @@ -143,13 +146,14 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrIsSetToFalse) { Init({fdef}); XlaKernelCreator xla_kernel_creator; - Status status = xla_kernel_creator.CreateKernel(flr_, ToNodeDef(R"proto( - name: 'XTimesY' - op: 'XTimesY' - input: 'a' - input: 'b' - )proto"), - &kernel_); + Status status = + xla_kernel_creator.CreateKernel(flr_, ToNodeProperties(R"proto( + name: 'XTimesY' + op: 'XTimesY' + input: 'a' + input: 'b' + )proto"), + &kernel_); EXPECT_TRUE(errors::IsInternal(status)) << status.ToString(); } diff --git a/tensorflow/compiler/jit/xla_kernel_creator_util.cc b/tensorflow/compiler/jit/xla_kernel_creator_util.cc index 5aab0ff3bd6..de091fc93b4 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator_util.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator_util.cc @@ -218,12 +218,13 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def, TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function)); Device* dev = flr->device(); Status s; - OpKernelConstruction construction( - DeviceType(dev->device_type()), dev, - dev->GetAllocator(AllocatorAttributes()), &node_def, - &fbody->fdef.signature(), flr, dev->resource_manager(), fbody->arg_types, - input_memory_types, fbody->ret_types, output_memory_types, - flr->graph_def_version(), &s); + auto props = std::make_shared( + &fbody->fdef.signature(), node_def, fbody->arg_types, fbody->ret_types); + OpKernelConstruction construction(DeviceType(dev->device_type()), dev, + dev->GetAllocator(AllocatorAttributes()), + flr, dev->resource_manager(), props, + input_memory_types, output_memory_types, + flr->graph_def_version(), &s); *kernel = absl::make_unique( &construction, constant_arg_indices, resource_arg_indices, function, diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc index a88f2b5e29e..bc42de6832d 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc @@ -127,9 +127,14 @@ class TRTEngineOpTestBase : public OpsTestBase { private: Status InitOpWithFunctionLibrary() { OpKernel* kernel = nullptr; - Status status = CreateOpKernel(device_type_, device_, allocator(), - pflr_->GetFLR(device_->name()), node_def_, - TF_GRAPH_DEF_VERSION, &kernel); + auto flr = pflr_->GetFLR(device_->name()); + std::shared_ptr props; + Status status = NodeProperties::CreateFromNodeDef( + node_def_, flr->GetFunctionLibraryDefinition(), &props); + if (status.ok()) { + status.Update(CreateOpKernel(device_type_, device_, allocator(), flr, + props, TF_GRAPH_DEF_VERSION, &kernel)); + } kernel_ = std::unique_ptr(kernel); if (kernel_ != nullptr) input_types_ = kernel_->input_types(); return status; diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 34888fc0e2f..f0aebc9b543 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -133,7 +133,7 @@ Status GraphCompiler::Compile() { OpKernel* op_kernel_raw = nullptr; // The kernel is not actually run for functional ops, we just need it // for metadata. - Status s = flib_->CreateKernel(n->def(), &op_kernel_raw); + Status s = flib_->CreateKernel(n->properties(), &op_kernel_raw); // Transfer ownership of the kernel to a local smart pointer. std::unique_ptr op_kernel(op_kernel_raw); diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index b89068c7a83..4f0df417037 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -472,6 +472,7 @@ tf_cuda_library( "//tensorflow/core/framework:memory_types.h", "//tensorflow/core/framework:node_def_builder.h", "//tensorflow/core/framework:node_def_util.h", + "//tensorflow/core/framework:node_properties.h", "//tensorflow/core/framework:numeric_op.h", "//tensorflow/core/framework:numeric_types.h", "//tensorflow/core/framework:op.h", @@ -2323,6 +2324,7 @@ tf_cuda_library( "//tensorflow/core/framework:bfloat16", "//tensorflow/core/framework:common_shape_fns", "//tensorflow/core/framework:node_def_util", + "//tensorflow/core/framework:node_properties", "//tensorflow/core/framework:numeric_types", "//tensorflow/core/framework:op", "//tensorflow/core/framework:op_def_builder", diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 098217a607a..a196f74c65b 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -1356,24 +1356,25 @@ Status DirectSession::CreateExecutors( params.session_metadata = session_metadata; params.function_library = lib; auto opseg = device->op_segment(); - params.create_kernel = [this, lib, opseg](const NodeDef& ndef, - OpKernel** kernel) { - // NOTE(mrry): We must not share function kernels (implemented - // using `CallOp`) between subgraphs, because `CallOp::handle_` - // is tied to a particular subgraph. Even if the function itself - // is stateful, the `CallOp` that invokes it is not. - if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) { - return lib->CreateKernel(ndef, kernel); - } - auto create_fn = [lib, &ndef](OpKernel** kernel) { - return lib->CreateKernel(ndef, kernel); - }; - // Kernels created for subgraph nodes need to be cached. On - // cache miss, create_fn() is invoked to create a kernel based - // on the function library here + global op registry. - return opseg->FindOrCreate(session_handle_, ndef.name(), kernel, - create_fn); - }; + params.create_kernel = + [this, lib, opseg](const std::shared_ptr& props, + OpKernel** kernel) { + // NOTE(mrry): We must not share function kernels (implemented + // using `CallOp`) between subgraphs, because `CallOp::handle_` + // is tied to a particular subgraph. Even if the function itself + // is stateful, the `CallOp` that invokes it is not. + if (!OpSegment::ShouldOwnKernel(lib, props->node_def.op())) { + return lib->CreateKernel(props, kernel); + } + auto create_fn = [lib, &props](OpKernel** kernel) { + return lib->CreateKernel(props, kernel); + }; + // Kernels created for subgraph nodes need to be cached. On + // cache miss, create_fn() is invoked to create a kernel based + // on the function library here + global op registry. + return opseg->FindOrCreate(session_handle_, props->node_def.name(), + kernel, create_fn); + }; params.delete_kernel = [lib](OpKernel* kernel) { if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string())) delete kernel; diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc index 6e8a5b9689a..8ca02ca51c0 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc @@ -98,7 +98,10 @@ Status KernelAndDeviceOp::Init(const NodeDef& ndef, "A valid FunctionLibraryRuntime must be provided when running ops " "based on OpKernel."); } - TF_RETURN_IF_ERROR(flr_->CreateKernel(ndef, &k)); + std::shared_ptr props; + TF_RETURN_IF_ERROR(NodeProperties::CreateFromNodeDef( + ndef, flr_->GetFunctionLibraryDefinition(), &props)); + TF_RETURN_IF_ERROR(flr_->CreateKernel(props, &k)); kernel_.reset(k); input_alloc_attrs_.resize(kernel_->num_inputs()); diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index bd3e14129b3..3a43a193b9e 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -654,7 +654,7 @@ Status ExecutorImpl::Initialize(const Graph& graph) { item->input_start = frame_info->total_inputs; frame_info->total_inputs += n->num_inputs(); - Status s = params_.create_kernel(n->def(), &item->kernel); + Status s = params_.create_kernel(n->properties(), &item->kernel); if (!s.ok()) { item->kernel = nullptr; s = AttachDef(s, *n); @@ -2974,12 +2974,12 @@ Status NewLocalExecutor(const LocalExecutorParams& params, const Graph& graph, } Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib, - const NodeDef& ndef, int graph_def_version, - OpKernel** kernel) { + const std::shared_ptr& props, + int graph_def_version, OpKernel** kernel) { const auto device_type = DeviceType(device->attributes().device_type()); auto allocator = device->GetAllocator(AllocatorAttributes()); return CreateOpKernel(device_type, device, allocator, flib, - device->resource_manager(), ndef, graph_def_version, + device->resource_manager(), props, graph_def_version, kernel); } diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h index a7cb01ec7f0..fcc64b9d986 100644 --- a/tensorflow/core/common_runtime/executor.h +++ b/tensorflow/core/common_runtime/executor.h @@ -145,7 +145,9 @@ struct LocalExecutorParams { // create_kernel returns an instance of op kernel based on NodeDef. // delete_kernel is called for every kernel used by the executor // when the executor is deleted. - std::function create_kernel; + std::function&, + OpKernel**)> + create_kernel; std::function delete_kernel; Executor::RendezvousFactory rendezvous_factory; @@ -240,12 +242,12 @@ class ExecutorBarrier { // A few helpers to facilitate create/delete kernels. -// Creates a kernel based on "ndef" on device "device". The kernel can +// Creates a kernel based on "props" on device "device". The kernel can // access the functions in the "flib". The caller takes ownership of // returned "*kernel". Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib, - const NodeDef& ndef, int graph_def_version, - OpKernel** kernel); + const std::shared_ptr& props, + int graph_def_version, OpKernel** kernel); // Deletes "kernel" returned by CreateKernel. void DeleteNonCachedKernel(OpKernel* kernel); diff --git a/tensorflow/core/common_runtime/executor_test.cc b/tensorflow/core/common_runtime/executor_test.cc index e994512a43f..3f143c75714 100644 --- a/tensorflow/core/common_runtime/executor_test.cc +++ b/tensorflow/core/common_runtime/executor_test.cc @@ -61,11 +61,12 @@ class ExecutorTest : public ::testing::Test { const int version = graph->versions().producer(); LocalExecutorParams params; params.device = device_.get(); - params.create_kernel = [this, version](const NodeDef& ndef, - OpKernel** kernel) { - return CreateNonCachedKernel(device_.get(), nullptr, ndef, version, - kernel); - }; + params.create_kernel = + [this, version](const std::shared_ptr& props, + OpKernel** kernel) { + return CreateNonCachedKernel(device_.get(), nullptr, props, version, + kernel); + }; params.delete_kernel = [](OpKernel* kernel) { DeleteNonCachedKernel(kernel); }; diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 14c0a8f5ad2..2140bf7f72b 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -187,7 +187,8 @@ class FunctionLibraryRuntimeOverlay : public FunctionLibraryRuntime { void Run(const Options& opts, Handle handle, CallFrameInterface* call_frame, DoneCallback done) override; - Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) override; + Status CreateKernel(const std::shared_ptr& props, + OpKernel** kernel) override; bool IsStateful(const string& function_name) const override; @@ -256,7 +257,8 @@ void FunctionLibraryRuntimeOverlay::Run(const Options& opts, Handle handle, base_flr_->Run(opts, handle, call_frame, std::move(done)); } -Status FunctionLibraryRuntimeOverlay::CreateKernel(const NodeDef&, OpKernel**) { +Status FunctionLibraryRuntimeOverlay::CreateKernel( + const std::shared_ptr&, OpKernel**) { // We don't have access to base_lib_def_ in base function library runtime (aka // FunctionLibraryRuntimeImpl), so to make sure we do not create a kernel with // the wrong lib_def we just disable creation of new kernels through overlays. @@ -344,7 +346,8 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { Status GetRetTypes(Handle handle, DataTypeVector* ret_types) override; - Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) override; + Status CreateKernel(const std::shared_ptr& props, + OpKernel** kernel) override; void Run(const Options& opts, Handle handle, gtl::ArraySlice args, std::vector* rets, DoneCallback done) override; @@ -393,7 +396,9 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { const string device_name_; std::function get_func_sig_; - std::function create_kernel_; + std::function&, + OpKernel**)> + create_kernel_; mutable mutex mu_; @@ -426,8 +431,8 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { // to use for kernel creation and execution. In particular, this method can // accept a FunctionLibraryRuntimeOverlay that overlays a different // FunctionLibraryDefinition. - Status CreateKernel(const NodeDef& ndef, FunctionLibraryRuntime* flr, - OpKernel** kernel); + Status CreateKernel(const std::shared_ptr& props, + FunctionLibraryRuntime* flr, OpKernel** kernel); Status FunctionDefToBody(const FunctionDef& fdef, AttrSlice attrs, const FunctionLibraryDefinition* lib_def, std::unique_ptr* fbody); @@ -476,8 +481,9 @@ FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl( get_func_sig_ = [this](const string& op, const OpDef** sig) { return base_lib_def_->LookUpOpDef(op, sig); }; - create_kernel_ = [this](const NodeDef& ndef, OpKernel** kernel) { - return CreateKernel(ndef, kernel); + create_kernel_ = [this](const std::shared_ptr& props, + OpKernel** kernel) { + return CreateKernel(props, kernel); }; thread::ThreadPool* pool = nullptr; if (device_ != nullptr) { @@ -589,20 +595,20 @@ Status FunctionLibraryRuntimeImpl::GetRetTypes(Handle h, return Status::OK(); } -Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef, - OpKernel** kernel) { - return CreateKernel(ndef, this, kernel); +Status FunctionLibraryRuntimeImpl::CreateKernel( + const std::shared_ptr& props, OpKernel** kernel) { + return CreateKernel(props, this, kernel); } -Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef, - FunctionLibraryRuntime* flr, - OpKernel** kernel) { +Status FunctionLibraryRuntimeImpl::CreateKernel( + const std::shared_ptr& props, + FunctionLibraryRuntime* flr, OpKernel** kernel) { // If a custom kernel creator is given, try that. Status s; if (custom_kernel_creator_ != nullptr && - custom_kernel_creator_->CanCreateKernel(*this, ndef)) { + custom_kernel_creator_->CanCreateKernel(*this, props)) { std::unique_ptr ret; - s = custom_kernel_creator_->CreateKernel(this, ndef, &ret); + s = custom_kernel_creator_->CreateKernel(this, props, &ret); if (s.ok()) { *kernel = ret.release(); } else { @@ -613,9 +619,9 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef, const FunctionLibraryDefinition* lib_def = flr->GetFunctionLibraryDefinition(); - if (lib_def->Find(ndef.op()) == nullptr) { + if (lib_def->Find(props->node_def.op()) == nullptr) { // A primitive operation. Creates the registered kernel. - return CreateNonCachedKernel(device_, flr, ndef, graph_def_version_, + return CreateNonCachedKernel(device_, flr, props, graph_def_version_, kernel); } @@ -626,8 +632,9 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef, options.lib_def = lib_def; } Handle handle; - TF_RETURN_IF_ERROR( - Instantiate(ndef.op(), AttrSlice(&ndef.attr()), options, &handle)); + TF_RETURN_IF_ERROR(Instantiate(props->node_def.op(), + AttrSlice(&props->node_def.attr()), options, + &handle)); const FunctionBody* fbody = GetFunctionBody(handle); CHECK_NOTNULL(fbody); @@ -647,10 +654,12 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef, // Constructs a CallOp kernel for running the instantiated function. auto device_type = DeviceType(device_->attributes().device_type()); + auto new_props = std::make_shared( + &fbody->fdef.signature(), props->node_def, fbody->arg_types, + fbody->ret_types); OpKernelConstruction construction( - device_type, device_, device_->GetAllocator(AllocatorAttributes()), &ndef, - &fbody->fdef.signature(), flr, device_->resource_manager(), - fbody->arg_types, input_memory_types, fbody->ret_types, + device_type, device_, device_->GetAllocator(AllocatorAttributes()), flr, + device_->resource_manager(), props, input_memory_types, output_memory_types, graph_def_version_, &s); if (s.ok()) { *kernel = new CallOp(handle, &construction); @@ -953,9 +962,11 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Item** item) { if (flr == this) { params.create_kernel = create_kernel_; } else { - params.create_kernel = [this, flr](const NodeDef& ndef, OpKernel** kernel) { - return CreateKernel(ndef, flr, kernel); - }; + params.create_kernel = + [this, flr](const std::shared_ptr& props, + OpKernel** kernel) { + return CreateKernel(props, flr, kernel); + }; } params.delete_kernel = [](OpKernel* kernel) { DeleteNonCachedKernel(kernel); diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index c1247190d2d..3e2371a686a 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -90,11 +90,12 @@ class FunctionTest : public ::testing::Test { const int version = g->versions().producer(); LocalExecutorParams params; params.device = device_.get(); - params.create_kernel = [this, version](const NodeDef& ndef, - OpKernel** kernel) { - return CreateNonCachedKernel(device_.get(), nullptr, ndef, version, - kernel); - }; + params.create_kernel = + [this, version](const std::shared_ptr& props, + OpKernel** kernel) { + return CreateNonCachedKernel(device_.get(), nullptr, props, version, + kernel); + }; params.delete_kernel = [](OpKernel* kernel) { DeleteNonCachedKernel(kernel); }; diff --git a/tensorflow/core/common_runtime/graph_runner.cc b/tensorflow/core/common_runtime/graph_runner.cc index 0a7d50f9ea4..7ffb860a2ce 100644 --- a/tensorflow/core/common_runtime/graph_runner.cc +++ b/tensorflow/core/common_runtime/graph_runner.cc @@ -157,9 +157,10 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library, params.device = device_; params.function_library = function_library; const int producer = graph_to_run->versions().producer(); - params.create_kernel = [this, function_library, producer](const NodeDef& ndef, - OpKernel** kernel) { - return CreateNonCachedKernel(device_, function_library, ndef, producer, + params.create_kernel = [this, function_library, producer]( + const std::shared_ptr& props, + OpKernel** kernel) { + return CreateNonCachedKernel(device_, function_library, props, producer, kernel); }; params.delete_kernel = [](OpKernel* kernel) { delete kernel; }; diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc index fe703050602..4118534cb3e 100644 --- a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc +++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc @@ -84,9 +84,10 @@ Benchmark::Benchmark(const string& device, Graph* g, LocalExecutorParams params; params.device = device_.get(); params.function_library = nullptr; - params.create_kernel = [this, graph_def_version](const NodeDef& ndef, - OpKernel** kernel) { - return CreateNonCachedKernel(device_.get(), nullptr, ndef, + params.create_kernel = [this, graph_def_version]( + const std::shared_ptr& props, + OpKernel** kernel) { + return CreateNonCachedKernel(device_.get(), nullptr, props, graph_def_version, kernel); }; params.delete_kernel = [](OpKernel* kernel) { diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index 9b28651c597..96fc4f3d4f3 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -233,23 +233,25 @@ Status GraphMgr::InitItem( // Construct the root executor for the subgraph. params.device = unit->device; params.function_library = lib; - params.create_kernel = [handle, lib, opseg](const NodeDef& ndef, - OpKernel** kernel) { - // NOTE(mrry): We must not share function kernels (implemented - // using `CallOp`) between subgraphs, because `CallOp::handle_` - // is tied to a particular subgraph. Even if the function itself - // is stateful, the `CallOp` that invokes it is not. - if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) { - return lib->CreateKernel(ndef, kernel); - } - auto create_fn = [lib, &ndef](OpKernel** kernel) { - return lib->CreateKernel(ndef, kernel); - }; - // Kernels created for subgraph nodes need to be cached. On - // cache miss, create_fn() is invoked to create a kernel based - // on the function library here + global op registry. - return opseg->FindOrCreate(handle, ndef.name(), kernel, create_fn); - }; + params.create_kernel = + [handle, lib, opseg](const std::shared_ptr& props, + OpKernel** kernel) { + // NOTE(mrry): We must not share function kernels (implemented + // using `CallOp`) between subgraphs, because `CallOp::handle_` + // is tied to a particular subgraph. Even if the function itself + // is stateful, the `CallOp` that invokes it is not. + if (!OpSegment::ShouldOwnKernel(lib, props->node_def.op())) { + return lib->CreateKernel(props, kernel); + } + auto create_fn = [lib, &props](OpKernel** kernel) { + return lib->CreateKernel(props, kernel); + }; + // Kernels created for subgraph nodes need to be cached. On + // cache miss, create_fn() is invoked to create a kernel based + // on the function library here + global op registry. + return opseg->FindOrCreate(handle, props->node_def.name(), kernel, + create_fn); + }; params.delete_kernel = [lib](OpKernel* kernel) { if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string())) { delete kernel; diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD index 003e4894788..f3207dd657a 100644 --- a/tensorflow/core/framework/BUILD +++ b/tensorflow/core/framework/BUILD @@ -129,6 +129,7 @@ exports_files( "attr_value_util.h", "common_shape_fns.h", "node_def_util.h", + "node_properties.h", "op.h", "op_def_builder.h", "op_def_util.h", @@ -172,6 +173,7 @@ filegroup( "model.h", "node_def_builder.h", "node_def_util.h", + "node_properties.h", "numeric_op.h", "numeric_types.h", "op.h", @@ -338,6 +340,8 @@ filegroup( "node_def_builder.h", "node_def_util.cc", "node_def_util.h", + "node_properties.cc", + "node_properties.h", "numeric_op.h", "op.cc", "op.h", @@ -862,6 +866,21 @@ cc_library( ], ) +cc_library( + name = "node_properties", + srcs = ["node_properties.cc"], + hdrs = ["node_properties.h"], + deps = [ + ":node_def_proto_cc", + ":node_def_util", + ":op", + ":op_def_proto_cc", + ":tensor", + ":types_proto_cc", + "//tensorflow/core/lib/core:status", + ], +) + cc_library( name = "op_def_builder", srcs = ["op_def_builder.cc"], @@ -967,6 +986,7 @@ tf_cc_tests( "model_test.cc", "node_def_builder_test.cc", "node_def_util_test.cc", + "node_properties_test.cc", "op_compatibility_test.cc", "op_def_builder_test.cc", "op_def_util_test.cc", diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 0e260d26592..58cc1bbdaf9 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -722,11 +722,13 @@ class FunctionLibraryRuntime { virtual void Run(const Options& opts, Handle handle, CallFrameInterface* call_frame, DoneCallback done) = 0; - // Creates a "kernel" for the given node def "ndef". + // Creates a "kernel" for the given NodeProperties "props". // // If succeeds, returns OK and the caller takes the ownership of the // returned "*kernel". Otherwise, returns an error. - virtual Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) = 0; + virtual Status CreateKernel( + const std::shared_ptr& props, + OpKernel** kernel) = 0; // Returns true iff the function named `function_name` is stateful. // @@ -818,12 +820,15 @@ class CustomKernelCreator { // Given a NodeDef 'node_def' and the function library runtime 'flr', // validate if the class supports creating such a kernel. - virtual bool CanCreateKernel(const FunctionLibraryRuntime& flr, - const NodeDef& node_def) const = 0; + virtual bool CanCreateKernel( + const FunctionLibraryRuntime& flr, + const std::shared_ptr& props) const = 0; // Given a supported NodeDef, returns a kernel that computes the node. - virtual Status CreateKernel(FunctionLibraryRuntime* flr, const NodeDef& ndef, - std::unique_ptr* kernel) const = 0; + virtual Status CreateKernel( + FunctionLibraryRuntime* flr, + const std::shared_ptr& props, + std::unique_ptr* kernel) const = 0; }; // Used to instantiate and run functions in a distributed system. diff --git a/tensorflow/core/framework/node_properties.cc b/tensorflow/core/framework/node_properties.cc new file mode 100644 index 00000000000..bcc81bdbbff --- /dev/null +++ b/tensorflow/core/framework/node_properties.cc @@ -0,0 +1,39 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/node_properties.h" + +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +// static +Status NodeProperties::CreateFromNodeDef( + NodeDef node_def, const OpRegistryInterface* op_registry, + std::shared_ptr* props) { + const OpDef* op_def; + TF_RETURN_IF_ERROR(op_registry->LookUpOpDef(node_def.op(), &op_def)); + DataTypeVector input_types; + DataTypeVector output_types; + TF_RETURN_IF_ERROR( + InOutTypesForNode(node_def, *op_def, &input_types, &output_types)); + props->reset(new NodeProperties(op_def, std::move(node_def), + std::move(input_types), + std::move(output_types))); + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/node_properties.h b/tensorflow/core/framework/node_properties.h new file mode 100644 index 00000000000..0382321f486 --- /dev/null +++ b/tensorflow/core/framework/node_properties.h @@ -0,0 +1,63 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_NODE_PROPERTIES_H_ +#define TENSORFLOW_CORE_FRAMEWORK_NODE_PROPERTIES_H_ + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class OpRegistryInterface; + +struct NodeProperties { + public: + NodeProperties(const OpDef* op_def, NodeDef node_def, + const DataTypeSlice inputs, const DataTypeSlice outputs) + : NodeProperties(op_def, std::move(node_def), + DataTypeVector(inputs.begin(), inputs.end()), + DataTypeVector(outputs.begin(), outputs.end())) {} + + NodeProperties(const OpDef* _op_def, NodeDef&& _node_def, + DataTypeVector inputs, DataTypeVector outputs) + : op_def(_op_def), + node_def(std::move(_node_def)), + input_types(std::move(inputs)), + input_types_slice(input_types), + output_types(std::move(outputs)), + output_types_slice(output_types) {} + + // Resets the 'props' shared pointer to point to a new NodeProperties created + // from the given NodeDef. 'op_registry' is used to look up the OpDef + // corresponding to node_def.op(). Returns an error if OpDef lookup or + // creation failed. + static Status CreateFromNodeDef(NodeDef node_def, + const OpRegistryInterface* op_registry, + std::shared_ptr* props); + + const OpDef* op_def; // not owned. + NodeDef node_def; + DataTypeVector input_types; + DataTypeSlice input_types_slice; + DataTypeVector output_types; + DataTypeSlice output_types_slice; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_NODE_PROPERTIES_H_ diff --git a/tensorflow/core/framework/node_properties_test.cc b/tensorflow/core/framework/node_properties_test.cc new file mode 100644 index 00000000000..9f76b953b06 --- /dev/null +++ b/tensorflow/core/framework/node_properties_test.cc @@ -0,0 +1,128 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/node_properties.h" + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +OpDef ToOpDef(const OpDefBuilder& builder) { + OpRegistrationData op_reg_data; + EXPECT_TRUE(builder.Finalize(&op_reg_data).ok()); + return op_reg_data.op_def; +} + +class MockOpRegistry : public OpRegistryInterface { + public: + MockOpRegistry() + : op_reg_(ToOpDef(OpDefBuilder("Foo") + .Input("f: float") + .Input("i: int32") + .Output("of: double"))) {} + ~MockOpRegistry() override {} + + // Returns an error status and sets *op_reg_data to nullptr if no OpDef is + // registered under that name, otherwise returns the registered OpDef. + // Caller must not delete the returned pointer. + Status LookUp(const string& op_type_name, + const OpRegistrationData** op_reg_data) const override { + if (op_type_name == "Foo") { + *op_reg_data = &op_reg_; + return Status::OK(); + } else { + *op_reg_data = nullptr; + return errors::InvalidArgument("Op type named ", op_type_name, + " not found"); + } + } + + const OpDef* get_op_def_addr() { return &op_reg_.op_def; } + + private: + const OpRegistrationData op_reg_; +}; + +void ValidateNodeProperties(const NodeProperties& props, const OpDef* op_def, + const NodeDef& node_def, + const DataTypeVector& input_types, + const DataTypeVector& output_types) { + EXPECT_EQ(props.op_def, op_def); + EXPECT_EQ(props.node_def.name(), node_def.name()); + ASSERT_EQ(props.input_types.size(), input_types.size()); + for (int i = 0; i < input_types.size(); ++i) { + EXPECT_EQ(props.input_types[i], input_types[i]); + EXPECT_EQ(props.input_types_slice[i], input_types[i]); + } + ASSERT_EQ(props.output_types.size(), output_types.size()); + for (int i = 0; i < output_types.size(); ++i) { + EXPECT_EQ(props.output_types[i], output_types[i]); + EXPECT_EQ(props.output_types_slice[i], output_types[i]); + } +} + +} // namespace + +TEST(NodeProperties, Contructors) { + OpDef op_def; + NodeDef node_def; + node_def.set_name("foo"); + DataTypeVector input_types{DT_FLOAT, DT_INT32}; + DataTypeVector output_types{DT_DOUBLE}; + DataTypeSlice input_types_slice(input_types); + DataTypeSlice output_types_slice(output_types); + + // Construct from slices. + NodeProperties props_from_slices(&op_def, node_def, input_types_slice, + output_types_slice); + ValidateNodeProperties(props_from_slices, &op_def, node_def, input_types, + output_types); + + // Construct from vectors. + NodeProperties props_from_vectors(&op_def, node_def, input_types, + output_types); + ValidateNodeProperties(props_from_vectors, &op_def, node_def, input_types, + output_types); +} + +TEST(NodeProperties, CreateFromNodeDef) { + MockOpRegistry op_registry; + NodeDef node_def; + node_def.set_name("bar"); + node_def.set_op("Foo"); + node_def.add_input("f_in"); + node_def.add_input("i_in"); + + std::shared_ptr props; + EXPECT_TRUE( + NodeProperties::CreateFromNodeDef(node_def, &op_registry, &props).ok()); + + DataTypeVector input_types{DT_FLOAT, DT_INT32}; + DataTypeVector output_types{DT_DOUBLE}; + ValidateNodeProperties(*props, op_registry.get_op_def_addr(), node_def, + input_types, output_types); + + // The OpDef lookup should fail for this one: + node_def.set_op("Baz"); + std::shared_ptr props_bad; + EXPECT_FALSE( + NodeProperties::CreateFromNodeDef(node_def, &op_registry, &props_bad) + .ok()); + EXPECT_EQ(props_bad, nullptr); +} +} // namespace tensorflow diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index 2feb84a1786..38c56eb3b1c 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -35,9 +35,9 @@ limitations under the License. #include "tensorflow/core/framework/memory_types.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/node_properties.h" #include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -91,35 +91,53 @@ Status MatchSignatureHelper(const DataTypeSlice expected_inputs, // OpKernel ------------------------------------------------------------------ -OpKernel::OpKernel(OpKernelConstruction* context) - : OpKernel(context, MakeUnique(context->def())) {} +OpKernel::OpKernel(OpKernelConstruction* context) : OpKernel(context, false) {} OpKernel::OpKernel(OpKernelConstruction* context, bool is_deferred) - : OpKernel(context, MakeUnique(context->def()), - is_deferred) {} - -OpKernel::OpKernel(OpKernelConstruction* context, - std::unique_ptr node_def, bool is_deferred) - : def_(std::move(node_def)), - input_types_(context->input_types().begin(), - context->input_types().end()), + : props_(context->props_), input_memory_types_(context->input_memory_types().begin(), context->input_memory_types().end()), - output_types_(context->output_types().begin(), - context->output_types().end()), output_memory_types_(context->output_memory_types().begin(), context->output_memory_types().end()), input_name_map_(context->num_inputs()), output_name_map_(context->num_outputs()), - name_view_(def_->name()), - type_string_view_(def_->op()), + name_view_(props_->node_def.name()), + type_string_view_(props_->node_def.op()), graph_def_version_(context->graph_def_version()), is_deferred_(is_deferred), cost_estimate_(OpKernel::kInitialCostEstimateCycles) { OP_REQUIRES_OK(context, - NameRangesForNode(*def_, *context->op_def_, &input_name_map_, - &output_name_map_)); - OP_REQUIRES_OK(context, CheckOpDeprecation(*context->op_def_, + NameRangesForNode(props_->node_def, *props_->op_def, + &input_name_map_, &output_name_map_)); + OP_REQUIRES_OK(context, CheckOpDeprecation(*props_->op_def, + context->graph_def_version())); + + // Kernels executing on GPU/SYCL tie very few resources on the CPU where the + // scheduler runs: we consider them as inexpensive. + expensive_ = context->device_type() != DeviceType(DEVICE_GPU) && + context->device_type() != DeviceType(DEVICE_SYCL); +} + +OpKernel::OpKernel(OpKernelConstruction* context, NodeDef&& custom_def, + bool is_deferred) + : props_(std::make_shared( + context->props_->op_def, std::move(custom_def), + context->props_->input_types, context->props_->output_types)), + input_memory_types_(context->input_memory_types().begin(), + context->input_memory_types().end()), + output_memory_types_(context->output_memory_types().begin(), + context->output_memory_types().end()), + input_name_map_(context->num_inputs()), + output_name_map_(context->num_outputs()), + name_view_(props_->node_def.name()), + type_string_view_(props_->node_def.op()), + graph_def_version_(context->graph_def_version()), + is_deferred_(is_deferred), + cost_estimate_(OpKernel::kInitialCostEstimateCycles) { + OP_REQUIRES_OK(context, + NameRangesForNode(props_->node_def, *props_->op_def, + &input_name_map_, &output_name_map_)); + OP_REQUIRES_OK(context, CheckOpDeprecation(*props_->op_def, context->graph_def_version())); // Kernels executing on GPU/SYCL tie very few resources on the CPU where the @@ -134,10 +152,6 @@ const uint64 OpKernel::kInitialCostEstimateCycles; const uint64 OpKernel::kOpIsExpensiveThresholdCycles; const uint64 OpKernel::kCostDecay; -const string& OpKernel::name() const { return def_->name(); } -const string& OpKernel::type_string() const { return def_->op(); } -const string& OpKernel::requested_device() const { return def_->device(); } -const string& OpKernel::requested_input(int i) const { return def_->input(i); } Status OpKernel::InputRange(StringPiece input_name, int* start, int* stop) const { @@ -216,22 +230,18 @@ Tensor* PersistentTensor::AccessTensor(OpKernelContext* context) { OpKernelConstruction::OpKernelConstruction( DeviceType device_type, DeviceBase* device, Allocator* allocator, - const NodeDef* node_def, const OpDef* op_def, FunctionLibraryRuntime* flib, - ResourceMgr* resource_mgr, const DataTypeSlice& input_types, + FunctionLibraryRuntime* flib, ResourceMgr* resource_mgr, + const std::shared_ptr& props, const MemoryTypeSlice& input_memory_types, - const DataTypeSlice& output_types, const MemoryTypeSlice& output_memory_types, int graph_def_version, Status* status) : device_type_(std::move(device_type)), device_(device), allocator_(allocator), - def_(node_def), - op_def_(op_def), flib_(flib), resource_mgr_(resource_mgr), - input_types_(input_types), + props_(props), input_memory_types_(input_memory_types), - output_types_(output_types), output_memory_types_(output_memory_types), graph_def_version_(graph_def_version), status_(status) {} @@ -246,8 +256,8 @@ void OpKernelConstruction::SetStatus(const Status& status) { Status OpKernelConstruction::MatchSignature( const DataTypeSlice expected_inputs, const DataTypeSlice expected_outputs) { - return MatchSignatureHelper(expected_inputs, expected_outputs, input_types_, - output_types_); + return MatchSignatureHelper(expected_inputs, expected_outputs, + props_->input_types, props_->output_types); } Status OpKernelConstruction::allocate_temp(DataType type, @@ -263,7 +273,7 @@ Status OpKernelConstruction::allocate_temp(DataType type, } if (LogMemory::IsEnabled()) { LogMemory::RecordTensorAllocation( - def_->name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp); + def().name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp); } *out_temp = new_temp; return Status::OK(); @@ -288,7 +298,7 @@ Status OpKernelConstruction::allocate_temp(DataType type, } if (LogMemory::IsEnabled()) { LogMemory::RecordTensorAllocation( - def_->name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp); + def().name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp); } *out_temp = new_temp; return Status::OK(); @@ -1544,45 +1554,65 @@ string KernelsRegisteredForOp(StringPiece op_name) { return ret; } +/* TODO(rmlarsen): This API is deprecated. Remove it if possible to avoid + * copying the NodeDef. */ std::unique_ptr CreateOpKernel( DeviceType device_type, DeviceBase* device, Allocator* allocator, const NodeDef& node_def, int graph_def_version, Status* status) { + // Look up the Op registered for this op name. + std::shared_ptr props; + status->Update(NodeProperties::CreateFromNodeDef( + node_def, OpRegistry::Global(), &props)); + if (!status->ok()) { + errors::AppendToMessage(status, + " for node: ", FormatNodeDefForError(node_def)); + return nullptr; + } + return CreateOpKernel(device_type, device, allocator, props, + graph_def_version, status); +} + +std::unique_ptr CreateOpKernel( + DeviceType device_type, DeviceBase* device, Allocator* allocator, + const std::shared_ptr& props, int graph_def_version, + Status* status) { OpKernel* kernel = nullptr; - *status = CreateOpKernel(std::move(device_type), device, allocator, nullptr, - node_def, graph_def_version, &kernel); + *status = CreateOpKernel(std::move(device_type), device, allocator, + /*flib=*/nullptr, props, graph_def_version, &kernel); return std::unique_ptr(kernel); } Status CreateOpKernel(DeviceType device_type, DeviceBase* device, Allocator* allocator, FunctionLibraryRuntime* flib, - const NodeDef& node_def, int graph_def_version, - OpKernel** kernel) { + const std::shared_ptr& props, + int graph_def_version, OpKernel** kernel) { return CreateOpKernel(std::move(device_type), device, allocator, flib, - /* resource_mgr= */ nullptr, node_def, - graph_def_version, kernel); + /* resource_mgr= */ nullptr, props, graph_def_version, + kernel); } Status CreateOpKernel(DeviceType device_type, DeviceBase* device, Allocator* allocator, FunctionLibraryRuntime* flib, - ResourceMgr* resource_mgr, const NodeDef& node_def, + ResourceMgr* resource_mgr, + const std::shared_ptr& props, int graph_def_version, OpKernel** kernel) { - VLOG(1) << "Instantiating kernel for node: " << SummarizeNodeDef(node_def); - - // Look up the Op registered for this op name. - const OpDef* op_def = nullptr; - TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def)); - - // Validate node_def against OpDef. - TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, *op_def)); - - // Look up kernel registration. - const KernelRegistration* registration; + const NodeDef& node_def = props->node_def; bool was_attr_mismatch; - Status s = FindKernelRegistration(device_type, node_def, ®istration, - &was_attr_mismatch); - if (!s.ok()) { - errors::AppendToMessage(&s, " when instantiating ", node_def.op()); - return s; + const KernelRegistration* registration = nullptr; + Status s; + if (props != nullptr) { + VLOG(1) << "Instantiating kernel for node: " << SummarizeNodeDef(node_def); + + // Validate node_def against OpDef. + TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, *props->op_def)); + + // Look up kernel registration. + s = FindKernelRegistration(device_type, node_def, ®istration, + &was_attr_mismatch); + if (!s.ok()) { + errors::AppendToMessage(&s, " when instantiating ", node_def.op()); + return s; + } } if (registration == nullptr) { s.Update(errors::NotFound("No registered '", node_def.op(), @@ -1599,15 +1629,6 @@ Status CreateOpKernel(DeviceType device_type, DeviceBase* device, return s; } - // Get signature from the OpDef & NodeDef - DataTypeVector inputs; - DataTypeVector outputs; - s.Update(InOutTypesForNode(node_def, *op_def, &inputs, &outputs)); - if (!s.ok()) { - errors::AppendToMessage(&s, " for node: ", FormatNodeDefForError(node_def)); - return s; - } - // We are creating a kernel for an op registered in // OpRegistry::Global(), we consult the kernel registry to decide // the kernel's input and output memory types. @@ -1618,10 +1639,9 @@ Status CreateOpKernel(DeviceType device_type, DeviceBase* device, &output_memory_types)); // Everything needed for OpKernel construction. - OpKernelConstruction context(std::move(device_type), device, allocator, - &node_def, op_def, flib, resource_mgr, inputs, - input_memory_types, outputs, output_memory_types, - graph_def_version, &s); + OpKernelConstruction context(std::move(device_type), device, allocator, flib, + resource_mgr, props, input_memory_types, + output_memory_types, graph_def_version, &s); *kernel = registration->factory->Create(&context); if (!s.ok()) { delete *kernel; diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index 4f1cc91cd19..e0d9742768a 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/node_properties.h" #include "tensorflow/core/framework/op.h" // TODO(b/62899350): Remove #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/selective_registration.h" @@ -85,19 +86,18 @@ class OpKernel { // expensive initialization in the descendant's constructor. explicit OpKernel(OpKernelConstruction* context); - // Specialized constructor that enables the descendant to provide a different - // `NodeDef` value. For example, this constructor can be used to provide a - // stripped-down `NodeDef` that does not contain the full set of attrs (such - // as tensor values) if the descendant stores them in a different form. - explicit OpKernel(OpKernelConstruction* context, - std::unique_ptr node_def, - bool is_deferred = false); - // Specialized constructor that allows a kernel implementation to mark itself // as a "deferred" op. If true, the executor will provide access to the // `OpKernelContext::inc_num_deferred_ops_function()` and // `OpKernelContext::dec_num_deferred_ops_function()` methods at run-time. - explicit OpKernel(OpKernelConstruction* context, bool is_deferred); + OpKernel(OpKernelConstruction* context, bool is_deferred); + + // Specialized constructor that enables the descendant to provide a custom + // `NodeDef` value. For example, this constructor can be used to provide a + // stripped-down `NodeDef` that does not contain the full set of attrs (such + // as tensor values) if the descendant stores them in a different form. + OpKernel(OpKernelConstruction* context, NodeDef&& custom_def, + bool is_deferred); virtual ~OpKernel(); @@ -170,24 +170,26 @@ class OpKernel { } // Accessors. - const NodeDef& def() const { return *def_; } - const string& name() const; // Same as def().name() + const NodeDef& def() const { return props_->node_def; } + const string& name() const { return props_->node_def.name(); } absl::string_view name_view() const { return name_view_; } - const string& type_string() const; // Same as def().op() + const string& type_string() const { return props_->node_def.op(); } absl::string_view type_string_view() const { return type_string_view_; } - const string& requested_device() const; // Same as def().device() + const string& requested_input(int i) const { + return props_->node_def.input(i); + } + const string& requested_device() const { return props_->node_def.device(); } - int num_inputs() const { return input_types_.size(); } - DataType input_type(int i) const { return input_types_[i]; } - const DataTypeVector& input_types() const { return input_types_; } + int num_inputs() const { return props_->input_types.size(); } + DataType input_type(int i) const { return props_->input_types[i]; } + const DataTypeVector& input_types() const { return props_->input_types; } const MemoryTypeVector& input_memory_types() const { return input_memory_types_; } - const string& requested_input(int i) const; // Same as def().input(i) - int num_outputs() const { return output_types_.size(); } - DataType output_type(int o) const { return output_types_[o]; } - const DataTypeVector& output_types() const { return output_types_; } + int num_outputs() const { return props_->output_types.size(); } + DataType output_type(int o) const { return props_->output_types[o]; } + const DataTypeVector& output_types() const { return props_->output_types; } const MemoryTypeVector& output_memory_types() const { return output_memory_types_; } @@ -209,10 +211,8 @@ class OpKernel { string GetTraceArgument(OpKernelContext* ctx); private: - const std::unique_ptr def_; - const DataTypeVector input_types_; + const std::shared_ptr props_; const MemoryTypeVector input_memory_types_; - const DataTypeVector output_types_; const MemoryTypeVector output_memory_types_; NameRangeMap input_name_map_; NameRangeMap output_name_map_; @@ -284,12 +284,10 @@ class PersistentTensor { class OpKernelConstruction { public: OpKernelConstruction(DeviceType device_type, DeviceBase* device, - Allocator* allocator, const NodeDef* node_def, - const OpDef* op_def, FunctionLibraryRuntime* flib, + Allocator* allocator, FunctionLibraryRuntime* flib, ResourceMgr* resource_mgr, - const DataTypeSlice& input_types, + const std::shared_ptr& props, const MemoryTypeSlice& input_memory_types, - const DataTypeSlice& output_types, const MemoryTypeSlice& output_memory_types, int graph_def_version, Status* status); @@ -330,20 +328,22 @@ class OpKernelConstruction { Tensor** out_tensor); // User-supplied configuration of this operation. - const NodeDef& def() const { return *def_; } + const NodeDef& def() const { return props_->node_def; } // For inspecting the inputs to this operation. - int num_inputs() const { return input_types_.size(); } - DataType input_type(int i) const { return input_types_[i]; } - const DataTypeSlice& input_types() const { return input_types_; } + int num_inputs() const { return props_->input_types.size(); } + DataType input_type(int i) const { return props_->input_types[i]; } + const DataTypeSlice& input_types() const { return props_->input_types_slice; } const MemoryTypeSlice& input_memory_types() const { return input_memory_types_; } // For inspecting the outputs expected from this operation. - int num_outputs() const { return output_types_.size(); } - DataType output_type(int i) const { return output_types_[i]; } - const DataTypeSlice& output_types() const { return output_types_; } + int num_outputs() const { return props_->output_types.size(); } + DataType output_type(int i) const { return props_->output_types[i]; } + const DataTypeSlice& output_types() const { + return props_->output_types_slice; + } const MemoryTypeSlice& output_memory_types() const { return output_memory_types_; } @@ -403,19 +403,15 @@ class OpKernelConstruction { const DeviceType device_type_; DeviceBase* const device_; Allocator* allocator_; - const NodeDef* def_; - const OpDef* op_def_; FunctionLibraryRuntime* flib_; ResourceMgr* const resource_mgr_; - DataTypeSlice input_types_; + std::shared_ptr props_; MemoryTypeSlice input_memory_types_; - DataTypeSlice output_types_; MemoryTypeSlice output_memory_types_; const int graph_def_version_; Status* status_; - // Allow op_def_ across from OpKernel, but not from subclasses. - // TODO(irving): Remove protos from this header entirely. + // Allow access from OpKernel ctor. friend class OpKernel; TF_DISALLOW_COPY_AND_ASSIGN(OpKernelConstruction); @@ -1404,15 +1400,23 @@ const Eigen::SyclDevice& OpKernelContext::eigen_device() const; std::unique_ptr CreateOpKernel(DeviceType device_type, DeviceBase* device, Allocator* allocator, - const NodeDef& def, + const NodeDef& node_def, int graph_def_version, Status* status); + +std::unique_ptr CreateOpKernel( + DeviceType device_type, DeviceBase* device, Allocator* allocator, + const std::shared_ptr& props, int graph_def_version, + Status* status); + Status CreateOpKernel(DeviceType device_type, DeviceBase* device, Allocator* allocator, FunctionLibraryRuntime* flib, - const NodeDef& def, int graph_def_version, - OpKernel** kernel); + const std::shared_ptr& props, + int graph_def_version, OpKernel** kernel); + Status CreateOpKernel(DeviceType device_type, DeviceBase* device, Allocator* allocator, FunctionLibraryRuntime* flib, - ResourceMgr* resource_mgr, const NodeDef& def, + ResourceMgr* resource_mgr, + const std::shared_ptr& props, int graph_def_version, OpKernel** kernel); // Returns into 'device_types' the subset of prioritized_types that this diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc index ec887a0ad93..40425cf24e0 100644 --- a/tensorflow/core/framework/op_kernel_test.cc +++ b/tensorflow/core/framework/op_kernel_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/tensor_util.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/strings/str_util.h" diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index 6240d0fb1ca..1f8a4d06c7a 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/node_properties.h" #include "tensorflow/core/framework/op_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/versions.pb.h" @@ -37,23 +37,7 @@ namespace tensorflow { const int Graph::kControlSlot = -1; -struct NodeProperties { - public: - NodeProperties(const OpDef* op_def, NodeDef node_def, - const DataTypeSlice inputs, const DataTypeSlice outputs) - : op_def(op_def), - node_def(std::move(node_def)), - input_types(inputs.begin(), inputs.end()), - output_types(outputs.begin(), outputs.end()) {} - - const OpDef* op_def; // not owned - NodeDef node_def; - const DataTypeVector input_types; - const DataTypeVector output_types; -}; - // Node - #define REF_CLASS(key, value) \ {key, value}, { "Ref" key, value } @@ -97,7 +81,8 @@ const std::unordered_map& Node::kNodeClassTable = {"StatelessIf", NC_IF}, {"While", NC_WHILE}, {"StatelessWhile", NC_WHILE}, - // Not using the constants defined in FunctionLibraryDefinition for the + // Not using the constants defined in FunctionLibraryDefinition + // for the // 4 ops below because android inference library does not link // tf.function related files. {"_Arg", NC_ARG}, diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index b33c0319c75..235d944bd60 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -43,6 +43,7 @@ limitations under the License. #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/edgeset.h" @@ -67,7 +68,6 @@ class WhileContext; class NeighborIter; // Declared below class NodeIter; // Declared below -struct NodeProperties; // Defined in .cc class Node { public: @@ -229,11 +229,12 @@ class Node { while_ctx_ = while_ctx; } + std::shared_ptr properties() const { return props_; } + private: friend class Graph; Node(); - NodeProperties* properties() const { return props_.get(); } void Initialize(int id, int cost_id, std::shared_ptr props, bool is_function_op); diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc index 5931599c6e2..ccdafdf91c9 100644 --- a/tensorflow/core/kernels/constant_op.cc +++ b/tensorflow/core/kernels/constant_op.cc @@ -47,31 +47,30 @@ namespace tensorflow { namespace { -std::unique_ptr StripTensorDataFromNodeDef( - OpKernelConstruction* ctx) { +NodeDef StripTensorDataFromNodeDef(OpKernelConstruction* ctx) { #ifndef __ANDROID__ DCHECK_EQ(NodeDef::descriptor()->field_count(), 6) << "The NodeDef format has changed, and the attr-stripping code may need " << "to be updated."; #endif const NodeDef& original = ctx->def(); - NodeDef* ret = new NodeDef; - ret->set_name(original.name()); - ret->set_op(original.op()); - ret->set_device(original.device()); + NodeDef ret; + ret.set_name(original.name()); + ret.set_op(original.op()); + ret.set_device(original.device()); // Strip the "value" attr from the returned NodeDef. // NOTE(mrry): The present implementation of `OpKernel::OpKernel()` only uses // attrs that affect the cardinality of list-typed inputs and outputs, so it // is safe to drop other attrs from the NodeDef. - AddNodeAttr("dtype", ctx->output_type(0), ret); - MergeDebugInfo(original, ret); - return std::unique_ptr(ret); + AddNodeAttr("dtype", ctx->output_type(0), &ret); + MergeDebugInfo(original, &ret); + return ret; } } // namespace ConstantOp::ConstantOp(OpKernelConstruction* ctx) - : OpKernel(ctx, StripTensorDataFromNodeDef(ctx)), + : OpKernel(ctx, StripTensorDataFromNodeDef(ctx), false), tensor_(ctx->output_type(0)) { const TensorProto* proto = nullptr; MEMDEBUG_CACHE_OP(ctx->def().name().c_str()); diff --git a/tensorflow/core/kernels/data/dataset_test_base.cc b/tensorflow/core/kernels/data/dataset_test_base.cc index 7c5d0c3f679..817e075e69b 100644 --- a/tensorflow/core/kernels/data/dataset_test_base.cc +++ b/tensorflow/core/kernels/data/dataset_test_base.cc @@ -304,9 +304,14 @@ Status DatasetOpsTestBase::ExpectEqual(std::vector produced_tensors, Status DatasetOpsTestBase::CreateOpKernel( const NodeDef& node_def, std::unique_ptr* op_kernel) { OpKernel* kernel; + Status s; + + std::shared_ptr props; + TF_RETURN_IF_ERROR(NodeProperties::CreateFromNodeDef( + node_def, flr_->GetFunctionLibraryDefinition(), &props)); TF_RETURN_IF_ERROR(tensorflow::CreateOpKernel( device_type_, device_.get(), allocator_, flr_, - device_->resource_manager(), node_def, TF_GRAPH_DEF_VERSION, &kernel)); + device_->resource_manager(), props, TF_GRAPH_DEF_VERSION, &kernel)); op_kernel->reset(kernel); return Status::OK(); } @@ -435,9 +440,10 @@ Status DatasetOpsTestBase::RunFunction( LocalExecutorParams params; params.function_library = flr_; params.device = device_.get(); - params.create_kernel = [this, version](const NodeDef& ndef, - OpKernel** kernel) { - return CreateNonCachedKernel(device_.get(), this->flr_, ndef, version, + params.create_kernel = [this, version]( + const std::shared_ptr& props, + OpKernel** kernel) { + return CreateNonCachedKernel(device_.get(), this->flr_, props, version, kernel); }; params.delete_kernel = [](OpKernel* kernel) { diff --git a/tensorflow/core/kernels/data/single_threaded_executor.cc b/tensorflow/core/kernels/data/single_threaded_executor.cc index a6b31679fa6..5393d5557eb 100644 --- a/tensorflow/core/kernels/data/single_threaded_executor.cc +++ b/tensorflow/core/kernels/data/single_threaded_executor.cc @@ -108,7 +108,8 @@ class SingleThreadedExecutorImpl : public Executor { KernelState& kernel_state = kernels_[kernel_index]; node_to_index_map[n] = kernel_index; - TF_RETURN_IF_ERROR(params_.create_kernel(n->def(), &kernel_state.kernel)); + TF_RETURN_IF_ERROR( + params_.create_kernel(n->properties(), &kernel_state.kernel)); kernel_state.num_inputs = n->num_inputs(); kernel_state.num_outputs = n->num_outputs(); diff --git a/tensorflow/core/kernels/data/single_threaded_executor_test.cc b/tensorflow/core/kernels/data/single_threaded_executor_test.cc index 1a5059487a4..898a6555265 100644 --- a/tensorflow/core/kernels/data/single_threaded_executor_test.cc +++ b/tensorflow/core/kernels/data/single_threaded_executor_test.cc @@ -58,11 +58,12 @@ class ExecutorTest : public ::testing::Test { const int version = graph->versions().producer(); LocalExecutorParams params; params.device = device_.get(); - params.create_kernel = [this, version](const NodeDef& ndef, - OpKernel** kernel) { - return CreateNonCachedKernel(device_.get(), nullptr, ndef, version, - kernel); - }; + params.create_kernel = + [this, version](const std::shared_ptr& props, + OpKernel** kernel) { + return CreateNonCachedKernel(device_.get(), nullptr, props, version, + kernel); + }; params.delete_kernel = [](OpKernel* kernel) { DeleteNonCachedKernel(kernel); }; diff --git a/tensorflow/python/eager/pywrap_tfe_test.py b/tensorflow/python/eager/pywrap_tfe_test.py index f510f24d777..c2389025a25 100644 --- a/tensorflow/python/eager/pywrap_tfe_test.py +++ b/tensorflow/python/eager/pywrap_tfe_test.py @@ -237,8 +237,7 @@ class Tests(test.TestCase): @test_util.assert_no_garbage_created def testInvalidNumOutputs(self): with self.assertRaisesRegexp( - Exception, - "Value for attr 'num_split' of -1 must be at least minimum 1"): + Exception, r"Value for number_attr\(\) -1 < 0 \[Op:Split\]"): array_ops.split(value=[1, 2, 3], num_or_size_splits=-1) with self.assertRaisesRegexp(