Move AttrsMatch() to kernel_def_util.h so client can use it.
PiperOrigin-RevId: 201991753
This commit is contained in:
parent
7c5d31aa32
commit
73524dc765
@ -792,6 +792,7 @@ tf_cuda_library(
|
||||
"framework/graph_def_util.h",
|
||||
"framework/graph_to_functiondef.h",
|
||||
"framework/kernel_def_builder.h",
|
||||
"framework/kernel_def_util.h",
|
||||
"framework/log_memory.h",
|
||||
"framework/lookup_interface.h",
|
||||
"framework/memory_types.h",
|
||||
@ -3376,6 +3377,7 @@ tf_cc_tests(
|
||||
"framework/graph_def_util_test.cc",
|
||||
"framework/graph_to_functiondef_test.cc",
|
||||
"framework/kernel_def_builder_test.cc",
|
||||
"framework/kernel_def_util_test.cc",
|
||||
"framework/memory_types_test.cc",
|
||||
"framework/node_def_builder_test.cc",
|
||||
"framework/node_def_util_test.cc",
|
||||
|
83
tensorflow/core/framework/kernel_def_util.cc
Normal file
83
tensorflow/core/framework/kernel_def_util.cc
Normal file
@ -0,0 +1,83 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/kernel_def_util.h"
|
||||
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/kernel_def.pb_text.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
// Helper for KernelAttrsMatch().
|
||||
bool InTypeList(DataType dt, const AttrValue& type_list) {
|
||||
for (int in_list : type_list.list().type()) {
|
||||
if (dt == in_list) return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Status KernelAttrsMatch(const KernelDef& kernel_def, AttrSlice attrs,
|
||||
bool* match) {
|
||||
*match = false;
|
||||
for (const auto& constraint : kernel_def.constraint()) {
|
||||
if (constraint.allowed_values().list().type_size() == 0) {
|
||||
return errors::Unimplemented(
|
||||
"KernelDef '", ProtoShortDebugString(kernel_def),
|
||||
" has constraint on attr '", constraint.name(),
|
||||
"' with unsupported type: ",
|
||||
SummarizeAttrValue(constraint.allowed_values()));
|
||||
}
|
||||
|
||||
const AttrValue* found = attrs.Find(constraint.name());
|
||||
if (found) {
|
||||
if (found->type() != DT_INVALID) {
|
||||
if (!InTypeList(found->type(), constraint.allowed_values())) {
|
||||
return Status::OK();
|
||||
}
|
||||
} else {
|
||||
if (!AttrValueHasType(*found, "list(type)").ok()) {
|
||||
return errors::InvalidArgument(
|
||||
"KernelDef '", ProtoShortDebugString(kernel_def),
|
||||
"' has constraint on attr '", constraint.name(),
|
||||
"' that has value '", SummarizeAttrValue(*found),
|
||||
"' that does not have type 'type' or 'list(type)' in NodeDef "
|
||||
"'",
|
||||
attrs.SummarizeNode(), "'");
|
||||
}
|
||||
|
||||
for (int t : found->list().type()) {
|
||||
if (!InTypeList(static_cast<DataType>(t),
|
||||
constraint.allowed_values())) {
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return errors::InvalidArgument(
|
||||
"OpKernel '", kernel_def.op(), "' has constraint on attr '",
|
||||
constraint.name(), "' not in NodeDef '", attrs.SummarizeNode(),
|
||||
"', KernelDef: '", ProtoShortDebugString(kernel_def), "'");
|
||||
}
|
||||
}
|
||||
*match = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
31
tensorflow/core/framework/kernel_def_util.h
Normal file
31
tensorflow/core/framework/kernel_def_util.h
Normal file
@ -0,0 +1,31 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_FRAMEWORK_KERNEL_DEF_UTIL_H_
|
||||
#define TENSORFLOW_CORE_FRAMEWORK_KERNEL_DEF_UTIL_H_
|
||||
|
||||
#include "tensorflow/core/framework/kernel_def.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Returns whether the attrs satisfy the constraints in the kernel_def. Returns
|
||||
// an error if attrs in kernel_def are not found, or have a mismatching type.
|
||||
Status KernelAttrsMatch(const KernelDef& kernel_def, AttrSlice attrs,
|
||||
bool* match);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_FRAMEWORK_KERNEL_DEF_UTIL_H_
|
133
tensorflow/core/framework/kernel_def_util_test.cc
Normal file
133
tensorflow/core/framework/kernel_def_util_test.cc
Normal file
@ -0,0 +1,133 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/kernel_def_util.h"
|
||||
|
||||
#include "tensorflow/core/framework/kernel_def.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
NodeDef NodeDefFromText(const string& text) {
|
||||
NodeDef node_def;
|
||||
EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
|
||||
return node_def;
|
||||
}
|
||||
|
||||
KernelDef KernelDefFromText(const string& text) {
|
||||
KernelDef kernel_def;
|
||||
EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &kernel_def));
|
||||
return kernel_def;
|
||||
}
|
||||
|
||||
class AttrsMatchTest : public ::testing::Test {
|
||||
protected:
|
||||
void ExpectStatus(const string& node_def_str, const string& kernel_def_str,
|
||||
error::Code code) {
|
||||
bool match;
|
||||
auto status = KernelAttrsMatch(KernelDefFromText(kernel_def_str),
|
||||
NodeDefFromText(node_def_str), &match);
|
||||
LOG(INFO) << "status: " << status;
|
||||
EXPECT_EQ(code, status.code());
|
||||
if (!status.ok()) {
|
||||
EXPECT_FALSE(match)
|
||||
<< "Expect no match between the given NodeDef and KernelDef";
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(AttrsMatchTest, ValidConstraint) {
|
||||
string node_def_str = R"(
|
||||
name: "ValidConstraint-op"
|
||||
op: "ValidConstraint"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
)";
|
||||
string kernel_def_str = R"(
|
||||
op: "ValidConstraint"
|
||||
device_type: "CPU"
|
||||
constraint {
|
||||
name: "T"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
ExpectStatus(node_def_str, kernel_def_str, error::OK);
|
||||
}
|
||||
|
||||
TEST_F(AttrsMatchTest, BadConstraint) {
|
||||
string node_def_str = R"(
|
||||
name: "BadConstraint-op"
|
||||
op: "BadConstraint"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
)";
|
||||
string kernel_def_str = R"(
|
||||
op: "BadConstraint"
|
||||
device_type: "CPU"
|
||||
constraint {
|
||||
name: "T"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
ExpectStatus(node_def_str, kernel_def_str, error::INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
TEST_F(AttrsMatchTest, Unimplemented) {
|
||||
string node_def_str = R"(
|
||||
name: "BadConstraint-op"
|
||||
op: "BadConstraint"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
)";
|
||||
string kernel_def_str = R"(
|
||||
op: "BadConstraint"
|
||||
device_type: "CPU"
|
||||
constraint {
|
||||
name: "T"
|
||||
allowed_values {
|
||||
list {
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
ExpectStatus(node_def_str, kernel_def_str, error::UNIMPLEMENTED);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/framework/graph.pb_text.h"
|
||||
#include "tensorflow/core/framework/kernel_def.pb_text.h"
|
||||
#include "tensorflow/core/framework/kernel_def_util.h"
|
||||
#include "tensorflow/core/framework/log_memory.h"
|
||||
#include "tensorflow/core/framework/memory_types.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
@ -969,62 +970,6 @@ void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
|
||||
|
||||
namespace {
|
||||
|
||||
// Helper for AttrsMatch().
|
||||
bool InTypeList(DataType dt, const AttrValue& type_list) {
|
||||
for (int in_list : type_list.list().type()) {
|
||||
if (dt == in_list) return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Returns whether the attrs satisfy the constraints in the kernel_def. Returns
|
||||
// an error if attrs in kernel_def are not found, or have a mismatching type.
|
||||
Status AttrsMatch(AttrSlice attrs, const KernelDef& kernel_def, bool* match) {
|
||||
*match = false;
|
||||
for (const auto& constraint : kernel_def.constraint()) {
|
||||
if (constraint.allowed_values().list().type_size() == 0) {
|
||||
return errors::Unimplemented(
|
||||
"KernelDef '", ProtoShortDebugString(kernel_def),
|
||||
" has constraint on attr '", constraint.name(),
|
||||
"' with unsupported type: ",
|
||||
SummarizeAttrValue(constraint.allowed_values()));
|
||||
}
|
||||
|
||||
const AttrValue* found = attrs.Find(constraint.name());
|
||||
if (found) {
|
||||
if (found->type() != DT_INVALID) {
|
||||
if (!InTypeList(found->type(), constraint.allowed_values())) {
|
||||
return Status::OK();
|
||||
}
|
||||
} else {
|
||||
if (!AttrValueHasType(*found, "list(type)").ok()) {
|
||||
return errors::InvalidArgument(
|
||||
"KernelDef '", ProtoShortDebugString(kernel_def),
|
||||
"' has constraint on attr '", constraint.name(),
|
||||
"' that has value '", SummarizeAttrValue(*found),
|
||||
"' that does not have type 'type' or 'list(type)' in NodeDef "
|
||||
"'",
|
||||
attrs.SummarizeNode(), "'");
|
||||
}
|
||||
|
||||
for (int t : found->list().type()) {
|
||||
if (!InTypeList(static_cast<DataType>(t),
|
||||
constraint.allowed_values())) {
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return errors::InvalidArgument(
|
||||
"OpKernel '", kernel_def.op(), "' has constraint on attr '",
|
||||
constraint.name(), "' not in NodeDef '", attrs.SummarizeNode(),
|
||||
"', KernelDef: '", ProtoShortDebugString(kernel_def), "'");
|
||||
}
|
||||
}
|
||||
*match = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
static const StringPiece kKernelAttr("_kernel");
|
||||
|
||||
// TODO(irving): Replace with const Node& version below.
|
||||
@ -1043,7 +988,7 @@ Status FindKernelRegistration(const DeviceType& device_type,
|
||||
// If there is a kernel registered for the op and device_type,
|
||||
// check that the attrs match.
|
||||
bool match;
|
||||
TF_RETURN_IF_ERROR(AttrsMatch(node_def, iter->second.def, &match));
|
||||
TF_RETURN_IF_ERROR(KernelAttrsMatch(iter->second.def, node_def, &match));
|
||||
if (match) {
|
||||
if (*reg != nullptr) {
|
||||
return errors::InvalidArgument(
|
||||
|
Loading…
x
Reference in New Issue
Block a user