Add an experimental eager C API for generically fetching and setting op attributes.

Right now you can only fetch the whole attribute map and set it wholesale, but we can add more fine-grained attribute control in the future.

This allows the custom device API to pass in attributes, and custom devices to forward these to their own TFE_Execute calls. This is required for creating variables.

PiperOrigin-RevId: 296096192
Change-Id: I98c23bdcd13e479235b3e27850b1bb0bd7a53bba
This commit is contained in:
Allen Lavoie 2020-02-19 17:31:04 -08:00 committed by TensorFlower Gardener
parent ccfc7fd531
commit 1f5bc8a979
5 changed files with 171 additions and 15 deletions

View File

@ -1199,14 +1199,6 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
}
if (dtype == TF_STRING || dtype == TF_RESOURCE ||
!tensorflow::DataTypeCanUseMemcpy(
static_cast<tensorflow::DataType>(dtype))) {
status->status = tensorflow::errors::InvalidArgument(
"Trying to create a tensor with a pointer to non-pod memory.");
deallocator(data, len, deallocator_arg);
return nullptr;
}
// TODO(apassos) do we need to wrap the deallocator here to make sure to sync
// the device?
TF_ManagedBuffer* buf =
@ -1680,6 +1672,19 @@ void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context->StartStep(); }
void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context->EndStep(); }
void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs) {
*attrs = TFE_OpAttrs(&op->operation.Attrs());
}
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
tensorflow::AttrValueMap m;
attrs->attributes->FillAttrValueMap(&m);
tensorflow::AttrBuilder* destination = op->operation.MutableAttrs();
for (auto attribute : m) {
destination->Set(attribute.first, attribute.second);
}
}
namespace tensorflow {
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
const tensorflow::AttrValue& default_value,
@ -1799,10 +1804,10 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
op->Inputs()[i])});
}
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
// TODO(allenl): figure out how to get attrs from EagerOperation
TF_Status status;
TFE_OpAttrs attributes(&op->Attrs());
device_.execute(inputs.size(), inputs.data(), op->Name().c_str(),
num_retvals, outputs.data(), &status, info_);
&attributes, num_retvals, outputs.data(), &status, info_);
if (status.status.ok()) {
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(

View File

@ -424,7 +424,27 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
TF_Buffer* buf);
#define TFE_CUSTOM_DEVICE_VERSION 0
// APIs for generically dealing with op attributes (e.g. when forwarding them
// through custom device implementations).
//
// TODO(allenl): Currently these are black boxes, but we should have some way to
// inspect values. This would let people e.g. copy over most attributes and then
// modify some based on their values.
// A reference to an op's name -> attribute mapping
typedef struct TFE_OpAttrs TFE_OpAttrs;
// Fetch a struct with a reference to information about attributes of `op`.
//
// The `attrs` struct does not own any memory, and `op` must outlive it.
TF_CAPI_EXPORT extern void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs);
// Add attributes in `attrs` to `op`.
//
// Does not overwrite or update existing attributes, but adds new ones.
TF_CAPI_EXPORT extern void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs);
#define TFE_CUSTOM_DEVICE_VERSION 1
// Struct to be filled in
typedef struct TFE_CustomDevice {
@ -441,10 +461,10 @@ typedef struct TFE_CustomDevice {
void* device_info);
// Method to execute an operation.
// TODO(allenl) figure out a generic way of passing attrs here
void (*execute)(int num_inputs, TFE_TensorHandle** inputs,
const char* operation_name, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s, void* device_info);
const char* operation_name, const TFE_OpAttrs* attributes,
int* num_outputs, TFE_TensorHandle** outputs, TF_Status* s,
void* device_info);
// Method to delete a device.
void (*delete_device)(void* device_info);

View File

@ -236,4 +236,13 @@ struct TFE_Executor {
tensorflow::EagerExecutor* unowned_executor;
};
struct TFE_OpAttrs {
explicit TFE_OpAttrs() : attributes(nullptr) {}
explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value)
: attributes(value) {}
const tensorflow::AttrBuilder* attributes;
};
#endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_

View File

