Less pointer indirection for TFE_OpAttrs, add TFE_OpGetAttrs
We'll want this for implementing copy for `TF_AbstractOp`s backed by `TFE_Op`s (since we want to copy the type/attributes but not the inputs). PiperOrigin-RevId: 309756974 Change-Id: I07a8c48f50ab6d3c8a7d7db972fb60202b86434d
This commit is contained in:
parent
dbe7b589c5
commit
6e3bea20a1
|
@ -58,6 +58,7 @@ filegroup(
|
||||||
name = "pywrap_required_hdrs",
|
name = "pywrap_required_hdrs",
|
||||||
srcs = [
|
srcs = [
|
||||||
"c_api_internal.h",
|
"c_api_internal.h",
|
||||||
|
"conversion_macros.h",
|
||||||
"python_api.h",
|
"python_api.h",
|
||||||
"tensor_interface.h",
|
"tensor_interface.h",
|
||||||
"tf_status_helper.h",
|
"tf_status_helper.h",
|
||||||
|
|
|
@ -16,15 +16,18 @@ limitations under the License.
|
||||||
#ifndef TENSORFLOW_C_CONVERSION_MACROS_H_
|
#ifndef TENSORFLOW_C_CONVERSION_MACROS_H_
|
||||||
#define TENSORFLOW_C_CONVERSION_MACROS_H_
|
#define TENSORFLOW_C_CONVERSION_MACROS_H_
|
||||||
|
|
||||||
#define DEFINE_CONVERSION_FUNCTIONS(cpp_impl, wrapper) \
|
#define DEFINE_CONVERSION_FUNCTIONS(cpp_impl, wrapper) \
|
||||||
inline cpp_impl *unwrap(wrapper *w) { \
|
inline cpp_impl *unwrap(wrapper *w) { \
|
||||||
return reinterpret_cast<cpp_impl *>(w); \
|
return reinterpret_cast<cpp_impl *>(w); \
|
||||||
} \
|
} \
|
||||||
\
|
\
|
||||||
inline const cpp_impl *unwrap(const wrapper *w) { \
|
inline const cpp_impl *unwrap(const wrapper *w) { \
|
||||||
return reinterpret_cast<const cpp_impl *>(w); \
|
return reinterpret_cast<const cpp_impl *>(w); \
|
||||||
} \
|
} \
|
||||||
\
|
\
|
||||||
inline wrapper *wrap(cpp_impl *i) { return reinterpret_cast<wrapper *>(i); }
|
inline wrapper *wrap(cpp_impl *i) { return reinterpret_cast<wrapper *>(i); } \
|
||||||
|
inline const wrapper *wrap(const cpp_impl *i) { \
|
||||||
|
return reinterpret_cast<const wrapper *>(i); \
|
||||||
|
}
|
||||||
|
|
||||||
#endif // TENSORFLOW_C_CONVERSION_MACROS_H_
|
#endif // TENSORFLOW_C_CONVERSION_MACROS_H_
|
||||||
|
|
|
@ -247,6 +247,7 @@ cc_library(
|
||||||
"//tensorflow:internal",
|
"//tensorflow:internal",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
"//tensorflow/c:conversion_macros",
|
||||||
"//tensorflow/c:tf_status",
|
"//tensorflow/c:tf_status",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/common_runtime/eager:attr_builder",
|
"//tensorflow/core/common_runtime/eager:attr_builder",
|
||||||
|
|
|
@ -1483,9 +1483,14 @@ void TFE_ContextEndStep(TFE_Context* ctx) {
|
||||||
context->EndStep();
|
context->EndStep();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op) {
|
||||||
|
return tensorflow::wrap(
|
||||||
|
&OperationFromInterface(tensorflow::unwrap(op))->Attrs());
|
||||||
|
}
|
||||||
|
|
||||||
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
|
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
|
||||||
tensorflow::AttrValueMap m;
|
tensorflow::AttrValueMap m;
|
||||||
attrs->attributes->FillAttrValueMap(&m);
|
tensorflow::unwrap(attrs)->FillAttrValueMap(&m);
|
||||||
tensorflow::EagerOperation* operation =
|
tensorflow::EagerOperation* operation =
|
||||||
OperationFromInterface(tensorflow::unwrap(op));
|
OperationFromInterface(tensorflow::unwrap(op));
|
||||||
tensorflow::AttrBuilder* destination = operation->MutableAttrs();
|
tensorflow::AttrBuilder* destination = operation->MutableAttrs();
|
||||||
|
@ -1497,8 +1502,8 @@ void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
|
||||||
void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,
|
void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
tensorflow::NameAttrList name_and_attrs;
|
tensorflow::NameAttrList name_and_attrs;
|
||||||
attrs->attributes->FillAttrValueMap(name_and_attrs.mutable_attr());
|
tensorflow::unwrap(attrs)->FillAttrValueMap(name_and_attrs.mutable_attr());
|
||||||
name_and_attrs.set_name(attrs->attributes->op_name());
|
name_and_attrs.set_name(tensorflow::unwrap(attrs)->op_name());
|
||||||
status->status = MessageToBuffer(name_and_attrs, buf);
|
status->status = MessageToBuffer(name_and_attrs, buf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1619,9 +1624,9 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
|
||||||
}
|
}
|
||||||
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
|
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
|
||||||
TF_Status status;
|
TF_Status status;
|
||||||
TFE_OpAttrs attributes(&op->Attrs());
|
|
||||||
device_.execute(context_, inputs.size(), inputs.data(), op->Name().c_str(),
|
device_.execute(context_, inputs.size(), inputs.data(), op->Name().c_str(),
|
||||||
&attributes, num_retvals, outputs.data(), &status, info_);
|
wrap(&op->Attrs()), num_retvals, outputs.data(), &status,
|
||||||
|
info_);
|
||||||
if (status.status.ok()) {
|
if (status.status.ok()) {
|
||||||
for (int i = 0; i < *num_retvals; ++i) {
|
for (int i = 0; i < *num_retvals; ++i) {
|
||||||
retvals[i] = tensorflow::TensorHandleFromInterface(
|
retvals[i] = tensorflow::TensorHandleFromInterface(
|
||||||
|
|
|
@ -431,6 +431,9 @@ TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
|
||||||
// A reference to an op's name -> attribute mapping
|
// A reference to an op's name -> attribute mapping
|
||||||
typedef struct TFE_OpAttrs TFE_OpAttrs;
|
typedef struct TFE_OpAttrs TFE_OpAttrs;
|
||||||
|
|
||||||
|
// Fetch a reference to `op`'s attributes. The returned reference is only valid
|
||||||
|
// while `op` is alive.
|
||||||
|
const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op);
|
||||||
// Add attributes in `attrs` to `op`.
|
// Add attributes in `attrs` to `op`.
|
||||||
//
|
//
|
||||||
// Does not overwrite or update existing attributes, but adds new ones.
|
// Does not overwrite or update existing attributes, but adds new ones.
|
||||||
|
|
|
@ -1591,15 +1591,11 @@ TEST(CAPI, TestTFE_OpAddAttrs) {
|
||||||
TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
|
TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||||
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
|
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
|
||||||
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
|
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
|
||||||
// There is currently no API to fetch attributes from an operation, fetching
|
const TFE_OpAttrs* attributes = TFE_OpGetAttrs(var_op);
|
||||||
// happens only as an implementation detail of custom devices.
|
|
||||||
tensorflow::EagerOperation* operation =
|
|
||||||
OperationFromInterface(tensorflow::unwrap(var_op));
|
|
||||||
TFE_OpAttrs attributes{&operation->Attrs()};
|
|
||||||
|
|
||||||
TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status);
|
TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||||
TFE_OpSetAttrType(copy_op, "dtype", TF_FLOAT);
|
TFE_OpSetAttrType(copy_op, "dtype", TF_FLOAT);
|
||||||
TFE_OpAddAttrs(copy_op, &attributes);
|
TFE_OpAddAttrs(copy_op, attributes);
|
||||||
unsigned char is_list = 0;
|
unsigned char is_list = 0;
|
||||||
ASSERT_EQ(TF_ATTR_TYPE,
|
ASSERT_EQ(TF_ATTR_TYPE,
|
||||||
TFE_OpGetAttrType(copy_op, "dtype", &is_list, status));
|
TFE_OpGetAttrType(copy_op, "dtype", &is_list, status));
|
||||||
|
@ -1631,14 +1627,10 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) {
|
||||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
|
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
|
||||||
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
|
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
|
||||||
// There is currently no API to fetch attributes from an operation, fetching
|
const TFE_OpAttrs* attributes = TFE_OpGetAttrs(var_op);
|
||||||
// happens only as an implementation detail of custom devices.
|
|
||||||
tensorflow::EagerOperation* operation =
|
|
||||||
OperationFromInterface(tensorflow::unwrap(var_op));
|
|
||||||
TFE_OpAttrs attributes{&operation->Attrs()};
|
|
||||||
|
|
||||||
TF_Buffer* serialized_attr_values = TF_NewBuffer();
|
TF_Buffer* serialized_attr_values = TF_NewBuffer();
|
||||||
TFE_OpAttrsSerialize(&attributes, serialized_attr_values, status);
|
TFE_OpAttrsSerialize(attributes, serialized_attr_values, status);
|
||||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
tensorflow::NameAttrList name_and_attrs;
|
tensorflow::NameAttrList name_and_attrs;
|
||||||
ASSERT_TRUE(name_and_attrs.ParseFromArray(serialized_attr_values->data,
|
ASSERT_TRUE(name_and_attrs.ParseFromArray(serialized_attr_values->data,
|
||||||
|
|
|
@ -15,33 +15,21 @@ limitations under the License.
|
||||||
#ifndef TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_
|
#ifndef TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_
|
||||||
#define TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_
|
#define TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_
|
||||||
|
|
||||||
#include <algorithm>
|
#include "tensorflow/c/conversion_macros.h"
|
||||||
#include <cstddef>
|
|
||||||
#include <map>
|
|
||||||
#include <memory>
|
|
||||||
#include <queue>
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
|
|
||||||
// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways
|
// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways
|
||||||
// that sometimes do not require serialization.
|
// that sometimes do not require serialization.
|
||||||
|
typedef struct TFE_OpAttrs TFE_OpAttrs;
|
||||||
|
|
||||||
typedef struct TFE_Context TFE_Context;
|
typedef struct TFE_Context TFE_Context;
|
||||||
typedef struct TFE_Op TFE_Op;
|
typedef struct TFE_Op TFE_Op;
|
||||||
|
|
||||||
struct TFE_OpAttrs {
|
|
||||||
explicit TFE_OpAttrs() : attributes(nullptr) {}
|
|
||||||
|
|
||||||
explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value)
|
|
||||||
: attributes(value) {}
|
|
||||||
|
|
||||||
const tensorflow::AttrBuilder* attributes;
|
|
||||||
};
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AttrBuilder, TFE_OpAttrs);
|
||||||
|
|
||||||
// Set an AttrValue on the op. Doesn't handle the list types.
|
// Set an AttrValue on the op. Doesn't handle the list types.
|
||||||
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
|
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
|
||||||
const tensorflow::AttrValue& default_value,
|
const tensorflow::AttrValue& default_value,
|
||||||
|
|
Loading…
Reference in New Issue