Add an experimental eager C API for generically fetching and setting op attributes.
Right now you can only fetch the whole attribute map and set it wholesale, but we can add more fine-grained attribute control in the future. This allows the custom device API to pass in attributes, and custom devices to forward these to their own TFE_Execute calls. This is required for creating variables. PiperOrigin-RevId: 296096192 Change-Id: I98c23bdcd13e479235b3e27850b1bb0bd7a53bba
This commit is contained in:
parent
ccfc7fd531
commit
1f5bc8a979
@ -1199,14 +1199,6 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
||||
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
|
||||
}
|
||||
|
||||
if (dtype == TF_STRING || dtype == TF_RESOURCE ||
|
||||
!tensorflow::DataTypeCanUseMemcpy(
|
||||
static_cast<tensorflow::DataType>(dtype))) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Trying to create a tensor with a pointer to non-pod memory.");
|
||||
deallocator(data, len, deallocator_arg);
|
||||
return nullptr;
|
||||
}
|
||||
// TODO(apassos) do we need to wrap the deallocator here to make sure to sync
|
||||
// the device?
|
||||
TF_ManagedBuffer* buf =
|
||||
@ -1680,6 +1672,19 @@ void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context->StartStep(); }
|
||||
|
||||
void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context->EndStep(); }
|
||||
|
||||
void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs) {
|
||||
*attrs = TFE_OpAttrs(&op->operation.Attrs());
|
||||
}
|
||||
|
||||
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
|
||||
tensorflow::AttrValueMap m;
|
||||
attrs->attributes->FillAttrValueMap(&m);
|
||||
tensorflow::AttrBuilder* destination = op->operation.MutableAttrs();
|
||||
for (auto attribute : m) {
|
||||
destination->Set(attribute.first, attribute.second);
|
||||
}
|
||||
}
|
||||
|
||||
namespace tensorflow {
|
||||
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
|
||||
const tensorflow::AttrValue& default_value,
|
||||
@ -1799,10 +1804,10 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
|
||||
op->Inputs()[i])});
|
||||
}
|
||||
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
|
||||
// TODO(allenl): figure out how to get attrs from EagerOperation
|
||||
TF_Status status;
|
||||
TFE_OpAttrs attributes(&op->Attrs());
|
||||
device_.execute(inputs.size(), inputs.data(), op->Name().c_str(),
|
||||
num_retvals, outputs.data(), &status, info_);
|
||||
&attributes, num_retvals, outputs.data(), &status, info_);
|
||||
if (status.status.ok()) {
|
||||
for (int i = 0; i < *num_retvals; ++i) {
|
||||
retvals[i] = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||
|
@ -424,7 +424,27 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
||||
TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
|
||||
TF_Buffer* buf);
|
||||
|
||||
#define TFE_CUSTOM_DEVICE_VERSION 0
|
||||
// APIs for generically dealing with op attributes (e.g. when forwarding them
|
||||
// through custom device implementations).
|
||||
//
|
||||
// TODO(allenl): Currently these are black boxes, but we should have some way to
|
||||
// inspect values. This would let people e.g. copy over most attributes and then
|
||||
// modify some based on their values.
|
||||
|
||||
// A reference to an op's name -> attribute mapping
|
||||
typedef struct TFE_OpAttrs TFE_OpAttrs;
|
||||
|
||||
// Fetch a struct with a reference to information about attributes of `op`.
|
||||
//
|
||||
// The `attrs` struct does not own any memory, and `op` must outlive it.
|
||||
TF_CAPI_EXPORT extern void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs);
|
||||
|
||||
// Add attributes in `attrs` to `op`.
|
||||
//
|
||||
// Does not overwrite or update existing attributes, but adds new ones.
|
||||
TF_CAPI_EXPORT extern void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs);
|
||||
|
||||
#define TFE_CUSTOM_DEVICE_VERSION 1
|
||||
|
||||
// Struct to be filled in
|
||||
typedef struct TFE_CustomDevice {
|
||||
@ -441,10 +461,10 @@ typedef struct TFE_CustomDevice {
|
||||
void* device_info);
|
||||
|
||||
// Method to execute an operation.
|
||||
// TODO(allenl) figure out a generic way of passing attrs here
|
||||
void (*execute)(int num_inputs, TFE_TensorHandle** inputs,
|
||||
const char* operation_name, int* num_outputs,
|
||||
TFE_TensorHandle** outputs, TF_Status* s, void* device_info);
|
||||
const char* operation_name, const TFE_OpAttrs* attributes,
|
||||
int* num_outputs, TFE_TensorHandle** outputs, TF_Status* s,
|
||||
void* device_info);
|
||||
|
||||
// Method to delete a device.
|
||||
void (*delete_device)(void* device_info);
|
||||
|
@ -236,4 +236,13 @@ struct TFE_Executor {
|
||||
tensorflow::EagerExecutor* unowned_executor;
|
||||
};
|
||||
|
||||
struct TFE_OpAttrs {
|
||||
explicit TFE_OpAttrs() : attributes(nullptr) {}
|
||||
|
||||
explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value)
|
||||
: attributes(value) {}
|
||||
|
||||
const tensorflow::AttrBuilder* attributes;
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
|
||||
|
@ -1449,4 +1449,38 @@ TEST(CAPI, TestTFE_OpGetInputAndOutputLengthsFailForUnknownArguments) {
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
TEST(CAPI, TestTFE_OpGetAttrs) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_Op* varop = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||
TFE_OpSetAttrType(varop, "dtype", TF_INT64);
|
||||
TFE_OpSetAttrShape(varop, "shape", {}, 0, status);
|
||||
TFE_OpAttrs attributes;
|
||||
TFE_OpGetAttrs(varop, &attributes);
|
||||
|
||||
TFE_Op* varop_copy = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||
TFE_OpSetAttrType(varop_copy, "dtype", TF_FLOAT);
|
||||
TFE_OpAddAttrs(varop_copy, &attributes);
|
||||
unsigned char is_list = 0;
|
||||
ASSERT_EQ(TF_ATTR_TYPE,
|
||||
TFE_OpGetAttrType(varop_copy, "dtype", &is_list, status));
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
ASSERT_EQ(TF_ATTR_SHAPE,
|
||||
TFE_OpGetAttrType(varop_copy, "shape", &is_list, status));
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
tensorflow::AttrValueMap attr_values;
|
||||
varop_copy->operation.Attrs().FillAttrValueMap(&attr_values);
|
||||
EXPECT_EQ(tensorflow::DT_FLOAT, attr_values.find("dtype")->second.type());
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
TFE_DeleteOp(varop);
|
||||
TFE_DeleteOp(varop_copy);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace {
|
||||
@ -83,12 +84,14 @@ TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_TensorHandle* tensor,
|
||||
}
|
||||
|
||||
void LoggingDeviceExecute(int num_inputs, TFE_TensorHandle** inputs,
|
||||
const char* operation_name, int* num_outputs,
|
||||
const char* operation_name,
|
||||
const TFE_OpAttrs* attributes, int* num_outputs,
|
||||
TFE_TensorHandle** outputs, TF_Status* s,
|
||||
void* device_info) {
|
||||
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
|
||||
TFE_Op* op(TFE_NewOp(dev->ctx, operation_name, s));
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
TFE_OpAddAttrs(op, attributes);
|
||||
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
|
||||
for (int j = 0; j < num_inputs; ++j) {
|
||||
TFE_TensorHandle* input = inputs[j];
|
||||
@ -203,4 +206,89 @@ TEST(CUSTOM_DEVICE, ResetOperation) {
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
}
|
||||
|
||||
TEST(CUSTOM_DEVICE, MakeVariable) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
bool arrived = false;
|
||||
bool executed = false;
|
||||
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
RegisterLoggingDevice(context.get(), name, &arrived, &executed);
|
||||
|
||||
// Create a variable handle placed on the custom device.
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||
TFE_OpSetAttrShape(op.get(), "shape", {}, 0, status.get());
|
||||
TFE_OpSetAttrString(op.get(), "container", "", 0);
|
||||
TFE_OpSetAttrString(op.get(), "shared_name", "", 0);
|
||||
TFE_OpSetDevice(op.get(), name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_TensorHandle* var_handle = nullptr;
|
||||
int num_retvals = 1;
|
||||
executed = false;
|
||||
TFE_Execute(op.get(), &var_handle, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_TRUE(executed);
|
||||
auto handle_cleaner = tensorflow::gtl::MakeCleanup(
|
||||
[var_handle]() { TFE_DeleteTensorHandle(var_handle); });
|
||||
|
||||
// Assign to the variable, copying to the custom device.
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
|
||||
TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle);
|
||||
op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
|
||||
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||
TFE_OpAddInput(op.get(), one.get(), status.get());
|
||||
TFE_OpSetDevice(op.get(), name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
executed = false;
|
||||
num_retvals = 0;
|
||||
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_TRUE(executed);
|
||||
|
||||
// Read the variable's value.
|
||||
op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
|
||||
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||
TFE_OpSetDevice(op.get(), name, status.get());
|
||||
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
executed = false;
|
||||
num_retvals = 1;
|
||||
TFE_TensorHandle* var_value = nullptr;
|
||||
TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_TRUE(executed);
|
||||
auto value_cleaner = tensorflow::gtl::MakeCleanup(
|
||||
[var_value]() { TFE_DeleteTensorHandle(var_value); });
|
||||
ASSERT_EQ(tensorflow::string(name),
|
||||
tensorflow::string(
|
||||
TFE_TensorHandleBackingDeviceName(var_value, status.get())));
|
||||
TFE_TensorHandle* var_value_unpacked =
|
||||
reinterpret_cast<LoggedTensor*>(
|
||||
TFE_TensorHandleDevicePointer(var_value, status.get()))
|
||||
->tensor;
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> resolved_value(
|
||||
TFE_TensorHandleResolve(var_value_unpacked, status.get()),
|
||||
TF_DeleteTensor);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_EQ(111., *static_cast<float*>(TF_TensorData(resolved_value.get())));
|
||||
|
||||
// Free the backing buffer for the variable.
|
||||
op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
|
||||
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||
TFE_OpSetDevice(op.get(), name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
num_retvals = 0;
|
||||
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user