diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc index d33a91b5898..27f98be14ad 100644 --- a/tensorflow/c/kernels.cc +++ b/tensorflow/c/kernels.cc @@ -32,7 +32,6 @@ limitations under the License. #include "tensorflow/stream_executor/stream.h" #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) -using tensorflow::errors::InvalidArgument; // This file forms the basis of a stable ABI for third-party kernel // implementations. It is crucial that changes to this file are made cautiously // and with a focus on maintaining both source and binary compatibility. @@ -88,25 +87,9 @@ void AddTypeConstraint(TF_KernelBuilder* kernel_builder, const char* attr_name, TF_SetStatus(status, TF_OK, ""); } #undef CASE - } // namespace } // namespace tensorflow -namespace { -const tensorflow::AttrValue* GetAttrValue(TF_OpKernelConstruction* ctx, - const char* attr_name, - TF_Status* status) { - auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); - const tensorflow::AttrValue* attr = - ::tensorflow::AttrSlice(cc_ctx->def()).Find(attr_name); - if (attr == nullptr) { - status->status = InvalidArgument("Operation '", cc_ctx->def().name(), - "' has no attr named '", attr_name, "'."); - } - return attr; -} -} // namespace - void TF_KernelBuilder_TypeConstraint(TF_KernelBuilder* kernel_builder, const char* attr_name, const TF_DataType type, @@ -274,81 +257,7 @@ void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) { cc_ctx->CtxFailure(s); } -void TF_OpKernelConstruction_GetAttrSize(TF_OpKernelConstruction* ctx, - const char* attr_name, - int32_t* list_size, - int32_t* total_size, - TF_Status* status) { - const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status); - if (!status->status.ok()) { - *list_size = -1; - *total_size = -1; - return; - } - switch (attr->value_case()) { -#define SINGLE_CASE(kK, attr_type, size_expr) \ - case tensorflow::AttrValue::kK: \ - *list_size = -1; \ - *total_size = size_expr; \ - break; - - SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length()); - SINGLE_CASE(kI, TF_ATTR_INT, -1); - SINGLE_CASE(kF, TF_ATTR_FLOAT, -1); - SINGLE_CASE(kB, TF_ATTR_BOOL, -1); - SINGLE_CASE(kType, TF_ATTR_TYPE, -1); - SINGLE_CASE(kShape, TF_ATTR_SHAPE, - attr->shape().unknown_rank() ? -1 : attr->shape().dim_size()); - SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1); -#undef SINGLE_CASE - - case tensorflow::AttrValue::kList: - *list_size = 0; - *total_size = -1; -#define LIST_CASE(field, attr_type, ...) \ - if (attr->list().field##_size() > 0) { \ - *list_size = attr->list().field##_size(); \ - __VA_ARGS__; \ - break; \ - } - - LIST_CASE( - s, TF_ATTR_STRING, *total_size = 0; - for (int i = 0; i < attr->list().s_size(); - ++i) { *total_size += attr->list().s(i).size(); }); - LIST_CASE(i, TF_ATTR_INT); - LIST_CASE(f, TF_ATTR_FLOAT); - LIST_CASE(b, TF_ATTR_BOOL); - LIST_CASE(type, TF_ATTR_TYPE); - LIST_CASE( - shape, TF_ATTR_SHAPE, *total_size = 0; - for (int i = 0; i < attr->list().shape_size(); ++i) { - const auto& s = attr->list().shape(i); - *total_size += s.unknown_rank() ? 0 : s.dim_size(); - }); - LIST_CASE(tensor, TF_ATTR_TENSOR); - LIST_CASE(tensor, TF_ATTR_FUNC); -#undef LIST_CASE - break; - - case tensorflow::AttrValue::kPlaceholder: - *list_size = -1; - *total_size = -1; - break; - - case tensorflow::AttrValue::kFunc: - *list_size = -1; - *total_size = -1; - break; - - case tensorflow::AttrValue::VALUE_NOT_SET: - status->status = - InvalidArgument("Attribute '", attr_name, "' has no value set"); - break; - } -} - -#define DEFINE_TF_GETATTR(func, c_type, cc_type, attr_type, list_field) \ +#define DEFINE_TF_GETATTR(func, c_type, cc_type) \ void TF_OpKernelConstruction_GetAttr##func(TF_OpKernelConstruction* ctx, \ const char* attr_name, \ c_type* val, TF_Status* status) { \ @@ -360,84 +269,10 @@ void TF_OpKernelConstruction_GetAttrSize(TF_OpKernelConstruction* ctx, if (s.ok()) { \ *val = static_cast(v); \ } \ - } \ - void TF_OpKernelConstruction_GetAttr##func##List( \ - TF_OpKernelConstruction* ctx, const char* attr_name, c_type* vals, \ - int max_vals, TF_Status* status) { \ - TF_SetStatus(status, TF_OK, ""); \ - const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status); \ - if (!status->status.ok()) return; \ - if (attr->value_case() != tensorflow::AttrValue::kList) { \ - status->status = \ - InvalidArgument("Value for '", attr_name, "' is not a list."); \ - return; \ - } \ - status->status = \ - tensorflow::AttrValueHasType(*attr, "list(" attr_type ")"); \ - if (!status->status.ok()) return; \ - const auto len = std::min(max_vals, attr->list().list_field##_size()); \ - for (int i = 0; i < len; ++i) { \ - vals[i] = static_cast(attr->list().list_field(i)); \ - } \ } -DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType, "type", type) -DEFINE_TF_GETATTR(Int32, int32_t, tensorflow::int32, "int", i) -DEFINE_TF_GETATTR(Int64, int64_t, tensorflow::int64, "int", i) -DEFINE_TF_GETATTR(Float, float, float, "float", f) -DEFINE_TF_GETATTR(Bool, TF_Bool, bool, "bool", b) - -void TF_OpKernelConstruction_GetAttrString(TF_OpKernelConstruction* ctx, - const char* attr_name, char* value, - size_t max_length, - TF_Status* status) { - std::string v; - auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); - ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v); - ::tensorflow::Set_TF_Status_from_Status(status, s); - - if (!status->status.ok()) return; - - if (max_length <= 0) { - return; - } - std::memcpy(value, v.data(), std::min(v.length(), max_length)); -} - -void TF_OpKernelConstruction_GetAttrStringList(TF_OpKernelConstruction* ctx, - const char* attr_name, - char** values, size_t* lengths, - int max_values, void* storage, - size_t storage_size, - TF_Status* status) { - std::vector v; - auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); - ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v); - ::tensorflow::Set_TF_Status_from_Status(status, s); - - if (!status->status.ok()) return; - - const auto len = std::min(max_values, static_cast(v.size())); - char* p = static_cast(storage); - for (int i = 0; i < len; ++i) { - const std::string& s = v[i]; - values[i] = p; - lengths[i] = s.size(); - if ((p + s.size()) > (static_cast(storage) + storage_size)) { - status->status = InvalidArgument( - "Not enough storage to hold the requested list of strings"); - return; - } - memcpy(values[i], s.data(), s.size()); - p += s.size(); - } -} - -bool TF_OpKernelConstruction_HasAttr(TF_OpKernelConstruction* ctx, - const char* attr_name, TF_Status* status) { - auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); - return cc_ctx->HasAttr(attr_name); -} +DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType) +DEFINE_TF_GETATTR(Int32, tensorflow::int32, int32_t) TF_StringView TF_OpKernelConstruction_GetName(TF_OpKernelConstruction* ctx) { auto* cc_ctx = reinterpret_cast(ctx); diff --git a/tensorflow/c/kernels.h b/tensorflow/c/kernels.h index 508d59b1223..34848a1c92a 100644 --- a/tensorflow/c/kernels.h +++ b/tensorflow/c/kernels.h @@ -184,24 +184,6 @@ TF_CAPI_EXPORT extern TF_DataType TF_ExpectedOutputDataType( // Returns the step ID of the given context. TF_CAPI_EXPORT extern int64_t TF_StepId(TF_OpKernelContext* ctx); -// Get the list_size and total_size of the attribute `attr_name` of `oper`. -// list_size - the length of the list. -// total_size - total size of the list. -// (1) If attr_type == TF_ATTR_STRING -// then total_size is the cumulative byte size -// of all the strings in the list. -// (3) If attr_type == TF_ATTR_SHAPE -// then total_size is the number of dimensions -// of the shape valued attribute, or -1 -// if its rank is unknown. -// (4) If attr_type == TF_ATTR_SHAPE -// then total_size is the cumulative number -// of dimensions of all shapes in the list. -// (5) Otherwise, total_size is undefined. -TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrSize( - TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* list_size, - int32_t* total_size, TF_Status* status); - // Interprets the named kernel construction attribute as a TF_DataType and // places it into *val. *status is set to TF_OK. // @@ -220,112 +202,6 @@ TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32( TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* val, TF_Status* status); -// Interprets the named kernel construction attribute as int64_t and -// places it into *val. *status is set to TF_OK. -// -// If the attribute could not be found or could not be interpreted as -// int64, *status is populated with an error. -TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt64( - TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* val, - TF_Status* status); - -// Interprets the named kernel construction attribute as float and -// places it into *val. *status is set to TF_OK. -// -// If the attribute could not be found or could not be interpreted as -// float, *status is populated with an error. -TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrFloat( - TF_OpKernelConstruction* ctx, const char* attr_name, float* val, - TF_Status* status); - -// Interprets the named kernel construction attribute as bool and -// places it into *val. *status is set to TF_OK. -// -// If the attribute could not be found or could not be interpreted as -// bool, *status is populated with an error. -TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrBool( - TF_OpKernelConstruction* ctx, const char* attr_name, TF_Bool* val, - TF_Status* status); - -// Interprets the named kernel construction attribute as string and -// places it into *val. `val` must -// point to an array of length at least `max_length` (ideally set to -// total_size from TF_OpKernelConstruction_GetAttrSize(ctx, -// attr_name, list_size, total_size)). *status is set to TF_OK. -// -// If the attribute could not be found or could not be interpreted as -// string, *status is populated with an error. -TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrString( - TF_OpKernelConstruction* ctx, const char* attr_name, char* val, - size_t max_length, TF_Status* status); - -// Interprets the named kernel construction attribute as a TF_DataType array and -// places it into *vals. *status is set to TF_OK. -// `vals` must point to an array of length at least `max_values` (ideally set -// to list_size from -// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size, -// total_size)). -TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrTypeList( - TF_OpKernelConstruction* ctx, const char* attr_name, TF_DataType* vals, - int max_vals, TF_Status* status); - -// Interprets the named kernel construction attribute as int32_t array and -// places it into *vals. *status is set to TF_OK. -// `vals` must point to an array of length at least `max_values` (ideally set -// to list_size from -// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size, -// total_size)). -TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32List( - TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* vals, - int max_vals, TF_Status* status); - -// Interprets the named kernel construction attribute as int64_t array and -// places it into *vals. *status is set to TF_OK. -// `vals` must point to an array of length at least `max_values` (ideally set -// to list_size from -// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size, -// total_size)). -TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt64List( - TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* vals, - int max_vals, TF_Status* status); - -// Interprets the named kernel construction attribute as float array and -// places it into *vals. *status is set to TF_OK. -// `vals` must point to an array of length at least `max_values` (ideally set -// to list_size from -// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size, -// total_size)). -TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrFloatList( - TF_OpKernelConstruction* ctx, const char* attr_name, float* vals, - int max_vals, TF_Status* status); - -// Interprets the named kernel construction attribute as bool array and -// places it into *vals. *status is set to TF_OK. -// `vals` must point to an array of length at least `max_values` (ideally set -// to list_size from -// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size, -// total_size)). -TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrBoolList( - TF_OpKernelConstruction* ctx, const char* attr_name, TF_Bool* vals, - int max_vals, TF_Status* status); - -// Interprets the named kernel construction attribute as string array and fills -// in `vals` and `lengths`, each of which must point to an array of length at -// least `max_values`. *status is set to TF_OK. The elements of values will -// point to addresses in `storage` which must be at least `storage_size` bytes -// in length. Ideally, max_values would be set to list_size and `storage` would -// be at least total_size, obtained from -// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size, -// total_size). -TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrStringList( - TF_OpKernelConstruction* ctx, const char* attr_name, char** vals, - size_t* lengths, int max_values, void* storage, size_t storage_size, - TF_Status* status); - -// Return true if the kernel construction has the attr_name -TF_CAPI_EXPORT extern bool TF_OpKernelConstruction_HasAttr( - TF_OpKernelConstruction* ctx, const char* attr_name, TF_Status* status); - // Returns the unique operation name for this OpKernel. TF_CAPI_EXPORT extern TF_StringView TF_OpKernelConstruction_GetName( TF_OpKernelConstruction* ctx); diff --git a/tensorflow/c/kernels_test.cc b/tensorflow/c/kernels_test.cc index 4716bc3da55..49a168af076 100644 --- a/tensorflow/c/kernels_test.cc +++ b/tensorflow/c/kernels_test.cc @@ -161,337 +161,6 @@ TEST(TestKernel, TestRegisterKernelBuilder) { ASSERT_TRUE(delete_called); } -// REGISTER_OP for TF_OpKernelConstruction_GetAttr* test cases. -// Registers two ops, each with a single attribute called 'Attr'. -// The attribute in one op will have a type 'type', the other -// will have list(type). -#define ATTR_TEST_REGISTER_OP(name, type) \ - REGISTER_OP("TestKernelAttr" #name) \ - .Attr("Attr: " #type) \ - .SetShapeFn(tensorflow::shape_inference::UnknownShape); \ - REGISTER_OP("TestKernelAttr" #name "List") \ - .Attr("Attr: list(" #type ")") \ - .SetShapeFn(tensorflow::shape_inference::UnknownShape) -ATTR_TEST_REGISTER_OP(String, string); -ATTR_TEST_REGISTER_OP(Int, int); -ATTR_TEST_REGISTER_OP(Float, float); -ATTR_TEST_REGISTER_OP(Bool, bool); -ATTR_TEST_REGISTER_OP(Type, type); -#undef ATTR_TEST_REGISTER_OP - -// Helper macros for the TF_OpKernelConstruction_GetAttr* tests. -#define EXPECT_TF_SIZE(attr_name, expected_list_size, expected_total_size) \ - do { \ - int32_t list_size, total_size; \ - TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, &list_size, \ - &total_size, status); \ - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); \ - EXPECT_EQ(expected_list_size, list_size); \ - EXPECT_EQ(expected_total_size, total_size); \ - } while (0) - -typedef void* (*MyCreateFuncWithAttr)(TF_OpKernelConstruction*); -class TestKernelAttr : public ::testing::Test { - public: - TestKernelAttr() {} - ~TestKernelAttr() {} - - std::unique_ptr GetFakeKernelWithAttr(const char* op_name, - AttrValue v, Status* status) { - NodeDef def; - def.set_op(op_name); - def.set_name("FakeNode"); - def.set_device("FakeDevice"); - (*def.mutable_attr())["Attr"] = v; - return CreateOpKernel(DeviceType("FakeDevice"), nullptr, nullptr, def, 1, - status); - } - - void SetAttr(MyCreateFuncWithAttr MyCreateFuncAttr, const char* op_name, - AttrValue& v) { - TF_KernelBuilder* builder = TF_NewKernelBuilder( - op_name, "FakeDevice", MyCreateFuncAttr, &MyComputeFunc, &MyDeleteFunc); - { - TF_Status* status = TF_NewStatus(); - TF_RegisterKernelBuilder("FakeNode", builder, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)); - TF_DeleteStatus(status); - } - Status status; - std::unique_ptr kernel = - GetFakeKernelWithAttr(op_name, v, &status); - TF_EXPECT_OK(status); - ASSERT_NE(nullptr, kernel.get()); - kernel->Compute(nullptr); - - ASSERT_TRUE(delete_called); - } -}; - -TEST_F(TestKernelAttr, String) { - auto my_create_func = [](TF_OpKernelConstruction* ctx) { - struct MyCustomKernel* s = new struct MyCustomKernel; - s->created = true; - s->compute_called = false; - - std::unique_ptr val(new char[5]); - TF_Status* status = TF_NewStatus(); - EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1, - /*expected_total_size*/ 5); - TF_OpKernelConstruction_GetAttrString(ctx, "Attr", val.get(), - /*max_length*/ 5, status); - - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - EXPECT_EQ("bunny", string(static_cast(val.get()), 5)); - TF_DeleteStatus(status); - return static_cast(s); - }; - - AttrValue v; - v.set_s("bunny"); - SetAttr(my_create_func, "TestKernelAttrString", v); -} - -TEST_F(TestKernelAttr, StringList) { - auto my_create_func = [](TF_OpKernelConstruction* ctx) { - struct MyCustomKernel* s = new struct MyCustomKernel; - s->created = true; - s->compute_called = false; - - std::vector list = {"bugs", "bunny", "duck"}; - int list_total_size = 0; - for (const auto& s : list) { - list_total_size += s.size(); - } - - TF_Status* status = TF_NewStatus(); - std::unique_ptr values(new char*[list.size()]); - std::unique_ptr lens(new size_t[list.size()]); - std::unique_ptr storage(new char[list_total_size]); - EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list.size(), - /*expected_total_size*/ list_total_size); - TF_OpKernelConstruction_GetAttrStringList( - ctx, "Attr", values.get(), lens.get(), list.size(), storage.get(), - list_total_size, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - - for (size_t i = 0; i < list.size(); ++i) { - EXPECT_EQ(list[i].size(), lens[i]) << i; - EXPECT_EQ(list[i], string(static_cast(values[i]), lens[i])) - << i; - } - TF_DeleteStatus(status); - return static_cast(s); - }; - - AttrValue v; - auto attr_in = gtl::ArraySlice({"bugs", "bunny", "duck"}); - SetAttrValue(attr_in, &v); - SetAttr(my_create_func, "TestKernelAttrStringList", v); -} - -TEST_F(TestKernelAttr, Int) { - auto my_create_func = [](TF_OpKernelConstruction* ctx) { - struct MyCustomKernel* s = new struct MyCustomKernel; - s->created = true; - s->compute_called = false; - - int64_t val; - TF_Status* status = TF_NewStatus(); - EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1, - /*expected_total_size*/ -1); - TF_OpKernelConstruction_GetAttrInt64(ctx, "Attr", &val, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - EXPECT_EQ(1234, val); - TF_DeleteStatus(status); - return static_cast(s); - }; - - AttrValue v; - v.set_i(1234); - SetAttr(my_create_func, "TestKernelAttrInt", v); -} - -TEST_F(TestKernelAttr, IntList) { - auto my_create_func = [](TF_OpKernelConstruction* ctx) { - struct MyCustomKernel* s = new struct MyCustomKernel; - s->created = true; - s->compute_called = false; - - const int64_t list[] = {1, 2, 3, 4}; - const size_t list_size = TF_ARRAYSIZE(list); - int64_t values[list_size]; - - TF_Status* status = TF_NewStatus(); - EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size, - /*expected_total_size*/ -1); - TF_OpKernelConstruction_GetAttrInt64List(ctx, "Attr", values, list_size, - status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - EXPECT_TRUE( - std::equal(std::begin(list), std::end(list), std::begin(values))); - TF_DeleteStatus(status); - return static_cast(s); - }; - - AttrValue v; - auto attr_in = gtl::ArraySlice({1, 2, 3, 4}); - SetAttrValue(attr_in, &v); - SetAttr(my_create_func, "TestKernelAttrIntList", v); -} - -TEST_F(TestKernelAttr, Float) { - auto my_create_func = [](TF_OpKernelConstruction* ctx) { - struct MyCustomKernel* s = new struct MyCustomKernel; - s->created = true; - s->compute_called = false; - - float val; - TF_Status* status = TF_NewStatus(); - EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1, - /*expected_total_size*/ -1); - TF_OpKernelConstruction_GetAttrFloat(ctx, "Attr", &val, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - EXPECT_FLOAT_EQ(2.718, val); - TF_DeleteStatus(status); - return static_cast(s); - }; - - AttrValue v; - v.set_f(2.718); - SetAttr(my_create_func, "TestKernelAttrFloat", v); -} - -TEST_F(TestKernelAttr, FloatList) { - auto my_create_func = [](TF_OpKernelConstruction* ctx) { - struct MyCustomKernel* s = new struct MyCustomKernel; - s->created = true; - s->compute_called = false; - - const float list[] = {1.414, 2.718, 3.1415}; - const size_t list_size = TF_ARRAYSIZE(list); - float values[list_size]; - - TF_Status* status = TF_NewStatus(); - EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size, - /*expected_total_size*/ -1); - TF_OpKernelConstruction_GetAttrFloatList(ctx, "Attr", values, list_size, - status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - EXPECT_TRUE( - std::equal(std::begin(list), std::end(list), std::begin(values))); - TF_DeleteStatus(status); - return static_cast(s); - }; - - AttrValue v; - auto attr_in = gtl::ArraySlice({1.414, 2.718, 3.1415}); - SetAttrValue(attr_in, &v); - SetAttr(my_create_func, "TestKernelAttrFloatList", v); -} - -TEST_F(TestKernelAttr, Bool) { - auto my_create_func = [](TF_OpKernelConstruction* ctx) { - struct MyCustomKernel* s = new struct MyCustomKernel; - s->created = true; - s->compute_called = false; - - unsigned char val; - TF_Status* status = TF_NewStatus(); - EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1, - /*expected_total_size*/ -1); - TF_OpKernelConstruction_GetAttrBool(ctx, "Attr", &val, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - EXPECT_EQ(1, val); - TF_DeleteStatus(status); - return static_cast(s); - }; - - AttrValue v; - v.set_b(1); - SetAttr(my_create_func, "TestKernelAttrBool", v); -} - -TEST_F(TestKernelAttr, BoolList) { - auto my_create_func = [](TF_OpKernelConstruction* ctx) { - struct MyCustomKernel* s = new struct MyCustomKernel; - s->created = true; - s->compute_called = false; - - const unsigned char list[] = {1, 0, 1, 0}; - const size_t list_size = TF_ARRAYSIZE(list); - unsigned char values[list_size]; - - TF_Status* status = TF_NewStatus(); - EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size, - /*expected_total_size*/ -1); - TF_OpKernelConstruction_GetAttrBoolList(ctx, "Attr", values, list_size, - status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - EXPECT_TRUE( - std::equal(std::begin(list), std::end(list), std::begin(values))); - TF_DeleteStatus(status); - return static_cast(s); - }; - - AttrValue v; - auto attr_in = gtl::ArraySlice({1, 0, 1, 0}); - SetAttrValue(attr_in, &v); - SetAttr(my_create_func, "TestKernelAttrBoolList", v); -} - -TEST_F(TestKernelAttr, Type) { - auto my_create_func = [](TF_OpKernelConstruction* ctx) { - struct MyCustomKernel* s = new struct MyCustomKernel; - s->created = true; - s->compute_called = false; - - TF_DataType val; - TF_Status* status = TF_NewStatus(); - EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1, - /*expected_total_size*/ -1); - TF_OpKernelConstruction_GetAttrType(ctx, "Attr", &val, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - EXPECT_EQ(TF_FLOAT, val); - TF_DeleteStatus(status); - return static_cast(s); - }; - - AttrValue v; - v.set_type(DT_FLOAT); - SetAttr(my_create_func, "TestKernelAttrType", v); -} - -TEST_F(TestKernelAttr, TypeList) { - auto my_create_func = [](TF_OpKernelConstruction* ctx) { - struct MyCustomKernel* s = new struct MyCustomKernel; - s->created = true; - s->compute_called = false; - - const TF_DataType list[] = {TF_FLOAT, TF_DOUBLE, TF_HALF, TF_COMPLEX128}; - const size_t list_size = TF_ARRAYSIZE(list); - TF_DataType values[list_size]; - - TF_Status* status = TF_NewStatus(); - EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size, - /*expected_total_size*/ -1); - TF_OpKernelConstruction_GetAttrTypeList(ctx, "Attr", values, list_size, - status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - EXPECT_TRUE( - std::equal(std::begin(list), std::end(list), std::begin(values))); - TF_DeleteStatus(status); - return static_cast(s); - }; - - AttrValue v; - auto attr_in = - gtl::ArraySlice({DT_FLOAT, DT_DOUBLE, DT_HALF, DT_COMPLEX128}); - SetAttrValue(attr_in, &v); - SetAttr(my_create_func, "TestKernelAttrTypeList", v); -} -#undef EXPECT_TF_SIZE - class DummyDevice : public DeviceBase { public: explicit DummyDevice(Env* env) : DeviceBase(env) {}