Less pointer indirection for TFE_OpAttrs, add TFE_OpGetAttrs
We'll want this for implementing copy for `TF_AbstractOp`s backed by `TFE_Op`s (since we want to copy the type/attributes but not the inputs). PiperOrigin-RevId: 309756974 Change-Id: I07a8c48f50ab6d3c8a7d7db972fb60202b86434d
This commit is contained in:
parent
dbe7b589c5
commit
6e3bea20a1
|
@ -58,6 +58,7 @@ filegroup(
|
|||
name = "pywrap_required_hdrs",
|
||||
srcs = [
|
||||
"c_api_internal.h",
|
||||
"conversion_macros.h",
|
||||
"python_api.h",
|
||||
"tensor_interface.h",
|
||||
"tf_status_helper.h",
|
||||
|
|
|
@ -16,15 +16,18 @@ limitations under the License.
|
|||
#ifndef TENSORFLOW_C_CONVERSION_MACROS_H_
|
||||
#define TENSORFLOW_C_CONVERSION_MACROS_H_
|
||||
|
||||
#define DEFINE_CONVERSION_FUNCTIONS(cpp_impl, wrapper) \
|
||||
inline cpp_impl *unwrap(wrapper *w) { \
|
||||
return reinterpret_cast<cpp_impl *>(w); \
|
||||
} \
|
||||
\
|
||||
inline const cpp_impl *unwrap(const wrapper *w) { \
|
||||
return reinterpret_cast<const cpp_impl *>(w); \
|
||||
} \
|
||||
\
|
||||
inline wrapper *wrap(cpp_impl *i) { return reinterpret_cast<wrapper *>(i); }
|
||||
#define DEFINE_CONVERSION_FUNCTIONS(cpp_impl, wrapper) \
|
||||
inline cpp_impl *unwrap(wrapper *w) { \
|
||||
return reinterpret_cast<cpp_impl *>(w); \
|
||||
} \
|
||||
\
|
||||
inline const cpp_impl *unwrap(const wrapper *w) { \
|
||||
return reinterpret_cast<const cpp_impl *>(w); \
|
||||
} \
|
||||
\
|
||||
inline wrapper *wrap(cpp_impl *i) { return reinterpret_cast<wrapper *>(i); } \
|
||||
inline const wrapper *wrap(const cpp_impl *i) { \
|
||||
return reinterpret_cast<const wrapper *>(i); \
|
||||
}
|
||||
|
||||
#endif // TENSORFLOW_C_CONVERSION_MACROS_H_
|
||||
|
|
|
@ -247,6 +247,7 @@ cc_library(
|
|||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:conversion_macros",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/common_runtime/eager:attr_builder",
|
||||
|
|
|
@ -1483,9 +1483,14 @@ void TFE_ContextEndStep(TFE_Context* ctx) {
|
|||
context->EndStep();
|
||||
}
|
||||
|
||||
const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op) {
|
||||
return tensorflow::wrap(
|
||||
&OperationFromInterface(tensorflow::unwrap(op))->Attrs());
|
||||
}
|
||||
|
||||
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
|
||||
tensorflow::AttrValueMap m;
|
||||
attrs->attributes->FillAttrValueMap(&m);
|
||||
tensorflow::unwrap(attrs)->FillAttrValueMap(&m);
|
||||
tensorflow::EagerOperation* operation =
|
||||
OperationFromInterface(tensorflow::unwrap(op));
|
||||
tensorflow::AttrBuilder* destination = operation->MutableAttrs();
|
||||
|
@ -1497,8 +1502,8 @@ void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
|
|||
void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,
|
||||
TF_Status* status) {
|
||||
tensorflow::NameAttrList name_and_attrs;
|
||||
attrs->attributes->FillAttrValueMap(name_and_attrs.mutable_attr());
|
||||
name_and_attrs.set_name(attrs->attributes->op_name());
|
||||
tensorflow::unwrap(attrs)->FillAttrValueMap(name_and_attrs.mutable_attr());
|
||||
name_and_attrs.set_name(tensorflow::unwrap(attrs)->op_name());
|
||||
status->status = MessageToBuffer(name_and_attrs, buf);
|
||||
}
|
||||
|
||||
|
@ -1619,9 +1624,9 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
|
|||
}
|
||||
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
|
||||
TF_Status status;
|
||||
TFE_OpAttrs attributes(&op->Attrs());
|
||||
device_.execute(context_, inputs.size(), inputs.data(), op->Name().c_str(),
|
||||
&attributes, num_retvals, outputs.data(), &status, info_);
|
||||
wrap(&op->Attrs()), num_retvals, outputs.data(), &status,
|
||||
info_);
|
||||
if (status.status.ok()) {
|
||||
for (int i = 0; i < *num_retvals; ++i) {
|
||||
retvals[i] = tensorflow::TensorHandleFromInterface(
|
||||
|
|
|
@ -431,6 +431,9 @@ TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
|
|||
// A reference to an op's name -> attribute mapping
|
||||
typedef struct TFE_OpAttrs TFE_OpAttrs;
|
||||
|
||||
// Fetch a reference to `op`'s attributes. The returned reference is only valid
|
||||
// while `op` is alive.
|
||||
const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op);
|
||||
// Add attributes in `attrs` to `op`.
|
||||
//
|
||||
// Does not overwrite or update existing attributes, but adds new ones.
|
||||
|
|
|
@ -1591,15 +1591,11 @@ TEST(CAPI, TestTFE_OpAddAttrs) {
|
|||
TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
|
||||
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
|
||||
// There is currently no API to fetch attributes from an operation, fetching
|
||||
// happens only as an implementation detail of custom devices.
|
||||
tensorflow::EagerOperation* operation =
|
||||
OperationFromInterface(tensorflow::unwrap(var_op));
|
||||
TFE_OpAttrs attributes{&operation->Attrs()};
|
||||
const TFE_OpAttrs* attributes = TFE_OpGetAttrs(var_op);
|
||||
|
||||
TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||
TFE_OpSetAttrType(copy_op, "dtype", TF_FLOAT);
|
||||
TFE_OpAddAttrs(copy_op, &attributes);
|
||||
TFE_OpAddAttrs(copy_op, attributes);
|
||||
unsigned char is_list = 0;
|
||||
ASSERT_EQ(TF_ATTR_TYPE,
|
||||
TFE_OpGetAttrType(copy_op, "dtype", &is_list, status));
|
||||
|
@ -1631,14 +1627,10 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) {
|
|||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
|
||||
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
|
||||
// There is currently no API to fetch attributes from an operation, fetching
|
||||
// happens only as an implementation detail of custom devices.
|
||||
tensorflow::EagerOperation* operation =
|
||||
OperationFromInterface(tensorflow::unwrap(var_op));
|
||||
TFE_OpAttrs attributes{&operation->Attrs()};
|
||||
const TFE_OpAttrs* attributes = TFE_OpGetAttrs(var_op);
|
||||
|
||||
TF_Buffer* serialized_attr_values = TF_NewBuffer();
|
||||
TFE_OpAttrsSerialize(&attributes, serialized_attr_values, status);
|
||||
TFE_OpAttrsSerialize(attributes, serialized_attr_values, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
tensorflow::NameAttrList name_and_attrs;
|
||||
ASSERT_TRUE(name_and_attrs.ParseFromArray(serialized_attr_values->data,
|
||||
|
|
|
@ -15,33 +15,21 @@ limitations under the License.
|
|||
#ifndef TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
|
||||
// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways
|
||||
// that sometimes do not require serialization.
|
||||
typedef struct TFE_OpAttrs TFE_OpAttrs;
|
||||
|
||||
typedef struct TFE_Context TFE_Context;
|
||||
typedef struct TFE_Op TFE_Op;
|
||||
|
||||
struct TFE_OpAttrs {
|
||||
explicit TFE_OpAttrs() : attributes(nullptr) {}
|
||||
|
||||
explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value)
|
||||
: attributes(value) {}
|
||||
|
||||
const tensorflow::AttrBuilder* attributes;
|
||||
};
|
||||
|
||||
namespace tensorflow {
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AttrBuilder, TFE_OpAttrs);
|
||||
|
||||
// Set an AttrValue on the op. Doesn't handle the list types.
|
||||
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
|
||||
const tensorflow::AttrValue& default_value,
|
||||
|
|
Loading…
Reference in New Issue