Add an experimental eager C API for generically fetching and setting op attributes.
Right now you can only fetch the whole attribute map and set it wholesale, but we can add more fine-grained attribute control in the future. This allows the custom device API to pass in attributes, and custom devices to forward these to their own TFE_Execute calls. This is required for creating variables. PiperOrigin-RevId: 296096192 Change-Id: I98c23bdcd13e479235b3e27850b1bb0bd7a53bba
This commit is contained in:
parent
ccfc7fd531
commit
1f5bc8a979
@ -1199,14 +1199,6 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
|||||||
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
|
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (dtype == TF_STRING || dtype == TF_RESOURCE ||
|
|
||||||
!tensorflow::DataTypeCanUseMemcpy(
|
|
||||||
static_cast<tensorflow::DataType>(dtype))) {
|
|
||||||
status->status = tensorflow::errors::InvalidArgument(
|
|
||||||
"Trying to create a tensor with a pointer to non-pod memory.");
|
|
||||||
deallocator(data, len, deallocator_arg);
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
// TODO(apassos) do we need to wrap the deallocator here to make sure to sync
|
// TODO(apassos) do we need to wrap the deallocator here to make sure to sync
|
||||||
// the device?
|
// the device?
|
||||||
TF_ManagedBuffer* buf =
|
TF_ManagedBuffer* buf =
|
||||||
@ -1680,6 +1672,19 @@ void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context->StartStep(); }
|
|||||||
|
|
||||||
void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context->EndStep(); }
|
void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context->EndStep(); }
|
||||||
|
|
||||||
|
// Points `attrs` at the attribute builder that belongs to `op`.
//
// Only a pointer is stored in `*attrs`; no attribute data is copied, so the
// result is valid only for as long as `op` is.
void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs) {
  const tensorflow::AttrBuilder& op_attributes = op->operation.Attrs();
  *attrs = TFE_OpAttrs(&op_attributes);
}
|
||||||
|
|
||||||
|
// Adds the attributes referenced by `attrs` to `op`.
//
// The source attributes are first materialized into a temporary AttrValueMap
// and then set, one by one, on `op`'s mutable attribute builder. A handle that
// refers to no attributes (the state a default-constructed TFE_OpAttrs is in)
// is treated as empty and leaves `op` untouched.
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
  // A default-constructed TFE_OpAttrs stores a null builder; there is nothing
  // to add in that case, and dereferencing it below would crash.
  if (attrs->attributes == nullptr) return;

  tensorflow::AttrValueMap m;
  attrs->attributes->FillAttrValueMap(&m);
  tensorflow::AttrBuilder* destination = op->operation.MutableAttrs();
  // Bind each entry by const reference: iterating by value would copy the key
  // string and the AttrValue of every attribute before passing them to Set().
  for (const auto& attribute : m) {
    destination->Set(attribute.first, attribute.second);
  }
}
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
|
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
|
||||||
const tensorflow::AttrValue& default_value,
|
const tensorflow::AttrValue& default_value,
|
||||||
@ -1799,10 +1804,10 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
|
|||||||
op->Inputs()[i])});
|
op->Inputs()[i])});
|
||||||
}
|
}
|
||||||
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
|
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
|
||||||
// TODO(allenl): figure out how to get attrs from EagerOperation
|
|
||||||
TF_Status status;
|
TF_Status status;
|
||||||
|
TFE_OpAttrs attributes(&op->Attrs());
|
||||||
device_.execute(inputs.size(), inputs.data(), op->Name().c_str(),
|
device_.execute(inputs.size(), inputs.data(), op->Name().c_str(),
|
||||||
num_retvals, outputs.data(), &status, info_);
|
&attributes, num_retvals, outputs.data(), &status, info_);
|
||||||
if (status.status.ok()) {
|
if (status.status.ok()) {
|
||||||
for (int i = 0; i < *num_retvals; ++i) {
|
for (int i = 0; i < *num_retvals; ++i) {
|
||||||
retvals[i] = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
retvals[i] = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||||
|
@ -424,7 +424,27 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
|||||||
TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
|
TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
|
||||||
TF_Buffer* buf);
|
TF_Buffer* buf);
|
||||||
|
|
||||||
#define TFE_CUSTOM_DEVICE_VERSION 0
|
// APIs for generically dealing with op attributes (e.g. when forwarding them
|
||||||
|
// through custom device implementations).
|
||||||
|
//
|
||||||
|
// TODO(allenl): Currently these are black boxes, but we should have some way to
|
||||||
|
// inspect values. This would let people e.g. copy over most attributes and then
|
||||||
|
// modify some based on their values.
|
||||||
|
|
||||||
|
// A reference to an op's name -> attribute mapping
|
||||||
|
typedef struct TFE_OpAttrs TFE_OpAttrs;
|
||||||
|
|
||||||
|
// Fetch a struct with a reference to information about attributes of `op`.
|
||||||
|
//
|
||||||
|
// The `attrs` struct does not own any memory, and `op` must outlive it.
|
||||||
|
TF_CAPI_EXPORT extern void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs);
|
||||||
|
|
||||||
|
// Add attributes in `attrs` to `op`.
|
||||||
|
//
|
||||||
|
// Does not overwrite or update existing attributes, but adds new ones.
|
||||||
|
TF_CAPI_EXPORT extern void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs);
|
||||||
|
|
||||||
|
#define TFE_CUSTOM_DEVICE_VERSION 1
|
||||||
|
|
||||||
// Struct to be filled in
|
// Struct to be filled in
|
||||||
typedef struct TFE_CustomDevice {
|
typedef struct TFE_CustomDevice {
|
||||||
@ -441,10 +461,10 @@ typedef struct TFE_CustomDevice {
|
|||||||
void* device_info);
|
void* device_info);
|
||||||
|
|
||||||
// Method to execute an operation.
|
// Method to execute an operation.
|
||||||
// TODO(allenl) figure out a generic way of passing attrs here
|
|
||||||
void (*execute)(int num_inputs, TFE_TensorHandle** inputs,
|
void (*execute)(int num_inputs, TFE_TensorHandle** inputs,
|
||||||
const char* operation_name, int* num_outputs,
|
const char* operation_name, const TFE_OpAttrs* attributes,
|
||||||
TFE_TensorHandle** outputs, TF_Status* s, void* device_info);
|
int* num_outputs, TFE_TensorHandle** outputs, TF_Status* s,
|
||||||
|
void* device_info);
|
||||||
|
|
||||||
// Method to delete a device.
|
// Method to delete a device.
|
||||||
void (*delete_device)(void* device_info);
|
void (*delete_device)(void* device_info);
|
||||||
|
@ -236,4 +236,13 @@ struct TFE_Executor {
|
|||||||
tensorflow::EagerExecutor* unowned_executor;
|
tensorflow::EagerExecutor* unowned_executor;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct TFE_OpAttrs {
|
||||||
|
explicit TFE_OpAttrs() : attributes(nullptr) {}
|
||||||
|
|
||||||
|
explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value)
|
||||||
|
: attributes(value) {}
|
||||||
|
|
||||||
|
const tensorflow::AttrBuilder* attributes;
|
||||||
|
};
|
||||||
|
|
||||||
#endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
|
#endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
|
||||||
|
@ -1449,4 +1449,38 @@ TEST(CAPI, TestTFE_OpGetInputAndOutputLengthsFailForUnknownArguments) {
|
|||||||
TFE_DeleteContext(ctx);
|
TFE_DeleteContext(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Checks that attributes fetched from one op with TFE_OpGetAttrs can be added
// to a second op with TFE_OpAddAttrs, and that an attribute the second op had
// already set ("dtype") keeps its own value afterwards.
TEST(CAPI, TestTFE_OpGetAttrs) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  // Source op: dtype=int64 plus a shape set with zero dimensions.
  TFE_Op* varop = TFE_NewOp(ctx, "VarHandleOp", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpSetAttrType(varop, "dtype", TF_INT64);
  TFE_OpSetAttrShape(varop, "shape", {}, 0, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAttrs attributes;
  TFE_OpGetAttrs(varop, &attributes);

  // Destination op: sets its own dtype before the source attributes are added.
  TFE_Op* varop_copy = TFE_NewOp(ctx, "VarHandleOp", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpSetAttrType(varop_copy, "dtype", TF_FLOAT);
  TFE_OpAddAttrs(varop_copy, &attributes);
  unsigned char is_list = 0;
  ASSERT_EQ(TF_ATTR_TYPE,
            TFE_OpGetAttrType(varop_copy, "dtype", &is_list, status));
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  ASSERT_EQ(TF_ATTR_SHAPE,
            TFE_OpGetAttrType(varop_copy, "shape", &is_list, status));
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  tensorflow::AttrValueMap attr_values;
  varop_copy->operation.Attrs().FillAttrValueMap(&attr_values);
  // Make sure the lookup succeeded before dereferencing the iterator.
  const auto dtype_attr = attr_values.find("dtype");
  ASSERT_TRUE(dtype_attr != attr_values.end());
  // The float dtype set directly on the copy must not have been replaced by
  // the int64 dtype carried in `attributes`.
  EXPECT_EQ(tensorflow::DT_FLOAT, dtype_attr->second.type());

  TF_DeleteStatus(status);
  TFE_DeleteOp(varop);
  TFE_DeleteOp(varop_copy);
  TFE_DeleteContext(ctx);
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
|
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
@ -83,12 +84,14 @@ TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_TensorHandle* tensor,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void LoggingDeviceExecute(int num_inputs, TFE_TensorHandle** inputs,
|
void LoggingDeviceExecute(int num_inputs, TFE_TensorHandle** inputs,
|
||||||
const char* operation_name, int* num_outputs,
|
const char* operation_name,
|
||||||
|
const TFE_OpAttrs* attributes, int* num_outputs,
|
||||||
TFE_TensorHandle** outputs, TF_Status* s,
|
TFE_TensorHandle** outputs, TF_Status* s,
|
||||||
void* device_info) {
|
void* device_info) {
|
||||||
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
|
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
|
||||||
TFE_Op* op(TFE_NewOp(dev->ctx, operation_name, s));
|
TFE_Op* op(TFE_NewOp(dev->ctx, operation_name, s));
|
||||||
if (TF_GetCode(s) != TF_OK) return;
|
if (TF_GetCode(s) != TF_OK) return;
|
||||||
|
TFE_OpAddAttrs(op, attributes);
|
||||||
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
|
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
|
||||||
for (int j = 0; j < num_inputs; ++j) {
|
for (int j = 0; j < num_inputs; ++j) {
|
||||||
TFE_TensorHandle* input = inputs[j];
|
TFE_TensorHandle* input = inputs[j];
|
||||||
@ -203,4 +206,89 @@ TEST(CUSTOM_DEVICE, ResetOperation) {
|
|||||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Creates, assigns, reads and destroys a resource variable with every op
// placed on the device registered through RegisterLoggingDevice, checking
// after each executed op that the device's `executed` flag was raised.
TEST(CUSTOM_DEVICE, MakeVariable) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status_owner(
      TF_NewStatus(), TF_DeleteStatus);
  TF_Status* const s = status_owner.get();
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)>
      options(TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context_owner(
      TFE_NewContext(options.get(), s), TFE_DeleteContext);
  TFE_Context* const ctx = context_owner.get();
  ASSERT_TRUE(TF_GetCode(s) == TF_OK) << TF_Message(s);
  bool arrived = false;
  bool executed = false;
  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  RegisterLoggingDevice(ctx, device_name, &arrived, &executed);

  // Step 1: build a variable handle on the custom device.
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> current_op(
      TFE_NewOp(ctx, "VarHandleOp", s), TFE_DeleteOp);
  ASSERT_TRUE(TF_GetCode(s) == TF_OK) << TF_Message(s);
  TFE_OpSetAttrType(current_op.get(), "dtype", TF_FLOAT);
  TFE_OpSetAttrShape(current_op.get(), "shape", {}, 0, s);
  TFE_OpSetAttrString(current_op.get(), "container", "", 0);
  TFE_OpSetAttrString(current_op.get(), "shared_name", "", 0);
  TFE_OpSetDevice(current_op.get(), device_name, s);
  ASSERT_TRUE(TF_GetCode(s) == TF_OK) << TF_Message(s);
  TFE_TensorHandle* variable = nullptr;
  int output_count = 1;
  executed = false;
  TFE_Execute(current_op.get(), &variable, &output_count, s);
  ASSERT_TRUE(TF_GetCode(s) == TF_OK) << TF_Message(s);
  ASSERT_TRUE(executed);
  auto variable_cleanup = tensorflow::gtl::MakeCleanup(
      [variable]() { TFE_DeleteTensorHandle(variable); });

  // Step 2: assign a scalar to the variable through the custom device.
  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
      assigned_value(TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle);
  current_op.reset(TFE_NewOp(ctx, "AssignVariableOp", s));
  TFE_OpSetAttrType(current_op.get(), "dtype", TF_FLOAT);
  TFE_OpAddInput(current_op.get(), variable, s);
  TFE_OpAddInput(current_op.get(), assigned_value.get(), s);
  TFE_OpSetDevice(current_op.get(), device_name, s);
  ASSERT_TRUE(TF_GetCode(s) == TF_OK) << TF_Message(s);
  executed = false;
  output_count = 0;
  TFE_Execute(current_op.get(), nullptr, &output_count, s);
  ASSERT_TRUE(TF_GetCode(s) == TF_OK) << TF_Message(s);
  ASSERT_TRUE(executed);

  // Step 3: read the value back and check where it lives and what it holds.
  current_op.reset(TFE_NewOp(ctx, "ReadVariableOp", s));
  TFE_OpAddInput(current_op.get(), variable, s);
  TFE_OpSetDevice(current_op.get(), device_name, s);
  TFE_OpSetAttrType(current_op.get(), "dtype", TF_FLOAT);
  ASSERT_TRUE(TF_GetCode(s) == TF_OK) << TF_Message(s);
  executed = false;
  output_count = 1;
  TFE_TensorHandle* read_result = nullptr;
  TFE_Execute(current_op.get(), &read_result, &output_count, s);
  ASSERT_TRUE(TF_GetCode(s) == TF_OK) << TF_Message(s);
  ASSERT_TRUE(executed);
  auto read_result_cleanup = tensorflow::gtl::MakeCleanup(
      [read_result]() { TFE_DeleteTensorHandle(read_result); });
  ASSERT_EQ(tensorflow::string(device_name),
            tensorflow::string(
                TFE_TensorHandleBackingDeviceName(read_result, s)));
  // The device pointer of the result is reinterpreted as a LoggedTensor, whose
  // `tensor` member is the handle that gets resolved below.
  LoggedTensor* const logged = reinterpret_cast<LoggedTensor*>(
      TFE_TensorHandleDevicePointer(read_result, s));
  TFE_TensorHandle* inner_handle = logged->tensor;
  ASSERT_TRUE(TF_GetCode(s) == TF_OK) << TF_Message(s);
  std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> resolved(
      TFE_TensorHandleResolve(inner_handle, s), TF_DeleteTensor);
  ASSERT_TRUE(TF_GetCode(s) == TF_OK) << TF_Message(s);
  ASSERT_EQ(111., *static_cast<float*>(TF_TensorData(resolved.get())));

  // Step 4: release the buffer backing the variable.
  current_op.reset(TFE_NewOp(ctx, "DestroyResourceOp", s));
  TFE_OpAddInput(current_op.get(), variable, s);
  TFE_OpSetDevice(current_op.get(), device_name, s);
  ASSERT_TRUE(TF_GetCode(s) == TF_OK) << TF_Message(s);
  output_count = 0;
  TFE_Execute(current_op.get(), nullptr, &output_count, s);
  ASSERT_TRUE(TF_GetCode(s) == TF_OK) << TF_Message(s);
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
Loading…
x
Reference in New Issue
Block a user