Less pointer indirection for TFE_OpAttrs, add TFE_OpGetAttrs

We'll want this for implementing copy for `TF_AbstractOp`s backed by `TFE_Op`s (since we want to copy the type/attributes but not the inputs).

PiperOrigin-RevId: 309756974
Change-Id: I07a8c48f50ab6d3c8a7d7db972fb60202b86434d
This commit is contained in:
Allen Lavoie 2020-05-04 09:18:36 -07:00 committed by TensorFlower Gardener
parent dbe7b589c5
commit 6e3bea20a1
7 changed files with 37 additions and 44 deletions

View File

@ -58,6 +58,7 @@ filegroup(
name = "pywrap_required_hdrs", name = "pywrap_required_hdrs",
srcs = [ srcs = [
"c_api_internal.h", "c_api_internal.h",
"conversion_macros.h",
"python_api.h", "python_api.h",
"tensor_interface.h", "tensor_interface.h",
"tf_status_helper.h", "tf_status_helper.h",

View File

@ -16,15 +16,18 @@ limitations under the License.
#ifndef TENSORFLOW_C_CONVERSION_MACROS_H_ #ifndef TENSORFLOW_C_CONVERSION_MACROS_H_
#define TENSORFLOW_C_CONVERSION_MACROS_H_ #define TENSORFLOW_C_CONVERSION_MACROS_H_
#define DEFINE_CONVERSION_FUNCTIONS(cpp_impl, wrapper) \ #define DEFINE_CONVERSION_FUNCTIONS(cpp_impl, wrapper) \
inline cpp_impl *unwrap(wrapper *w) { \ inline cpp_impl *unwrap(wrapper *w) { \
return reinterpret_cast<cpp_impl *>(w); \ return reinterpret_cast<cpp_impl *>(w); \
} \ } \
\ \
inline const cpp_impl *unwrap(const wrapper *w) { \ inline const cpp_impl *unwrap(const wrapper *w) { \
return reinterpret_cast<const cpp_impl *>(w); \ return reinterpret_cast<const cpp_impl *>(w); \
} \ } \
\ \
inline wrapper *wrap(cpp_impl *i) { return reinterpret_cast<wrapper *>(i); } inline wrapper *wrap(cpp_impl *i) { return reinterpret_cast<wrapper *>(i); } \
inline const wrapper *wrap(const cpp_impl *i) { \
return reinterpret_cast<const wrapper *>(i); \
}
#endif // TENSORFLOW_C_CONVERSION_MACROS_H_ #endif // TENSORFLOW_C_CONVERSION_MACROS_H_

View File

@ -247,6 +247,7 @@ cc_library(
"//tensorflow:internal", "//tensorflow:internal",
], ],
deps = [ deps = [
"//tensorflow/c:conversion_macros",
"//tensorflow/c:tf_status", "//tensorflow/c:tf_status",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:attr_builder", "//tensorflow/core/common_runtime/eager:attr_builder",

View File

@ -1483,9 +1483,14 @@ void TFE_ContextEndStep(TFE_Context* ctx) {
context->EndStep(); context->EndStep();
} }
const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op) {
return tensorflow::wrap(
&OperationFromInterface(tensorflow::unwrap(op))->Attrs());
}
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) { void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
tensorflow::AttrValueMap m; tensorflow::AttrValueMap m;
attrs->attributes->FillAttrValueMap(&m); tensorflow::unwrap(attrs)->FillAttrValueMap(&m);
tensorflow::EagerOperation* operation = tensorflow::EagerOperation* operation =
OperationFromInterface(tensorflow::unwrap(op)); OperationFromInterface(tensorflow::unwrap(op));
tensorflow::AttrBuilder* destination = operation->MutableAttrs(); tensorflow::AttrBuilder* destination = operation->MutableAttrs();
@ -1497,8 +1502,8 @@ void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf, void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,
TF_Status* status) { TF_Status* status) {
tensorflow::NameAttrList name_and_attrs; tensorflow::NameAttrList name_and_attrs;
attrs->attributes->FillAttrValueMap(name_and_attrs.mutable_attr()); tensorflow::unwrap(attrs)->FillAttrValueMap(name_and_attrs.mutable_attr());
name_and_attrs.set_name(attrs->attributes->op_name()); name_and_attrs.set_name(tensorflow::unwrap(attrs)->op_name());
status->status = MessageToBuffer(name_and_attrs, buf); status->status = MessageToBuffer(name_and_attrs, buf);
} }
@ -1619,9 +1624,9 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
} }
std::vector<TFE_TensorHandle*> outputs(*num_retvals); std::vector<TFE_TensorHandle*> outputs(*num_retvals);
TF_Status status; TF_Status status;
TFE_OpAttrs attributes(&op->Attrs());
device_.execute(context_, inputs.size(), inputs.data(), op->Name().c_str(), device_.execute(context_, inputs.size(), inputs.data(), op->Name().c_str(),
&attributes, num_retvals, outputs.data(), &status, info_); wrap(&op->Attrs()), num_retvals, outputs.data(), &status,
info_);
if (status.status.ok()) { if (status.status.ok()) {
for (int i = 0; i < *num_retvals; ++i) { for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = tensorflow::TensorHandleFromInterface( retvals[i] = tensorflow::TensorHandleFromInterface(

View File

@ -431,6 +431,9 @@ TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
// A reference to an op's name -> attribute mapping // A reference to an op's name -> attribute mapping
typedef struct TFE_OpAttrs TFE_OpAttrs; typedef struct TFE_OpAttrs TFE_OpAttrs;
// Fetch a reference to `op`'s attributes. The returned reference is only valid
// while `op` is alive.
const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op);
// Add attributes in `attrs` to `op`. // Add attributes in `attrs` to `op`.
// //
// Does not overwrite or update existing attributes, but adds new ones. // Does not overwrite or update existing attributes, but adds new ones.

View File

@ -1591,15 +1591,11 @@ TEST(CAPI, TestTFE_OpAddAttrs) {
TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status); TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
TFE_OpSetAttrType(var_op, "dtype", TF_INT64); TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status); TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
// There is currently no API to fetch attributes from an operation, fetching const TFE_OpAttrs* attributes = TFE_OpGetAttrs(var_op);
// happens only as an implementation detail of custom devices.
tensorflow::EagerOperation* operation =
OperationFromInterface(tensorflow::unwrap(var_op));
TFE_OpAttrs attributes{&operation->Attrs()};
TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status); TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status);
TFE_OpSetAttrType(copy_op, "dtype", TF_FLOAT); TFE_OpSetAttrType(copy_op, "dtype", TF_FLOAT);
TFE_OpAddAttrs(copy_op, &attributes); TFE_OpAddAttrs(copy_op, attributes);
unsigned char is_list = 0; unsigned char is_list = 0;
ASSERT_EQ(TF_ATTR_TYPE, ASSERT_EQ(TF_ATTR_TYPE,
TFE_OpGetAttrType(copy_op, "dtype", &is_list, status)); TFE_OpGetAttrType(copy_op, "dtype", &is_list, status));
@ -1631,14 +1627,10 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpSetAttrType(var_op, "dtype", TF_INT64); TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status); TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
// There is currently no API to fetch attributes from an operation, fetching const TFE_OpAttrs* attributes = TFE_OpGetAttrs(var_op);
// happens only as an implementation detail of custom devices.
tensorflow::EagerOperation* operation =
OperationFromInterface(tensorflow::unwrap(var_op));
TFE_OpAttrs attributes{&operation->Attrs()};
TF_Buffer* serialized_attr_values = TF_NewBuffer(); TF_Buffer* serialized_attr_values = TF_NewBuffer();
TFE_OpAttrsSerialize(&attributes, serialized_attr_values, status); TFE_OpAttrsSerialize(attributes, serialized_attr_values, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::NameAttrList name_and_attrs; tensorflow::NameAttrList name_and_attrs;
ASSERT_TRUE(name_and_attrs.ParseFromArray(serialized_attr_values->data, ASSERT_TRUE(name_and_attrs.ParseFromArray(serialized_attr_values->data,

View File

@ -15,33 +15,21 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_ #ifndef TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_
#define TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_ #define TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_
#include <algorithm> #include "tensorflow/c/conversion_macros.h"
#include <cstddef>
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <vector>
#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value.pb.h"
// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways // An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways
// that sometimes do not require serialization. // that sometimes do not require serialization.
typedef struct TFE_OpAttrs TFE_OpAttrs;
typedef struct TFE_Context TFE_Context; typedef struct TFE_Context TFE_Context;
typedef struct TFE_Op TFE_Op; typedef struct TFE_Op TFE_Op;
struct TFE_OpAttrs {
explicit TFE_OpAttrs() : attributes(nullptr) {}
explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value)
: attributes(value) {}
const tensorflow::AttrBuilder* attributes;
};
namespace tensorflow { namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AttrBuilder, TFE_OpAttrs);
// Set an AttrValue on the op. Doesn't handle the list types. // Set an AttrValue on the op. Doesn't handle the list types.
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
const tensorflow::AttrValue& default_value, const tensorflow::AttrValue& default_value,