From 6e3bea20a13137af71fee48d93a80b843b9bc22a Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Mon, 4 May 2020 09:18:36 -0700 Subject: [PATCH] Less pointer indirection for TFE_OpAttrs, add TFE_OpGetAttrs We'll want this for implementing copy for `TF_AbstractOp`s backed by `TFE_Op`s (since we want to copy the type/attributes but not the inputs). PiperOrigin-RevId: 309756974 Change-Id: I07a8c48f50ab6d3c8a7d7db972fb60202b86434d --- tensorflow/c/BUILD | 1 + tensorflow/c/conversion_macros.h | 23 ++++++++++++---------- tensorflow/c/eager/BUILD | 1 + tensorflow/c/eager/c_api.cc | 15 +++++++++----- tensorflow/c/eager/c_api_experimental.h | 3 +++ tensorflow/c/eager/c_api_test.cc | 16 ++++----------- tensorflow/c/eager/tfe_op_attrs_internal.h | 22 +++++---------------- 7 files changed, 37 insertions(+), 44 deletions(-) diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index b2466a5c123..a4148147eb9 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -58,6 +58,7 @@ filegroup( name = "pywrap_required_hdrs", srcs = [ "c_api_internal.h", + "conversion_macros.h", "python_api.h", "tensor_interface.h", "tf_status_helper.h", diff --git a/tensorflow/c/conversion_macros.h b/tensorflow/c/conversion_macros.h index ce8adfadb26..d1f99b7b5b0 100644 --- a/tensorflow/c/conversion_macros.h +++ b/tensorflow/c/conversion_macros.h @@ -16,15 +16,18 @@ limitations under the License. #ifndef TENSORFLOW_C_CONVERSION_MACROS_H_ #define TENSORFLOW_C_CONVERSION_MACROS_H_ -#define DEFINE_CONVERSION_FUNCTIONS(cpp_impl, wrapper) \ - inline cpp_impl *unwrap(wrapper *w) { \ - return reinterpret_cast(w); \ - } \ - \ - inline const cpp_impl *unwrap(const wrapper *w) { \ - return reinterpret_cast(w); \ - } \ - \ - inline wrapper *wrap(cpp_impl *i) { return reinterpret_cast(i); } +#define DEFINE_CONVERSION_FUNCTIONS(cpp_impl, wrapper) \ + inline cpp_impl *unwrap(wrapper *w) { \ + return reinterpret_cast(w); \ + } \ + \ + inline const cpp_impl *unwrap(const wrapper *w) { \ + return reinterpret_cast(w); \ + } \ + \ + inline wrapper *wrap(cpp_impl *i) { return reinterpret_cast(i); } \ + inline const wrapper *wrap(const cpp_impl *i) { \ + return reinterpret_cast(i); \ + } #endif // TENSORFLOW_C_CONVERSION_MACROS_H_ diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 2efaa4ecb36..3d3fc7065a4 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -247,6 +247,7 @@ cc_library( "//tensorflow:internal", ], deps = [ + "//tensorflow/c:conversion_macros", "//tensorflow/c:tf_status", "//tensorflow/core:protos_all_cc", "//tensorflow/core/common_runtime/eager:attr_builder", diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 4ef178eb30c..9be1290fd91 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -1483,9 +1483,14 @@ void TFE_ContextEndStep(TFE_Context* ctx) { context->EndStep(); } +const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op) { + return tensorflow::wrap( + &OperationFromInterface(tensorflow::unwrap(op))->Attrs()); +} + void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) { tensorflow::AttrValueMap m; - attrs->attributes->FillAttrValueMap(&m); + tensorflow::unwrap(attrs)->FillAttrValueMap(&m); tensorflow::EagerOperation* operation = OperationFromInterface(tensorflow::unwrap(op)); tensorflow::AttrBuilder* destination = operation->MutableAttrs(); @@ -1497,8 +1502,8 @@ void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) { void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf, TF_Status* status) { tensorflow::NameAttrList name_and_attrs; - attrs->attributes->FillAttrValueMap(name_and_attrs.mutable_attr()); - name_and_attrs.set_name(attrs->attributes->op_name()); + tensorflow::unwrap(attrs)->FillAttrValueMap(name_and_attrs.mutable_attr()); + name_and_attrs.set_name(tensorflow::unwrap(attrs)->op_name()); status->status = MessageToBuffer(name_and_attrs, buf); } @@ -1619,9 +1624,9 @@ class CustomDeviceAPI : public tensorflow::CustomDevice { } std::vector outputs(*num_retvals); TF_Status status; - TFE_OpAttrs attributes(&op->Attrs()); device_.execute(context_, inputs.size(), inputs.data(), op->Name().c_str(), - &attributes, num_retvals, outputs.data(), &status, info_); + wrap(&op->Attrs()), num_retvals, outputs.data(), &status, + info_); if (status.status.ok()) { for (int i = 0; i < *num_retvals; ++i) { retvals[i] = tensorflow::TensorHandleFromInterface( diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index d1e99d86180..33adce40da0 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -431,6 +431,9 @@ TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx, // A reference to an op's name -> attribute mapping typedef struct TFE_OpAttrs TFE_OpAttrs; +// Fetch a reference to `op`'s attributes. The returned reference is only valid +// while `op` is alive. +const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op); // Add attributes in `attrs` to `op`. // // Does not overwrite or update existing attributes, but adds new ones. diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 0e4183dad16..3160cb0e585 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -1591,15 +1591,11 @@ TEST(CAPI, TestTFE_OpAddAttrs) { TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status); TFE_OpSetAttrType(var_op, "dtype", TF_INT64); TFE_OpSetAttrShape(var_op, "shape", {}, 0, status); - // There is currently no API to fetch attributes from an operation, fetching - // happens only as an implementation detail of custom devices. - tensorflow::EagerOperation* operation = - OperationFromInterface(tensorflow::unwrap(var_op)); - TFE_OpAttrs attributes{&operation->Attrs()}; + const TFE_OpAttrs* attributes = TFE_OpGetAttrs(var_op); TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status); TFE_OpSetAttrType(copy_op, "dtype", TF_FLOAT); - TFE_OpAddAttrs(copy_op, &attributes); + TFE_OpAddAttrs(copy_op, attributes); unsigned char is_list = 0; ASSERT_EQ(TF_ATTR_TYPE, TFE_OpGetAttrType(copy_op, "dtype", &is_list, status)); @@ -1631,14 +1627,10 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_OpSetAttrType(var_op, "dtype", TF_INT64); TFE_OpSetAttrShape(var_op, "shape", {}, 0, status); - // There is currently no API to fetch attributes from an operation, fetching - // happens only as an implementation detail of custom devices. - tensorflow::EagerOperation* operation = - OperationFromInterface(tensorflow::unwrap(var_op)); - TFE_OpAttrs attributes{&operation->Attrs()}; + const TFE_OpAttrs* attributes = TFE_OpGetAttrs(var_op); TF_Buffer* serialized_attr_values = TF_NewBuffer(); - TFE_OpAttrsSerialize(&attributes, serialized_attr_values, status); + TFE_OpAttrsSerialize(attributes, serialized_attr_values, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); tensorflow::NameAttrList name_and_attrs; ASSERT_TRUE(name_and_attrs.ParseFromArray(serialized_attr_values->data, diff --git a/tensorflow/c/eager/tfe_op_attrs_internal.h b/tensorflow/c/eager/tfe_op_attrs_internal.h index 935d7d520e5..0287502dea6 100644 --- a/tensorflow/c/eager/tfe_op_attrs_internal.h +++ b/tensorflow/c/eager/tfe_op_attrs_internal.h @@ -15,33 +15,21 @@ limitations under the License. #ifndef TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_ #define TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_ -#include -#include -#include -#include -#include -#include -#include - +#include "tensorflow/c/conversion_macros.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/framework/attr_value.pb.h" // An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways // that sometimes do not require serialization. +typedef struct TFE_OpAttrs TFE_OpAttrs; + typedef struct TFE_Context TFE_Context; typedef struct TFE_Op TFE_Op; -struct TFE_OpAttrs { - explicit TFE_OpAttrs() : attributes(nullptr) {} - - explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value) - : attributes(value) {} - - const tensorflow::AttrBuilder* attributes; -}; - namespace tensorflow { +DEFINE_CONVERSION_FUNCTIONS(tensorflow::AttrBuilder, TFE_OpAttrs); + // Set an AttrValue on the op. Doesn't handle the list types. void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value,