Replace NodeDef with std::shared_ptr<NodeProperties> in the kernel-creation code paths, avoiding copies of NodeDefs wherever possible. In most cases this allows the NodeDef to be shared between the OpKernel and the graph Node from which it is created.
This reduces the number of allocations in the executor benchmark by about 8%:

    name                                  old time/op    new time/op    delta
    BM_executor/16/1k [Nodes = 9824  ]    911µs ± 3%     911µs ± 1%     ~       (p=0.548 n=5+5)
    BM_executor/32/8k [Nodes = 141991]    17.1ms ± 2%    16.8ms ± 1%    -2.17%  (p=0.016 n=5+5)
    BM_executor/1k/16 [Nodes = 6781  ]    1.21ms ± 1%    1.25ms ± 7%    ~       (p=0.095 n=5+5)
    BM_executor/8k/32 [Nodes = 130875]    4.35s ± 0%     4.34s ± 0%     ~       (p=0.841 n=5+5)
    BM_executor/1k/1k [Nodes = 526256]    3.33s ± 1%     3.31s ± 1%     ~       (p=0.095 n=5+5)
    BM_FeedInputFetchOutput               54.0µs ± 7%    56.9µs ±13%    ~       (p=0.222 n=5+5)

    name                                  old allocs/op  new allocs/op  delta
    BM_executor/16/1k [Nodes = 9824  ]    15.4k ± 0%     14.1k ± 0%     -7.95%  (p=0.008 n=5+5)
    BM_executor/32/8k [Nodes = 141991]    226k ± 0%      208k ± 0%      -7.86%  (p=0.008 n=5+5)
    BM_executor/1k/16 [Nodes = 6781  ]    10.2k ± 0%     9.3k ± 0%      -8.36%  (p=0.008 n=5+5)
    BM_executor/8k/32 [Nodes = 130875]    197k ± 0%      180k ± 0%      -8.31%  (p=0.016 n=4+5)
    BM_executor/1k/1k [Nodes = 526256]    771k ± 0%      706k ± 0%      -8.53%  (p=0.008 n=5+5)
    BM_FeedInputFetchOutput               58.0 ± 0%      57.0 ± 0%      -1.72%  (p=0.008 n=5+5)

PiperOrigin-RevId: 295803318
Change-Id: I0d262c6082822023f449f9817dc943d20bd302d5
Parent: 36fe0e7aad
Commit: 5f3a3019ba
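The mechanical core of the change, condensed for orientation (an editorial sketch assembled from the KernelAndDeviceOp::Init hunk below, not an additional change in the commit; `ndef` is the incoming NodeDef and `flr_` a FunctionLibraryRuntime*):

    // Build one ref-counted NodeProperties (OpDef pointer + NodeDef +
    // input/output types) up front...
    std::shared_ptr<const NodeProperties> props;
    TF_RETURN_IF_ERROR(NodeProperties::CreateFromNodeDef(
        ndef, flr_->GetFunctionLibraryDefinition(), &props));
    // ...and hand the same object to kernel creation; the OpKernel keeps a
    // reference to it instead of copying the NodeDef.
    OpKernel* k = nullptr;
    TF_RETURN_IF_ERROR(flr_->CreateKernel(props, &k));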
@@ -20,15 +20,17 @@ limitations under the License.
 
 namespace tensorflow {
 
-bool XlaKernelCreator::CanCreateKernel(const FunctionLibraryRuntime& flr,
-                                       const NodeDef& node_def) const {
-  return CanCreateXlaKernel(node_def);
+bool XlaKernelCreator::CanCreateKernel(
+    const FunctionLibraryRuntime& flr,
+    const std::shared_ptr<const NodeProperties>& props) const {
+  return CanCreateXlaKernel(props->node_def);
 }
 
-Status XlaKernelCreator::CreateKernel(FunctionLibraryRuntime* flr,
-                                      const NodeDef& node_def,
-                                      std::unique_ptr<OpKernel>* kernel) const {
-  return CreateXlaKernel(flr, node_def, kernel);
+Status XlaKernelCreator::CreateKernel(
+    FunctionLibraryRuntime* flr,
+    const std::shared_ptr<const NodeProperties>& props,
+    std::unique_ptr<OpKernel>* kernel) const {
+  return CreateXlaKernel(flr, props->node_def, kernel);
 }
 
 namespace {
@@ -29,11 +29,13 @@ class XlaKernelCreator : public CustomKernelCreator {
   // Given a NodeDef 'node_def' and the function library runtime 'flr', returns
   // true if 'node_def' is a call to a compilable function defined in 'flr',
   // with the kXlaCompileAttr set.
-  bool CanCreateKernel(const FunctionLibraryRuntime& flr,
-                       const NodeDef& node_def) const override;
+  bool CanCreateKernel(
+      const FunctionLibraryRuntime& flr,
+      const std::shared_ptr<const NodeProperties>& props) const override;
 
   // Given a supported NodeDef, returns a XlaLaunchOp that computes the node.
-  Status CreateKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
+  Status CreateKernel(FunctionLibraryRuntime* flr,
+                      const std::shared_ptr<const NodeProperties>& props,
                       std::unique_ptr<OpKernel>* kernel) const override;
 };
@@ -30,10 +30,12 @@ limitations under the License.
 
 namespace tensorflow {
 
-NodeDef ToNodeDef(const string& text) {
+std::shared_ptr<NodeProperties> ToNodeProperties(const string& text) {
   NodeDef node_def;
+  DataTypeVector dummy;
   EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
-  return node_def;
+  return std::make_shared<NodeProperties>(nullptr, std::move(node_def), dummy,
+                                          dummy);
 }
 
 // Create a FunctionDef that takes one resource and one regular param
@@ -98,11 +100,11 @@ TEST_F(XlaKernelCreatorTest, OneFloatOneResourceArgument) {
   (*fdef.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
   Init({fdef});
   XlaKernelCreator xla_kernel_creator;
-  NodeDef callsite =
-      ToNodeDef(R"pb(
+  auto callsite =
+      ToNodeProperties(R"pb(
         name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b'
       )pb");
-  (*callsite.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
+  (*(callsite->node_def.mutable_attr()))["_XlaMustCompile"] = BoolAttr(true);
 
   // Note: need to set attribute on the created node.
   Status status = xla_kernel_creator.CreateKernel(flr_, callsite, &kernel_);
@@ -127,13 +129,14 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrNotSet) {
   Init({fdef});
   XlaKernelCreator xla_kernel_creator;
 
-  Status status = xla_kernel_creator.CreateKernel(flr_, ToNodeDef(R"proto(
-                                                    name: 'XTimesY'
-                                                    op: 'XTimesY'
-                                                    input: 'a'
-                                                    input: 'b'
-                                                  )proto"),
-                                                  &kernel_);
+  Status status =
+      xla_kernel_creator.CreateKernel(flr_, ToNodeProperties(R"proto(
+                                        name: 'XTimesY'
+                                        op: 'XTimesY'
+                                        input: 'a'
+                                        input: 'b'
+                                      )proto"),
+                                      &kernel_);
   EXPECT_TRUE(errors::IsInternal(status)) << status.ToString();
 }
 
@@ -143,13 +146,14 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrIsSetToFalse) {
   Init({fdef});
   XlaKernelCreator xla_kernel_creator;
 
-  Status status = xla_kernel_creator.CreateKernel(flr_, ToNodeDef(R"proto(
-                                                    name: 'XTimesY'
-                                                    op: 'XTimesY'
-                                                    input: 'a'
-                                                    input: 'b'
-                                                  )proto"),
-                                                  &kernel_);
+  Status status =
+      xla_kernel_creator.CreateKernel(flr_, ToNodeProperties(R"proto(
+                                        name: 'XTimesY'
+                                        op: 'XTimesY'
+                                        input: 'a'
+                                        input: 'b'
+                                      )proto"),
+                                      &kernel_);
   EXPECT_TRUE(errors::IsInternal(status)) << status.ToString();
 }
@@ -218,12 +218,13 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
   TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
   Device* dev = flr->device();
   Status s;
-  OpKernelConstruction construction(
-      DeviceType(dev->device_type()), dev,
-      dev->GetAllocator(AllocatorAttributes()), &node_def,
-      &fbody->fdef.signature(), flr, dev->resource_manager(), fbody->arg_types,
-      input_memory_types, fbody->ret_types, output_memory_types,
-      flr->graph_def_version(), &s);
+  auto props = std::make_shared<NodeProperties>(
+      &fbody->fdef.signature(), node_def, fbody->arg_types, fbody->ret_types);
+  OpKernelConstruction construction(DeviceType(dev->device_type()), dev,
+                                    dev->GetAllocator(AllocatorAttributes()),
+                                    flr, dev->resource_manager(), props,
+                                    input_memory_types, output_memory_types,
+                                    flr->graph_def_version(), &s);
 
   *kernel = absl::make_unique<XlaLocalLaunchBase>(
       &construction, constant_arg_indices, resource_arg_indices, function,
@@ -127,9 +127,14 @@ class TRTEngineOpTestBase : public OpsTestBase {
  private:
   Status InitOpWithFunctionLibrary() {
     OpKernel* kernel = nullptr;
-    Status status = CreateOpKernel(device_type_, device_, allocator(),
-                                   pflr_->GetFLR(device_->name()), node_def_,
-                                   TF_GRAPH_DEF_VERSION, &kernel);
+    auto flr = pflr_->GetFLR(device_->name());
+    std::shared_ptr<const NodeProperties> props;
+    Status status = NodeProperties::CreateFromNodeDef(
+        node_def_, flr->GetFunctionLibraryDefinition(), &props);
+    if (status.ok()) {
+      status.Update(CreateOpKernel(device_type_, device_, allocator(), flr,
+                                   props, TF_GRAPH_DEF_VERSION, &kernel));
+    }
     kernel_ = std::unique_ptr<OpKernel>(kernel);
     if (kernel_ != nullptr) input_types_ = kernel_->input_types();
     return status;
@@ -133,7 +133,7 @@ Status GraphCompiler::Compile() {
     OpKernel* op_kernel_raw = nullptr;
     // The kernel is not actually run for functional ops, we just need it
     // for metadata.
-    Status s = flib_->CreateKernel(n->def(), &op_kernel_raw);
+    Status s = flib_->CreateKernel(n->properties(), &op_kernel_raw);
     // Transfer ownership of the kernel to a local smart pointer.
     std::unique_ptr<OpKernel> op_kernel(op_kernel_raw);
@@ -472,6 +472,7 @@ tf_cuda_library(
         "//tensorflow/core/framework:memory_types.h",
         "//tensorflow/core/framework:node_def_builder.h",
         "//tensorflow/core/framework:node_def_util.h",
+        "//tensorflow/core/framework:node_properties.h",
         "//tensorflow/core/framework:numeric_op.h",
        "//tensorflow/core/framework:numeric_types.h",
         "//tensorflow/core/framework:op.h",
@@ -2323,6 +2324,7 @@ tf_cuda_library(
         "//tensorflow/core/framework:bfloat16",
         "//tensorflow/core/framework:common_shape_fns",
         "//tensorflow/core/framework:node_def_util",
+        "//tensorflow/core/framework:node_properties",
         "//tensorflow/core/framework:numeric_types",
         "//tensorflow/core/framework:op",
         "//tensorflow/core/framework:op_def_builder",
@@ -1356,24 +1356,25 @@ Status DirectSession::CreateExecutors(
     params.session_metadata = session_metadata;
     params.function_library = lib;
     auto opseg = device->op_segment();
-    params.create_kernel = [this, lib, opseg](const NodeDef& ndef,
-                                              OpKernel** kernel) {
-      // NOTE(mrry): We must not share function kernels (implemented
-      // using `CallOp`) between subgraphs, because `CallOp::handle_`
-      // is tied to a particular subgraph. Even if the function itself
-      // is stateful, the `CallOp` that invokes it is not.
-      if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) {
-        return lib->CreateKernel(ndef, kernel);
-      }
-      auto create_fn = [lib, &ndef](OpKernel** kernel) {
-        return lib->CreateKernel(ndef, kernel);
-      };
-      // Kernels created for subgraph nodes need to be cached. On
-      // cache miss, create_fn() is invoked to create a kernel based
-      // on the function library here + global op registry.
-      return opseg->FindOrCreate(session_handle_, ndef.name(), kernel,
-                                 create_fn);
-    };
+    params.create_kernel =
+        [this, lib, opseg](const std::shared_ptr<const NodeProperties>& props,
+                           OpKernel** kernel) {
+          // NOTE(mrry): We must not share function kernels (implemented
+          // using `CallOp`) between subgraphs, because `CallOp::handle_`
+          // is tied to a particular subgraph. Even if the function itself
+          // is stateful, the `CallOp` that invokes it is not.
+          if (!OpSegment::ShouldOwnKernel(lib, props->node_def.op())) {
+            return lib->CreateKernel(props, kernel);
+          }
+          auto create_fn = [lib, &props](OpKernel** kernel) {
+            return lib->CreateKernel(props, kernel);
+          };
+          // Kernels created for subgraph nodes need to be cached. On
+          // cache miss, create_fn() is invoked to create a kernel based
+          // on the function library here + global op registry.
+          return opseg->FindOrCreate(session_handle_, props->node_def.name(),
+                                     kernel, create_fn);
+        };
     params.delete_kernel = [lib](OpKernel* kernel) {
       if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string()))
         delete kernel;
@@ -98,7 +98,10 @@ Status KernelAndDeviceOp::Init(const NodeDef& ndef,
         "A valid FunctionLibraryRuntime must be provided when running ops "
         "based on OpKernel.");
   }
-  TF_RETURN_IF_ERROR(flr_->CreateKernel(ndef, &k));
+  std::shared_ptr<const NodeProperties> props;
+  TF_RETURN_IF_ERROR(NodeProperties::CreateFromNodeDef(
+      ndef, flr_->GetFunctionLibraryDefinition(), &props));
+  TF_RETURN_IF_ERROR(flr_->CreateKernel(props, &k));
   kernel_.reset(k);
 
   input_alloc_attrs_.resize(kernel_->num_inputs());
@@ -654,7 +654,7 @@ Status ExecutorImpl::Initialize(const Graph& graph) {
     item->input_start = frame_info->total_inputs;
     frame_info->total_inputs += n->num_inputs();
 
-    Status s = params_.create_kernel(n->def(), &item->kernel);
+    Status s = params_.create_kernel(n->properties(), &item->kernel);
     if (!s.ok()) {
       item->kernel = nullptr;
       s = AttachDef(s, *n);
@@ -2974,12 +2974,12 @@ Status NewLocalExecutor(const LocalExecutorParams& params, const Graph& graph,
 }
 
 Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
-                             const NodeDef& ndef, int graph_def_version,
-                             OpKernel** kernel) {
+                             const std::shared_ptr<const NodeProperties>& props,
+                             int graph_def_version, OpKernel** kernel) {
   const auto device_type = DeviceType(device->attributes().device_type());
   auto allocator = device->GetAllocator(AllocatorAttributes());
   return CreateOpKernel(device_type, device, allocator, flib,
-                        device->resource_manager(), ndef, graph_def_version,
+                        device->resource_manager(), props, graph_def_version,
                         kernel);
 }
@@ -145,7 +145,9 @@ struct LocalExecutorParams {
   // create_kernel returns an instance of op kernel based on NodeDef.
   // delete_kernel is called for every kernel used by the executor
   // when the executor is deleted.
-  std::function<Status(const NodeDef&, OpKernel**)> create_kernel;
+  std::function<Status(const std::shared_ptr<const NodeProperties>&,
+                       OpKernel**)>
+      create_kernel;
   std::function<void(OpKernel*)> delete_kernel;
 
   Executor::RendezvousFactory rendezvous_factory;
@@ -240,12 +242,12 @@ class ExecutorBarrier {
 
 // A few helpers to facilitate create/delete kernels.
 
-// Creates a kernel based on "ndef" on device "device". The kernel can
+// Creates a kernel based on "props" on device "device". The kernel can
 // access the functions in the "flib". The caller takes ownership of
 // returned "*kernel".
 Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
-                             const NodeDef& ndef, int graph_def_version,
-                             OpKernel** kernel);
+                             const std::shared_ptr<const NodeProperties>& props,
+                             int graph_def_version, OpKernel** kernel);
 
 // Deletes "kernel" returned by CreateKernel.
 void DeleteNonCachedKernel(OpKernel* kernel);
@@ -61,11 +61,12 @@ class ExecutorTest : public ::testing::Test {
     const int version = graph->versions().producer();
     LocalExecutorParams params;
     params.device = device_.get();
-    params.create_kernel = [this, version](const NodeDef& ndef,
-                                           OpKernel** kernel) {
-      return CreateNonCachedKernel(device_.get(), nullptr, ndef, version,
-                                   kernel);
-    };
+    params.create_kernel =
+        [this, version](const std::shared_ptr<const NodeProperties>& props,
+                        OpKernel** kernel) {
+          return CreateNonCachedKernel(device_.get(), nullptr, props, version,
+                                       kernel);
+        };
     params.delete_kernel = [](OpKernel* kernel) {
       DeleteNonCachedKernel(kernel);
     };
@@ -187,7 +187,8 @@ class FunctionLibraryRuntimeOverlay : public FunctionLibraryRuntime {
   void Run(const Options& opts, Handle handle, CallFrameInterface* call_frame,
            DoneCallback done) override;
 
-  Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) override;
+  Status CreateKernel(const std::shared_ptr<const NodeProperties>& props,
+                      OpKernel** kernel) override;
 
   bool IsStateful(const string& function_name) const override;
 
@@ -256,7 +257,8 @@ void FunctionLibraryRuntimeOverlay::Run(const Options& opts, Handle handle,
   base_flr_->Run(opts, handle, call_frame, std::move(done));
 }
 
-Status FunctionLibraryRuntimeOverlay::CreateKernel(const NodeDef&, OpKernel**) {
+Status FunctionLibraryRuntimeOverlay::CreateKernel(
+    const std::shared_ptr<const NodeProperties>&, OpKernel**) {
   // We don't have access to base_lib_def_ in base function library runtime (aka
   // FunctionLibraryRuntimeImpl), so to make sure we do not create a kernel with
   // the wrong lib_def we just disable creation of new kernels through overlays.
@@ -344,7 +346,8 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
 
   Status GetRetTypes(Handle handle, DataTypeVector* ret_types) override;
 
-  Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) override;
+  Status CreateKernel(const std::shared_ptr<const NodeProperties>& props,
+                      OpKernel** kernel) override;
 
   void Run(const Options& opts, Handle handle, gtl::ArraySlice<Tensor> args,
            std::vector<Tensor>* rets, DoneCallback done) override;
@@ -393,7 +396,9 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
   const string device_name_;
 
   std::function<Status(const string&, const OpDef**)> get_func_sig_;
-  std::function<Status(const NodeDef&, OpKernel**)> create_kernel_;
+  std::function<Status(const std::shared_ptr<const NodeProperties>&,
+                       OpKernel**)>
+      create_kernel_;
 
   mutable mutex mu_;
 
@@ -426,8 +431,8 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
   // to use for kernel creation and execution. In particular, this method can
   // accept a FunctionLibraryRuntimeOverlay that overlays a different
   // FunctionLibraryDefinition.
-  Status CreateKernel(const NodeDef& ndef, FunctionLibraryRuntime* flr,
-                      OpKernel** kernel);
+  Status CreateKernel(const std::shared_ptr<const NodeProperties>& props,
+                      FunctionLibraryRuntime* flr, OpKernel** kernel);
   Status FunctionDefToBody(const FunctionDef& fdef, AttrSlice attrs,
                            const FunctionLibraryDefinition* lib_def,
                            std::unique_ptr<FunctionBody>* fbody);
@@ -476,8 +481,9 @@ FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
   get_func_sig_ = [this](const string& op, const OpDef** sig) {
     return base_lib_def_->LookUpOpDef(op, sig);
   };
-  create_kernel_ = [this](const NodeDef& ndef, OpKernel** kernel) {
-    return CreateKernel(ndef, kernel);
+  create_kernel_ = [this](const std::shared_ptr<const NodeProperties>& props,
+                          OpKernel** kernel) {
+    return CreateKernel(props, kernel);
   };
   thread::ThreadPool* pool = nullptr;
   if (device_ != nullptr) {
@@ -589,20 +595,20 @@ Status FunctionLibraryRuntimeImpl::GetRetTypes(Handle h,
   return Status::OK();
 }
 
-Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
-                                                OpKernel** kernel) {
-  return CreateKernel(ndef, this, kernel);
+Status FunctionLibraryRuntimeImpl::CreateKernel(
+    const std::shared_ptr<const NodeProperties>& props, OpKernel** kernel) {
+  return CreateKernel(props, this, kernel);
 }
 
-Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
-                                                FunctionLibraryRuntime* flr,
-                                                OpKernel** kernel) {
+Status FunctionLibraryRuntimeImpl::CreateKernel(
+    const std::shared_ptr<const NodeProperties>& props,
+    FunctionLibraryRuntime* flr, OpKernel** kernel) {
   // If a custom kernel creator is given, try that.
   Status s;
   if (custom_kernel_creator_ != nullptr &&
-      custom_kernel_creator_->CanCreateKernel(*this, ndef)) {
+      custom_kernel_creator_->CanCreateKernel(*this, props)) {
     std::unique_ptr<OpKernel> ret;
-    s = custom_kernel_creator_->CreateKernel(this, ndef, &ret);
+    s = custom_kernel_creator_->CreateKernel(this, props, &ret);
     if (s.ok()) {
       *kernel = ret.release();
     } else {
@@ -613,9 +619,9 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
 
   const FunctionLibraryDefinition* lib_def =
       flr->GetFunctionLibraryDefinition();
-  if (lib_def->Find(ndef.op()) == nullptr) {
+  if (lib_def->Find(props->node_def.op()) == nullptr) {
     // A primitive operation. Creates the registered kernel.
-    return CreateNonCachedKernel(device_, flr, ndef, graph_def_version_,
+    return CreateNonCachedKernel(device_, flr, props, graph_def_version_,
                                  kernel);
   }
 
@@ -626,8 +632,9 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
     options.lib_def = lib_def;
   }
   Handle handle;
-  TF_RETURN_IF_ERROR(
-      Instantiate(ndef.op(), AttrSlice(&ndef.attr()), options, &handle));
+  TF_RETURN_IF_ERROR(Instantiate(props->node_def.op(),
+                                 AttrSlice(&props->node_def.attr()), options,
+                                 &handle));
 
   const FunctionBody* fbody = GetFunctionBody(handle);
   CHECK_NOTNULL(fbody);
@@ -647,10 +654,12 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
 
   // Constructs a CallOp kernel for running the instantiated function.
   auto device_type = DeviceType(device_->attributes().device_type());
+  auto new_props = std::make_shared<NodeProperties>(
+      &fbody->fdef.signature(), props->node_def, fbody->arg_types,
+      fbody->ret_types);
   OpKernelConstruction construction(
-      device_type, device_, device_->GetAllocator(AllocatorAttributes()), &ndef,
-      &fbody->fdef.signature(), flr, device_->resource_manager(),
-      fbody->arg_types, input_memory_types, fbody->ret_types,
+      device_type, device_, device_->GetAllocator(AllocatorAttributes()), flr,
+      device_->resource_manager(), props, input_memory_types,
       output_memory_types, graph_def_version_, &s);
   if (s.ok()) {
     *kernel = new CallOp(handle, &construction);
@@ -953,9 +962,11 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Item** item) {
   if (flr == this) {
     params.create_kernel = create_kernel_;
   } else {
-    params.create_kernel = [this, flr](const NodeDef& ndef, OpKernel** kernel) {
-      return CreateKernel(ndef, flr, kernel);
-    };
+    params.create_kernel =
+        [this, flr](const std::shared_ptr<const NodeProperties>& props,
+                    OpKernel** kernel) {
+          return CreateKernel(props, flr, kernel);
+        };
   }
   params.delete_kernel = [](OpKernel* kernel) {
     DeleteNonCachedKernel(kernel);
@@ -90,11 +90,12 @@ class FunctionTest : public ::testing::Test {
     const int version = g->versions().producer();
     LocalExecutorParams params;
     params.device = device_.get();
-    params.create_kernel = [this, version](const NodeDef& ndef,
-                                           OpKernel** kernel) {
-      return CreateNonCachedKernel(device_.get(), nullptr, ndef, version,
-                                   kernel);
-    };
+    params.create_kernel =
+        [this, version](const std::shared_ptr<const NodeProperties>& props,
+                        OpKernel** kernel) {
+          return CreateNonCachedKernel(device_.get(), nullptr, props, version,
+                                       kernel);
+        };
     params.delete_kernel = [](OpKernel* kernel) {
       DeleteNonCachedKernel(kernel);
     };
@@ -157,9 +157,10 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
   params.device = device_;
   params.function_library = function_library;
   const int producer = graph_to_run->versions().producer();
-  params.create_kernel = [this, function_library, producer](const NodeDef& ndef,
-                                                            OpKernel** kernel) {
-    return CreateNonCachedKernel(device_, function_library, ndef, producer,
+  params.create_kernel = [this, function_library, producer](
+                             const std::shared_ptr<const NodeProperties>& props,
+                             OpKernel** kernel) {
+    return CreateNonCachedKernel(device_, function_library, props, producer,
                                  kernel);
   };
   params.delete_kernel = [](OpKernel* kernel) { delete kernel; };
@@ -84,9 +84,10 @@ Benchmark::Benchmark(const string& device, Graph* g,
   LocalExecutorParams params;
   params.device = device_.get();
   params.function_library = nullptr;
-  params.create_kernel = [this, graph_def_version](const NodeDef& ndef,
-                                                   OpKernel** kernel) {
-    return CreateNonCachedKernel(device_.get(), nullptr, ndef,
+  params.create_kernel = [this, graph_def_version](
+                             const std::shared_ptr<const NodeProperties>& props,
+                             OpKernel** kernel) {
+    return CreateNonCachedKernel(device_.get(), nullptr, props,
                                  graph_def_version, kernel);
   };
   params.delete_kernel = [](OpKernel* kernel) {
@@ -233,23 +233,25 @@ Status GraphMgr::InitItem(
     // Construct the root executor for the subgraph.
     params.device = unit->device;
     params.function_library = lib;
-    params.create_kernel = [handle, lib, opseg](const NodeDef& ndef,
-                                                OpKernel** kernel) {
-      // NOTE(mrry): We must not share function kernels (implemented
-      // using `CallOp`) between subgraphs, because `CallOp::handle_`
-      // is tied to a particular subgraph. Even if the function itself
-      // is stateful, the `CallOp` that invokes it is not.
-      if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) {
-        return lib->CreateKernel(ndef, kernel);
-      }
-      auto create_fn = [lib, &ndef](OpKernel** kernel) {
-        return lib->CreateKernel(ndef, kernel);
-      };
-      // Kernels created for subgraph nodes need to be cached. On
-      // cache miss, create_fn() is invoked to create a kernel based
-      // on the function library here + global op registry.
-      return opseg->FindOrCreate(handle, ndef.name(), kernel, create_fn);
-    };
+    params.create_kernel =
+        [handle, lib, opseg](const std::shared_ptr<const NodeProperties>& props,
+                             OpKernel** kernel) {
+          // NOTE(mrry): We must not share function kernels (implemented
+          // using `CallOp`) between subgraphs, because `CallOp::handle_`
+          // is tied to a particular subgraph. Even if the function itself
+          // is stateful, the `CallOp` that invokes it is not.
+          if (!OpSegment::ShouldOwnKernel(lib, props->node_def.op())) {
+            return lib->CreateKernel(props, kernel);
+          }
+          auto create_fn = [lib, &props](OpKernel** kernel) {
+            return lib->CreateKernel(props, kernel);
+          };
+          // Kernels created for subgraph nodes need to be cached. On
+          // cache miss, create_fn() is invoked to create a kernel based
+          // on the function library here + global op registry.
+          return opseg->FindOrCreate(handle, props->node_def.name(), kernel,
+                                     create_fn);
+        };
     params.delete_kernel = [lib](OpKernel* kernel) {
       if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string())) {
         delete kernel;
@@ -129,6 +129,7 @@ exports_files(
         "attr_value_util.h",
         "common_shape_fns.h",
         "node_def_util.h",
+        "node_properties.h",
         "op.h",
         "op_def_builder.h",
         "op_def_util.h",
@@ -172,6 +173,7 @@ filegroup(
         "model.h",
         "node_def_builder.h",
         "node_def_util.h",
+        "node_properties.h",
         "numeric_op.h",
         "numeric_types.h",
         "op.h",
@@ -338,6 +340,8 @@ filegroup(
         "node_def_builder.h",
         "node_def_util.cc",
         "node_def_util.h",
+        "node_properties.cc",
+        "node_properties.h",
         "numeric_op.h",
         "op.cc",
         "op.h",
@@ -862,6 +866,21 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "node_properties",
+    srcs = ["node_properties.cc"],
+    hdrs = ["node_properties.h"],
+    deps = [
+        ":node_def_proto_cc",
+        ":node_def_util",
+        ":op",
+        ":op_def_proto_cc",
+        ":tensor",
+        ":types_proto_cc",
+        "//tensorflow/core/lib/core:status",
+    ],
+)
+
 cc_library(
     name = "op_def_builder",
     srcs = ["op_def_builder.cc"],
@@ -967,6 +986,7 @@ tf_cc_tests(
         "model_test.cc",
         "node_def_builder_test.cc",
         "node_def_util_test.cc",
+        "node_properties_test.cc",
         "op_compatibility_test.cc",
         "op_def_builder_test.cc",
         "op_def_util_test.cc",
@@ -722,11 +722,13 @@ class FunctionLibraryRuntime {
   virtual void Run(const Options& opts, Handle handle,
                    CallFrameInterface* call_frame, DoneCallback done) = 0;
 
-  // Creates a "kernel" for the given node def "ndef".
+  // Creates a "kernel" for the given NodeProperties "props".
   //
   // If succeeds, returns OK and the caller takes the ownership of the
   // returned "*kernel". Otherwise, returns an error.
-  virtual Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) = 0;
+  virtual Status CreateKernel(
+      const std::shared_ptr<const NodeProperties>& props,
+      OpKernel** kernel) = 0;
 
   // Returns true iff the function named `function_name` is stateful.
   //
@@ -818,12 +820,15 @@ class CustomKernelCreator {
 
   // Given a NodeDef 'node_def' and the function library runtime 'flr',
   // validate if the class supports creating such a kernel.
-  virtual bool CanCreateKernel(const FunctionLibraryRuntime& flr,
-                               const NodeDef& node_def) const = 0;
+  virtual bool CanCreateKernel(
+      const FunctionLibraryRuntime& flr,
+      const std::shared_ptr<const NodeProperties>& props) const = 0;
 
   // Given a supported NodeDef, returns a kernel that computes the node.
-  virtual Status CreateKernel(FunctionLibraryRuntime* flr, const NodeDef& ndef,
-                              std::unique_ptr<OpKernel>* kernel) const = 0;
+  virtual Status CreateKernel(
+      FunctionLibraryRuntime* flr,
+      const std::shared_ptr<const NodeProperties>& props,
+      std::unique_ptr<OpKernel>* kernel) const = 0;
 };
 
 // Used to instantiate and run functions in a distributed system.
@@ -0,0 +1,39 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/node_properties.h"
+
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+// static
+Status NodeProperties::CreateFromNodeDef(
+    NodeDef node_def, const OpRegistryInterface* op_registry,
+    std::shared_ptr<const NodeProperties>* props) {
+  const OpDef* op_def;
+  TF_RETURN_IF_ERROR(op_registry->LookUpOpDef(node_def.op(), &op_def));
+  DataTypeVector input_types;
+  DataTypeVector output_types;
+  TF_RETURN_IF_ERROR(
+      InOutTypesForNode(node_def, *op_def, &input_types, &output_types));
+  props->reset(new NodeProperties(op_def, std::move(node_def),
+                                  std::move(input_types),
+                                  std::move(output_types)));
+  return Status::OK();
+}
+
+}  // namespace tensorflow
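Editorial aside: CreateFromNodeDef takes its NodeDef parameter by value, so a caller that is finished with its NodeDef can move it in and avoid copying the proto entirely. A minimal hypothetical call site:

    std::shared_ptr<const NodeProperties> props;
    // std::move transfers the proto into the NodeProperties; without it the
    // pass-by-value parameter would copy node_def.
    TF_RETURN_IF_ERROR(NodeProperties::CreateFromNodeDef(
        std::move(node_def), OpRegistry::Global(), &props));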
@@ -0,0 +1,63 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_FRAMEWORK_NODE_PROPERTIES_H_
+#define TENSORFLOW_CORE_FRAMEWORK_NODE_PROPERTIES_H_
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+class OpRegistryInterface;
+
+struct NodeProperties {
+ public:
+  NodeProperties(const OpDef* op_def, NodeDef node_def,
+                 const DataTypeSlice inputs, const DataTypeSlice outputs)
+      : NodeProperties(op_def, std::move(node_def),
+                       DataTypeVector(inputs.begin(), inputs.end()),
+                       DataTypeVector(outputs.begin(), outputs.end())) {}
+
+  NodeProperties(const OpDef* _op_def, NodeDef&& _node_def,
+                 DataTypeVector inputs, DataTypeVector outputs)
+      : op_def(_op_def),
+        node_def(std::move(_node_def)),
+        input_types(std::move(inputs)),
+        input_types_slice(input_types),
+        output_types(std::move(outputs)),
+        output_types_slice(output_types) {}
+
+  // Resets the 'props' shared pointer to point to a new NodeProperties created
+  // from the given NodeDef. 'op_registry' is used to look up the OpDef
+  // corresponding to node_def.op(). Returns an error if OpDef lookup or
+  // creation failed.
+  static Status CreateFromNodeDef(NodeDef node_def,
+                                  const OpRegistryInterface* op_registry,
+                                  std::shared_ptr<const NodeProperties>* props);
+
+  const OpDef* op_def;  // not owned.
+  NodeDef node_def;
+  DataTypeVector input_types;
+  DataTypeSlice input_types_slice;
+  DataTypeVector output_types;
+  DataTypeSlice output_types_slice;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_FRAMEWORK_NODE_PROPERTIES_H_
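Editorial aside on the layout: the struct owns its DataTypeVectors and also stores DataTypeSlice views into them, so accessors that previously returned const DataTypeSlice& (e.g. OpKernelConstruction::input_types(), changed in the op_kernel.h hunks below) keep their signatures without allocating per call. The slices alias the owning members, which is safe here because a NodeProperties is initialized once and then shared immutably as shared_ptr<const NodeProperties>. An illustration (hypothetical usage, not from the commit):

    auto props = std::make_shared<const NodeProperties>(
        /*op_def=*/nullptr, NodeDef(), DataTypeVector{DT_FLOAT},
        DataTypeVector{DT_DOUBLE});
    // input_types_slice is a non-owning view over input_types' storage.
    CHECK(props->input_types_slice.data() == props->input_types.data());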
@@ -0,0 +1,128 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/node_properties.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def_builder.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+OpDef ToOpDef(const OpDefBuilder& builder) {
+  OpRegistrationData op_reg_data;
+  EXPECT_TRUE(builder.Finalize(&op_reg_data).ok());
+  return op_reg_data.op_def;
+}
+
+class MockOpRegistry : public OpRegistryInterface {
+ public:
+  MockOpRegistry()
+      : op_reg_(ToOpDef(OpDefBuilder("Foo")
+                            .Input("f: float")
+                            .Input("i: int32")
+                            .Output("of: double"))) {}
+  ~MockOpRegistry() override {}
+
+  // Returns an error status and sets *op_reg_data to nullptr if no OpDef is
+  // registered under that name, otherwise returns the registered OpDef.
+  // Caller must not delete the returned pointer.
+  Status LookUp(const string& op_type_name,
+                const OpRegistrationData** op_reg_data) const override {
+    if (op_type_name == "Foo") {
+      *op_reg_data = &op_reg_;
+      return Status::OK();
+    } else {
+      *op_reg_data = nullptr;
+      return errors::InvalidArgument("Op type named ", op_type_name,
+                                     " not found");
+    }
+  }
+
+  const OpDef* get_op_def_addr() { return &op_reg_.op_def; }
+
+ private:
+  const OpRegistrationData op_reg_;
+};
+
+void ValidateNodeProperties(const NodeProperties& props, const OpDef* op_def,
+                            const NodeDef& node_def,
+                            const DataTypeVector& input_types,
+                            const DataTypeVector& output_types) {
+  EXPECT_EQ(props.op_def, op_def);
+  EXPECT_EQ(props.node_def.name(), node_def.name());
+  ASSERT_EQ(props.input_types.size(), input_types.size());
+  for (int i = 0; i < input_types.size(); ++i) {
+    EXPECT_EQ(props.input_types[i], input_types[i]);
+    EXPECT_EQ(props.input_types_slice[i], input_types[i]);
+  }
+  ASSERT_EQ(props.output_types.size(), output_types.size());
+  for (int i = 0; i < output_types.size(); ++i) {
+    EXPECT_EQ(props.output_types[i], output_types[i]);
+    EXPECT_EQ(props.output_types_slice[i], output_types[i]);
+  }
+}
+
+}  // namespace
+
+TEST(NodeProperties, Contructors) {
+  OpDef op_def;
+  NodeDef node_def;
+  node_def.set_name("foo");
+  DataTypeVector input_types{DT_FLOAT, DT_INT32};
+  DataTypeVector output_types{DT_DOUBLE};
+  DataTypeSlice input_types_slice(input_types);
+  DataTypeSlice output_types_slice(output_types);
+
+  // Construct from slices.
+  NodeProperties props_from_slices(&op_def, node_def, input_types_slice,
+                                   output_types_slice);
+  ValidateNodeProperties(props_from_slices, &op_def, node_def, input_types,
+                         output_types);
+
+  // Construct from vectors.
+  NodeProperties props_from_vectors(&op_def, node_def, input_types,
+                                    output_types);
+  ValidateNodeProperties(props_from_vectors, &op_def, node_def, input_types,
+                         output_types);
+}
+
+TEST(NodeProperties, CreateFromNodeDef) {
+  MockOpRegistry op_registry;
+  NodeDef node_def;
+  node_def.set_name("bar");
+  node_def.set_op("Foo");
+  node_def.add_input("f_in");
+  node_def.add_input("i_in");
+
+  std::shared_ptr<const NodeProperties> props;
+  EXPECT_TRUE(
+      NodeProperties::CreateFromNodeDef(node_def, &op_registry, &props).ok());
+
+  DataTypeVector input_types{DT_FLOAT, DT_INT32};
+  DataTypeVector output_types{DT_DOUBLE};
+  ValidateNodeProperties(*props, op_registry.get_op_def_addr(), node_def,
+                         input_types, output_types);
+
+  // The OpDef lookup should fail for this one:
+  node_def.set_op("Baz");
+  std::shared_ptr<const NodeProperties> props_bad;
+  EXPECT_FALSE(
+      NodeProperties::CreateFromNodeDef(node_def, &op_registry, &props_bad)
+          .ok());
+  EXPECT_EQ(props_bad, nullptr);
+}
+
+}  // namespace tensorflow
@@ -35,9 +35,9 @@ limitations under the License.
 #include "tensorflow/core/framework/memory_types.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/node_properties.h"
 #include "tensorflow/core/framework/op_def_util.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/notification.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
@@ -91,35 +91,53 @@ Status MatchSignatureHelper(const DataTypeSlice expected_inputs,
 
 // OpKernel ------------------------------------------------------------------
 
-OpKernel::OpKernel(OpKernelConstruction* context)
-    : OpKernel(context, MakeUnique<const NodeDef>(context->def())) {}
+OpKernel::OpKernel(OpKernelConstruction* context) : OpKernel(context, false) {}
 
 OpKernel::OpKernel(OpKernelConstruction* context, bool is_deferred)
-    : OpKernel(context, MakeUnique<const NodeDef>(context->def()),
-               is_deferred) {}
-
-OpKernel::OpKernel(OpKernelConstruction* context,
-                   std::unique_ptr<const NodeDef> node_def, bool is_deferred)
-    : def_(std::move(node_def)),
-      input_types_(context->input_types().begin(),
-                   context->input_types().end()),
+    : props_(context->props_),
       input_memory_types_(context->input_memory_types().begin(),
                           context->input_memory_types().end()),
-      output_types_(context->output_types().begin(),
-                    context->output_types().end()),
      output_memory_types_(context->output_memory_types().begin(),
                            context->output_memory_types().end()),
       input_name_map_(context->num_inputs()),
       output_name_map_(context->num_outputs()),
-      name_view_(def_->name()),
-      type_string_view_(def_->op()),
+      name_view_(props_->node_def.name()),
+      type_string_view_(props_->node_def.op()),
       graph_def_version_(context->graph_def_version()),
       is_deferred_(is_deferred),
       cost_estimate_(OpKernel::kInitialCostEstimateCycles) {
   OP_REQUIRES_OK(context,
-                 NameRangesForNode(*def_, *context->op_def_, &input_name_map_,
-                                   &output_name_map_));
-  OP_REQUIRES_OK(context, CheckOpDeprecation(*context->op_def_,
+                 NameRangesForNode(props_->node_def, *props_->op_def,
+                                   &input_name_map_, &output_name_map_));
+  OP_REQUIRES_OK(context, CheckOpDeprecation(*props_->op_def,
                                              context->graph_def_version()));
 
   // Kernels executing on GPU/SYCL tie very few resources on the CPU where the
   // scheduler runs: we consider them as inexpensive.
   expensive_ = context->device_type() != DeviceType(DEVICE_GPU) &&
                context->device_type() != DeviceType(DEVICE_SYCL);
 }
 
+OpKernel::OpKernel(OpKernelConstruction* context, NodeDef&& custom_def,
+                   bool is_deferred)
+    : props_(std::make_shared<const NodeProperties>(
+          context->props_->op_def, std::move(custom_def),
+          context->props_->input_types, context->props_->output_types)),
+      input_memory_types_(context->input_memory_types().begin(),
+                          context->input_memory_types().end()),
+      output_memory_types_(context->output_memory_types().begin(),
+                           context->output_memory_types().end()),
+      input_name_map_(context->num_inputs()),
+      output_name_map_(context->num_outputs()),
+      name_view_(props_->node_def.name()),
+      type_string_view_(props_->node_def.op()),
+      graph_def_version_(context->graph_def_version()),
+      is_deferred_(is_deferred),
+      cost_estimate_(OpKernel::kInitialCostEstimateCycles) {
+  OP_REQUIRES_OK(context,
+                 NameRangesForNode(props_->node_def, *props_->op_def,
+                                   &input_name_map_, &output_name_map_));
+  OP_REQUIRES_OK(context, CheckOpDeprecation(*props_->op_def,
+                                             context->graph_def_version()));
+
+  // Kernels executing on GPU/SYCL tie very few resources on the CPU where the
@@ -134,10 +152,6 @@ const uint64 OpKernel::kInitialCostEstimateCycles;
 const uint64 OpKernel::kOpIsExpensiveThresholdCycles;
 const uint64 OpKernel::kCostDecay;
 
-const string& OpKernel::name() const { return def_->name(); }
-const string& OpKernel::type_string() const { return def_->op(); }
-const string& OpKernel::requested_device() const { return def_->device(); }
-const string& OpKernel::requested_input(int i) const { return def_->input(i); }
-
 Status OpKernel::InputRange(StringPiece input_name, int* start,
                             int* stop) const {
@@ -216,22 +230,18 @@ Tensor* PersistentTensor::AccessTensor(OpKernelContext* context) {
 
 OpKernelConstruction::OpKernelConstruction(
     DeviceType device_type, DeviceBase* device, Allocator* allocator,
-    const NodeDef* node_def, const OpDef* op_def, FunctionLibraryRuntime* flib,
-    ResourceMgr* resource_mgr, const DataTypeSlice& input_types,
+    FunctionLibraryRuntime* flib, ResourceMgr* resource_mgr,
+    const std::shared_ptr<const NodeProperties>& props,
     const MemoryTypeSlice& input_memory_types,
-    const DataTypeSlice& output_types,
     const MemoryTypeSlice& output_memory_types, int graph_def_version,
     Status* status)
     : device_type_(std::move(device_type)),
      device_(device),
       allocator_(allocator),
-      def_(node_def),
-      op_def_(op_def),
       flib_(flib),
       resource_mgr_(resource_mgr),
-      input_types_(input_types),
+      props_(props),
       input_memory_types_(input_memory_types),
-      output_types_(output_types),
       output_memory_types_(output_memory_types),
       graph_def_version_(graph_def_version),
       status_(status) {}
@@ -246,8 +256,8 @@ void OpKernelConstruction::SetStatus(const Status& status) {
 
 Status OpKernelConstruction::MatchSignature(
     const DataTypeSlice expected_inputs, const DataTypeSlice expected_outputs) {
-  return MatchSignatureHelper(expected_inputs, expected_outputs, input_types_,
-                              output_types_);
+  return MatchSignatureHelper(expected_inputs, expected_outputs,
+                              props_->input_types, props_->output_types);
 }
 
 Status OpKernelConstruction::allocate_temp(DataType type,
@@ -263,7 +273,7 @@ Status OpKernelConstruction::allocate_temp(DataType type,
   }
   if (LogMemory::IsEnabled()) {
     LogMemory::RecordTensorAllocation(
-        def_->name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp);
+        def().name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp);
   }
   *out_temp = new_temp;
   return Status::OK();
@@ -288,7 +298,7 @@ Status OpKernelConstruction::allocate_temp(DataType type,
   }
   if (LogMemory::IsEnabled()) {
     LogMemory::RecordTensorAllocation(
-        def_->name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp);
+        def().name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp);
   }
   *out_temp = new_temp;
   return Status::OK();
@@ -1544,45 +1554,65 @@ string KernelsRegisteredForOp(StringPiece op_name) {
   return ret;
 }
 
+/* TODO(rmlarsen): This API is deprecated. Remove it if possible to avoid
+ * copying the NodeDef. */
 std::unique_ptr<OpKernel> CreateOpKernel(
     DeviceType device_type, DeviceBase* device, Allocator* allocator,
     const NodeDef& node_def, int graph_def_version, Status* status) {
+  // Look up the Op registered for this op name.
+  std::shared_ptr<const NodeProperties> props;
+  status->Update(NodeProperties::CreateFromNodeDef(
+      node_def, OpRegistry::Global(), &props));
+  if (!status->ok()) {
+    errors::AppendToMessage(status,
+                            " for node: ", FormatNodeDefForError(node_def));
+    return nullptr;
+  }
+  return CreateOpKernel(device_type, device, allocator, props,
+                        graph_def_version, status);
+}
+
+std::unique_ptr<OpKernel> CreateOpKernel(
+    DeviceType device_type, DeviceBase* device, Allocator* allocator,
+    const std::shared_ptr<const NodeProperties>& props, int graph_def_version,
+    Status* status) {
   OpKernel* kernel = nullptr;
-  *status = CreateOpKernel(std::move(device_type), device, allocator, nullptr,
-                           node_def, graph_def_version, &kernel);
+  *status = CreateOpKernel(std::move(device_type), device, allocator,
+                           /*flib=*/nullptr, props, graph_def_version, &kernel);
   return std::unique_ptr<OpKernel>(kernel);
 }
 
 Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                       Allocator* allocator, FunctionLibraryRuntime* flib,
-                      const NodeDef& node_def, int graph_def_version,
-                      OpKernel** kernel) {
+                      const std::shared_ptr<const NodeProperties>& props,
+                      int graph_def_version, OpKernel** kernel) {
   return CreateOpKernel(std::move(device_type), device, allocator, flib,
-                        /* resource_mgr= */ nullptr, node_def,
-                        graph_def_version, kernel);
+                        /* resource_mgr= */ nullptr, props, graph_def_version,
+                        kernel);
 }
 
 Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                       Allocator* allocator, FunctionLibraryRuntime* flib,
-                      ResourceMgr* resource_mgr, const NodeDef& node_def,
+                      ResourceMgr* resource_mgr,
+                      const std::shared_ptr<const NodeProperties>& props,
                       int graph_def_version, OpKernel** kernel) {
-  VLOG(1) << "Instantiating kernel for node: " << SummarizeNodeDef(node_def);
-
-  // Look up the Op registered for this op name.
-  const OpDef* op_def = nullptr;
-  TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def));
-
-  // Validate node_def against OpDef.
-  TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, *op_def));
-
-  // Look up kernel registration.
-  const KernelRegistration* registration;
+  const NodeDef& node_def = props->node_def;
   bool was_attr_mismatch;
-  Status s = FindKernelRegistration(device_type, node_def, &registration,
-                                    &was_attr_mismatch);
-  if (!s.ok()) {
-    errors::AppendToMessage(&s, " when instantiating ", node_def.op());
-    return s;
+  const KernelRegistration* registration = nullptr;
+  Status s;
+  if (props != nullptr) {
+    VLOG(1) << "Instantiating kernel for node: " << SummarizeNodeDef(node_def);
+
+    // Validate node_def against OpDef.
+    TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, *props->op_def));
+
+    // Look up kernel registration.
+    s = FindKernelRegistration(device_type, node_def, &registration,
+                               &was_attr_mismatch);
+    if (!s.ok()) {
+      errors::AppendToMessage(&s, " when instantiating ", node_def.op());
+      return s;
+    }
   }
   if (registration == nullptr) {
     s.Update(errors::NotFound("No registered '", node_def.op(),
@@ -1599,15 +1629,6 @@ Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
     return s;
   }
 
-  // Get signature from the OpDef & NodeDef
-  DataTypeVector inputs;
-  DataTypeVector outputs;
-  s.Update(InOutTypesForNode(node_def, *op_def, &inputs, &outputs));
-  if (!s.ok()) {
-    errors::AppendToMessage(&s, " for node: ", FormatNodeDefForError(node_def));
-    return s;
-  }
-
   // We are creating a kernel for an op registered in
   // OpRegistry::Global(), we consult the kernel registry to decide
   // the kernel's input and output memory types.
@@ -1618,10 +1639,9 @@ Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                                          &output_memory_types));
 
   // Everything needed for OpKernel construction.
-  OpKernelConstruction context(std::move(device_type), device, allocator,
-                               &node_def, op_def, flib, resource_mgr, inputs,
-                               input_memory_types, outputs, output_memory_types,
-                               graph_def_version, &s);
+  OpKernelConstruction context(std::move(device_type), device, allocator, flib,
+                               resource_mgr, props, input_memory_types,
+                               output_memory_types, graph_def_version, &s);
   *kernel = registration->factory->Create(&context);
   if (!s.ok()) {
     delete *kernel;
@ -31,6 +31,7 @@ limitations under the License.
|
|||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/node_properties.h"
|
||||
#include "tensorflow/core/framework/op.h" // TODO(b/62899350): Remove
|
||||
#include "tensorflow/core/framework/rendezvous.h"
|
||||
#include "tensorflow/core/framework/selective_registration.h"
|
||||
|
@ -85,19 +86,18 @@ class OpKernel {
|
|||
// expensive initialization in the descendant's constructor.
|
||||
explicit OpKernel(OpKernelConstruction* context);
|
||||
|
||||
// Specialized constructor that enables the descendant to provide a different
|
||||
// `NodeDef` value. For example, this constructor can be used to provide a
|
||||
// stripped-down `NodeDef` that does not contain the full set of attrs (such
|
||||
// as tensor values) if the descendant stores them in a different form.
|
||||
explicit OpKernel(OpKernelConstruction* context,
|
||||
std::unique_ptr<const NodeDef> node_def,
|
||||
bool is_deferred = false);
|
||||
|
||||
// Specialized constructor that allows a kernel implementation to mark itself
|
||||
// as a "deferred" op. If true, the executor will provide access to the
|
||||
// `OpKernelContext::inc_num_deferred_ops_function()` and
|
||||
// `OpKernelContext::dec_num_deferred_ops_function()` methods at run-time.
|
||||
explicit OpKernel(OpKernelConstruction* context, bool is_deferred);
|
||||
OpKernel(OpKernelConstruction* context, bool is_deferred);
|
||||
|
||||
// Specialized constructor that enables the descendant to provide a custom
|
||||
// `NodeDef` value. For example, this constructor can be used to provide a
|
||||
// stripped-down `NodeDef` that does not contain the full set of attrs (such
|
||||
// as tensor values) if the descendant stores them in a different form.
|
||||
OpKernel(OpKernelConstruction* context, NodeDef&& custom_def,
|
||||
bool is_deferred);
|
||||
|
||||
virtual ~OpKernel();
|
||||
|
||||
|
@ -170,24 +170,26 @@ class OpKernel {
|
|||
}
|
||||
|
||||
// Accessors.
|
||||
const NodeDef& def() const { return *def_; }
|
||||
const string& name() const; // Same as def().name()
|
||||
const NodeDef& def() const { return props_->node_def; }
|
||||
const string& name() const { return props_->node_def.name(); }
|
||||
absl::string_view name_view() const { return name_view_; }
|
||||
const string& type_string() const; // Same as def().op()
|
||||
const string& type_string() const { return props_->node_def.op(); }
|
||||
absl::string_view type_string_view() const { return type_string_view_; }
|
||||
const string& requested_device() const; // Same as def().device()
|
||||
const string& requested_input(int i) const {
|
||||
return props_->node_def.input(i);
|
||||
}
|
||||
const string& requested_device() const { return props_->node_def.device(); }
|
||||
|
||||
int num_inputs() const { return input_types_.size(); }
|
||||
DataType input_type(int i) const { return input_types_[i]; }
|
||||
const DataTypeVector& input_types() const { return input_types_; }
|
||||
int num_inputs() const { return props_->input_types.size(); }
|
||||
DataType input_type(int i) const { return props_->input_types[i]; }
|
||||
const DataTypeVector& input_types() const { return props_->input_types; }
|
||||
const MemoryTypeVector& input_memory_types() const {
|
||||
return input_memory_types_;
|
||||
}
|
||||
const string& requested_input(int i) const; // Same as def().input(i)
|
||||
|
||||
int num_outputs() const { return output_types_.size(); }
|
||||
DataType output_type(int o) const { return output_types_[o]; }
|
||||
const DataTypeVector& output_types() const { return output_types_; }
|
||||
int num_outputs() const { return props_->output_types.size(); }
|
||||
DataType output_type(int o) const { return props_->output_types[o]; }
|
||||
const DataTypeVector& output_types() const { return props_->output_types; }
|
||||
const MemoryTypeVector& output_memory_types() const {
|
||||
return output_memory_types_;
|
||||
}
|
||||
|
@ -209,10 +211,8 @@ class OpKernel {
|
|||
string GetTraceArgument(OpKernelContext* ctx);
|
||||
|
||||
private:
|
||||
const std::unique_ptr<const NodeDef> def_;
|
||||
const DataTypeVector input_types_;
|
||||
const std::shared_ptr<const NodeProperties> props_;
|
||||
const MemoryTypeVector input_memory_types_;
|
||||
const DataTypeVector output_types_;
|
||||
const MemoryTypeVector output_memory_types_;
|
||||
NameRangeMap input_name_map_;
|
||||
NameRangeMap output_name_map_;
|
||||
|
@@ -284,12 +284,10 @@ class PersistentTensor {
class OpKernelConstruction {
 public:
  OpKernelConstruction(DeviceType device_type, DeviceBase* device,
                       Allocator* allocator, const NodeDef* node_def,
                       const OpDef* op_def, FunctionLibraryRuntime* flib,
                       Allocator* allocator, FunctionLibraryRuntime* flib,
                       ResourceMgr* resource_mgr,
                       const DataTypeSlice& input_types,
                       const std::shared_ptr<const NodeProperties>& props,
                       const MemoryTypeSlice& input_memory_types,
                       const DataTypeSlice& output_types,
                       const MemoryTypeSlice& output_memory_types,
                       int graph_def_version, Status* status);
@@ -330,20 +328,22 @@ class OpKernelConstruction {
                    Tensor** out_tensor);

  // User-supplied configuration of this operation.
  const NodeDef& def() const { return *def_; }
  const NodeDef& def() const { return props_->node_def; }

  // For inspecting the inputs to this operation.
  int num_inputs() const { return input_types_.size(); }
  DataType input_type(int i) const { return input_types_[i]; }
  const DataTypeSlice& input_types() const { return input_types_; }
  int num_inputs() const { return props_->input_types.size(); }
  DataType input_type(int i) const { return props_->input_types[i]; }
  const DataTypeSlice& input_types() const { return props_->input_types_slice; }
  const MemoryTypeSlice& input_memory_types() const {
    return input_memory_types_;
  }

  // For inspecting the outputs expected from this operation.
  int num_outputs() const { return output_types_.size(); }
  DataType output_type(int i) const { return output_types_[i]; }
  const DataTypeSlice& output_types() const { return output_types_; }
  int num_outputs() const { return props_->output_types.size(); }
  DataType output_type(int i) const { return props_->output_types[i]; }
  const DataTypeSlice& output_types() const {
    return props_->output_types_slice;
  }
  const MemoryTypeSlice& output_memory_types() const {
    return output_memory_types_;
  }
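Note that the slice accessors now return `const DataTypeSlice&`, a reference to a view, so `NodeProperties` must keep the slice objects alive alongside the vectors they point into. A sketch of the assumed layout (member names inferred from the accessors above; this is not a quote of the real header):

```cpp
// Assumed storage inside NodeProperties: the vectors own the dtype arrays and
// the slices are long-lived views over them, which lets OpKernelConstruction
// hand out `const DataTypeSlice&` without dangling.
struct NodePropertiesSketch {
  NodeDef node_def;
  DataTypeVector input_types;        // owning storage
  DataTypeSlice input_types_slice;   // view over input_types
  DataTypeVector output_types;       // owning storage
  DataTypeSlice output_types_slice;  // view over output_types
};
```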
@@ -403,19 +403,15 @@ class OpKernelConstruction {
  const DeviceType device_type_;
  DeviceBase* const device_;
  Allocator* allocator_;
  const NodeDef* def_;
  const OpDef* op_def_;
  FunctionLibraryRuntime* flib_;
  ResourceMgr* const resource_mgr_;
  DataTypeSlice input_types_;
  std::shared_ptr<const NodeProperties> props_;
  MemoryTypeSlice input_memory_types_;
  DataTypeSlice output_types_;
  MemoryTypeSlice output_memory_types_;
  const int graph_def_version_;
  Status* status_;

  // Allow op_def_ across from OpKernel, but not from subclasses.
  // TODO(irving): Remove protos from this header entirely.
  // Allow access from OpKernel ctor.
  friend class OpKernel;

  TF_DISALLOW_COPY_AND_ASSIGN(OpKernelConstruction);
@@ -1404,15 +1400,23 @@ const Eigen::SyclDevice& OpKernelContext::eigen_device() const;
std::unique_ptr<OpKernel> CreateOpKernel(DeviceType device_type,
                                         DeviceBase* device,
                                         Allocator* allocator,
                                         const NodeDef& def,
                                         const NodeDef& node_def,
                                         int graph_def_version, Status* status);

std::unique_ptr<OpKernel> CreateOpKernel(
    DeviceType device_type, DeviceBase* device, Allocator* allocator,
    const std::shared_ptr<const NodeProperties>& props, int graph_def_version,
    Status* status);

Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                      Allocator* allocator, FunctionLibraryRuntime* flib,
                      const NodeDef& def, int graph_def_version,
                      OpKernel** kernel);
                      const std::shared_ptr<const NodeProperties>& props,
                      int graph_def_version, OpKernel** kernel);

Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                      Allocator* allocator, FunctionLibraryRuntime* flib,
                      ResourceMgr* resource_mgr, const NodeDef& def,
                      ResourceMgr* resource_mgr,
                      const std::shared_ptr<const NodeProperties>& props,
                      int graph_def_version, OpKernel** kernel);

// Returns into 'device_types' the subset of prioritized_types that this
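A caller of the new props-taking overloads builds the shared properties once, up front. The helper below, `MakeKernel`, is hypothetical and assumes the `NodeProperties` constructor keeps the shape shown in the graph.cc hunk further down (`op_def`, `node_def`, input and output dtype slices):

```cpp
// Hypothetical helper: construct the shared NodeProperties once, then create
// the kernel through the new unique_ptr-returning overload.
Status MakeKernel(DeviceType device_type, DeviceBase* device,
                  Allocator* allocator, const OpDef* op_def, NodeDef node_def,
                  const DataTypeSlice inputs, const DataTypeSlice outputs,
                  std::unique_ptr<OpKernel>* out) {
  auto props = std::make_shared<NodeProperties>(op_def, std::move(node_def),
                                                inputs, outputs);
  Status s;
  *out = CreateOpKernel(device_type, device, allocator, props,
                        TF_GRAPH_DEF_VERSION, &s);
  return s;
}
```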
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -19,7 +19,7 @@ limitations under the License.

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/node_properties.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/versions.pb.h"
@@ -37,23 +37,7 @@ namespace tensorflow {

const int Graph::kControlSlot = -1;

struct NodeProperties {
 public:
  NodeProperties(const OpDef* op_def, NodeDef node_def,
                 const DataTypeSlice inputs, const DataTypeSlice outputs)
      : op_def(op_def),
        node_def(std::move(node_def)),
        input_types(inputs.begin(), inputs.end()),
        output_types(outputs.begin(), outputs.end()) {}

  const OpDef* op_def;  // not owned
  NodeDef node_def;
  const DataTypeVector input_types;
  const DataTypeVector output_types;
};

// Node

#define REF_CLASS(key, value) \
  {key, value}, { "Ref" key, value }
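The local `NodeProperties` struct is deleted here; together with the `node_properties.h` include added above and the `NodeProperties::CreateFromNodeDef` call in the test changes below, this indicates the type now lives in its own framework header. A sketch of how callers build one from a bare `NodeDef`, mirroring the usage later in this diff (the exact signature in node_properties.h is assumed, not quoted; `flib_def` is a `const FunctionLibraryDefinition*` used for the OpDef lookup):

```cpp
// Sketch based on the CreateFromNodeDef usage shown later in this diff.
std::shared_ptr<const NodeProperties> props;
Status s = NodeProperties::CreateFromNodeDef(node_def, flib_def, &props);
if (s.ok()) {
  // props->node_def, props->input_types and props->output_types are now
  // populated from the OpDef's signature.
}
```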
@@ -97,7 +81,8 @@ const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable =
        {"StatelessIf", NC_IF},
        {"While", NC_WHILE},
        {"StatelessWhile", NC_WHILE},
        // Not using the constants defined in FunctionLibraryDefinition for the
        // Not using the constants defined in FunctionLibraryDefinition
        // for the
        // 4 ops below because android inference library does not link
        // tf.function related files.
        {"_Arg", NC_ARG},
@@ -43,6 +43,7 @@ limitations under the License.

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/edgeset.h"
@@ -67,7 +68,6 @@ class WhileContext;

class NeighborIter;  // Declared below
class NodeIter;      // Declared below
struct NodeProperties;  // Defined in .cc

class Node {
 public:
@@ -229,11 +229,12 @@ class Node {
    while_ctx_ = while_ctx;
  }

  std::shared_ptr<NodeProperties> properties() const { return props_; }

 private:
  friend class Graph;
  Node();

  NodeProperties* properties() const { return props_.get(); }

  void Initialize(int id, int cost_id, std::shared_ptr<NodeProperties> props,
                  bool is_function_op);
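Moving `properties()` from the private section to the public API, and returning the `shared_ptr` itself rather than a raw pointer, is what lets callers such as the executors below share ownership instead of copying the `NodeDef`. A sketch:

```cpp
// Sketch: callers can now extend the lifetime of a node's properties past the
// Node itself, at the cost of one refcount increment and no NodeDef copy.
std::shared_ptr<NodeProperties> ShareProps(const Node* n) {
  std::shared_ptr<NodeProperties> props = n->properties();
  return props;  // stays valid even if the graph later deletes `n`
}
```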
@@ -47,31 +47,30 @@ namespace tensorflow {

namespace {

std::unique_ptr<const NodeDef> StripTensorDataFromNodeDef(
    OpKernelConstruction* ctx) {
NodeDef StripTensorDataFromNodeDef(OpKernelConstruction* ctx) {
#ifndef __ANDROID__
  DCHECK_EQ(NodeDef::descriptor()->field_count(), 6)
      << "The NodeDef format has changed, and the attr-stripping code may need "
      << "to be updated.";
#endif
  const NodeDef& original = ctx->def();
  NodeDef* ret = new NodeDef;
  ret->set_name(original.name());
  ret->set_op(original.op());
  ret->set_device(original.device());
  NodeDef ret;
  ret.set_name(original.name());
  ret.set_op(original.op());
  ret.set_device(original.device());
  // Strip the "value" attr from the returned NodeDef.
  // NOTE(mrry): The present implementation of `OpKernel::OpKernel()` only uses
  // attrs that affect the cardinality of list-typed inputs and outputs, so it
  // is safe to drop other attrs from the NodeDef.
  AddNodeAttr("dtype", ctx->output_type(0), ret);
  MergeDebugInfo(original, ret);
  return std::unique_ptr<const NodeDef>(ret);
  AddNodeAttr("dtype", ctx->output_type(0), &ret);
  MergeDebugInfo(original, &ret);
  return ret;
}

}  // namespace

ConstantOp::ConstantOp(OpKernelConstruction* ctx)
    : OpKernel(ctx, StripTensorDataFromNodeDef(ctx)),
    : OpKernel(ctx, StripTensorDataFromNodeDef(ctx), false),
      tensor_(ctx->output_type(0)) {
  const TensorProto* proto = nullptr;
  MEMDEBUG_CACHE_OP(ctx->def().name().c_str());
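Returning the stripped `NodeDef` by value works because the result is a prvalue: it binds directly to the new `NodeDef&&` overload of the `OpKernel` constructor and is moved (or elided), never deep-copied. In isolation:

```cpp
// Hypothetical stand-in for StripTensorDataFromNodeDef, shown only to
// illustrate the by-value return.
NodeDef make_def() {
  NodeDef d;
  d.set_op("Const");
  return d;  // NRVO or move; no protobuf deep copy
}
// The prvalue binds to the NodeDef&& parameter, as in ConstantOp above:
//   : OpKernel(ctx, make_def(), /*is_deferred=*/false)
```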
@@ -304,9 +304,14 @@ Status DatasetOpsTestBase::ExpectEqual(std::vector<Tensor> produced_tensors,
Status DatasetOpsTestBase::CreateOpKernel(
    const NodeDef& node_def, std::unique_ptr<OpKernel>* op_kernel) {
  OpKernel* kernel;
  Status s;

  std::shared_ptr<const NodeProperties> props;
  TF_RETURN_IF_ERROR(NodeProperties::CreateFromNodeDef(
      node_def, flr_->GetFunctionLibraryDefinition(), &props));
  TF_RETURN_IF_ERROR(tensorflow::CreateOpKernel(
      device_type_, device_.get(), allocator_, flr_,
      device_->resource_manager(), node_def, TF_GRAPH_DEF_VERSION, &kernel));
      device_->resource_manager(), props, TF_GRAPH_DEF_VERSION, &kernel));
  op_kernel->reset(kernel);
  return Status::OK();
}
@@ -435,9 +440,10 @@ Status DatasetOpsTestBase::RunFunction(
  LocalExecutorParams params;
  params.function_library = flr_;
  params.device = device_.get();
  params.create_kernel = [this, version](const NodeDef& ndef,
                                         OpKernel** kernel) {
    return CreateNonCachedKernel(device_.get(), this->flr_, ndef, version,
  params.create_kernel = [this, version](
                             const std::shared_ptr<const NodeProperties>& props,
                             OpKernel** kernel) {
    return CreateNonCachedKernel(device_.get(), this->flr_, props, version,
                                 kernel);
  };
  params.delete_kernel = [](OpKernel* kernel) {
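Both this lambda and the ones in the executor changes below now receive the shared properties instead of a `NodeDef`, which implies the callback member of `LocalExecutorParams` changed along these lines (a sketch; the actual typedef lives in the executor header and is assumed here, not quoted):

```cpp
// Presumed shape of the updated callback member; it previously received
// `const NodeDef&`.
std::function<Status(const std::shared_ptr<const NodeProperties>&, OpKernel**)>
    create_kernel;
```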
@@ -108,7 +108,8 @@ class SingleThreadedExecutorImpl : public Executor {
      KernelState& kernel_state = kernels_[kernel_index];
      node_to_index_map[n] = kernel_index;

      TF_RETURN_IF_ERROR(params_.create_kernel(n->def(), &kernel_state.kernel));
      TF_RETURN_IF_ERROR(
          params_.create_kernel(n->properties(), &kernel_state.kernel));
      kernel_state.num_inputs = n->num_inputs();
      kernel_state.num_outputs = n->num_outputs();
@@ -58,11 +58,12 @@ class ExecutorTest : public ::testing::Test {
    const int version = graph->versions().producer();
    LocalExecutorParams params;
    params.device = device_.get();
    params.create_kernel = [this, version](const NodeDef& ndef,
                                           OpKernel** kernel) {
      return CreateNonCachedKernel(device_.get(), nullptr, ndef, version,
                                   kernel);
    };
    params.create_kernel =
        [this, version](const std::shared_ptr<const NodeProperties>& props,
                        OpKernel** kernel) {
          return CreateNonCachedKernel(device_.get(), nullptr, props, version,
                                       kernel);
        };
    params.delete_kernel = [](OpKernel* kernel) {
      DeleteNonCachedKernel(kernel);
    };
@@ -237,8 +237,7 @@ class Tests(test.TestCase):
  @test_util.assert_no_garbage_created
  def testInvalidNumOutputs(self):
    with self.assertRaisesRegexp(
        Exception,
        "Value for attr 'num_split' of -1 must be at least minimum 1"):
        Exception, r"Value for number_attr\(\) -1 < 0 \[Op:Split\]"):
      array_ops.split(value=[1, 2, 3], num_or_size_splits=-1)

    with self.assertRaisesRegexp(