@ -1449,4 +1449,38 @@ TEST(CAPI, TestTFE_OpGetInputAndOutputLengthsFailForUnknownArguments) {
TFE_DeleteContext(ctx);
}
TEST(CAPI, TestTFE_OpGetAttrs) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_Op* varop = TFE_NewOp(ctx, "VarHandleOp", status);
TFE_OpSetAttrType(varop, "dtype", TF_INT64);
TFE_OpSetAttrShape(varop, "shape", {}, 0, status);
TFE_OpAttrs attributes;
TFE_OpGetAttrs(varop, &attributes);
TFE_Op* varop_copy = TFE_NewOp(ctx, "VarHandleOp", status);
TFE_OpSetAttrType(varop_copy, "dtype", TF_FLOAT);
TFE_OpAddAttrs(varop_copy, &attributes);
unsigned char is_list = 0;
ASSERT_EQ(TF_ATTR_TYPE,
TFE_OpGetAttrType(varop_copy, "dtype", &is_list, status));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(TF_ATTR_SHAPE,
TFE_OpGetAttrType(varop_copy, "shape", &is_list, status));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::AttrValueMap attr_values;
varop_copy->operation.Attrs().FillAttrValueMap(&attr_values);
EXPECT_EQ(tensorflow::DT_FLOAT, attr_values.find("dtype")->second.type());
TF_DeleteStatus(status);
TFE_DeleteOp(varop);
TFE_DeleteOp(varop_copy);
TFE_DeleteContext(ctx);
}
} // namespace

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/test.h"
namespace {
@ -83,12 +84,14 @@ TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_TensorHandle* tensor,
}
void LoggingDeviceExecute(int num_inputs, TFE_TensorHandle** inputs,
const char* operation_name, int* num_outputs,
const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s,
void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_Op* op(TFE_NewOp(dev->ctx, operation_name, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddAttrs(op, attributes);
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
for (int j = 0; j < num_inputs; ++j) {
TFE_TensorHandle* input = inputs[j];
@ -203,4 +206,89 @@ TEST(CUSTOM_DEVICE, ResetOperation) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
TEST(CUSTOM_DEVICE, MakeVariable) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
bool arrived = false;
bool executed = false;
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), name, &arrived, &executed);
// Create a variable handle placed on the custom device.
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
TFE_OpSetAttrShape(op.get(), "shape", {}, 0, status.get());
TFE_OpSetAttrString(op.get(), "container", "", 0);
TFE_OpSetAttrString(op.get(), "shared_name", "", 0);
TFE_OpSetDevice(op.get(), name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* var_handle = nullptr;
int num_retvals = 1;
executed = false;
TFE_Execute(op.get(), &var_handle, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
auto handle_cleaner = tensorflow::gtl::MakeCleanup(
[var_handle]() { TFE_DeleteTensorHandle(var_handle); });
// Assign to the variable, copying to the custom device.
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle);
op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
TFE_OpAddInput(op.get(), var_handle, status.get());
TFE_OpAddInput(op.get(), one.get(), status.get());
TFE_OpSetDevice(op.get(), name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
executed = false;
num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
// Read the variable's value.
op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
TFE_OpAddInput(op.get(), var_handle, status.get());
TFE_OpSetDevice(op.get(), name, status.get());
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
executed = false;
num_retvals = 1;
TFE_TensorHandle* var_value = nullptr;
TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
auto value_cleaner = tensorflow::gtl::MakeCleanup(
[var_value]() { TFE_DeleteTensorHandle(var_value); });
ASSERT_EQ(tensorflow::string(name),
tensorflow::string(
TFE_TensorHandleBackingDeviceName(var_value, status.get())));
TFE_TensorHandle* var_value_unpacked =
reinterpret_cast<LoggedTensor*>(
TFE_TensorHandleDevicePointer(var_value, status.get()))
->tensor;
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> resolved_value(
TFE_TensorHandleResolve(var_value_unpacked, status.get()),
TF_DeleteTensor);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(111., *static_cast<float*>(TF_TensorData(resolved_value.get())));
// Free the backing buffer for the variable.
op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
TFE_OpAddInput(op.get(), var_handle, status.get());
TFE_OpSetDevice(op.get(), name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
} // namespace