Revert PR #44017: [PluggableDevice] Kernel C API enhancement for retrieving attributes

PiperOrigin-RevId: 346844467
Change-Id: I9d2121c3ea2402df51879852d0bade21577b3d9b
This commit is contained in:
Mihai Maruseac 2020-12-10 12:45:57 -08:00 committed by TensorFlower Gardener
parent 94bef4cb74
commit 03e4d9f255
3 changed files with 3 additions and 623 deletions

View File

@ -32,7 +32,6 @@ limitations under the License.
#include "tensorflow/stream_executor/stream.h"
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
using tensorflow::errors::InvalidArgument;
// This file forms the basis of a stable ABI for third-party kernel
// implementations. It is crucial that changes to this file are made cautiously
// and with a focus on maintaining both source and binary compatibility.
@ -88,25 +87,9 @@ void AddTypeConstraint(TF_KernelBuilder* kernel_builder, const char* attr_name,
TF_SetStatus(status, TF_OK, "");
}
#undef CASE
} // namespace
} // namespace tensorflow
namespace {
const tensorflow::AttrValue* GetAttrValue(TF_OpKernelConstruction* ctx,
const char* attr_name,
TF_Status* status) {
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
const tensorflow::AttrValue* attr =
::tensorflow::AttrSlice(cc_ctx->def()).Find(attr_name);
if (attr == nullptr) {
status->status = InvalidArgument("Operation '", cc_ctx->def().name(),
"' has no attr named '", attr_name, "'.");
}
return attr;
}
} // namespace
void TF_KernelBuilder_TypeConstraint(TF_KernelBuilder* kernel_builder,
const char* attr_name,
const TF_DataType type,
@ -274,81 +257,7 @@ void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) {
cc_ctx->CtxFailure(s);
}
void TF_OpKernelConstruction_GetAttrSize(TF_OpKernelConstruction* ctx,
const char* attr_name,
int32_t* list_size,
int32_t* total_size,
TF_Status* status) {
const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status);
if (!status->status.ok()) {
*list_size = -1;
*total_size = -1;
return;
}
switch (attr->value_case()) {
#define SINGLE_CASE(kK, attr_type, size_expr) \
case tensorflow::AttrValue::kK: \
*list_size = -1; \
*total_size = size_expr; \
break;
SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length());
SINGLE_CASE(kI, TF_ATTR_INT, -1);
SINGLE_CASE(kF, TF_ATTR_FLOAT, -1);
SINGLE_CASE(kB, TF_ATTR_BOOL, -1);
SINGLE_CASE(kType, TF_ATTR_TYPE, -1);
SINGLE_CASE(kShape, TF_ATTR_SHAPE,
attr->shape().unknown_rank() ? -1 : attr->shape().dim_size());
SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1);
#undef SINGLE_CASE
case tensorflow::AttrValue::kList:
*list_size = 0;
*total_size = -1;
#define LIST_CASE(field, attr_type, ...) \
if (attr->list().field##_size() > 0) { \
*list_size = attr->list().field##_size(); \
__VA_ARGS__; \
break; \
}
LIST_CASE(
s, TF_ATTR_STRING, *total_size = 0;
for (int i = 0; i < attr->list().s_size();
++i) { *total_size += attr->list().s(i).size(); });
LIST_CASE(i, TF_ATTR_INT);
LIST_CASE(f, TF_ATTR_FLOAT);
LIST_CASE(b, TF_ATTR_BOOL);
LIST_CASE(type, TF_ATTR_TYPE);
LIST_CASE(
shape, TF_ATTR_SHAPE, *total_size = 0;
for (int i = 0; i < attr->list().shape_size(); ++i) {
const auto& s = attr->list().shape(i);
*total_size += s.unknown_rank() ? 0 : s.dim_size();
});
LIST_CASE(tensor, TF_ATTR_TENSOR);
LIST_CASE(tensor, TF_ATTR_FUNC);
#undef LIST_CASE
break;
case tensorflow::AttrValue::kPlaceholder:
*list_size = -1;
*total_size = -1;
break;
case tensorflow::AttrValue::kFunc:
*list_size = -1;
*total_size = -1;
break;
case tensorflow::AttrValue::VALUE_NOT_SET:
status->status =
InvalidArgument("Attribute '", attr_name, "' has no value set");
break;
}
}
#define DEFINE_TF_GETATTR(func, c_type, cc_type, attr_type, list_field) \
#define DEFINE_TF_GETATTR(func, c_type, cc_type) \
void TF_OpKernelConstruction_GetAttr##func(TF_OpKernelConstruction* ctx, \
const char* attr_name, \
c_type* val, TF_Status* status) { \
@ -360,84 +269,10 @@ void TF_OpKernelConstruction_GetAttrSize(TF_OpKernelConstruction* ctx,
if (s.ok()) { \
*val = static_cast<c_type>(v); \
} \
} \
void TF_OpKernelConstruction_GetAttr##func##List( \
TF_OpKernelConstruction* ctx, const char* attr_name, c_type* vals, \
int max_vals, TF_Status* status) { \
TF_SetStatus(status, TF_OK, ""); \
const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status); \
if (!status->status.ok()) return; \
if (attr->value_case() != tensorflow::AttrValue::kList) { \
status->status = \
InvalidArgument("Value for '", attr_name, "' is not a list."); \
return; \
} \
status->status = \
tensorflow::AttrValueHasType(*attr, "list(" attr_type ")"); \
if (!status->status.ok()) return; \
const auto len = std::min(max_vals, attr->list().list_field##_size()); \
for (int i = 0; i < len; ++i) { \
vals[i] = static_cast<c_type>(attr->list().list_field(i)); \
} \
}
DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType, "type", type)
DEFINE_TF_GETATTR(Int32, int32_t, tensorflow::int32, "int", i)
DEFINE_TF_GETATTR(Int64, int64_t, tensorflow::int64, "int", i)
DEFINE_TF_GETATTR(Float, float, float, "float", f)
DEFINE_TF_GETATTR(Bool, TF_Bool, bool, "bool", b)
void TF_OpKernelConstruction_GetAttrString(TF_OpKernelConstruction* ctx,
const char* attr_name, char* value,
size_t max_length,
TF_Status* status) {
std::string v;
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);
::tensorflow::Set_TF_Status_from_Status(status, s);
if (!status->status.ok()) return;
if (max_length <= 0) {
return;
}
std::memcpy(value, v.data(), std::min<size_t>(v.length(), max_length));
}
void TF_OpKernelConstruction_GetAttrStringList(TF_OpKernelConstruction* ctx,
const char* attr_name,
char** values, size_t* lengths,
int max_values, void* storage,
size_t storage_size,
TF_Status* status) {
std::vector<std::string> v;
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);
::tensorflow::Set_TF_Status_from_Status(status, s);
if (!status->status.ok()) return;
const auto len = std::min(max_values, static_cast<int>(v.size()));
char* p = static_cast<char*>(storage);
for (int i = 0; i < len; ++i) {
const std::string& s = v[i];
values[i] = p;
lengths[i] = s.size();
if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
status->status = InvalidArgument(
"Not enough storage to hold the requested list of strings");
return;
}
memcpy(values[i], s.data(), s.size());
p += s.size();
}
}
bool TF_OpKernelConstruction_HasAttr(TF_OpKernelConstruction* ctx,
const char* attr_name, TF_Status* status) {
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
return cc_ctx->HasAttr(attr_name);
}
DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType)
DEFINE_TF_GETATTR(Int32, tensorflow::int32, int32_t)
TF_StringView TF_OpKernelConstruction_GetName(TF_OpKernelConstruction* ctx) {
auto* cc_ctx = reinterpret_cast<tensorflow::OpKernelConstruction*>(ctx);

View File

@ -184,24 +184,6 @@ TF_CAPI_EXPORT extern TF_DataType TF_ExpectedOutputDataType(
// Returns the step ID of the given context.
TF_CAPI_EXPORT extern int64_t TF_StepId(TF_OpKernelContext* ctx);
// Get the list_size and total_size of the attribute `attr_name` of `oper`.
// list_size - the length of the list.
// total_size - total size of the list.
// (1) If attr_type == TF_ATTR_STRING
// then total_size is the cumulative byte size
// of all the strings in the list.
// (3) If attr_type == TF_ATTR_SHAPE
// then total_size is the number of dimensions
// of the shape valued attribute, or -1
// if its rank is unknown.
// (4) If attr_type == TF_ATTR_SHAPE
// then total_size is the cumulative number
// of dimensions of all shapes in the list.
// (5) Otherwise, total_size is undefined.
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrSize(
TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* list_size,
int32_t* total_size, TF_Status* status);
// Interprets the named kernel construction attribute as a TF_DataType and
// places it into *val. *status is set to TF_OK.
//
@ -220,112 +202,6 @@ TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32(
TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* val,
TF_Status* status);
// Interprets the named kernel construction attribute as int64_t and
// places it into *val. *status is set to TF_OK.
//
// If the attribute could not be found or could not be interpreted as
// int64, *status is populated with an error.
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt64(
TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* val,
TF_Status* status);
// Interprets the named kernel construction attribute as float and
// places it into *val. *status is set to TF_OK.
//
// If the attribute could not be found or could not be interpreted as
// float, *status is populated with an error.
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrFloat(
TF_OpKernelConstruction* ctx, const char* attr_name, float* val,
TF_Status* status);
// Interprets the named kernel construction attribute as bool and
// places it into *val. *status is set to TF_OK.
//
// If the attribute could not be found or could not be interpreted as
// bool, *status is populated with an error.
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrBool(
TF_OpKernelConstruction* ctx, const char* attr_name, TF_Bool* val,
TF_Status* status);
// Interprets the named kernel construction attribute as string and
// places it into *val. `val` must
// point to an array of length at least `max_length` (ideally set to
// total_size from TF_OpKernelConstruction_GetAttrSize(ctx,
// attr_name, list_size, total_size)). *status is set to TF_OK.
//
// If the attribute could not be found or could not be interpreted as
// string, *status is populated with an error.
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrString(
TF_OpKernelConstruction* ctx, const char* attr_name, char* val,
size_t max_length, TF_Status* status);
// Interprets the named kernel construction attribute as a TF_DataType array and
// places it into *vals. *status is set to TF_OK.
// `vals` must point to an array of length at least `max_values` (ideally set
// to list_size from
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
// total_size)).
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrTypeList(
TF_OpKernelConstruction* ctx, const char* attr_name, TF_DataType* vals,
int max_vals, TF_Status* status);
// Interprets the named kernel construction attribute as int32_t array and
// places it into *vals. *status is set to TF_OK.
// `vals` must point to an array of length at least `max_values` (ideally set
// to list_size from
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
// total_size)).
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32List(
TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* vals,
int max_vals, TF_Status* status);
// Interprets the named kernel construction attribute as int64_t array and
// places it into *vals. *status is set to TF_OK.
// `vals` must point to an array of length at least `max_values` (ideally set
// to list_size from
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
// total_size)).
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt64List(
TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* vals,
int max_vals, TF_Status* status);
// Interprets the named kernel construction attribute as float array and
// places it into *vals. *status is set to TF_OK.
// `vals` must point to an array of length at least `max_values` (ideally set
// to list_size from
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
// total_size)).
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrFloatList(
TF_OpKernelConstruction* ctx, const char* attr_name, float* vals,
int max_vals, TF_Status* status);
// Interprets the named kernel construction attribute as bool array and
// places it into *vals. *status is set to TF_OK.
// `vals` must point to an array of length at least `max_values` (ideally set
// to list_size from
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
// total_size)).
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrBoolList(
TF_OpKernelConstruction* ctx, const char* attr_name, TF_Bool* vals,
int max_vals, TF_Status* status);
// Interprets the named kernel construction attribute as string array and fills
// in `vals` and `lengths`, each of which must point to an array of length at
// least `max_values`. *status is set to TF_OK. The elements of values will
// point to addresses in `storage` which must be at least `storage_size` bytes
// in length. Ideally, max_values would be set to list_size and `storage` would
// be at least total_size, obtained from
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
// total_size).
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrStringList(
TF_OpKernelConstruction* ctx, const char* attr_name, char** vals,
size_t* lengths, int max_values, void* storage, size_t storage_size,
TF_Status* status);
// Return true if the kernel construction has the attr_name
TF_CAPI_EXPORT extern bool TF_OpKernelConstruction_HasAttr(
TF_OpKernelConstruction* ctx, const char* attr_name, TF_Status* status);
// Returns the unique operation name for this OpKernel.
TF_CAPI_EXPORT extern TF_StringView TF_OpKernelConstruction_GetName(
TF_OpKernelConstruction* ctx);

View File

@ -161,337 +161,6 @@ TEST(TestKernel, TestRegisterKernelBuilder) {
ASSERT_TRUE(delete_called);
}
// REGISTER_OP for TF_OpKernelConstruction_GetAttr* test cases.
// Registers two ops, each with a single attribute called 'Attr'.
// The attribute in one op will have a type 'type', the other
// will have list(type).
#define ATTR_TEST_REGISTER_OP(name, type) \
REGISTER_OP("TestKernelAttr" #name) \
.Attr("Attr: " #type) \
.SetShapeFn(tensorflow::shape_inference::UnknownShape); \
REGISTER_OP("TestKernelAttr" #name "List") \
.Attr("Attr: list(" #type ")") \
.SetShapeFn(tensorflow::shape_inference::UnknownShape)
ATTR_TEST_REGISTER_OP(String, string);
ATTR_TEST_REGISTER_OP(Int, int);
ATTR_TEST_REGISTER_OP(Float, float);
ATTR_TEST_REGISTER_OP(Bool, bool);
ATTR_TEST_REGISTER_OP(Type, type);
#undef ATTR_TEST_REGISTER_OP
// Helper macros for the TF_OpKernelConstruction_GetAttr* tests.
#define EXPECT_TF_SIZE(attr_name, expected_list_size, expected_total_size) \
do { \
int32_t list_size, total_size; \
TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, &list_size, \
&total_size, status); \
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); \
EXPECT_EQ(expected_list_size, list_size); \
EXPECT_EQ(expected_total_size, total_size); \
} while (0)
typedef void* (*MyCreateFuncWithAttr)(TF_OpKernelConstruction*);
class TestKernelAttr : public ::testing::Test {
public:
TestKernelAttr() {}
~TestKernelAttr() {}
std::unique_ptr<OpKernel> GetFakeKernelWithAttr(const char* op_name,
AttrValue v, Status* status) {
NodeDef def;
def.set_op(op_name);
def.set_name("FakeNode");
def.set_device("FakeDevice");
(*def.mutable_attr())["Attr"] = v;
return CreateOpKernel(DeviceType("FakeDevice"), nullptr, nullptr, def, 1,
status);
}
void SetAttr(MyCreateFuncWithAttr MyCreateFuncAttr, const char* op_name,
AttrValue& v) {
TF_KernelBuilder* builder = TF_NewKernelBuilder(
op_name, "FakeDevice", MyCreateFuncAttr, &MyComputeFunc, &MyDeleteFunc);
{
TF_Status* status = TF_NewStatus();
TF_RegisterKernelBuilder("FakeNode", builder, status);
EXPECT_EQ(TF_OK, TF_GetCode(status));
TF_DeleteStatus(status);
}
Status status;
std::unique_ptr<OpKernel> kernel =
GetFakeKernelWithAttr(op_name, v, &status);
TF_EXPECT_OK(status);
ASSERT_NE(nullptr, kernel.get());
kernel->Compute(nullptr);
ASSERT_TRUE(delete_called);
}
};
TEST_F(TestKernelAttr, String) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
std::unique_ptr<char[]> val(new char[5]);
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
/*expected_total_size*/ 5);
TF_OpKernelConstruction_GetAttrString(ctx, "Attr", val.get(),
/*max_length*/ 5, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ("bunny", string(static_cast<const char*>(val.get()), 5));
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
v.set_s("bunny");
SetAttr(my_create_func, "TestKernelAttrString", v);
}
TEST_F(TestKernelAttr, StringList) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
std::vector<string> list = {"bugs", "bunny", "duck"};
int list_total_size = 0;
for (const auto& s : list) {
list_total_size += s.size();
}
TF_Status* status = TF_NewStatus();
std::unique_ptr<char*[]> values(new char*[list.size()]);
std::unique_ptr<size_t[]> lens(new size_t[list.size()]);
std::unique_ptr<char[]> storage(new char[list_total_size]);
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list.size(),
/*expected_total_size*/ list_total_size);
TF_OpKernelConstruction_GetAttrStringList(
ctx, "Attr", values.get(), lens.get(), list.size(), storage.get(),
list_total_size, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
for (size_t i = 0; i < list.size(); ++i) {
EXPECT_EQ(list[i].size(), lens[i]) << i;
EXPECT_EQ(list[i], string(static_cast<const char*>(values[i]), lens[i]))
<< i;
}
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
auto attr_in = gtl::ArraySlice<StringPiece>({"bugs", "bunny", "duck"});
SetAttrValue(attr_in, &v);
SetAttr(my_create_func, "TestKernelAttrStringList", v);
}
TEST_F(TestKernelAttr, Int) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
int64_t val;
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
/*expected_total_size*/ -1);
TF_OpKernelConstruction_GetAttrInt64(ctx, "Attr", &val, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(1234, val);
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
v.set_i(1234);
SetAttr(my_create_func, "TestKernelAttrInt", v);
}
TEST_F(TestKernelAttr, IntList) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
const int64_t list[] = {1, 2, 3, 4};
const size_t list_size = TF_ARRAYSIZE(list);
int64_t values[list_size];
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
/*expected_total_size*/ -1);
TF_OpKernelConstruction_GetAttrInt64List(ctx, "Attr", values, list_size,
status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_TRUE(
std::equal(std::begin(list), std::end(list), std::begin(values)));
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
auto attr_in = gtl::ArraySlice<int64>({1, 2, 3, 4});
SetAttrValue(attr_in, &v);
SetAttr(my_create_func, "TestKernelAttrIntList", v);
}
TEST_F(TestKernelAttr, Float) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
float val;
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
/*expected_total_size*/ -1);
TF_OpKernelConstruction_GetAttrFloat(ctx, "Attr", &val, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_FLOAT_EQ(2.718, val);
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
v.set_f(2.718);
SetAttr(my_create_func, "TestKernelAttrFloat", v);
}
TEST_F(TestKernelAttr, FloatList) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
const float list[] = {1.414, 2.718, 3.1415};
const size_t list_size = TF_ARRAYSIZE(list);
float values[list_size];
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
/*expected_total_size*/ -1);
TF_OpKernelConstruction_GetAttrFloatList(ctx, "Attr", values, list_size,
status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_TRUE(
std::equal(std::begin(list), std::end(list), std::begin(values)));
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
auto attr_in = gtl::ArraySlice<float>({1.414, 2.718, 3.1415});
SetAttrValue(attr_in, &v);
SetAttr(my_create_func, "TestKernelAttrFloatList", v);
}
TEST_F(TestKernelAttr, Bool) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
unsigned char val;
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
/*expected_total_size*/ -1);
TF_OpKernelConstruction_GetAttrBool(ctx, "Attr", &val, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(1, val);
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
v.set_b(1);
SetAttr(my_create_func, "TestKernelAttrBool", v);
}
TEST_F(TestKernelAttr, BoolList) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
const unsigned char list[] = {1, 0, 1, 0};
const size_t list_size = TF_ARRAYSIZE(list);
unsigned char values[list_size];
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
/*expected_total_size*/ -1);
TF_OpKernelConstruction_GetAttrBoolList(ctx, "Attr", values, list_size,
status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_TRUE(
std::equal(std::begin(list), std::end(list), std::begin(values)));
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
auto attr_in = gtl::ArraySlice<bool>({1, 0, 1, 0});
SetAttrValue(attr_in, &v);
SetAttr(my_create_func, "TestKernelAttrBoolList", v);
}
TEST_F(TestKernelAttr, Type) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
TF_DataType val;
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
/*expected_total_size*/ -1);
TF_OpKernelConstruction_GetAttrType(ctx, "Attr", &val, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(TF_FLOAT, val);
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
v.set_type(DT_FLOAT);
SetAttr(my_create_func, "TestKernelAttrType", v);
}
TEST_F(TestKernelAttr, TypeList) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
const TF_DataType list[] = {TF_FLOAT, TF_DOUBLE, TF_HALF, TF_COMPLEX128};
const size_t list_size = TF_ARRAYSIZE(list);
TF_DataType values[list_size];
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
/*expected_total_size*/ -1);
TF_OpKernelConstruction_GetAttrTypeList(ctx, "Attr", values, list_size,
status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_TRUE(
std::equal(std::begin(list), std::end(list), std::begin(values)));
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
auto attr_in =
gtl::ArraySlice<DataType>({DT_FLOAT, DT_DOUBLE, DT_HALF, DT_COMPLEX128});
SetAttrValue(attr_in, &v);
SetAttr(my_create_func, "TestKernelAttrTypeList", v);
}
#undef EXPECT_TF_SIZE
class DummyDevice : public DeviceBase {
public:
explicit DummyDevice(Env* env) : DeviceBase(env) {}