Less pointer indirection for TFE_OpAttrs, add TFE_OpGetAttrs

We'll want this for implementing copy for `TF_AbstractOp`s backed by `TFE_Op`s (since we want to copy the type/attributes but not the inputs).

PiperOrigin-RevId: 309756974
Change-Id: I07a8c48f50ab6d3c8a7d7db972fb60202b86434d
This commit is contained in:
Allen Lavoie 2020-05-04 09:18:36 -07:00 committed by TensorFlower Gardener
parent dbe7b589c5
commit 6e3bea20a1
7 changed files with 37 additions and 44 deletions

View File

@ -58,6 +58,7 @@ filegroup(
name = "pywrap_required_hdrs",
srcs = [
"c_api_internal.h",
"conversion_macros.h",
"python_api.h",
"tensor_interface.h",
"tf_status_helper.h",

View File

@ -16,15 +16,18 @@ limitations under the License.
#ifndef TENSORFLOW_C_CONVERSION_MACROS_H_
#define TENSORFLOW_C_CONVERSION_MACROS_H_
#define DEFINE_CONVERSION_FUNCTIONS(cpp_impl, wrapper) \
inline cpp_impl *unwrap(wrapper *w) { \
return reinterpret_cast<cpp_impl *>(w); \
} \
\
inline const cpp_impl *unwrap(const wrapper *w) { \
return reinterpret_cast<const cpp_impl *>(w); \
} \
\
inline wrapper *wrap(cpp_impl *i) { return reinterpret_cast<wrapper *>(i); }
#define DEFINE_CONVERSION_FUNCTIONS(cpp_impl, wrapper) \
inline cpp_impl *unwrap(wrapper *w) { \
return reinterpret_cast<cpp_impl *>(w); \
} \
\
inline const cpp_impl *unwrap(const wrapper *w) { \
return reinterpret_cast<const cpp_impl *>(w); \
} \
\
inline wrapper *wrap(cpp_impl *i) { return reinterpret_cast<wrapper *>(i); } \
inline const wrapper *wrap(const cpp_impl *i) { \
return reinterpret_cast<const wrapper *>(i); \
}
#endif // TENSORFLOW_C_CONVERSION_MACROS_H_

View File

@ -247,6 +247,7 @@ cc_library(
"//tensorflow:internal",
],
deps = [
"//tensorflow/c:conversion_macros",
"//tensorflow/c:tf_status",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:attr_builder",

View File

@ -1483,9 +1483,14 @@ void TFE_ContextEndStep(TFE_Context* ctx) {
context->EndStep();
}
const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op) {
return tensorflow::wrap(
&OperationFromInterface(tensorflow::unwrap(op))->Attrs());
}
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
tensorflow::AttrValueMap m;
attrs->attributes->FillAttrValueMap(&m);
tensorflow::unwrap(attrs)->FillAttrValueMap(&m);
tensorflow::EagerOperation* operation =
OperationFromInterface(tensorflow::unwrap(op));
tensorflow::AttrBuilder* destination = operation->MutableAttrs();
@ -1497,8 +1502,8 @@ void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,
TF_Status* status) {
tensorflow::NameAttrList name_and_attrs;
attrs->attributes->FillAttrValueMap(name_and_attrs.mutable_attr());
name_and_attrs.set_name(attrs->attributes->op_name());
tensorflow::unwrap(attrs)->FillAttrValueMap(name_and_attrs.mutable_attr());
name_and_attrs.set_name(tensorflow::unwrap(attrs)->op_name());
status->status = MessageToBuffer(name_and_attrs, buf);
}
@ -1619,9 +1624,9 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
}
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
TF_Status status;
TFE_OpAttrs attributes(&op->Attrs());
device_.execute(context_, inputs.size(), inputs.data(), op->Name().c_str(),
&attributes, num_retvals, outputs.data(), &status, info_);
wrap(&op->Attrs()), num_retvals, outputs.data(), &status,
info_);
if (status.status.ok()) {
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = tensorflow::TensorHandleFromInterface(

View File

@ -431,6 +431,9 @@ TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
// A reference to an op's name -> attribute mapping
typedef struct TFE_OpAttrs TFE_OpAttrs;
// Fetch a reference to `op`'s attributes. The returned reference is only valid
// while `op` is alive.
const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op);
// Add attributes in `attrs` to `op`.
//
// Does not overwrite or update existing attributes, but adds new ones.

View File

@ -1591,15 +1591,11 @@ TEST(CAPI, TestTFE_OpAddAttrs) {
TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
// There is currently no API to fetch attributes from an operation, fetching
// happens only as an implementation detail of custom devices.
tensorflow::EagerOperation* operation =
OperationFromInterface(tensorflow::unwrap(var_op));
TFE_OpAttrs attributes{&operation->Attrs()};
const TFE_OpAttrs* attributes = TFE_OpGetAttrs(var_op);
TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status);
TFE_OpSetAttrType(copy_op, "dtype", TF_FLOAT);
TFE_OpAddAttrs(copy_op, &attributes);
TFE_OpAddAttrs(copy_op, attributes);
unsigned char is_list = 0;
ASSERT_EQ(TF_ATTR_TYPE,
TFE_OpGetAttrType(copy_op, "dtype", &is_list, status));
@ -1631,14 +1627,10 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
// There is currently no API to fetch attributes from an operation, fetching
// happens only as an implementation detail of custom devices.
tensorflow::EagerOperation* operation =
OperationFromInterface(tensorflow::unwrap(var_op));
TFE_OpAttrs attributes{&operation->Attrs()};
const TFE_OpAttrs* attributes = TFE_OpGetAttrs(var_op);
TF_Buffer* serialized_attr_values = TF_NewBuffer();
TFE_OpAttrsSerialize(&attributes, serialized_attr_values, status);
TFE_OpAttrsSerialize(attributes, serialized_attr_values, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::NameAttrList name_and_attrs;
ASSERT_TRUE(name_and_attrs.ParseFromArray(serialized_attr_values->data,

View File

@ -15,33 +15,21 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_
#define TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_
#include <algorithm>
#include <cstddef>
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <vector>
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/framework/attr_value.pb.h"
// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways
// that sometimes do not require serialization.
typedef struct TFE_OpAttrs TFE_OpAttrs;
typedef struct TFE_Context TFE_Context;
typedef struct TFE_Op TFE_Op;
struct TFE_OpAttrs {
explicit TFE_OpAttrs() : attributes(nullptr) {}
explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value)
: attributes(value) {}
const tensorflow::AttrBuilder* attributes;
};
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AttrBuilder, TFE_OpAttrs);
// Set an AttrValue on the op. Doesn't handle the list types.
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
const tensorflow::AttrValue& default_value,