Add kernel builder constraints for int64, string, bool
PiperOrigin-RevId: 245792486
This commit is contained in:
parent
0251c63f04
commit
4bee2098db
@ -34,6 +34,77 @@ KernelDefBuilder& KernelDefBuilder::Device(const char* device_type) {
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <>
|
||||
KernelDefBuilder& KernelDefBuilder::AttrConstraint<int64>(
|
||||
const char* attr_name, gtl::ArraySlice<int64> allowed) {
|
||||
auto* constraint = kernel_def_->add_constraint();
|
||||
constraint->set_name(attr_name);
|
||||
auto* allowed_values = constraint->mutable_allowed_values()->mutable_list();
|
||||
for (const int64 integer : allowed) {
|
||||
LOG(INFO) << integer;
|
||||
allowed_values->add_i(integer);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <>
|
||||
KernelDefBuilder& KernelDefBuilder::AttrConstraint<int64>(const char* attr_name,
|
||||
int64 allowed) {
|
||||
return AttrConstraint(
|
||||
attr_name,
|
||||
gtl::ArraySlice<int64>(std::initializer_list<int64>({allowed})));
|
||||
}
|
||||
|
||||
template <>
|
||||
KernelDefBuilder& KernelDefBuilder::AttrConstraint<string>(
|
||||
const char* attr_name, gtl::ArraySlice<string> allowed) {
|
||||
auto* constraint = kernel_def_->add_constraint();
|
||||
constraint->set_name(attr_name);
|
||||
auto* allowed_values = constraint->mutable_allowed_values()->mutable_list();
|
||||
for (const auto& str : allowed) {
|
||||
allowed_values->add_s(str);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <>
|
||||
KernelDefBuilder& KernelDefBuilder::AttrConstraint<string>(
|
||||
const char* attr_name, string allowed) {
|
||||
return AttrConstraint(
|
||||
attr_name,
|
||||
gtl::ArraySlice<string>(std::initializer_list<string>({allowed})));
|
||||
}
|
||||
|
||||
template <>
|
||||
KernelDefBuilder& KernelDefBuilder::AttrConstraint<const char*>(
|
||||
const char* attr_name, gtl::ArraySlice<const char*> allowed) {
|
||||
auto* constraint = kernel_def_->add_constraint();
|
||||
constraint->set_name(attr_name);
|
||||
auto* allowed_values = constraint->mutable_allowed_values()->mutable_list();
|
||||
for (const auto& str : allowed) {
|
||||
allowed_values->add_s(str);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <>
|
||||
KernelDefBuilder& KernelDefBuilder::AttrConstraint<const char*>(
|
||||
const char* attr_name, const char* allowed) {
|
||||
return AttrConstraint(attr_name,
|
||||
gtl::ArraySlice<const char*>(
|
||||
std::initializer_list<const char*>({allowed})));
|
||||
}
|
||||
|
||||
template <>
|
||||
KernelDefBuilder& KernelDefBuilder::AttrConstraint<bool>(const char* attr_name,
|
||||
bool allowed) {
|
||||
auto* constraint = kernel_def_->add_constraint();
|
||||
constraint->set_name(attr_name);
|
||||
auto* allowed_values = constraint->mutable_allowed_values()->mutable_list();
|
||||
allowed_values->add_b(allowed);
|
||||
return *this;
|
||||
}
|
||||
|
||||
KernelDefBuilder& KernelDefBuilder::TypeConstraint(
|
||||
const char* attr_name, gtl::ArraySlice<DataType> allowed) {
|
||||
auto* constraint = kernel_def_->add_constraint();
|
||||
|
@ -39,6 +39,18 @@ class KernelDefBuilder {
|
||||
KernelDefBuilder& Device(const char* device_type);
|
||||
// KernelDefBuilder& Device(DeviceType device_type);
|
||||
|
||||
// Specify that this kernel supports a limited set of values for a
|
||||
// particular type or list(type) attr (a further restriction than
|
||||
// what the Op allows).
|
||||
// Returns *this.
|
||||
template <typename T>
|
||||
KernelDefBuilder& AttrConstraint(const char* attr_name,
|
||||
gtl::ArraySlice<T> allowed);
|
||||
|
||||
// Like AttrConstraint above but supports just a single value.
|
||||
template <typename T>
|
||||
KernelDefBuilder& AttrConstraint(const char* attr_name, T allowed);
|
||||
|
||||
// Specify that this kernel supports a limited set of values for a
|
||||
// particular type or list(type) attr (a further restriction than
|
||||
// what the Op allows).
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
|
||||
#include "tensorflow/core/framework/kernel_def.pb.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
@ -72,6 +73,86 @@ TEST(KernelDefBuilderTest, TypeConstraint) {
|
||||
delete def;
|
||||
}
|
||||
|
||||
TEST(KernelDefBuilderTest, Int64Constraint) {
|
||||
const KernelDef* def =
|
||||
KernelDefBuilder("B").Device(DEVICE_GPU).AttrConstraint("T", 5ll).Build();
|
||||
KernelDef expected;
|
||||
protobuf::TextFormat::ParseFromString(R"proto(
|
||||
op: 'B'
|
||||
device_type: 'GPU'
|
||||
constraint {
|
||||
name: 'T'
|
||||
allowed_values { list { i: 5 } }
|
||||
})proto",
|
||||
&expected);
|
||||
|
||||
EXPECT_EQ(def->DebugString(), expected.DebugString());
|
||||
delete def;
|
||||
|
||||
def = KernelDefBuilder("C")
|
||||
.Device(DEVICE_GPU)
|
||||
.AttrConstraint("U", gtl::ArraySlice<int64>{5ll, 17ll})
|
||||
.AttrConstraint("V", string("proto"))
|
||||
.Build();
|
||||
|
||||
protobuf::TextFormat::ParseFromString(
|
||||
R"proto(
|
||||
op: 'C'
|
||||
device_type: 'GPU'
|
||||
constraint {
|
||||
name: 'U'
|
||||
allowed_values { list { i: [ 5, 17 ] } }
|
||||
}
|
||||
constraint {
|
||||
name: 'V'
|
||||
allowed_values { list { s: 'proto' } }
|
||||
})proto",
|
||||
&expected);
|
||||
EXPECT_EQ(def->DebugString(), expected.DebugString());
|
||||
delete def;
|
||||
}
|
||||
|
||||
TEST(KernelDefBuilderTest, StringConstraint) {
|
||||
const KernelDef* def = KernelDefBuilder("B")
|
||||
.Device(DEVICE_GPU)
|
||||
.AttrConstraint("T", "hi")
|
||||
.Build();
|
||||
KernelDef expected;
|
||||
protobuf::TextFormat::ParseFromString(R"proto(
|
||||
op: 'B'
|
||||
device_type: 'GPU'
|
||||
constraint {
|
||||
name: 'T'
|
||||
allowed_values { list { s: 'hi' } }
|
||||
})proto",
|
||||
&expected);
|
||||
|
||||
EXPECT_EQ(def->DebugString(), expected.DebugString());
|
||||
delete def;
|
||||
|
||||
def = KernelDefBuilder("C")
|
||||
.Device(DEVICE_GPU)
|
||||
.AttrConstraint("U", gtl::ArraySlice<const char*>{"boo", "ya"})
|
||||
.AttrConstraint("V", string("proto"))
|
||||
.Build();
|
||||
|
||||
protobuf::TextFormat::ParseFromString(
|
||||
R"proto(
|
||||
op: 'C'
|
||||
device_type: 'GPU'
|
||||
constraint {
|
||||
name: 'U'
|
||||
allowed_values { list { s: [ 'boo', 'ya' ] } }
|
||||
}
|
||||
constraint {
|
||||
name: 'V'
|
||||
allowed_values { list { s: 'proto' } }
|
||||
})proto",
|
||||
&expected);
|
||||
EXPECT_EQ(def->DebugString(), expected.DebugString());
|
||||
delete def;
|
||||
}
|
||||
|
||||
TEST(KernelDefBuilderTest, HostMemory) {
|
||||
const KernelDef* def = KernelDefBuilder("E")
|
||||
.Device(DEVICE_GPU)
|
||||
|
Loading…
Reference in New Issue
Block a user