Replace NodeDef with std::shared_ptr<NodeProperties> in the kernel-creation code paths, avoiding copies of NodeDefs wherever possible. In most cases this allows the NodeDef to be shared between the OpKernel and the graph Node from which it is created.
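
For illustration, a minimal sketch of the new call pattern (based on the call sites changed below; `flr` and `ndef` are assumed to already exist):

    // Build the NodeProperties once; the NodeDef and type vectors are then
    // shared rather than copied at each kernel-creation site.
    std::shared_ptr<const NodeProperties> props;
    TF_RETURN_IF_ERROR(NodeProperties::CreateFromNodeDef(
        ndef, flr->GetFunctionLibraryDefinition(), &props));
    // The same `props` can back both the graph Node and the OpKernel.
    OpKernel* kernel = nullptr;
    TF_RETURN_IF_ERROR(flr->CreateKernel(props, &kernel));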

This reduces the number of allocations per op in the executor benchmarks by about 8%:

name                                                 old time/op             new time/op             delta
BM_executor/16/1k       [Nodes = 9824  ]              911µs ± 3%              911µs ± 1%    ~     (p=0.548 n=5+5)
BM_executor/32/8k       [Nodes = 141991]             17.1ms ± 2%             16.8ms ± 1%  -2.17%  (p=0.016 n=5+5)
BM_executor/1k/16       [Nodes = 6781  ]             1.21ms ± 1%             1.25ms ± 7%    ~     (p=0.095 n=5+5)
BM_executor/8k/32       [Nodes = 130875]              4.35s ± 0%              4.34s ± 0%    ~     (p=0.841 n=5+5)
BM_executor/1k/1k       [Nodes = 526256]              3.33s ± 1%              3.31s ± 1%    ~     (p=0.095 n=5+5)
BM_FeedInputFetchOutput                              54.0µs ± 7%             56.9µs ±13%    ~     (p=0.222 n=5+5)

name                                                 old allocs/op           new allocs/op           delta
BM_executor/16/1k       [Nodes = 9824  ]              15.4k ± 0%              14.1k ± 0%  -7.95%  (p=0.008 n=5+5)
BM_executor/32/8k       [Nodes = 141991]               226k ± 0%               208k ± 0%  -7.86%  (p=0.008 n=5+5)
BM_executor/1k/16       [Nodes = 6781  ]              10.2k ± 0%               9.3k ± 0%  -8.36%  (p=0.008 n=5+5)
BM_executor/8k/32       [Nodes = 130875]               197k ± 0%               180k ± 0%  -8.31%  (p=0.016 n=4+5)
BM_executor/1k/1k       [Nodes = 526256]               771k ± 0%               706k ± 0%  -8.53%  (p=0.008 n=5+5)
BM_FeedInputFetchOutput                                58.0 ± 0%               57.0 ± 0%  -1.72%  (p=0.008 n=5+5)

PiperOrigin-RevId: 295803318
Change-Id: I0d262c6082822023f449f9817dc943d20bd302d5
Authored by A. Unique TensorFlower on 2020-02-18 13:05:19 -08:00; committed by TensorFlower Gardener
parent 36fe0e7aad
commit 5f3a3019ba
32 changed files with 597 additions and 286 deletions


@@ -20,15 +20,17 @@ limitations under the License.
 namespace tensorflow {
-bool XlaKernelCreator::CanCreateKernel(const FunctionLibraryRuntime& flr,
-                                       const NodeDef& node_def) const {
-  return CanCreateXlaKernel(node_def);
+bool XlaKernelCreator::CanCreateKernel(
+    const FunctionLibraryRuntime& flr,
+    const std::shared_ptr<const NodeProperties>& props) const {
+  return CanCreateXlaKernel(props->node_def);
 }
-Status XlaKernelCreator::CreateKernel(FunctionLibraryRuntime* flr,
-                                      const NodeDef& node_def,
+Status XlaKernelCreator::CreateKernel(
+    FunctionLibraryRuntime* flr,
+    const std::shared_ptr<const NodeProperties>& props,
                                       std::unique_ptr<OpKernel>* kernel) const {
-  return CreateXlaKernel(flr, node_def, kernel);
+  return CreateXlaKernel(flr, props->node_def, kernel);
 }
 namespace {


@@ -29,11 +29,13 @@ class XlaKernelCreator : public CustomKernelCreator {
   // Given a NodeDef 'node_def' and the function library runtime 'flr', returns
   // true if 'node_def' is a call to a compilable function defined in 'flr',
   // with the kXlaCompileAttr set.
-  bool CanCreateKernel(const FunctionLibraryRuntime& flr,
-                       const NodeDef& node_def) const override;
+  bool CanCreateKernel(
+      const FunctionLibraryRuntime& flr,
+      const std::shared_ptr<const NodeProperties>& props) const override;
   // Given a supported NodeDef, returns a XlaLaunchOp that computes the node.
-  Status CreateKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
+  Status CreateKernel(FunctionLibraryRuntime* flr,
+                      const std::shared_ptr<const NodeProperties>& props,
                       std::unique_ptr<OpKernel>* kernel) const override;
 };


@@ -30,10 +30,12 @@ limitations under the License.
 namespace tensorflow {
-NodeDef ToNodeDef(const string& text) {
+std::shared_ptr<NodeProperties> ToNodeProperties(const string& text) {
   NodeDef node_def;
+  DataTypeVector dummy;
   EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
-  return node_def;
+  return std::make_shared<NodeProperties>(nullptr, std::move(node_def), dummy,
+                                          dummy);
 }
 // Create a FunctionDef that takes one resource and one regular param
@@ -98,11 +100,11 @@ TEST_F(XlaKernelCreatorTest, OneFloatOneResourceArgument) {
   (*fdef.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
   Init({fdef});
   XlaKernelCreator xla_kernel_creator;
-  NodeDef callsite =
-      ToNodeDef(R"pb(
+  auto callsite =
+      ToNodeProperties(R"pb(
         name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b'
       )pb");
-  (*callsite.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
+  (*(callsite->node_def.mutable_attr()))["_XlaMustCompile"] = BoolAttr(true);
   // Note: need to set attribute on the created node.
   Status status = xla_kernel_creator.CreateKernel(flr_, callsite, &kernel_);
@@ -127,7 +129,8 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrNotSet) {
   Init({fdef});
   XlaKernelCreator xla_kernel_creator;
-  Status status = xla_kernel_creator.CreateKernel(flr_, ToNodeDef(R"proto(
+  Status status =
+      xla_kernel_creator.CreateKernel(flr_, ToNodeProperties(R"proto(
     name: 'XTimesY'
     op: 'XTimesY'
     input: 'a'
@@ -143,7 +146,8 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrIsSetToFalse) {
   Init({fdef});
   XlaKernelCreator xla_kernel_creator;
-  Status status = xla_kernel_creator.CreateKernel(flr_, ToNodeDef(R"proto(
+  Status status =
+      xla_kernel_creator.CreateKernel(flr_, ToNodeProperties(R"proto(
     name: 'XTimesY'
     op: 'XTimesY'
    input: 'a'


@@ -218,11 +218,12 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
   TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
   Device* dev = flr->device();
   Status s;
-  OpKernelConstruction construction(
-      DeviceType(dev->device_type()), dev,
-      dev->GetAllocator(AllocatorAttributes()), &node_def,
-      &fbody->fdef.signature(), flr, dev->resource_manager(), fbody->arg_types,
-      input_memory_types, fbody->ret_types, output_memory_types,
+  auto props = std::make_shared<NodeProperties>(
+      &fbody->fdef.signature(), node_def, fbody->arg_types, fbody->ret_types);
+  OpKernelConstruction construction(DeviceType(dev->device_type()), dev,
+                                    dev->GetAllocator(AllocatorAttributes()),
+                                    flr, dev->resource_manager(), props,
+                                    input_memory_types, output_memory_types,
                                     flr->graph_def_version(), &s);
   *kernel = absl::make_unique<XlaLocalLaunchBase>(


@@ -127,9 +127,14 @@ class TRTEngineOpTestBase : public OpsTestBase {
  private:
   Status InitOpWithFunctionLibrary() {
     OpKernel* kernel = nullptr;
-    Status status = CreateOpKernel(device_type_, device_, allocator(),
-                                   pflr_->GetFLR(device_->name()), node_def_,
-                                   TF_GRAPH_DEF_VERSION, &kernel);
+    auto flr = pflr_->GetFLR(device_->name());
+    std::shared_ptr<const NodeProperties> props;
+    Status status = NodeProperties::CreateFromNodeDef(
+        node_def_, flr->GetFunctionLibraryDefinition(), &props);
+    if (status.ok()) {
+      status.Update(CreateOpKernel(device_type_, device_, allocator(), flr,
+                                   props, TF_GRAPH_DEF_VERSION, &kernel));
+    }
     kernel_ = std::unique_ptr<OpKernel>(kernel);
     if (kernel_ != nullptr) input_types_ = kernel_->input_types();
     return status;


@@ -133,7 +133,7 @@ Status GraphCompiler::Compile() {
     OpKernel* op_kernel_raw = nullptr;
     // The kernel is not actually run for functional ops, we just need it
     // for metadata.
-    Status s = flib_->CreateKernel(n->def(), &op_kernel_raw);
+    Status s = flib_->CreateKernel(n->properties(), &op_kernel_raw);
     // Transfer ownership of the kernel to a local smart pointer.
     std::unique_ptr<OpKernel> op_kernel(op_kernel_raw);


@@ -472,6 +472,7 @@ tf_cuda_library(
         "//tensorflow/core/framework:memory_types.h",
         "//tensorflow/core/framework:node_def_builder.h",
         "//tensorflow/core/framework:node_def_util.h",
+        "//tensorflow/core/framework:node_properties.h",
         "//tensorflow/core/framework:numeric_op.h",
         "//tensorflow/core/framework:numeric_types.h",
         "//tensorflow/core/framework:op.h",
@@ -2323,6 +2324,7 @@ tf_cuda_library(
         "//tensorflow/core/framework:bfloat16",
         "//tensorflow/core/framework:common_shape_fns",
         "//tensorflow/core/framework:node_def_util",
+        "//tensorflow/core/framework:node_properties",
         "//tensorflow/core/framework:numeric_types",
         "//tensorflow/core/framework:op",
         "//tensorflow/core/framework:op_def_builder",


@@ -1356,23 +1356,24 @@ Status DirectSession::CreateExecutors(
     params.session_metadata = session_metadata;
     params.function_library = lib;
     auto opseg = device->op_segment();
-    params.create_kernel = [this, lib, opseg](const NodeDef& ndef,
+    params.create_kernel =
+        [this, lib, opseg](const std::shared_ptr<const NodeProperties>& props,
                            OpKernel** kernel) {
       // NOTE(mrry): We must not share function kernels (implemented
       // using `CallOp`) between subgraphs, because `CallOp::handle_`
       // is tied to a particular subgraph. Even if the function itself
       // is stateful, the `CallOp` that invokes it is not.
-      if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) {
-        return lib->CreateKernel(ndef, kernel);
+      if (!OpSegment::ShouldOwnKernel(lib, props->node_def.op())) {
+        return lib->CreateKernel(props, kernel);
       }
-      auto create_fn = [lib, &ndef](OpKernel** kernel) {
-        return lib->CreateKernel(ndef, kernel);
+      auto create_fn = [lib, &props](OpKernel** kernel) {
+        return lib->CreateKernel(props, kernel);
       };
       // Kernels created for subgraph nodes need to be cached. On
       // cache miss, create_fn() is invoked to create a kernel based
       // on the function library here + global op registry.
-      return opseg->FindOrCreate(session_handle_, ndef.name(), kernel,
-                                 create_fn);
+      return opseg->FindOrCreate(session_handle_, props->node_def.name(),
+                                 kernel, create_fn);
     };
     params.delete_kernel = [lib](OpKernel* kernel) {
       if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string()))


@@ -98,7 +98,10 @@ Status KernelAndDeviceOp::Init(const NodeDef& ndef,
         "A valid FunctionLibraryRuntime must be provided when running ops "
         "based on OpKernel.");
   }
-  TF_RETURN_IF_ERROR(flr_->CreateKernel(ndef, &k));
+  std::shared_ptr<const NodeProperties> props;
+  TF_RETURN_IF_ERROR(NodeProperties::CreateFromNodeDef(
+      ndef, flr_->GetFunctionLibraryDefinition(), &props));
+  TF_RETURN_IF_ERROR(flr_->CreateKernel(props, &k));
   kernel_.reset(k);
   input_alloc_attrs_.resize(kernel_->num_inputs());


@@ -654,7 +654,7 @@ Status ExecutorImpl::Initialize(const Graph& graph) {
     item->input_start = frame_info->total_inputs;
     frame_info->total_inputs += n->num_inputs();
-    Status s = params_.create_kernel(n->def(), &item->kernel);
+    Status s = params_.create_kernel(n->properties(), &item->kernel);
     if (!s.ok()) {
       item->kernel = nullptr;
       s = AttachDef(s, *n);
@@ -2974,12 +2974,12 @@ Status NewLocalExecutor(const LocalExecutorParams& params, const Graph& graph,
 }
 Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
-                             const NodeDef& ndef, int graph_def_version,
-                             OpKernel** kernel) {
+                             const std::shared_ptr<const NodeProperties>& props,
+                             int graph_def_version, OpKernel** kernel) {
   const auto device_type = DeviceType(device->attributes().device_type());
   auto allocator = device->GetAllocator(AllocatorAttributes());
   return CreateOpKernel(device_type, device, allocator, flib,
-                        device->resource_manager(), ndef, graph_def_version,
+                        device->resource_manager(), props, graph_def_version,
                         kernel);
 }


@@ -145,7 +145,9 @@ struct LocalExecutorParams {
   // create_kernel returns an instance of op kernel based on NodeDef.
   // delete_kernel is called for every kernel used by the executor
   // when the executor is deleted.
-  std::function<Status(const NodeDef&, OpKernel**)> create_kernel;
+  std::function<Status(const std::shared_ptr<const NodeProperties>&,
+                       OpKernel**)>
+      create_kernel;
   std::function<void(OpKernel*)> delete_kernel;
   Executor::RendezvousFactory rendezvous_factory;
@@ -240,12 +242,12 @@ class ExecutorBarrier {
 // A few helpers to facilitate create/delete kernels.
-// Creates a kernel based on "ndef" on device "device". The kernel can
+// Creates a kernel based on "props" on device "device". The kernel can
 // access the functions in the "flib". The caller takes ownership of
 // returned "*kernel".
 Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
-                             const NodeDef& ndef, int graph_def_version,
-                             OpKernel** kernel);
+                             const std::shared_ptr<const NodeProperties>& props,
+                             int graph_def_version, OpKernel** kernel);
 // Deletes "kernel" returned by CreateKernel.
 void DeleteNonCachedKernel(OpKernel* kernel);


@@ -61,9 +61,10 @@ class ExecutorTest : public ::testing::Test {
     const int version = graph->versions().producer();
     LocalExecutorParams params;
     params.device = device_.get();
-    params.create_kernel = [this, version](const NodeDef& ndef,
+    params.create_kernel =
+        [this, version](const std::shared_ptr<const NodeProperties>& props,
                         OpKernel** kernel) {
-      return CreateNonCachedKernel(device_.get(), nullptr, ndef, version,
+      return CreateNonCachedKernel(device_.get(), nullptr, props, version,
                                    kernel);
     };
     params.delete_kernel = [](OpKernel* kernel) {


@@ -187,7 +187,8 @@ class FunctionLibraryRuntimeOverlay : public FunctionLibraryRuntime {
   void Run(const Options& opts, Handle handle, CallFrameInterface* call_frame,
            DoneCallback done) override;
-  Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) override;
+  Status CreateKernel(const std::shared_ptr<const NodeProperties>& props,
+                      OpKernel** kernel) override;
   bool IsStateful(const string& function_name) const override;
@@ -256,7 +257,8 @@ void FunctionLibraryRuntimeOverlay::Run(const Options& opts, Handle handle,
   base_flr_->Run(opts, handle, call_frame, std::move(done));
 }
-Status FunctionLibraryRuntimeOverlay::CreateKernel(const NodeDef&, OpKernel**) {
+Status FunctionLibraryRuntimeOverlay::CreateKernel(
+    const std::shared_ptr<const NodeProperties>&, OpKernel**) {
   // We don't have access to base_lib_def_ in base function library runtime (aka
   // FunctionLibraryRuntimeImpl), so to make sure we do not create a kernel with
   // the wrong lib_def we just disable creation of new kernels through overlays.
@@ -344,7 +346,8 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
   Status GetRetTypes(Handle handle, DataTypeVector* ret_types) override;
-  Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) override;
+  Status CreateKernel(const std::shared_ptr<const NodeProperties>& props,
+                      OpKernel** kernel) override;
   void Run(const Options& opts, Handle handle, gtl::ArraySlice<Tensor> args,
            std::vector<Tensor>* rets, DoneCallback done) override;
@@ -393,7 +396,9 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
   const string device_name_;
   std::function<Status(const string&, const OpDef**)> get_func_sig_;
-  std::function<Status(const NodeDef&, OpKernel**)> create_kernel_;
+  std::function<Status(const std::shared_ptr<const NodeProperties>&,
+                       OpKernel**)>
+      create_kernel_;
   mutable mutex mu_;
@@ -426,8 +431,8 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
   // to use for kernel creation and execution. In particular, this method can
   // accept a FunctionLibraryRuntimeOverlay that overlays a different
   // FunctionLibraryDefinition.
-  Status CreateKernel(const NodeDef& ndef, FunctionLibraryRuntime* flr,
-                      OpKernel** kernel);
+  Status CreateKernel(const std::shared_ptr<const NodeProperties>& props,
+                      FunctionLibraryRuntime* flr, OpKernel** kernel);
   Status FunctionDefToBody(const FunctionDef& fdef, AttrSlice attrs,
                            const FunctionLibraryDefinition* lib_def,
                            std::unique_ptr<FunctionBody>* fbody);
@@ -476,8 +481,9 @@ FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
   get_func_sig_ = [this](const string& op, const OpDef** sig) {
     return base_lib_def_->LookUpOpDef(op, sig);
   };
-  create_kernel_ = [this](const NodeDef& ndef, OpKernel** kernel) {
-    return CreateKernel(ndef, kernel);
+  create_kernel_ = [this](const std::shared_ptr<const NodeProperties>& props,
+                          OpKernel** kernel) {
+    return CreateKernel(props, kernel);
   };
   thread::ThreadPool* pool = nullptr;
   if (device_ != nullptr) {
@@ -589,20 +595,20 @@ Status FunctionLibraryRuntimeImpl::GetRetTypes(Handle h,
   return Status::OK();
 }
-Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
-                                                OpKernel** kernel) {
-  return CreateKernel(ndef, this, kernel);
+Status FunctionLibraryRuntimeImpl::CreateKernel(
+    const std::shared_ptr<const NodeProperties>& props, OpKernel** kernel) {
+  return CreateKernel(props, this, kernel);
 }
-Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
-                                                FunctionLibraryRuntime* flr,
-                                                OpKernel** kernel) {
+Status FunctionLibraryRuntimeImpl::CreateKernel(
+    const std::shared_ptr<const NodeProperties>& props,
+    FunctionLibraryRuntime* flr, OpKernel** kernel) {
   // If a custom kernel creator is given, try that.
   Status s;
   if (custom_kernel_creator_ != nullptr &&
-      custom_kernel_creator_->CanCreateKernel(*this, ndef)) {
+      custom_kernel_creator_->CanCreateKernel(*this, props)) {
     std::unique_ptr<OpKernel> ret;
-    s = custom_kernel_creator_->CreateKernel(this, ndef, &ret);
+    s = custom_kernel_creator_->CreateKernel(this, props, &ret);
     if (s.ok()) {
       *kernel = ret.release();
     } else {
@@ -613,9 +619,9 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
   const FunctionLibraryDefinition* lib_def =
       flr->GetFunctionLibraryDefinition();
-  if (lib_def->Find(ndef.op()) == nullptr) {
+  if (lib_def->Find(props->node_def.op()) == nullptr) {
     // A primitive operation. Creates the registered kernel.
-    return CreateNonCachedKernel(device_, flr, ndef, graph_def_version_,
+    return CreateNonCachedKernel(device_, flr, props, graph_def_version_,
                                  kernel);
   }
@@ -626,8 +632,9 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
     options.lib_def = lib_def;
   }
   Handle handle;
-  TF_RETURN_IF_ERROR(
-      Instantiate(ndef.op(), AttrSlice(&ndef.attr()), options, &handle));
+  TF_RETURN_IF_ERROR(Instantiate(props->node_def.op(),
+                                 AttrSlice(&props->node_def.attr()), options,
+                                 &handle));
   const FunctionBody* fbody = GetFunctionBody(handle);
   CHECK_NOTNULL(fbody);
@@ -647,10 +654,12 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
   // Constructs a CallOp kernel for running the instantiated function.
   auto device_type = DeviceType(device_->attributes().device_type());
+  auto new_props = std::make_shared<NodeProperties>(
+      &fbody->fdef.signature(), props->node_def, fbody->arg_types,
+      fbody->ret_types);
   OpKernelConstruction construction(
-      device_type, device_, device_->GetAllocator(AllocatorAttributes()), &ndef,
-      &fbody->fdef.signature(), flr, device_->resource_manager(),
-      fbody->arg_types, input_memory_types, fbody->ret_types,
+      device_type, device_, device_->GetAllocator(AllocatorAttributes()), flr,
+      device_->resource_manager(), props, input_memory_types,
       output_memory_types, graph_def_version_, &s);
   if (s.ok()) {
     *kernel = new CallOp(handle, &construction);
@@ -953,8 +962,10 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Item** item) {
   if (flr == this) {
     params.create_kernel = create_kernel_;
   } else {
-    params.create_kernel = [this, flr](const NodeDef& ndef, OpKernel** kernel) {
-      return CreateKernel(ndef, flr, kernel);
+    params.create_kernel =
+        [this, flr](const std::shared_ptr<const NodeProperties>& props,
+                    OpKernel** kernel) {
+      return CreateKernel(props, flr, kernel);
     };
   }
   params.delete_kernel = [](OpKernel* kernel) {


@@ -90,9 +90,10 @@ class FunctionTest : public ::testing::Test {
     const int version = g->versions().producer();
     LocalExecutorParams params;
     params.device = device_.get();
-    params.create_kernel = [this, version](const NodeDef& ndef,
+    params.create_kernel =
+        [this, version](const std::shared_ptr<const NodeProperties>& props,
                         OpKernel** kernel) {
-      return CreateNonCachedKernel(device_.get(), nullptr, ndef, version,
+      return CreateNonCachedKernel(device_.get(), nullptr, props, version,
                                    kernel);
     };
     params.delete_kernel = [](OpKernel* kernel) {


@@ -157,9 +157,10 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
   params.device = device_;
   params.function_library = function_library;
   const int producer = graph_to_run->versions().producer();
-  params.create_kernel = [this, function_library, producer](const NodeDef& ndef,
+  params.create_kernel = [this, function_library, producer](
+                             const std::shared_ptr<const NodeProperties>& props,
                              OpKernel** kernel) {
-    return CreateNonCachedKernel(device_, function_library, ndef, producer,
+    return CreateNonCachedKernel(device_, function_library, props, producer,
                                  kernel);
   };
   params.delete_kernel = [](OpKernel* kernel) { delete kernel; };


@@ -84,9 +84,10 @@ Benchmark::Benchmark(const string& device, Graph* g,
   LocalExecutorParams params;
   params.device = device_.get();
   params.function_library = nullptr;
-  params.create_kernel = [this, graph_def_version](const NodeDef& ndef,
+  params.create_kernel = [this, graph_def_version](
+                             const std::shared_ptr<const NodeProperties>& props,
                              OpKernel** kernel) {
-    return CreateNonCachedKernel(device_.get(), nullptr, ndef,
+    return CreateNonCachedKernel(device_.get(), nullptr, props,
                                  graph_def_version, kernel);
   };
   params.delete_kernel = [](OpKernel* kernel) {


@@ -233,22 +233,24 @@ Status GraphMgr::InitItem(
     // Construct the root executor for the subgraph.
     params.device = unit->device;
     params.function_library = lib;
-    params.create_kernel = [handle, lib, opseg](const NodeDef& ndef,
+    params.create_kernel =
+        [handle, lib, opseg](const std::shared_ptr<const NodeProperties>& props,
                              OpKernel** kernel) {
       // NOTE(mrry): We must not share function kernels (implemented
       // using `CallOp`) between subgraphs, because `CallOp::handle_`
       // is tied to a particular subgraph. Even if the function itself
       // is stateful, the `CallOp` that invokes it is not.
-      if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) {
-        return lib->CreateKernel(ndef, kernel);
+      if (!OpSegment::ShouldOwnKernel(lib, props->node_def.op())) {
+        return lib->CreateKernel(props, kernel);
       }
-      auto create_fn = [lib, &ndef](OpKernel** kernel) {
-        return lib->CreateKernel(ndef, kernel);
+      auto create_fn = [lib, &props](OpKernel** kernel) {
+        return lib->CreateKernel(props, kernel);
       };
       // Kernels created for subgraph nodes need to be cached. On
       // cache miss, create_fn() is invoked to create a kernel based
      // on the function library here + global op registry.
-      return opseg->FindOrCreate(handle, ndef.name(), kernel, create_fn);
+      return opseg->FindOrCreate(handle, props->node_def.name(), kernel,
+                                 create_fn);
     };
     params.delete_kernel = [lib](OpKernel* kernel) {
       if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string())) {


@@ -129,6 +129,7 @@ exports_files(
         "attr_value_util.h",
         "common_shape_fns.h",
         "node_def_util.h",
+        "node_properties.h",
         "op.h",
         "op_def_builder.h",
         "op_def_util.h",
@@ -172,6 +173,7 @@ filegroup(
         "model.h",
         "node_def_builder.h",
         "node_def_util.h",
+        "node_properties.h",
         "numeric_op.h",
         "numeric_types.h",
         "op.h",
@@ -338,6 +340,8 @@ filegroup(
         "node_def_builder.h",
         "node_def_util.cc",
         "node_def_util.h",
+        "node_properties.cc",
+        "node_properties.h",
         "numeric_op.h",
         "op.cc",
         "op.h",
@@ -862,6 +866,21 @@ cc_library(
     ],
 )
+cc_library(
+    name = "node_properties",
+    srcs = ["node_properties.cc"],
+    hdrs = ["node_properties.h"],
+    deps = [
+        ":node_def_proto_cc",
+        ":node_def_util",
+        ":op",
+        ":op_def_proto_cc",
+        ":tensor",
+        ":types_proto_cc",
+        "//tensorflow/core/lib/core:status",
+    ],
+)
 cc_library(
     name = "op_def_builder",
     srcs = ["op_def_builder.cc"],
@@ -967,6 +986,7 @@ tf_cc_tests(
         "model_test.cc",
         "node_def_builder_test.cc",
         "node_def_util_test.cc",
+        "node_properties_test.cc",
         "op_compatibility_test.cc",
         "op_def_builder_test.cc",
         "op_def_util_test.cc",


@@ -722,11 +722,13 @@ class FunctionLibraryRuntime {
   virtual void Run(const Options& opts, Handle handle,
                    CallFrameInterface* call_frame, DoneCallback done) = 0;
-  // Creates a "kernel" for the given node def "ndef".
+  // Creates a "kernel" for the given NodeProperties "props".
   //
   // If succeeds, returns OK and the caller takes the ownership of the
   // returned "*kernel". Otherwise, returns an error.
-  virtual Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) = 0;
+  virtual Status CreateKernel(
+      const std::shared_ptr<const NodeProperties>& props,
+      OpKernel** kernel) = 0;
   // Returns true iff the function named `function_name` is stateful.
   //
@@ -818,11 +820,14 @@ class CustomKernelCreator {
   // Given a NodeDef 'node_def' and the function library runtime 'flr',
   // validate if the class supports creating such a kernel.
-  virtual bool CanCreateKernel(const FunctionLibraryRuntime& flr,
-                               const NodeDef& node_def) const = 0;
+  virtual bool CanCreateKernel(
+      const FunctionLibraryRuntime& flr,
+      const std::shared_ptr<const NodeProperties>& props) const = 0;
   // Given a supported NodeDef, returns a kernel that computes the node.
-  virtual Status CreateKernel(FunctionLibraryRuntime* flr, const NodeDef& ndef,
+  virtual Status CreateKernel(
+      FunctionLibraryRuntime* flr,
+      const std::shared_ptr<const NodeProperties>& props,
                               std::unique_ptr<OpKernel>* kernel) const = 0;
 };


@@ -0,0 +1,39 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/node_properties.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
namespace tensorflow {
// static
Status NodeProperties::CreateFromNodeDef(
    NodeDef node_def, const OpRegistryInterface* op_registry,
    std::shared_ptr<const NodeProperties>* props) {
  const OpDef* op_def;
  TF_RETURN_IF_ERROR(op_registry->LookUpOpDef(node_def.op(), &op_def));
  DataTypeVector input_types;
  DataTypeVector output_types;
  TF_RETURN_IF_ERROR(
      InOutTypesForNode(node_def, *op_def, &input_types, &output_types));
  props->reset(new NodeProperties(op_def, std::move(node_def),
                                  std::move(input_types),
                                  std::move(output_types)));
  return Status::OK();
}
} // namespace tensorflow


@@ -0,0 +1,63 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_NODE_PROPERTIES_H_
#define TENSORFLOW_CORE_FRAMEWORK_NODE_PROPERTIES_H_
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
class OpRegistryInterface;
struct NodeProperties {
 public:
  NodeProperties(const OpDef* op_def, NodeDef node_def,
                 const DataTypeSlice inputs, const DataTypeSlice outputs)
      : NodeProperties(op_def, std::move(node_def),
                       DataTypeVector(inputs.begin(), inputs.end()),
                       DataTypeVector(outputs.begin(), outputs.end())) {}

  NodeProperties(const OpDef* _op_def, NodeDef&& _node_def,
                 DataTypeVector inputs, DataTypeVector outputs)
      : op_def(_op_def),
        node_def(std::move(_node_def)),
        input_types(std::move(inputs)),
        input_types_slice(input_types),
        output_types(std::move(outputs)),
        output_types_slice(output_types) {}

  // Resets the 'props' shared pointer to point to a new NodeProperties created
  // from the given NodeDef. 'op_registry' is used to look up the OpDef
  // corresponding to node_def.op(). Returns an error if OpDef lookup or
  // creation failed.
  static Status CreateFromNodeDef(NodeDef node_def,
                                  const OpRegistryInterface* op_registry,
                                  std::shared_ptr<const NodeProperties>* props);

  const OpDef* op_def;  // not owned.
  NodeDef node_def;
  DataTypeVector input_types;
  DataTypeSlice input_types_slice;
  DataTypeVector output_types;
  DataTypeSlice output_types_slice;
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_NODE_PROPERTIES_H_
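
As a rough usage sketch of the struct above (not part of this change's files; `node_def` and `fbody` are assumed to exist, mirroring the call sites in the surrounding diffs):

    // Build NodeProperties directly when the OpDef and input/output types are
    // already known, e.g. from an instantiated function body.
    auto props = std::make_shared<NodeProperties>(
        &fbody->fdef.signature(),  // OpDef, not owned by NodeProperties
        node_def,                  // NodeDef for the kernel
        fbody->arg_types, fbody->ret_types);
    // The resulting shared_ptr is what OpKernelConstruction and
    // FunctionLibraryRuntime::CreateKernel now take instead of a NodeDef.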


@@ -0,0 +1,128 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/node_properties.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
OpDef ToOpDef(const OpDefBuilder& builder) {
  OpRegistrationData op_reg_data;
  EXPECT_TRUE(builder.Finalize(&op_reg_data).ok());
  return op_reg_data.op_def;
}

class MockOpRegistry : public OpRegistryInterface {
 public:
  MockOpRegistry()
      : op_reg_(ToOpDef(OpDefBuilder("Foo")
                            .Input("f: float")
                            .Input("i: int32")
                            .Output("of: double"))) {}
  ~MockOpRegistry() override {}

  // Returns an error status and sets *op_reg_data to nullptr if no OpDef is
  // registered under that name, otherwise returns the registered OpDef.
  // Caller must not delete the returned pointer.
  Status LookUp(const string& op_type_name,
                const OpRegistrationData** op_reg_data) const override {
    if (op_type_name == "Foo") {
      *op_reg_data = &op_reg_;
      return Status::OK();
    } else {
      *op_reg_data = nullptr;
      return errors::InvalidArgument("Op type named ", op_type_name,
                                     " not found");
    }
  }

  const OpDef* get_op_def_addr() { return &op_reg_.op_def; }

 private:
  const OpRegistrationData op_reg_;
};

void ValidateNodeProperties(const NodeProperties& props, const OpDef* op_def,
                            const NodeDef& node_def,
                            const DataTypeVector& input_types,
                            const DataTypeVector& output_types) {
  EXPECT_EQ(props.op_def, op_def);
  EXPECT_EQ(props.node_def.name(), node_def.name());
  ASSERT_EQ(props.input_types.size(), input_types.size());
  for (int i = 0; i < input_types.size(); ++i) {
    EXPECT_EQ(props.input_types[i], input_types[i]);
    EXPECT_EQ(props.input_types_slice[i], input_types[i]);
  }
  ASSERT_EQ(props.output_types.size(), output_types.size());
  for (int i = 0; i < output_types.size(); ++i) {
    EXPECT_EQ(props.output_types[i], output_types[i]);
    EXPECT_EQ(props.output_types_slice[i], output_types[i]);
  }
}

}  // namespace

TEST(NodeProperties, Contructors) {
  OpDef op_def;
  NodeDef node_def;
  node_def.set_name("foo");
  DataTypeVector input_types{DT_FLOAT, DT_INT32};
  DataTypeVector output_types{DT_DOUBLE};
  DataTypeSlice input_types_slice(input_types);
  DataTypeSlice output_types_slice(output_types);

  // Construct from slices.
  NodeProperties props_from_slices(&op_def, node_def, input_types_slice,
                                   output_types_slice);
  ValidateNodeProperties(props_from_slices, &op_def, node_def, input_types,
                         output_types);

  // Construct from vectors.
  NodeProperties props_from_vectors(&op_def, node_def, input_types,
                                    output_types);
  ValidateNodeProperties(props_from_vectors, &op_def, node_def, input_types,
                         output_types);
}

TEST(NodeProperties, CreateFromNodeDef) {
  MockOpRegistry op_registry;
  NodeDef node_def;
  node_def.set_name("bar");
  node_def.set_op("Foo");
  node_def.add_input("f_in");
  node_def.add_input("i_in");

  std::shared_ptr<const NodeProperties> props;
  EXPECT_TRUE(
      NodeProperties::CreateFromNodeDef(node_def, &op_registry, &props).ok());

  DataTypeVector input_types{DT_FLOAT, DT_INT32};
  DataTypeVector output_types{DT_DOUBLE};
  ValidateNodeProperties(*props, op_registry.get_op_def_addr(), node_def,
                         input_types, output_types);

  // The OpDef lookup should fail for this one:
  node_def.set_op("Baz");
  std::shared_ptr<const NodeProperties> props_bad;
  EXPECT_FALSE(
      NodeProperties::CreateFromNodeDef(node_def, &op_registry, &props_bad)
          .ok());
  EXPECT_EQ(props_bad, nullptr);
}
}  // namespace tensorflow


@@ -35,9 +35,9 @@ limitations under the License.
 #include "tensorflow/core/framework/memory_types.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/node_properties.h"
 #include "tensorflow/core/framework/op_def_util.h"
 #include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/notification.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
@@ -91,35 +91,53 @@ Status MatchSignatureHelper(const DataTypeSlice expected_inputs,
 // OpKernel ------------------------------------------------------------------
-OpKernel::OpKernel(OpKernelConstruction* context)
-    : OpKernel(context, MakeUnique<const NodeDef>(context->def())) {}
+OpKernel::OpKernel(OpKernelConstruction* context) : OpKernel(context, false) {}
 OpKernel::OpKernel(OpKernelConstruction* context, bool is_deferred)
-    : OpKernel(context, MakeUnique<const NodeDef>(context->def()),
-               is_deferred) {}
-OpKernel::OpKernel(OpKernelConstruction* context,
-                   std::unique_ptr<const NodeDef> node_def, bool is_deferred)
-    : def_(std::move(node_def)),
-      input_types_(context->input_types().begin(),
-                   context->input_types().end()),
+    : props_(context->props_),
       input_memory_types_(context->input_memory_types().begin(),
                           context->input_memory_types().end()),
-      output_types_(context->output_types().begin(),
-                    context->output_types().end()),
       output_memory_types_(context->output_memory_types().begin(),
                            context->output_memory_types().end()),
       input_name_map_(context->num_inputs()),
       output_name_map_(context->num_outputs()),
-      name_view_(def_->name()),
-      type_string_view_(def_->op()),
+      name_view_(props_->node_def.name()),
+      type_string_view_(props_->node_def.op()),
       graph_def_version_(context->graph_def_version()),
       is_deferred_(is_deferred),
       cost_estimate_(OpKernel::kInitialCostEstimateCycles) {
   OP_REQUIRES_OK(context,
-                 NameRangesForNode(*def_, *context->op_def_, &input_name_map_,
-                                   &output_name_map_));
-  OP_REQUIRES_OK(context, CheckOpDeprecation(*context->op_def_,
+                 NameRangesForNode(props_->node_def, *props_->op_def,
+                                   &input_name_map_, &output_name_map_));
+  OP_REQUIRES_OK(context, CheckOpDeprecation(*props_->op_def,
+                                             context->graph_def_version()));
+  // Kernels executing on GPU/SYCL tie very few resources on the CPU where the
+  // scheduler runs: we consider them as inexpensive.
+  expensive_ = context->device_type() != DeviceType(DEVICE_GPU) &&
+               context->device_type() != DeviceType(DEVICE_SYCL);
+}
+OpKernel::OpKernel(OpKernelConstruction* context, NodeDef&& custom_def,
+                   bool is_deferred)
+    : props_(std::make_shared<const NodeProperties>(
+          context->props_->op_def, std::move(custom_def),
+          context->props_->input_types, context->props_->output_types)),
+      input_memory_types_(context->input_memory_types().begin(),
+                          context->input_memory_types().end()),
+      output_memory_types_(context->output_memory_types().begin(),
+                           context->output_memory_types().end()),
+      input_name_map_(context->num_inputs()),
+      output_name_map_(context->num_outputs()),
+      name_view_(props_->node_def.name()),
+      type_string_view_(props_->node_def.op()),
+      graph_def_version_(context->graph_def_version()),
+      is_deferred_(is_deferred),
+      cost_estimate_(OpKernel::kInitialCostEstimateCycles) {
+  OP_REQUIRES_OK(context,
+                 NameRangesForNode(props_->node_def, *props_->op_def,
+                                   &input_name_map_, &output_name_map_));
+  OP_REQUIRES_OK(context, CheckOpDeprecation(*props_->op_def,
                                              context->graph_def_version()));
   // Kernels executing on GPU/SYCL tie very few resources on the CPU where the
@@ -134,10 +152,6 @@ const uint64 OpKernel::kInitialCostEstimateCycles;
 const uint64 OpKernel::kOpIsExpensiveThresholdCycles;
 const uint64 OpKernel::kCostDecay;
-const string& OpKernel::name() const { return def_->name(); }
-const string& OpKernel::type_string() const { return def_->op(); }
-const string& OpKernel::requested_device() const { return def_->device(); }
-const string& OpKernel::requested_input(int i) const { return def_->input(i); }
 Status OpKernel::InputRange(StringPiece input_name, int* start,
                             int* stop) const {
@@ -216,22 +230,18 @@ Tensor* PersistentTensor::AccessTensor(OpKernelContext* context) {
 OpKernelConstruction::OpKernelConstruction(
     DeviceType device_type, DeviceBase* device, Allocator* allocator,
-    const NodeDef* node_def, const OpDef* op_def, FunctionLibraryRuntime* flib,
-    ResourceMgr* resource_mgr, const DataTypeSlice& input_types,
+    FunctionLibraryRuntime* flib, ResourceMgr* resource_mgr,
+    const std::shared_ptr<const NodeProperties>& props,
     const MemoryTypeSlice& input_memory_types,
-    const DataTypeSlice& output_types,
     const MemoryTypeSlice& output_memory_types, int graph_def_version,
     Status* status)
     : device_type_(std::move(device_type)),
       device_(device),
       allocator_(allocator),
-      def_(node_def),
-      op_def_(op_def),
      flib_(flib),
      resource_mgr_(resource_mgr),
-      input_types_(input_types),
+      props_(props),
       input_memory_types_(input_memory_types),
-      output_types_(output_types),
       output_memory_types_(output_memory_types),
       graph_def_version_(graph_def_version),
       status_(status) {}
@@ -246,8 +256,8 @@ void OpKernelConstruction::SetStatus(const Status& status) {
 Status OpKernelConstruction::MatchSignature(
     const DataTypeSlice expected_inputs, const DataTypeSlice expected_outputs) {
-  return MatchSignatureHelper(expected_inputs, expected_outputs, input_types_,
-                              output_types_);
+  return MatchSignatureHelper(expected_inputs, expected_outputs,
+                              props_->input_types, props_->output_types);
 }
 Status OpKernelConstruction::allocate_temp(DataType type,
@@ -263,7 +273,7 @@ Status OpKernelConstruction::allocate_temp(DataType type,
   }
   if (LogMemory::IsEnabled()) {
     LogMemory::RecordTensorAllocation(
-        def_->name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp);
+        def().name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp);
   }
   *out_temp = new_temp;
   return Status::OK();
@@ -288,7 +298,7 @@ Status OpKernelConstruction::allocate_temp(DataType type,
   }
   if (LogMemory::IsEnabled()) {
     LogMemory::RecordTensorAllocation(
-        def_->name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp);
+        def().name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp);
   }
   *out_temp = new_temp;
   return Status::OK();
@@ -1544,46 +1554,66 @@ string KernelsRegisteredForOp(StringPiece op_name) {
   return ret;
 }
+/* TODO(rmlarsen): This API is deprecated. Remove it if possible to avoid
+ * copying the NodeDef. */
 std::unique_ptr<OpKernel> CreateOpKernel(
     DeviceType device_type, DeviceBase* device, Allocator* allocator,
     const NodeDef& node_def, int graph_def_version, Status* status) {
+  // Look up the Op registered for this op name.
+  std::shared_ptr<const NodeProperties> props;
+  status->Update(NodeProperties::CreateFromNodeDef(
+      node_def, OpRegistry::Global(), &props));
+  if (!status->ok()) {
+    errors::AppendToMessage(status,
+                            " for node: ", FormatNodeDefForError(node_def));
+    return nullptr;
+  }
+  return CreateOpKernel(device_type, device, allocator, props,
+                        graph_def_version, status);
+}
+std::unique_ptr<OpKernel> CreateOpKernel(
+    DeviceType device_type, DeviceBase* device, Allocator* allocator,
+    const std::shared_ptr<const NodeProperties>& props, int graph_def_version,
+    Status* status) {
   OpKernel* kernel = nullptr;
-  *status = CreateOpKernel(std::move(device_type), device, allocator, nullptr,
-                           node_def, graph_def_version, &kernel);
+  *status = CreateOpKernel(std::move(device_type), device, allocator,
+                           /*flib=*/nullptr, props, graph_def_version, &kernel);
   return std::unique_ptr<OpKernel>(kernel);
 }
 Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                       Allocator* allocator, FunctionLibraryRuntime* flib,
-                      const NodeDef& node_def, int graph_def_version,
-                      OpKernel** kernel) {
+                      const std::shared_ptr<const NodeProperties>& props,
+                      int graph_def_version, OpKernel** kernel) {
   return CreateOpKernel(std::move(device_type), device, allocator, flib,
-                        /* resource_mgr= */ nullptr, node_def,
-                        graph_def_version, kernel);
+                        /* resource_mgr= */ nullptr, props, graph_def_version,
+                        kernel);
 }
 Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                       Allocator* allocator, FunctionLibraryRuntime* flib,
-                      ResourceMgr* resource_mgr, const NodeDef& node_def,
+                      ResourceMgr* resource_mgr,
+                      const std::shared_ptr<const NodeProperties>& props,
                       int graph_def_version, OpKernel** kernel) {
+  const NodeDef& node_def = props->node_def;
+  bool was_attr_mismatch;
+  const KernelRegistration* registration = nullptr;
+  Status s;
+  if (props != nullptr) {
-  VLOG(1) << "Instantiating kernel for node: " << SummarizeNodeDef(node_def);
-  // Look up the Op registered for this op name.
-  const OpDef* op_def = nullptr;
-  TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def));
-  // Validate node_def against OpDef.
-  TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, *op_def));
-  // Look up kernel registration.
-  const KernelRegistration* registration;
-  bool was_attr_mismatch;
-  Status s = FindKernelRegistration(device_type, node_def, &registration,
-                                    &was_attr_mismatch);
-  if (!s.ok()) {
-    errors::AppendToMessage(&s, " when instantiating ", node_def.op());
-    return s;
-  }
+    VLOG(1) << "Instantiating kernel for node: " << SummarizeNodeDef(node_def);
+    // Validate node_def against OpDef.
+    TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, *props->op_def));
+    // Look up kernel registration.
+    s = FindKernelRegistration(device_type, node_def, &registration,
+                               &was_attr_mismatch);
+    if (!s.ok()) {
+      errors::AppendToMessage(&s, " when instantiating ", node_def.op());
+      return s;
+    }
+  }
   if (registration == nullptr) {
     s.Update(errors::NotFound("No registered '", node_def.op(),
                               "' OpKernel for '", DeviceTypeString(device_type),
@@ -1599,15 +1629,6 @@ Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
     return s;
   }
-  // Get signature from the OpDef & NodeDef
-  DataTypeVector inputs;
-  DataTypeVector outputs;
-  s.Update(InOutTypesForNode(node_def, *op_def, &inputs, &outputs));
-  if (!s.ok()) {
-    errors::AppendToMessage(&s, " for node: ", FormatNodeDefForError(node_def));
-    return s;
-  }
   // We are creating a kernel for an op registered in
   // OpRegistry::Global(), we consult the kernel registry to decide
   // the kernel's input and output memory types.
@@ -1618,10 +1639,9 @@ Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                                     &output_memory_types));
   // Everything needed for OpKernel construction.
-  OpKernelConstruction context(std::move(device_type), device, allocator,
-                               &node_def, op_def, flib, resource_mgr, inputs,
-                               input_memory_types, outputs, output_memory_types,
-                               graph_def_version, &s);
+  OpKernelConstruction context(std::move(device_type), device, allocator, flib,
+                               resource_mgr, props, input_memory_types,
+                               output_memory_types, graph_def_version, &s);
   *kernel = registration->factory->Create(&context);
   if (!s.ok()) {
     delete *kernel;


@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/node_properties.h"
#include "tensorflow/core/framework/op.h" // TODO(b/62899350): Remove #include "tensorflow/core/framework/op.h" // TODO(b/62899350): Remove
#include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/selective_registration.h" #include "tensorflow/core/framework/selective_registration.h"
@@ -85,19 +86,18 @@ class OpKernel {
   // expensive initialization in the descendant's constructor.
   explicit OpKernel(OpKernelConstruction* context);
-  // Specialized constructor that enables the descendant to provide a different
-  // `NodeDef` value. For example, this constructor can be used to provide a
-  // stripped-down `NodeDef` that does not contain the full set of attrs (such
-  // as tensor values) if the descendant stores them in a different form.
-  explicit OpKernel(OpKernelConstruction* context,
-                    std::unique_ptr<const NodeDef> node_def,
-                    bool is_deferred = false);
   // Specialized constructor that allows a kernel implementation to mark itself
   // as a "deferred" op. If true, the executor will provide access to the
   // `OpKernelContext::inc_num_deferred_ops_function()` and
   // `OpKernelContext::dec_num_deferred_ops_function()` methods at run-time.
-  explicit OpKernel(OpKernelConstruction* context, bool is_deferred);
+  OpKernel(OpKernelConstruction* context, bool is_deferred);
+  // Specialized constructor that enables the descendant to provide a custom
+  // `NodeDef` value. For example, this constructor can be used to provide a
+  // stripped-down `NodeDef` that does not contain the full set of attrs (such
+  // as tensor values) if the descendant stores them in a different form.
+  OpKernel(OpKernelConstruction* context, NodeDef&& custom_def,
+           bool is_deferred);
   virtual ~OpKernel();
@@ -170,24 +170,26 @@ class OpKernel {
   }
   // Accessors.
-  const NodeDef& def() const { return *def_; }
-  const string& name() const;              // Same as def().name()
+  const NodeDef& def() const { return props_->node_def; }
+  const string& name() const { return props_->node_def.name(); }
   absl::string_view name_view() const { return name_view_; }
-  const string& type_string() const;       // Same as def().op()
+  const string& type_string() const { return props_->node_def.op(); }
   absl::string_view type_string_view() const { return type_string_view_; }
-  const string& requested_device() const;  // Same as def().device()
+  const string& requested_input(int i) const {
+    return props_->node_def.input(i);
+  }
+  const string& requested_device() const { return props_->node_def.device(); }
-  int num_inputs() const { return input_types_.size(); }
-  DataType input_type(int i) const { return input_types_[i]; }
-  const DataTypeVector& input_types() const { return input_types_; }
+  int num_inputs() const { return props_->input_types.size(); }
+  DataType input_type(int i) const { return props_->input_types[i]; }
+  const DataTypeVector& input_types() const { return props_->input_types; }
   const MemoryTypeVector& input_memory_types() const {
     return input_memory_types_;
   }
-  const string& requested_input(int i) const;  // Same as def().input(i)
-  int num_outputs() const { return output_types_.size(); }
-  DataType output_type(int o) const { return output_types_[o]; }
-  const DataTypeVector& output_types() const { return output_types_; }
+  int num_outputs() const { return props_->output_types.size(); }
+  DataType output_type(int o) const { return props_->output_types[o]; }
+  const DataTypeVector& output_types() const { return props_->output_types; }
   const MemoryTypeVector& output_memory_types() const {
     return output_memory_types_;
   }
@@ -209,10 +211,8 @@ class OpKernel {
   string GetTraceArgument(OpKernelContext* ctx);
  private:
-  const std::unique_ptr<const NodeDef> def_;
-  const DataTypeVector input_types_;
+  const std::shared_ptr<const NodeProperties> props_;
   const MemoryTypeVector input_memory_types_;
-  const DataTypeVector output_types_;
   const MemoryTypeVector output_memory_types_;
   NameRangeMap input_name_map_;
   NameRangeMap output_name_map_;
@@ -284,12 +284,10 @@ class PersistentTensor {
 class OpKernelConstruction {
  public:
   OpKernelConstruction(DeviceType device_type, DeviceBase* device,
-                       Allocator* allocator, const NodeDef* node_def,
-                       const OpDef* op_def, FunctionLibraryRuntime* flib,
+                       Allocator* allocator, FunctionLibraryRuntime* flib,
                        ResourceMgr* resource_mgr,
-                       const DataTypeSlice& input_types,
+                       const std::shared_ptr<const NodeProperties>& props,
                        const MemoryTypeSlice& input_memory_types,
-                       const DataTypeSlice& output_types,
                        const MemoryTypeSlice& output_memory_types,
                        int graph_def_version, Status* status);
@@ -330,20 +328,22 @@ class OpKernelConstruction {
                               Tensor** out_tensor);
   // User-supplied configuration of this operation.
-  const NodeDef& def() const { return *def_; }
+  const NodeDef& def() const { return props_->node_def; }
   // For inspecting the inputs to this operation.
-  int num_inputs() const { return input_types_.size(); }
-  DataType input_type(int i) const { return input_types_[i]; }
-  const DataTypeSlice& input_types() const { return input_types_; }
+  int num_inputs() const { return props_->input_types.size(); }
+  DataType input_type(int i) const { return props_->input_types[i]; }
+  const DataTypeSlice& input_types() const { return props_->input_types_slice; }
   const MemoryTypeSlice& input_memory_types() const {
     return input_memory_types_;
   }
   // For inspecting the outputs expected from this operation.
-  int num_outputs() const { return output_types_.size(); }
-  DataType output_type(int i) const { return output_types_[i]; }
-  const DataTypeSlice& output_types() const { return output_types_; }
+  int num_outputs() const { return props_->output_types.size(); }
+  DataType output_type(int i) const { return props_->output_types[i]; }
+  const DataTypeSlice& output_types() const {
+    return props_->output_types_slice;
+  }
   const MemoryTypeSlice& output_memory_types() const {
     return output_memory_types_;
   }
@@ -403,19 +403,15 @@ class OpKernelConstruction {
   const DeviceType device_type_;
   DeviceBase* const device_;
   Allocator* allocator_;
-  const NodeDef* def_;
-  const OpDef* op_def_;
   FunctionLibraryRuntime* flib_;
   ResourceMgr* const resource_mgr_;
-  DataTypeSlice input_types_;
+  std::shared_ptr<const NodeProperties> props_;
   MemoryTypeSlice input_memory_types_;
-  DataTypeSlice output_types_;
   MemoryTypeSlice output_memory_types_;
   const int graph_def_version_;
   Status* status_;
-  // Allow op_def_ across from OpKernel, but not from subclasses.
-  // TODO(irving): Remove protos from this header entirely.
+  // Allow access from OpKernel ctor.
   friend class OpKernel;
   TF_DISALLOW_COPY_AND_ASSIGN(OpKernelConstruction);
@@ -1404,15 +1400,23 @@ const Eigen::SyclDevice& OpKernelContext::eigen_device() const;
 std::unique_ptr<OpKernel> CreateOpKernel(DeviceType device_type,
                                          DeviceBase* device,
                                          Allocator* allocator,
-                                         const NodeDef& def,
+                                         const NodeDef& node_def,
                                          int graph_def_version, Status* status);
+std::unique_ptr<OpKernel> CreateOpKernel(
+    DeviceType device_type, DeviceBase* device, Allocator* allocator,
+    const std::shared_ptr<const NodeProperties>& props, int graph_def_version,
+    Status* status);
 Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                       Allocator* allocator, FunctionLibraryRuntime* flib,
-                      const NodeDef& def, int graph_def_version,
-                      OpKernel** kernel);
+                      const std::shared_ptr<const NodeProperties>& props,
+                      int graph_def_version, OpKernel** kernel);
 Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                       Allocator* allocator, FunctionLibraryRuntime* flib,
-                      ResourceMgr* resource_mgr, const NodeDef& def,
+                      ResourceMgr* resource_mgr,
+                      const std::shared_ptr<const NodeProperties>& props,
                       int graph_def_version, OpKernel** kernel);
 // Returns into 'device_types' the subset of prioritized_types that this
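The two new NodeProperties-based overloads above are what the rest of this change migrates callers onto. A hedged usage sketch of the std::unique_ptr variant, assuming the caller already holds a graph Node (which, after the graph.h change below, exposes its shared properties publicly); KernelForNode is a hypothetical wrapper, not part of this commit:

#include <memory>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

// Sketch: create a kernel for an existing graph Node without copying its
// NodeDef; the Node and the resulting OpKernel share one NodeProperties.
std::unique_ptr<OpKernel> KernelForNode(DeviceType device_type,
                                        DeviceBase* device,
                                        Allocator* allocator, const Node* node,
                                        Status* status) {
  return CreateOpKernel(device_type, device, allocator, node->properties(),
                        TF_GRAPH_DEF_VERSION, status);
}

}  // namespace tensorflow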

View File

@@ -28,6 +28,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/framework/tensor_util.h"
 #include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/strings/str_util.h"

View File

@@ -19,7 +19,7 @@ limitations under the License.
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/node_properties.h"
 #include "tensorflow/core/framework/op_def_builder.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/versions.pb.h"
@@ -37,23 +37,7 @@ namespace tensorflow {
 const int Graph::kControlSlot = -1;
-struct NodeProperties {
- public:
-  NodeProperties(const OpDef* op_def, NodeDef node_def,
-                 const DataTypeSlice inputs, const DataTypeSlice outputs)
-      : op_def(op_def),
-        node_def(std::move(node_def)),
-        input_types(inputs.begin(), inputs.end()),
-        output_types(outputs.begin(), outputs.end()) {}
-  const OpDef* op_def;  // not owned
-  NodeDef node_def;
-  const DataTypeVector input_types;
-  const DataTypeVector output_types;
-};
 // Node
 #define REF_CLASS(key, value) \
   {key, value}, { "Ref" key, value }
@@ -97,7 +81,8 @@ const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable =
         {"StatelessIf", NC_IF},
         {"While", NC_WHILE},
         {"StatelessWhile", NC_WHILE},
-        // Not using the constants defined in FunctionLibraryDefinition for the
+        // Not using the constants defined in FunctionLibraryDefinition
+        // for the
         // 4 ops below because android inference library does not link
        // tf.function related files.
         {"_Arg", NC_ARG},

View File

@@ -43,6 +43,7 @@ limitations under the License.
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/graph/edgeset.h"
@@ -67,7 +68,6 @@ class WhileContext;
 class NeighborIter;  // Declared below
 class NodeIter;      // Declared below
-struct NodeProperties;  // Defined in .cc
 class Node {
  public:
@@ -229,11 +229,12 @@ class Node {
     while_ctx_ = while_ctx;
   }
+  std::shared_ptr<NodeProperties> properties() const { return props_; }
  private:
   friend class Graph;
   Node();
-  NodeProperties* properties() const { return props_.get(); }
   void Initialize(int id, int cost_id, std::shared_ptr<NodeProperties> props,
                   bool is_function_op);
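Making properties() public, and returning the shared_ptr itself rather than a raw pointer, is what lets the executors below hand the very same NodeProperties to kernel construction. A hedged sanity-check sketch of the sharing this enables; CheckNodeDefIsShared is hypothetical and builds on the CreateOpKernel overload sketched earlier:

#include <memory>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

// Sketch: the kernel's def() should now alias the Node's NodeDef instead of
// pointing at a private copy, which is where the allocation savings come from.
void CheckNodeDefIsShared(DeviceType device_type, DeviceBase* device,
                          Allocator* allocator, const Node* node) {
  Status status;
  std::unique_ptr<OpKernel> kernel =
      CreateOpKernel(device_type, device, allocator, node->properties(),
                     TF_GRAPH_DEF_VERSION, &status);
  CHECK(status.ok()) << status;
  // Same underlying NodeDef object; no per-kernel copy was made.
  CHECK_EQ(&kernel->def(), &node->properties()->node_def);
}

}  // namespace tensorflow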

View File

@@ -47,31 +47,30 @@ namespace tensorflow {
 namespace {
-std::unique_ptr<const NodeDef> StripTensorDataFromNodeDef(
-    OpKernelConstruction* ctx) {
+NodeDef StripTensorDataFromNodeDef(OpKernelConstruction* ctx) {
 #ifndef __ANDROID__
   DCHECK_EQ(NodeDef::descriptor()->field_count(), 6)
       << "The NodeDef format has changed, and the attr-stripping code may need "
      << "to be updated.";
 #endif
   const NodeDef& original = ctx->def();
-  NodeDef* ret = new NodeDef;
-  ret->set_name(original.name());
-  ret->set_op(original.op());
-  ret->set_device(original.device());
+  NodeDef ret;
+  ret.set_name(original.name());
+  ret.set_op(original.op());
+  ret.set_device(original.device());
   // Strip the "value" attr from the returned NodeDef.
   // NOTE(mrry): The present implementation of `OpKernel::OpKernel()` only uses
   // attrs that affect the cardinality of list-typed inputs and outputs, so it
   // is safe to drop other attrs from the NodeDef.
-  AddNodeAttr("dtype", ctx->output_type(0), ret);
-  MergeDebugInfo(original, ret);
-  return std::unique_ptr<const NodeDef>(ret);
+  AddNodeAttr("dtype", ctx->output_type(0), &ret);
+  MergeDebugInfo(original, &ret);
+  return ret;
 }
 }  // namespace
 ConstantOp::ConstantOp(OpKernelConstruction* ctx)
-    : OpKernel(ctx, StripTensorDataFromNodeDef(ctx)),
+    : OpKernel(ctx, StripTensorDataFromNodeDef(ctx), false),
       tensor_(ctx->output_type(0)) {
   const TensorProto* proto = nullptr;
   MEMDEBUG_CACHE_OP(ctx->def().name().c_str());
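ConstantOp now builds the stripped NodeDef by value and hands it to the new OpKernel(ctx, NodeDef&&, is_deferred) constructor instead of transferring ownership of a heap-allocated NodeDef. A hedged sketch of another kernel using the same pattern; BigAttrOp, its "payload" attr, and the DT_FLOAT dtype are made up for illustration and are not part of this change:

#include <utility>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.pb.h"

namespace tensorflow {

// Hypothetical kernel that, like ConstantOp, stores a large attr itself and
// therefore passes a slimmed-down NodeDef to the OpKernel base class. This is
// safe only because the base constructor relies on attrs that affect the
// cardinality of list-typed inputs/outputs, which this op does not have.
class BigAttrOp : public OpKernel {
 public:
  explicit BigAttrOp(OpKernelConstruction* ctx)
      : OpKernel(ctx, StripBigAttr(ctx->def()), /*is_deferred=*/false) {}

  void Compute(OpKernelContext* ctx) override {}

 private:
  // Copies only the fields OpKernel needs; the heavy "payload" attr is kept
  // out of the NodeDef that the kernel retains.
  static NodeDef StripBigAttr(const NodeDef& original) {
    NodeDef ret;
    ret.set_name(original.name());
    ret.set_op(original.op());
    ret.set_device(original.device());
    AddNodeAttr("dtype", DT_FLOAT, &ret);  // illustrative attr only
    return ret;
  }
};

}  // namespace tensorflow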

View File

@@ -304,9 +304,14 @@ Status DatasetOpsTestBase::ExpectEqual(std::vector<Tensor> produced_tensors,
 Status DatasetOpsTestBase::CreateOpKernel(
     const NodeDef& node_def, std::unique_ptr<OpKernel>* op_kernel) {
   OpKernel* kernel;
+  Status s;
+  std::shared_ptr<const NodeProperties> props;
+  TF_RETURN_IF_ERROR(NodeProperties::CreateFromNodeDef(
+      node_def, flr_->GetFunctionLibraryDefinition(), &props));
   TF_RETURN_IF_ERROR(tensorflow::CreateOpKernel(
       device_type_, device_.get(), allocator_, flr_,
-      device_->resource_manager(), node_def, TF_GRAPH_DEF_VERSION, &kernel));
+      device_->resource_manager(), props, TF_GRAPH_DEF_VERSION, &kernel));
   op_kernel->reset(kernel);
   return Status::OK();
 }
@@ -435,9 +440,10 @@ Status DatasetOpsTestBase::RunFunction(
   LocalExecutorParams params;
   params.function_library = flr_;
   params.device = device_.get();
-  params.create_kernel = [this, version](const NodeDef& ndef,
-                                         OpKernel** kernel) {
-    return CreateNonCachedKernel(device_.get(), this->flr_, ndef, version,
+  params.create_kernel = [this, version](
+                             const std::shared_ptr<const NodeProperties>& props,
+                             OpKernel** kernel) {
+    return CreateNonCachedKernel(device_.get(), this->flr_, props, version,
                                  kernel);
   };
   params.delete_kernel = [](OpKernel* kernel) {

View File

@@ -108,7 +108,8 @@ class SingleThreadedExecutorImpl : public Executor {
     KernelState& kernel_state = kernels_[kernel_index];
     node_to_index_map[n] = kernel_index;
-    TF_RETURN_IF_ERROR(params_.create_kernel(n->def(), &kernel_state.kernel));
+    TF_RETURN_IF_ERROR(
+        params_.create_kernel(n->properties(), &kernel_state.kernel));
     kernel_state.num_inputs = n->num_inputs();
     kernel_state.num_outputs = n->num_outputs();

View File

@@ -58,9 +58,10 @@ class ExecutorTest : public ::testing::Test {
     const int version = graph->versions().producer();
     LocalExecutorParams params;
     params.device = device_.get();
-    params.create_kernel = [this, version](const NodeDef& ndef,
-                                           OpKernel** kernel) {
-      return CreateNonCachedKernel(device_.get(), nullptr, ndef, version,
+    params.create_kernel =
+        [this, version](const std::shared_ptr<const NodeProperties>& props,
+                        OpKernel** kernel) {
+          return CreateNonCachedKernel(device_.get(), nullptr, props, version,
                                    kernel);
     };
     params.delete_kernel = [](OpKernel* kernel) {
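Both executors' create_kernel callbacks now take the node's shared NodeProperties, which implies the LocalExecutorParams::create_kernel signature itself changed. The exact typedef is not shown in these hunks, so the alias below is an inference from the lambdas above; one practical consequence of the shared_ptr parameter is that a created kernel can keep its NodeProperties alive independently of the Graph that produced it.

#include <functional>
#include <memory>

#include "tensorflow/core/framework/node_properties.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

// Inferred shape of the updated kernel-creation callback type.
using CreateKernelFn =
    std::function<Status(const std::shared_ptr<const NodeProperties>&,
                         OpKernel**)>;

}  // namespace tensorflow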

View File

@@ -237,8 +237,7 @@ class Tests(test.TestCase):
   @test_util.assert_no_garbage_created
   def testInvalidNumOutputs(self):
     with self.assertRaisesRegexp(
-        Exception,
-        "Value for attr 'num_split' of -1 must be at least minimum 1"):
+        Exception, r"Value for number_attr\(\) -1 < 0 \[Op:Split\]"):
       array_ops.split(value=[1, 2, 3], num_or_size_splits=-1)
     with self.assertRaisesRegexp(