Move AttrsMatch() to kernel_def_util.h so client can use it.

PiperOrigin-RevId: 201991753
This commit is contained in:
A. Unique TensorFlower 2018-06-25 11:44:29 -07:00 committed by TensorFlower Gardener
parent 7c5d31aa32
commit 73524dc765
5 changed files with 251 additions and 57 deletions

View File

@ -792,6 +792,7 @@ tf_cuda_library(
"framework/graph_def_util.h",
"framework/graph_to_functiondef.h",
"framework/kernel_def_builder.h",
"framework/kernel_def_util.h",
"framework/log_memory.h",
"framework/lookup_interface.h",
"framework/memory_types.h",
@ -3376,6 +3377,7 @@ tf_cc_tests(
"framework/graph_def_util_test.cc",
"framework/graph_to_functiondef_test.cc",
"framework/kernel_def_builder_test.cc",
"framework/kernel_def_util_test.cc",
"framework/memory_types_test.cc",
"framework/node_def_builder_test.cc",
"framework/node_def_util_test.cc",

View File

@ -0,0 +1,83 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/kernel_def_util.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/kernel_def.pb_text.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/types.h"
namespace tensorflow {
namespace {
// Helper for KernelAttrsMatch().
bool InTypeList(DataType dt, const AttrValue& type_list) {
for (int in_list : type_list.list().type()) {
if (dt == in_list) return true;
}
return false;
}
} // namespace
Status KernelAttrsMatch(const KernelDef& kernel_def, AttrSlice attrs,
bool* match) {
*match = false;
for (const auto& constraint : kernel_def.constraint()) {
if (constraint.allowed_values().list().type_size() == 0) {
return errors::Unimplemented(
"KernelDef '", ProtoShortDebugString(kernel_def),
" has constraint on attr '", constraint.name(),
"' with unsupported type: ",
SummarizeAttrValue(constraint.allowed_values()));
}
const AttrValue* found = attrs.Find(constraint.name());
if (found) {
if (found->type() != DT_INVALID) {
if (!InTypeList(found->type(), constraint.allowed_values())) {
return Status::OK();
}
} else {
if (!AttrValueHasType(*found, "list(type)").ok()) {
return errors::InvalidArgument(
"KernelDef '", ProtoShortDebugString(kernel_def),
"' has constraint on attr '", constraint.name(),
"' that has value '", SummarizeAttrValue(*found),
"' that does not have type 'type' or 'list(type)' in NodeDef "
"'",
attrs.SummarizeNode(), "'");
}
for (int t : found->list().type()) {
if (!InTypeList(static_cast<DataType>(t),
constraint.allowed_values())) {
return Status::OK();
}
}
}
} else {
return errors::InvalidArgument(
"OpKernel '", kernel_def.op(), "' has constraint on attr '",
constraint.name(), "' not in NodeDef '", attrs.SummarizeNode(),
"', KernelDef: '", ProtoShortDebugString(kernel_def), "'");
}
}
*match = true;
return Status::OK();
}
} // namespace tensorflow

View File

@ -0,0 +1,31 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_KERNEL_DEF_UTIL_H_
#define TENSORFLOW_CORE_FRAMEWORK_KERNEL_DEF_UTIL_H_
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
namespace tensorflow {
// Returns whether the attrs satisfy the constraints in the kernel_def. Returns
// an error if attrs in kernel_def are not found, or have a mismatching type.
Status KernelAttrsMatch(const KernelDef& kernel_def, AttrSlice attrs,
bool* match);
} // namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_KERNEL_DEF_UTIL_H_

View File

@ -0,0 +1,133 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/kernel_def_util.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
NodeDef NodeDefFromText(const string& text) {
NodeDef node_def;
EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
return node_def;
}
KernelDef KernelDefFromText(const string& text) {
KernelDef kernel_def;
EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &kernel_def));
return kernel_def;
}
class AttrsMatchTest : public ::testing::Test {
protected:
void ExpectStatus(const string& node_def_str, const string& kernel_def_str,
error::Code code) {
bool match;
auto status = KernelAttrsMatch(KernelDefFromText(kernel_def_str),
NodeDefFromText(node_def_str), &match);
LOG(INFO) << "status: " << status;
EXPECT_EQ(code, status.code());
if (!status.ok()) {
EXPECT_FALSE(match)
<< "Expect no match between the given NodeDef and KernelDef";
}
}
};
TEST_F(AttrsMatchTest, ValidConstraint) {
string node_def_str = R"(
name: "ValidConstraint-op"
op: "ValidConstraint"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
)";
string kernel_def_str = R"(
op: "ValidConstraint"
device_type: "CPU"
constraint {
name: "T"
allowed_values {
list {
type: DT_FLOAT
}
}
}
)";
ExpectStatus(node_def_str, kernel_def_str, error::OK);
}
TEST_F(AttrsMatchTest, BadConstraint) {
string node_def_str = R"(
name: "BadConstraint-op"
op: "BadConstraint"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
)";
string kernel_def_str = R"(
op: "BadConstraint"
device_type: "CPU"
constraint {
name: "T"
allowed_values {
list {
type: DT_FLOAT
}
}
}
)";
ExpectStatus(node_def_str, kernel_def_str, error::INVALID_ARGUMENT);
}
TEST_F(AttrsMatchTest, Unimplemented) {
string node_def_str = R"(
name: "BadConstraint-op"
op: "BadConstraint"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
)";
string kernel_def_str = R"(
op: "BadConstraint"
device_type: "CPU"
constraint {
name: "T"
allowed_values {
list {
}
}
}
)";
ExpectStatus(node_def_str, kernel_def_str, error::UNIMPLEMENTED);
}
} // namespace
} // namespace tensorflow

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/graph.pb_text.h"
#include "tensorflow/core/framework/kernel_def.pb_text.h"
#include "tensorflow/core/framework/kernel_def_util.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
@ -969,62 +970,6 @@ void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
namespace {
// Helper for AttrsMatch().
bool InTypeList(DataType dt, const AttrValue& type_list) {
for (int in_list : type_list.list().type()) {
if (dt == in_list) return true;
}
return false;
}
// Returns whether the attrs satisfy the constraints in the kernel_def. Returns
// an error if attrs in kernel_def are not found, or have a mismatching type.
Status AttrsMatch(AttrSlice attrs, const KernelDef& kernel_def, bool* match) {
*match = false;
for (const auto& constraint : kernel_def.constraint()) {
if (constraint.allowed_values().list().type_size() == 0) {
return errors::Unimplemented(
"KernelDef '", ProtoShortDebugString(kernel_def),
" has constraint on attr '", constraint.name(),
"' with unsupported type: ",
SummarizeAttrValue(constraint.allowed_values()));
}
const AttrValue* found = attrs.Find(constraint.name());
if (found) {
if (found->type() != DT_INVALID) {
if (!InTypeList(found->type(), constraint.allowed_values())) {
return Status::OK();
}
} else {
if (!AttrValueHasType(*found, "list(type)").ok()) {
return errors::InvalidArgument(
"KernelDef '", ProtoShortDebugString(kernel_def),
"' has constraint on attr '", constraint.name(),
"' that has value '", SummarizeAttrValue(*found),
"' that does not have type 'type' or 'list(type)' in NodeDef "
"'",
attrs.SummarizeNode(), "'");
}
for (int t : found->list().type()) {
if (!InTypeList(static_cast<DataType>(t),
constraint.allowed_values())) {
return Status::OK();
}
}
}
} else {
return errors::InvalidArgument(
"OpKernel '", kernel_def.op(), "' has constraint on attr '",
constraint.name(), "' not in NodeDef '", attrs.SummarizeNode(),
"', KernelDef: '", ProtoShortDebugString(kernel_def), "'");
}
}
*match = true;
return Status::OK();
}
static const StringPiece kKernelAttr("_kernel");
// TODO(irving): Replace with const Node& version below.
@ -1043,7 +988,7 @@ Status FindKernelRegistration(const DeviceType& device_type,
// If there is a kernel registered for the op and device_type,
// check that the attrs match.
bool match;
TF_RETURN_IF_ERROR(AttrsMatch(node_def, iter->second.def, &match));
TF_RETURN_IF_ERROR(KernelAttrsMatch(iter->second.def, node_def, &match));
if (match) {
if (*reg != nullptr) {
return errors::InvalidArgument(