Add kernel builder constraints for int64, string, bool

PiperOrigin-RevId: 245792486
This commit is contained in:
A. Unique TensorFlower 2019-04-29 11:52:55 -07:00 committed by TensorFlower Gardener
parent 0251c63f04
commit 4bee2098db
3 changed files with 164 additions and 0 deletions

View File

@ -34,6 +34,77 @@ KernelDefBuilder& KernelDefBuilder::Device(const char* device_type) {
return *this;
}
template <>
KernelDefBuilder& KernelDefBuilder::AttrConstraint<int64>(
const char* attr_name, gtl::ArraySlice<int64> allowed) {
auto* constraint = kernel_def_->add_constraint();
constraint->set_name(attr_name);
auto* allowed_values = constraint->mutable_allowed_values()->mutable_list();
for (const int64 integer : allowed) {
LOG(INFO) << integer;
allowed_values->add_i(integer);
}
return *this;
}
template <>
KernelDefBuilder& KernelDefBuilder::AttrConstraint<int64>(const char* attr_name,
int64 allowed) {
return AttrConstraint(
attr_name,
gtl::ArraySlice<int64>(std::initializer_list<int64>({allowed})));
}
template <>
KernelDefBuilder& KernelDefBuilder::AttrConstraint<string>(
const char* attr_name, gtl::ArraySlice<string> allowed) {
auto* constraint = kernel_def_->add_constraint();
constraint->set_name(attr_name);
auto* allowed_values = constraint->mutable_allowed_values()->mutable_list();
for (const auto& str : allowed) {
allowed_values->add_s(str);
}
return *this;
}
template <>
KernelDefBuilder& KernelDefBuilder::AttrConstraint<string>(
const char* attr_name, string allowed) {
return AttrConstraint(
attr_name,
gtl::ArraySlice<string>(std::initializer_list<string>({allowed})));
}
template <>
KernelDefBuilder& KernelDefBuilder::AttrConstraint<const char*>(
const char* attr_name, gtl::ArraySlice<const char*> allowed) {
auto* constraint = kernel_def_->add_constraint();
constraint->set_name(attr_name);
auto* allowed_values = constraint->mutable_allowed_values()->mutable_list();
for (const auto& str : allowed) {
allowed_values->add_s(str);
}
return *this;
}
template <>
KernelDefBuilder& KernelDefBuilder::AttrConstraint<const char*>(
const char* attr_name, const char* allowed) {
return AttrConstraint(attr_name,
gtl::ArraySlice<const char*>(
std::initializer_list<const char*>({allowed})));
}
template <>
KernelDefBuilder& KernelDefBuilder::AttrConstraint<bool>(const char* attr_name,
bool allowed) {
auto* constraint = kernel_def_->add_constraint();
constraint->set_name(attr_name);
auto* allowed_values = constraint->mutable_allowed_values()->mutable_list();
allowed_values->add_b(allowed);
return *this;
}
KernelDefBuilder& KernelDefBuilder::TypeConstraint(
const char* attr_name, gtl::ArraySlice<DataType> allowed) {
auto* constraint = kernel_def_->add_constraint();

View File

@ -39,6 +39,18 @@ class KernelDefBuilder {
KernelDefBuilder& Device(const char* device_type);
// KernelDefBuilder& Device(DeviceType device_type);
// Specify that this kernel supports a limited set of values for a
// particular type or list(type) attr (a further restriction than
// what the Op allows).
// Returns *this.
template <typename T>
KernelDefBuilder& AttrConstraint(const char* attr_name,
gtl::ArraySlice<T> allowed);
// Like AttrConstraint above but supports just a single value.
template <typename T>
KernelDefBuilder& AttrConstraint(const char* attr_name, T allowed);
// Specify that this kernel supports a limited set of values for a
// particular type or list(type) attr (a further restriction than
// what the Op allows).

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
@ -72,6 +73,86 @@ TEST(KernelDefBuilderTest, TypeConstraint) {
delete def;
}
TEST(KernelDefBuilderTest, Int64Constraint) {
const KernelDef* def =
KernelDefBuilder("B").Device(DEVICE_GPU).AttrConstraint("T", 5ll).Build();
KernelDef expected;
protobuf::TextFormat::ParseFromString(R"proto(
op: 'B'
device_type: 'GPU'
constraint {
name: 'T'
allowed_values { list { i: 5 } }
})proto",
&expected);
EXPECT_EQ(def->DebugString(), expected.DebugString());
delete def;
def = KernelDefBuilder("C")
.Device(DEVICE_GPU)
.AttrConstraint("U", gtl::ArraySlice<int64>{5ll, 17ll})
.AttrConstraint("V", string("proto"))
.Build();
protobuf::TextFormat::ParseFromString(
R"proto(
op: 'C'
device_type: 'GPU'
constraint {
name: 'U'
allowed_values { list { i: [ 5, 17 ] } }
}
constraint {
name: 'V'
allowed_values { list { s: 'proto' } }
})proto",
&expected);
EXPECT_EQ(def->DebugString(), expected.DebugString());
delete def;
}
TEST(KernelDefBuilderTest, StringConstraint) {
const KernelDef* def = KernelDefBuilder("B")
.Device(DEVICE_GPU)
.AttrConstraint("T", "hi")
.Build();
KernelDef expected;
protobuf::TextFormat::ParseFromString(R"proto(
op: 'B'
device_type: 'GPU'
constraint {
name: 'T'
allowed_values { list { s: 'hi' } }
})proto",
&expected);
EXPECT_EQ(def->DebugString(), expected.DebugString());
delete def;
def = KernelDefBuilder("C")
.Device(DEVICE_GPU)
.AttrConstraint("U", gtl::ArraySlice<const char*>{"boo", "ya"})
.AttrConstraint("V", string("proto"))
.Build();
protobuf::TextFormat::ParseFromString(
R"proto(
op: 'C'
device_type: 'GPU'
constraint {
name: 'U'
allowed_values { list { s: [ 'boo', 'ya' ] } }
}
constraint {
name: 'V'
allowed_values { list { s: 'proto' } }
})proto",
&expected);
EXPECT_EQ(def->DebugString(), expected.DebugString());
delete def;
}
TEST(KernelDefBuilderTest, HostMemory) {
const KernelDef* def = KernelDefBuilder("E")
.Device(DEVICE_GPU)