Revert PR #44017: [PluggableDevice] Kernel C API enhancement for retrieving attributes
PiperOrigin-RevId: 346844467 Change-Id: I9d2121c3ea2402df51879852d0bade21577b3d9b
This commit is contained in:
parent
94bef4cb74
commit
03e4d9f255
@ -32,7 +32,6 @@ limitations under the License.
|
||||
#include "tensorflow/stream_executor/stream.h"
|
||||
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
||||
|
||||
using tensorflow::errors::InvalidArgument;
|
||||
// This file forms the basis of a stable ABI for third-party kernel
|
||||
// implementations. It is crucial that changes to this file are made cautiously
|
||||
// and with a focus on maintaining both source and binary compatibility.
|
||||
@ -88,25 +87,9 @@ void AddTypeConstraint(TF_KernelBuilder* kernel_builder, const char* attr_name,
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
#undef CASE
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
||||
namespace {
|
||||
const tensorflow::AttrValue* GetAttrValue(TF_OpKernelConstruction* ctx,
|
||||
const char* attr_name,
|
||||
TF_Status* status) {
|
||||
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
|
||||
const tensorflow::AttrValue* attr =
|
||||
::tensorflow::AttrSlice(cc_ctx->def()).Find(attr_name);
|
||||
if (attr == nullptr) {
|
||||
status->status = InvalidArgument("Operation '", cc_ctx->def().name(),
|
||||
"' has no attr named '", attr_name, "'.");
|
||||
}
|
||||
return attr;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void TF_KernelBuilder_TypeConstraint(TF_KernelBuilder* kernel_builder,
|
||||
const char* attr_name,
|
||||
const TF_DataType type,
|
||||
@ -274,81 +257,7 @@ void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) {
|
||||
cc_ctx->CtxFailure(s);
|
||||
}
|
||||
|
||||
void TF_OpKernelConstruction_GetAttrSize(TF_OpKernelConstruction* ctx,
|
||||
const char* attr_name,
|
||||
int32_t* list_size,
|
||||
int32_t* total_size,
|
||||
TF_Status* status) {
|
||||
const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status);
|
||||
if (!status->status.ok()) {
|
||||
*list_size = -1;
|
||||
*total_size = -1;
|
||||
return;
|
||||
}
|
||||
switch (attr->value_case()) {
|
||||
#define SINGLE_CASE(kK, attr_type, size_expr) \
|
||||
case tensorflow::AttrValue::kK: \
|
||||
*list_size = -1; \
|
||||
*total_size = size_expr; \
|
||||
break;
|
||||
|
||||
SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length());
|
||||
SINGLE_CASE(kI, TF_ATTR_INT, -1);
|
||||
SINGLE_CASE(kF, TF_ATTR_FLOAT, -1);
|
||||
SINGLE_CASE(kB, TF_ATTR_BOOL, -1);
|
||||
SINGLE_CASE(kType, TF_ATTR_TYPE, -1);
|
||||
SINGLE_CASE(kShape, TF_ATTR_SHAPE,
|
||||
attr->shape().unknown_rank() ? -1 : attr->shape().dim_size());
|
||||
SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1);
|
||||
#undef SINGLE_CASE
|
||||
|
||||
case tensorflow::AttrValue::kList:
|
||||
*list_size = 0;
|
||||
*total_size = -1;
|
||||
#define LIST_CASE(field, attr_type, ...) \
|
||||
if (attr->list().field##_size() > 0) { \
|
||||
*list_size = attr->list().field##_size(); \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
}
|
||||
|
||||
LIST_CASE(
|
||||
s, TF_ATTR_STRING, *total_size = 0;
|
||||
for (int i = 0; i < attr->list().s_size();
|
||||
++i) { *total_size += attr->list().s(i).size(); });
|
||||
LIST_CASE(i, TF_ATTR_INT);
|
||||
LIST_CASE(f, TF_ATTR_FLOAT);
|
||||
LIST_CASE(b, TF_ATTR_BOOL);
|
||||
LIST_CASE(type, TF_ATTR_TYPE);
|
||||
LIST_CASE(
|
||||
shape, TF_ATTR_SHAPE, *total_size = 0;
|
||||
for (int i = 0; i < attr->list().shape_size(); ++i) {
|
||||
const auto& s = attr->list().shape(i);
|
||||
*total_size += s.unknown_rank() ? 0 : s.dim_size();
|
||||
});
|
||||
LIST_CASE(tensor, TF_ATTR_TENSOR);
|
||||
LIST_CASE(tensor, TF_ATTR_FUNC);
|
||||
#undef LIST_CASE
|
||||
break;
|
||||
|
||||
case tensorflow::AttrValue::kPlaceholder:
|
||||
*list_size = -1;
|
||||
*total_size = -1;
|
||||
break;
|
||||
|
||||
case tensorflow::AttrValue::kFunc:
|
||||
*list_size = -1;
|
||||
*total_size = -1;
|
||||
break;
|
||||
|
||||
case tensorflow::AttrValue::VALUE_NOT_SET:
|
||||
status->status =
|
||||
InvalidArgument("Attribute '", attr_name, "' has no value set");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#define DEFINE_TF_GETATTR(func, c_type, cc_type, attr_type, list_field) \
|
||||
#define DEFINE_TF_GETATTR(func, c_type, cc_type) \
|
||||
void TF_OpKernelConstruction_GetAttr##func(TF_OpKernelConstruction* ctx, \
|
||||
const char* attr_name, \
|
||||
c_type* val, TF_Status* status) { \
|
||||
@ -360,84 +269,10 @@ void TF_OpKernelConstruction_GetAttrSize(TF_OpKernelConstruction* ctx,
|
||||
if (s.ok()) { \
|
||||
*val = static_cast<c_type>(v); \
|
||||
} \
|
||||
} \
|
||||
void TF_OpKernelConstruction_GetAttr##func##List( \
|
||||
TF_OpKernelConstruction* ctx, const char* attr_name, c_type* vals, \
|
||||
int max_vals, TF_Status* status) { \
|
||||
TF_SetStatus(status, TF_OK, ""); \
|
||||
const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status); \
|
||||
if (!status->status.ok()) return; \
|
||||
if (attr->value_case() != tensorflow::AttrValue::kList) { \
|
||||
status->status = \
|
||||
InvalidArgument("Value for '", attr_name, "' is not a list."); \
|
||||
return; \
|
||||
} \
|
||||
status->status = \
|
||||
tensorflow::AttrValueHasType(*attr, "list(" attr_type ")"); \
|
||||
if (!status->status.ok()) return; \
|
||||
const auto len = std::min(max_vals, attr->list().list_field##_size()); \
|
||||
for (int i = 0; i < len; ++i) { \
|
||||
vals[i] = static_cast<c_type>(attr->list().list_field(i)); \
|
||||
} \
|
||||
}
|
||||
|
||||
DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType, "type", type)
|
||||
DEFINE_TF_GETATTR(Int32, int32_t, tensorflow::int32, "int", i)
|
||||
DEFINE_TF_GETATTR(Int64, int64_t, tensorflow::int64, "int", i)
|
||||
DEFINE_TF_GETATTR(Float, float, float, "float", f)
|
||||
DEFINE_TF_GETATTR(Bool, TF_Bool, bool, "bool", b)
|
||||
|
||||
void TF_OpKernelConstruction_GetAttrString(TF_OpKernelConstruction* ctx,
|
||||
const char* attr_name, char* value,
|
||||
size_t max_length,
|
||||
TF_Status* status) {
|
||||
std::string v;
|
||||
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
|
||||
::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);
|
||||
::tensorflow::Set_TF_Status_from_Status(status, s);
|
||||
|
||||
if (!status->status.ok()) return;
|
||||
|
||||
if (max_length <= 0) {
|
||||
return;
|
||||
}
|
||||
std::memcpy(value, v.data(), std::min<size_t>(v.length(), max_length));
|
||||
}
|
||||
|
||||
void TF_OpKernelConstruction_GetAttrStringList(TF_OpKernelConstruction* ctx,
|
||||
const char* attr_name,
|
||||
char** values, size_t* lengths,
|
||||
int max_values, void* storage,
|
||||
size_t storage_size,
|
||||
TF_Status* status) {
|
||||
std::vector<std::string> v;
|
||||
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
|
||||
::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);
|
||||
::tensorflow::Set_TF_Status_from_Status(status, s);
|
||||
|
||||
if (!status->status.ok()) return;
|
||||
|
||||
const auto len = std::min(max_values, static_cast<int>(v.size()));
|
||||
char* p = static_cast<char*>(storage);
|
||||
for (int i = 0; i < len; ++i) {
|
||||
const std::string& s = v[i];
|
||||
values[i] = p;
|
||||
lengths[i] = s.size();
|
||||
if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
|
||||
status->status = InvalidArgument(
|
||||
"Not enough storage to hold the requested list of strings");
|
||||
return;
|
||||
}
|
||||
memcpy(values[i], s.data(), s.size());
|
||||
p += s.size();
|
||||
}
|
||||
}
|
||||
|
||||
bool TF_OpKernelConstruction_HasAttr(TF_OpKernelConstruction* ctx,
|
||||
const char* attr_name, TF_Status* status) {
|
||||
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
|
||||
return cc_ctx->HasAttr(attr_name);
|
||||
}
|
||||
DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType)
|
||||
DEFINE_TF_GETATTR(Int32, tensorflow::int32, int32_t)
|
||||
|
||||
TF_StringView TF_OpKernelConstruction_GetName(TF_OpKernelConstruction* ctx) {
|
||||
auto* cc_ctx = reinterpret_cast<tensorflow::OpKernelConstruction*>(ctx);
|
||||
|
@ -184,24 +184,6 @@ TF_CAPI_EXPORT extern TF_DataType TF_ExpectedOutputDataType(
|
||||
// Returns the step ID of the given context.
|
||||
TF_CAPI_EXPORT extern int64_t TF_StepId(TF_OpKernelContext* ctx);
|
||||
|
||||
// Get the list_size and total_size of the attribute `attr_name` of `oper`.
|
||||
// list_size - the length of the list.
|
||||
// total_size - total size of the list.
|
||||
// (1) If attr_type == TF_ATTR_STRING
|
||||
// then total_size is the cumulative byte size
|
||||
// of all the strings in the list.
|
||||
// (3) If attr_type == TF_ATTR_SHAPE
|
||||
// then total_size is the number of dimensions
|
||||
// of the shape valued attribute, or -1
|
||||
// if its rank is unknown.
|
||||
// (4) If attr_type == TF_ATTR_SHAPE
|
||||
// then total_size is the cumulative number
|
||||
// of dimensions of all shapes in the list.
|
||||
// (5) Otherwise, total_size is undefined.
|
||||
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrSize(
|
||||
TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* list_size,
|
||||
int32_t* total_size, TF_Status* status);
|
||||
|
||||
// Interprets the named kernel construction attribute as a TF_DataType and
|
||||
// places it into *val. *status is set to TF_OK.
|
||||
//
|
||||
@ -220,112 +202,6 @@ TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32(
|
||||
TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* val,
|
||||
TF_Status* status);
|
||||
|
||||
// Interprets the named kernel construction attribute as int64_t and
|
||||
// places it into *val. *status is set to TF_OK.
|
||||
//
|
||||
// If the attribute could not be found or could not be interpreted as
|
||||
// int64, *status is populated with an error.
|
||||
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt64(
|
||||
TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* val,
|
||||
TF_Status* status);
|
||||
|
||||
// Interprets the named kernel construction attribute as float and
|
||||
// places it into *val. *status is set to TF_OK.
|
||||
//
|
||||
// If the attribute could not be found or could not be interpreted as
|
||||
// float, *status is populated with an error.
|
||||
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrFloat(
|
||||
TF_OpKernelConstruction* ctx, const char* attr_name, float* val,
|
||||
TF_Status* status);
|
||||
|
||||
// Interprets the named kernel construction attribute as bool and
|
||||
// places it into *val. *status is set to TF_OK.
|
||||
//
|
||||
// If the attribute could not be found or could not be interpreted as
|
||||
// bool, *status is populated with an error.
|
||||
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrBool(
|
||||
TF_OpKernelConstruction* ctx, const char* attr_name, TF_Bool* val,
|
||||
TF_Status* status);
|
||||
|
||||
// Interprets the named kernel construction attribute as string and
|
||||
// places it into *val. `val` must
|
||||
// point to an array of length at least `max_length` (ideally set to
|
||||
// total_size from TF_OpKernelConstruction_GetAttrSize(ctx,
|
||||
// attr_name, list_size, total_size)). *status is set to TF_OK.
|
||||
//
|
||||
// If the attribute could not be found or could not be interpreted as
|
||||
// string, *status is populated with an error.
|
||||
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrString(
|
||||
TF_OpKernelConstruction* ctx, const char* attr_name, char* val,
|
||||
size_t max_length, TF_Status* status);
|
||||
|
||||
// Interprets the named kernel construction attribute as a TF_DataType array and
|
||||
// places it into *vals. *status is set to TF_OK.
|
||||
// `vals` must point to an array of length at least `max_values` (ideally set
|
||||
// to list_size from
|
||||
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
|
||||
// total_size)).
|
||||
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrTypeList(
|
||||
TF_OpKernelConstruction* ctx, const char* attr_name, TF_DataType* vals,
|
||||
int max_vals, TF_Status* status);
|
||||
|
||||
// Interprets the named kernel construction attribute as int32_t array and
|
||||
// places it into *vals. *status is set to TF_OK.
|
||||
// `vals` must point to an array of length at least `max_values` (ideally set
|
||||
// to list_size from
|
||||
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
|
||||
// total_size)).
|
||||
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32List(
|
||||
TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* vals,
|
||||
int max_vals, TF_Status* status);
|
||||
|
||||
// Interprets the named kernel construction attribute as int64_t array and
|
||||
// places it into *vals. *status is set to TF_OK.
|
||||
// `vals` must point to an array of length at least `max_values` (ideally set
|
||||
// to list_size from
|
||||
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
|
||||
// total_size)).
|
||||
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt64List(
|
||||
TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* vals,
|
||||
int max_vals, TF_Status* status);
|
||||
|
||||
// Interprets the named kernel construction attribute as float array and
|
||||
// places it into *vals. *status is set to TF_OK.
|
||||
// `vals` must point to an array of length at least `max_values` (ideally set
|
||||
// to list_size from
|
||||
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
|
||||
// total_size)).
|
||||
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrFloatList(
|
||||
TF_OpKernelConstruction* ctx, const char* attr_name, float* vals,
|
||||
int max_vals, TF_Status* status);
|
||||
|
||||
// Interprets the named kernel construction attribute as bool array and
|
||||
// places it into *vals. *status is set to TF_OK.
|
||||
// `vals` must point to an array of length at least `max_values` (ideally set
|
||||
// to list_size from
|
||||
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
|
||||
// total_size)).
|
||||
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrBoolList(
|
||||
TF_OpKernelConstruction* ctx, const char* attr_name, TF_Bool* vals,
|
||||
int max_vals, TF_Status* status);
|
||||
|
||||
// Interprets the named kernel construction attribute as string array and fills
|
||||
// in `vals` and `lengths`, each of which must point to an array of length at
|
||||
// least `max_values`. *status is set to TF_OK. The elements of values will
|
||||
// point to addresses in `storage` which must be at least `storage_size` bytes
|
||||
// in length. Ideally, max_values would be set to list_size and `storage` would
|
||||
// be at least total_size, obtained from
|
||||
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
|
||||
// total_size).
|
||||
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrStringList(
|
||||
TF_OpKernelConstruction* ctx, const char* attr_name, char** vals,
|
||||
size_t* lengths, int max_values, void* storage, size_t storage_size,
|
||||
TF_Status* status);
|
||||
|
||||
// Return true if the kernel construction has the attr_name
|
||||
TF_CAPI_EXPORT extern bool TF_OpKernelConstruction_HasAttr(
|
||||
TF_OpKernelConstruction* ctx, const char* attr_name, TF_Status* status);
|
||||
|
||||
// Returns the unique operation name for this OpKernel.
|
||||
TF_CAPI_EXPORT extern TF_StringView TF_OpKernelConstruction_GetName(
|
||||
TF_OpKernelConstruction* ctx);
|
||||
|
@ -161,337 +161,6 @@ TEST(TestKernel, TestRegisterKernelBuilder) {
|
||||
ASSERT_TRUE(delete_called);
|
||||
}
|
||||
|
||||
// REGISTER_OP for TF_OpKernelConstruction_GetAttr* test cases.
|
||||
// Registers two ops, each with a single attribute called 'Attr'.
|
||||
// The attribute in one op will have a type 'type', the other
|
||||
// will have list(type).
|
||||
#define ATTR_TEST_REGISTER_OP(name, type) \
|
||||
REGISTER_OP("TestKernelAttr" #name) \
|
||||
.Attr("Attr: " #type) \
|
||||
.SetShapeFn(tensorflow::shape_inference::UnknownShape); \
|
||||
REGISTER_OP("TestKernelAttr" #name "List") \
|
||||
.Attr("Attr: list(" #type ")") \
|
||||
.SetShapeFn(tensorflow::shape_inference::UnknownShape)
|
||||
ATTR_TEST_REGISTER_OP(String, string);
|
||||
ATTR_TEST_REGISTER_OP(Int, int);
|
||||
ATTR_TEST_REGISTER_OP(Float, float);
|
||||
ATTR_TEST_REGISTER_OP(Bool, bool);
|
||||
ATTR_TEST_REGISTER_OP(Type, type);
|
||||
#undef ATTR_TEST_REGISTER_OP
|
||||
|
||||
// Helper macros for the TF_OpKernelConstruction_GetAttr* tests.
|
||||
#define EXPECT_TF_SIZE(attr_name, expected_list_size, expected_total_size) \
|
||||
do { \
|
||||
int32_t list_size, total_size; \
|
||||
TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, &list_size, \
|
||||
&total_size, status); \
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); \
|
||||
EXPECT_EQ(expected_list_size, list_size); \
|
||||
EXPECT_EQ(expected_total_size, total_size); \
|
||||
} while (0)
|
||||
|
||||
typedef void* (*MyCreateFuncWithAttr)(TF_OpKernelConstruction*);
|
||||
class TestKernelAttr : public ::testing::Test {
|
||||
public:
|
||||
TestKernelAttr() {}
|
||||
~TestKernelAttr() {}
|
||||
|
||||
std::unique_ptr<OpKernel> GetFakeKernelWithAttr(const char* op_name,
|
||||
AttrValue v, Status* status) {
|
||||
NodeDef def;
|
||||
def.set_op(op_name);
|
||||
def.set_name("FakeNode");
|
||||
def.set_device("FakeDevice");
|
||||
(*def.mutable_attr())["Attr"] = v;
|
||||
return CreateOpKernel(DeviceType("FakeDevice"), nullptr, nullptr, def, 1,
|
||||
status);
|
||||
}
|
||||
|
||||
void SetAttr(MyCreateFuncWithAttr MyCreateFuncAttr, const char* op_name,
|
||||
AttrValue& v) {
|
||||
TF_KernelBuilder* builder = TF_NewKernelBuilder(
|
||||
op_name, "FakeDevice", MyCreateFuncAttr, &MyComputeFunc, &MyDeleteFunc);
|
||||
{
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_RegisterKernelBuilder("FakeNode", builder, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status));
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
Status status;
|
||||
std::unique_ptr<OpKernel> kernel =
|
||||
GetFakeKernelWithAttr(op_name, v, &status);
|
||||
TF_EXPECT_OK(status);
|
||||
ASSERT_NE(nullptr, kernel.get());
|
||||
kernel->Compute(nullptr);
|
||||
|
||||
ASSERT_TRUE(delete_called);
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(TestKernelAttr, String) {
|
||||
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
|
||||
struct MyCustomKernel* s = new struct MyCustomKernel;
|
||||
s->created = true;
|
||||
s->compute_called = false;
|
||||
|
||||
std::unique_ptr<char[]> val(new char[5]);
|
||||
TF_Status* status = TF_NewStatus();
|
||||
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
|
||||
/*expected_total_size*/ 5);
|
||||
TF_OpKernelConstruction_GetAttrString(ctx, "Attr", val.get(),
|
||||
/*max_length*/ 5, status);
|
||||
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_EQ("bunny", string(static_cast<const char*>(val.get()), 5));
|
||||
TF_DeleteStatus(status);
|
||||
return static_cast<void*>(s);
|
||||
};
|
||||
|
||||
AttrValue v;
|
||||
v.set_s("bunny");
|
||||
SetAttr(my_create_func, "TestKernelAttrString", v);
|
||||
}
|
||||
|
||||
TEST_F(TestKernelAttr, StringList) {
|
||||
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
|
||||
struct MyCustomKernel* s = new struct MyCustomKernel;
|
||||
s->created = true;
|
||||
s->compute_called = false;
|
||||
|
||||
std::vector<string> list = {"bugs", "bunny", "duck"};
|
||||
int list_total_size = 0;
|
||||
for (const auto& s : list) {
|
||||
list_total_size += s.size();
|
||||
}
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
std::unique_ptr<char*[]> values(new char*[list.size()]);
|
||||
std::unique_ptr<size_t[]> lens(new size_t[list.size()]);
|
||||
std::unique_ptr<char[]> storage(new char[list_total_size]);
|
||||
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list.size(),
|
||||
/*expected_total_size*/ list_total_size);
|
||||
TF_OpKernelConstruction_GetAttrStringList(
|
||||
ctx, "Attr", values.get(), lens.get(), list.size(), storage.get(),
|
||||
list_total_size, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
for (size_t i = 0; i < list.size(); ++i) {
|
||||
EXPECT_EQ(list[i].size(), lens[i]) << i;
|
||||
EXPECT_EQ(list[i], string(static_cast<const char*>(values[i]), lens[i]))
|
||||
<< i;
|
||||
}
|
||||
TF_DeleteStatus(status);
|
||||
return static_cast<void*>(s);
|
||||
};
|
||||
|
||||
AttrValue v;
|
||||
auto attr_in = gtl::ArraySlice<StringPiece>({"bugs", "bunny", "duck"});
|
||||
SetAttrValue(attr_in, &v);
|
||||
SetAttr(my_create_func, "TestKernelAttrStringList", v);
|
||||
}
|
||||
|
||||
TEST_F(TestKernelAttr, Int) {
|
||||
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
|
||||
struct MyCustomKernel* s = new struct MyCustomKernel;
|
||||
s->created = true;
|
||||
s->compute_called = false;
|
||||
|
||||
int64_t val;
|
||||
TF_Status* status = TF_NewStatus();
|
||||
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
|
||||
/*expected_total_size*/ -1);
|
||||
TF_OpKernelConstruction_GetAttrInt64(ctx, "Attr", &val, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_EQ(1234, val);
|
||||
TF_DeleteStatus(status);
|
||||
return static_cast<void*>(s);
|
||||
};
|
||||
|
||||
AttrValue v;
|
||||
v.set_i(1234);
|
||||
SetAttr(my_create_func, "TestKernelAttrInt", v);
|
||||
}
|
||||
|
||||
TEST_F(TestKernelAttr, IntList) {
|
||||
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
|
||||
struct MyCustomKernel* s = new struct MyCustomKernel;
|
||||
s->created = true;
|
||||
s->compute_called = false;
|
||||
|
||||
const int64_t list[] = {1, 2, 3, 4};
|
||||
const size_t list_size = TF_ARRAYSIZE(list);
|
||||
int64_t values[list_size];
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
|
||||
/*expected_total_size*/ -1);
|
||||
TF_OpKernelConstruction_GetAttrInt64List(ctx, "Attr", values, list_size,
|
||||
status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_TRUE(
|
||||
std::equal(std::begin(list), std::end(list), std::begin(values)));
|
||||
TF_DeleteStatus(status);
|
||||
return static_cast<void*>(s);
|
||||
};
|
||||
|
||||
AttrValue v;
|
||||
auto attr_in = gtl::ArraySlice<int64>({1, 2, 3, 4});
|
||||
SetAttrValue(attr_in, &v);
|
||||
SetAttr(my_create_func, "TestKernelAttrIntList", v);
|
||||
}
|
||||
|
||||
TEST_F(TestKernelAttr, Float) {
|
||||
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
|
||||
struct MyCustomKernel* s = new struct MyCustomKernel;
|
||||
s->created = true;
|
||||
s->compute_called = false;
|
||||
|
||||
float val;
|
||||
TF_Status* status = TF_NewStatus();
|
||||
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
|
||||
/*expected_total_size*/ -1);
|
||||
TF_OpKernelConstruction_GetAttrFloat(ctx, "Attr", &val, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_FLOAT_EQ(2.718, val);
|
||||
TF_DeleteStatus(status);
|
||||
return static_cast<void*>(s);
|
||||
};
|
||||
|
||||
AttrValue v;
|
||||
v.set_f(2.718);
|
||||
SetAttr(my_create_func, "TestKernelAttrFloat", v);
|
||||
}
|
||||
|
||||
TEST_F(TestKernelAttr, FloatList) {
|
||||
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
|
||||
struct MyCustomKernel* s = new struct MyCustomKernel;
|
||||
s->created = true;
|
||||
s->compute_called = false;
|
||||
|
||||
const float list[] = {1.414, 2.718, 3.1415};
|
||||
const size_t list_size = TF_ARRAYSIZE(list);
|
||||
float values[list_size];
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
|
||||
/*expected_total_size*/ -1);
|
||||
TF_OpKernelConstruction_GetAttrFloatList(ctx, "Attr", values, list_size,
|
||||
status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_TRUE(
|
||||
std::equal(std::begin(list), std::end(list), std::begin(values)));
|
||||
TF_DeleteStatus(status);
|
||||
return static_cast<void*>(s);
|
||||
};
|
||||
|
||||
AttrValue v;
|
||||
auto attr_in = gtl::ArraySlice<float>({1.414, 2.718, 3.1415});
|
||||
SetAttrValue(attr_in, &v);
|
||||
SetAttr(my_create_func, "TestKernelAttrFloatList", v);
|
||||
}
|
||||
|
||||
TEST_F(TestKernelAttr, Bool) {
|
||||
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
|
||||
struct MyCustomKernel* s = new struct MyCustomKernel;
|
||||
s->created = true;
|
||||
s->compute_called = false;
|
||||
|
||||
unsigned char val;
|
||||
TF_Status* status = TF_NewStatus();
|
||||
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
|
||||
/*expected_total_size*/ -1);
|
||||
TF_OpKernelConstruction_GetAttrBool(ctx, "Attr", &val, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_EQ(1, val);
|
||||
TF_DeleteStatus(status);
|
||||
return static_cast<void*>(s);
|
||||
};
|
||||
|
||||
AttrValue v;
|
||||
v.set_b(1);
|
||||
SetAttr(my_create_func, "TestKernelAttrBool", v);
|
||||
}
|
||||
|
||||
TEST_F(TestKernelAttr, BoolList) {
|
||||
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
|
||||
struct MyCustomKernel* s = new struct MyCustomKernel;
|
||||
s->created = true;
|
||||
s->compute_called = false;
|
||||
|
||||
const unsigned char list[] = {1, 0, 1, 0};
|
||||
const size_t list_size = TF_ARRAYSIZE(list);
|
||||
unsigned char values[list_size];
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
|
||||
/*expected_total_size*/ -1);
|
||||
TF_OpKernelConstruction_GetAttrBoolList(ctx, "Attr", values, list_size,
|
||||
status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_TRUE(
|
||||
std::equal(std::begin(list), std::end(list), std::begin(values)));
|
||||
TF_DeleteStatus(status);
|
||||
return static_cast<void*>(s);
|
||||
};
|
||||
|
||||
AttrValue v;
|
||||
auto attr_in = gtl::ArraySlice<bool>({1, 0, 1, 0});
|
||||
SetAttrValue(attr_in, &v);
|
||||
SetAttr(my_create_func, "TestKernelAttrBoolList", v);
|
||||
}
|
||||
|
||||
TEST_F(TestKernelAttr, Type) {
|
||||
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
|
||||
struct MyCustomKernel* s = new struct MyCustomKernel;
|
||||
s->created = true;
|
||||
s->compute_called = false;
|
||||
|
||||
TF_DataType val;
|
||||
TF_Status* status = TF_NewStatus();
|
||||
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
|
||||
/*expected_total_size*/ -1);
|
||||
TF_OpKernelConstruction_GetAttrType(ctx, "Attr", &val, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_EQ(TF_FLOAT, val);
|
||||
TF_DeleteStatus(status);
|
||||
return static_cast<void*>(s);
|
||||
};
|
||||
|
||||
AttrValue v;
|
||||
v.set_type(DT_FLOAT);
|
||||
SetAttr(my_create_func, "TestKernelAttrType", v);
|
||||
}
|
||||
|
||||
TEST_F(TestKernelAttr, TypeList) {
|
||||
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
|
||||
struct MyCustomKernel* s = new struct MyCustomKernel;
|
||||
s->created = true;
|
||||
s->compute_called = false;
|
||||
|
||||
const TF_DataType list[] = {TF_FLOAT, TF_DOUBLE, TF_HALF, TF_COMPLEX128};
|
||||
const size_t list_size = TF_ARRAYSIZE(list);
|
||||
TF_DataType values[list_size];
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
|
||||
/*expected_total_size*/ -1);
|
||||
TF_OpKernelConstruction_GetAttrTypeList(ctx, "Attr", values, list_size,
|
||||
status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_TRUE(
|
||||
std::equal(std::begin(list), std::end(list), std::begin(values)));
|
||||
TF_DeleteStatus(status);
|
||||
return static_cast<void*>(s);
|
||||
};
|
||||
|
||||
AttrValue v;
|
||||
auto attr_in =
|
||||
gtl::ArraySlice<DataType>({DT_FLOAT, DT_DOUBLE, DT_HALF, DT_COMPLEX128});
|
||||
SetAttrValue(attr_in, &v);
|
||||
SetAttr(my_create_func, "TestKernelAttrTypeList", v);
|
||||
}
|
||||
#undef EXPECT_TF_SIZE
|
||||
|
||||
class DummyDevice : public DeviceBase {
|
||||
public:
|
||||
explicit DummyDevice(Env* env) : DeviceBase(env) {}
|
||||
|
Loading…
Reference in New Issue
Block a user