Add an experimental eager C API for generically fetching and setting op attributes.
Right now you can only fetch the whole attribute map and set it wholesale, but we can add more fine-grained attribute control in the future. This allows the custom device API to pass in attributes, and custom devices to forward these to their own TFE_Execute calls. This is required for creating variables. PiperOrigin-RevId: 296096192 Change-Id: I98c23bdcd13e479235b3e27850b1bb0bd7a53bba
This commit is contained in:
		
							parent
							
								
									ccfc7fd531
								
							
						
					
					
						commit
						1f5bc8a979
					
				| @ -1199,14 +1199,6 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory( | ||||
|     dimvec[i] = static_cast<tensorflow::int64>(dims[i]); | ||||
|   } | ||||
| 
 | ||||
|   if (dtype == TF_STRING || dtype == TF_RESOURCE || | ||||
|       !tensorflow::DataTypeCanUseMemcpy( | ||||
|           static_cast<tensorflow::DataType>(dtype))) { | ||||
|     status->status = tensorflow::errors::InvalidArgument( | ||||
|         "Trying to create a tensor with a pointer to non-pod memory."); | ||||
|     deallocator(data, len, deallocator_arg); | ||||
|     return nullptr; | ||||
|   } | ||||
|   // TODO(apassos) do we need to wrap the deallocator here to make sure to sync
 | ||||
|   // the device?
 | ||||
|   TF_ManagedBuffer* buf = | ||||
| @ -1680,6 +1672,19 @@ void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context->StartStep(); } | ||||
| 
 | ||||
| void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context->EndStep(); } | ||||
| 
 | ||||
| void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs) { | ||||
|   *attrs = TFE_OpAttrs(&op->operation.Attrs()); | ||||
| } | ||||
| 
 | ||||
| void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) { | ||||
|   tensorflow::AttrValueMap m; | ||||
|   attrs->attributes->FillAttrValueMap(&m); | ||||
|   tensorflow::AttrBuilder* destination = op->operation.MutableAttrs(); | ||||
|   for (auto attribute : m) { | ||||
|     destination->Set(attribute.first, attribute.second); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| namespace tensorflow { | ||||
| void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, | ||||
|                           const tensorflow::AttrValue& default_value, | ||||
| @ -1799,10 +1804,10 @@ class CustomDeviceAPI : public tensorflow::CustomDevice { | ||||
|               op->Inputs()[i])}); | ||||
|     } | ||||
|     std::vector<TFE_TensorHandle*> outputs(*num_retvals); | ||||
|     // TODO(allenl): figure out how to get attrs from EagerOperation
 | ||||
|     TF_Status status; | ||||
|     TFE_OpAttrs attributes(&op->Attrs()); | ||||
|     device_.execute(inputs.size(), inputs.data(), op->Name().c_str(), | ||||
|                     num_retvals, outputs.data(), &status, info_); | ||||
|                     &attributes, num_retvals, outputs.data(), &status, info_); | ||||
|     if (status.status.ok()) { | ||||
|       for (int i = 0; i < *num_retvals; ++i) { | ||||
|         retvals[i] = tensorflow::down_cast<tensorflow::TensorHandleInterface*>( | ||||
|  | ||||
| @ -424,7 +424,27 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory( | ||||
| TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx, | ||||
|                                                 TF_Buffer* buf); | ||||
| 
 | ||||
| #define TFE_CUSTOM_DEVICE_VERSION 0 | ||||
| // APIs for generically dealing with op attributes (e.g. when forwarding them
 | ||||
| // through custom device implementations).
 | ||||
| //
 | ||||
| // TODO(allenl): Currently these are black boxes, but we should have some way to
 | ||||
| // inspect values. This would let people e.g. copy over most attributes and then
 | ||||
| // modify some based on their values.
 | ||||
| 
 | ||||
| // A reference to an op's name -> attribute mapping
 | ||||
| typedef struct TFE_OpAttrs TFE_OpAttrs; | ||||
| 
 | ||||
| // Fetch a struct with a reference to information about attributes of `op`.
 | ||||
| //
 | ||||
| // The `attrs` struct does not own any memory, and `op` must outlive it.
 | ||||
| TF_CAPI_EXPORT extern void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs); | ||||
| 
 | ||||
| // Add attributes in `attrs` to `op`.
 | ||||
| //
 | ||||
| // Does not overwrite or update existing attributes, but adds new ones.
 | ||||
| TF_CAPI_EXPORT extern void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs); | ||||
| 
 | ||||
| #define TFE_CUSTOM_DEVICE_VERSION 1 | ||||
| 
 | ||||
| // Struct to be filled in
 | ||||
| typedef struct TFE_CustomDevice { | ||||
| @ -441,10 +461,10 @@ typedef struct TFE_CustomDevice { | ||||
|                                                void* device_info); | ||||
| 
 | ||||
|   // Method to execute an operation.
 | ||||
|   // TODO(allenl) figure out a generic way of passing attrs here
 | ||||
|   void (*execute)(int num_inputs, TFE_TensorHandle** inputs, | ||||
|                   const char* operation_name, int* num_outputs, | ||||
|                   TFE_TensorHandle** outputs, TF_Status* s, void* device_info); | ||||
|                   const char* operation_name, const TFE_OpAttrs* attributes, | ||||
|                   int* num_outputs, TFE_TensorHandle** outputs, TF_Status* s, | ||||
|                   void* device_info); | ||||
| 
 | ||||
|   // Method to delete a device.
 | ||||
|   void (*delete_device)(void* device_info); | ||||
|  | ||||
| @ -236,4 +236,13 @@ struct TFE_Executor { | ||||
|   tensorflow::EagerExecutor* unowned_executor; | ||||
| }; | ||||
| 
 | ||||
| struct TFE_OpAttrs { | ||||
|   explicit TFE_OpAttrs() : attributes(nullptr) {} | ||||
| 
 | ||||
|   explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value) | ||||
|       : attributes(value) {} | ||||
| 
 | ||||
|   const tensorflow::AttrBuilder* attributes; | ||||
| }; | ||||
| 
 | ||||
| #endif  // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
 | ||||
|  | ||||
| @ -1449,4 +1449,38 @@ TEST(CAPI, TestTFE_OpGetInputAndOutputLengthsFailForUnknownArguments) { | ||||
|   TFE_DeleteContext(ctx); | ||||
| } | ||||
| 
 | ||||
| TEST(CAPI, TestTFE_OpGetAttrs) { | ||||
|   TF_Status* status = TF_NewStatus(); | ||||
|   TFE_ContextOptions* opts = TFE_NewContextOptions(); | ||||
|   TFE_Context* ctx = TFE_NewContext(opts, status); | ||||
|   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); | ||||
|   TFE_DeleteContextOptions(opts); | ||||
| 
 | ||||
|   TFE_Op* varop = TFE_NewOp(ctx, "VarHandleOp", status); | ||||
|   TFE_OpSetAttrType(varop, "dtype", TF_INT64); | ||||
|   TFE_OpSetAttrShape(varop, "shape", {}, 0, status); | ||||
|   TFE_OpAttrs attributes; | ||||
|   TFE_OpGetAttrs(varop, &attributes); | ||||
| 
 | ||||
|   TFE_Op* varop_copy = TFE_NewOp(ctx, "VarHandleOp", status); | ||||
|   TFE_OpSetAttrType(varop_copy, "dtype", TF_FLOAT); | ||||
|   TFE_OpAddAttrs(varop_copy, &attributes); | ||||
|   unsigned char is_list = 0; | ||||
|   ASSERT_EQ(TF_ATTR_TYPE, | ||||
|             TFE_OpGetAttrType(varop_copy, "dtype", &is_list, status)); | ||||
|   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); | ||||
|   ASSERT_EQ(TF_ATTR_SHAPE, | ||||
|             TFE_OpGetAttrType(varop_copy, "shape", &is_list, status)); | ||||
|   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); | ||||
| 
 | ||||
|   tensorflow::AttrValueMap attr_values; | ||||
|   varop_copy->operation.Attrs().FillAttrValueMap(&attr_values); | ||||
|   EXPECT_EQ(tensorflow::DT_FLOAT, attr_values.find("dtype")->second.type()); | ||||
| 
 | ||||
|   TF_DeleteStatus(status); | ||||
|   TFE_DeleteOp(varop); | ||||
|   TFE_DeleteOp(varop_copy); | ||||
|   TFE_DeleteContext(ctx); | ||||
| } | ||||
| 
 | ||||
| }  // namespace
 | ||||
|  | ||||
| @ -21,6 +21,7 @@ limitations under the License. | ||||
| #include "tensorflow/c/eager/c_api_experimental.h" | ||||
| #include "tensorflow/c/eager/c_api_test_util.h" | ||||
| #include "tensorflow/c/tf_status.h" | ||||
| #include "tensorflow/core/lib/gtl/cleanup.h" | ||||
| #include "tensorflow/core/platform/test.h" | ||||
| 
 | ||||
| namespace { | ||||
| @ -83,12 +84,14 @@ TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_TensorHandle* tensor, | ||||
| } | ||||
| 
 | ||||
| void LoggingDeviceExecute(int num_inputs, TFE_TensorHandle** inputs, | ||||
|                           const char* operation_name, int* num_outputs, | ||||
|                           const char* operation_name, | ||||
|                           const TFE_OpAttrs* attributes, int* num_outputs, | ||||
|                           TFE_TensorHandle** outputs, TF_Status* s, | ||||
|                           void* device_info) { | ||||
|   LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info); | ||||
|   TFE_Op* op(TFE_NewOp(dev->ctx, operation_name, s)); | ||||
|   if (TF_GetCode(s) != TF_OK) return; | ||||
|   TFE_OpAddAttrs(op, attributes); | ||||
|   TFE_OpSetDevice(op, dev->underlying_device.c_str(), s); | ||||
|   for (int j = 0; j < num_inputs; ++j) { | ||||
|     TFE_TensorHandle* input = inputs[j]; | ||||
| @ -203,4 +206,89 @@ TEST(CUSTOM_DEVICE, ResetOperation) { | ||||
|   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); | ||||
| } | ||||
| 
 | ||||
| TEST(CUSTOM_DEVICE, MakeVariable) { | ||||
|   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( | ||||
|       TF_NewStatus(), TF_DeleteStatus); | ||||
|   std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts( | ||||
|       TFE_NewContextOptions(), TFE_DeleteContextOptions); | ||||
|   std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context( | ||||
|       TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); | ||||
|   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); | ||||
|   bool arrived = false; | ||||
|   bool executed = false; | ||||
|   const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; | ||||
|   RegisterLoggingDevice(context.get(), name, &arrived, &executed); | ||||
| 
 | ||||
|   // Create a variable handle placed on the custom device.
 | ||||
|   std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op( | ||||
|       TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp); | ||||
|   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); | ||||
|   TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT); | ||||
|   TFE_OpSetAttrShape(op.get(), "shape", {}, 0, status.get()); | ||||
|   TFE_OpSetAttrString(op.get(), "container", "", 0); | ||||
|   TFE_OpSetAttrString(op.get(), "shared_name", "", 0); | ||||
|   TFE_OpSetDevice(op.get(), name, status.get()); | ||||
|   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); | ||||
|   TFE_TensorHandle* var_handle = nullptr; | ||||
|   int num_retvals = 1; | ||||
|   executed = false; | ||||
|   TFE_Execute(op.get(), &var_handle, &num_retvals, status.get()); | ||||
|   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); | ||||
|   ASSERT_TRUE(executed); | ||||
|   auto handle_cleaner = tensorflow::gtl::MakeCleanup( | ||||
|       [var_handle]() { TFE_DeleteTensorHandle(var_handle); }); | ||||
| 
 | ||||
|   // Assign to the variable, copying to the custom device.
 | ||||
|   std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one( | ||||
|       TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle); | ||||
|   op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get())); | ||||
|   TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT); | ||||
|   TFE_OpAddInput(op.get(), var_handle, status.get()); | ||||
|   TFE_OpAddInput(op.get(), one.get(), status.get()); | ||||
|   TFE_OpSetDevice(op.get(), name, status.get()); | ||||
|   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); | ||||
|   executed = false; | ||||
|   num_retvals = 0; | ||||
|   TFE_Execute(op.get(), nullptr, &num_retvals, status.get()); | ||||
|   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); | ||||
|   ASSERT_TRUE(executed); | ||||
| 
 | ||||
|   // Read the variable's value.
 | ||||
|   op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get())); | ||||
|   TFE_OpAddInput(op.get(), var_handle, status.get()); | ||||
|   TFE_OpSetDevice(op.get(), name, status.get()); | ||||
|   TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT); | ||||
|   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); | ||||
|   executed = false; | ||||
|   num_retvals = 1; | ||||
|   TFE_TensorHandle* var_value = nullptr; | ||||
|   TFE_Execute(op.get(), &var_value, &num_retvals, status.get()); | ||||
|   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); | ||||
|   ASSERT_TRUE(executed); | ||||
|   auto value_cleaner = tensorflow::gtl::MakeCleanup( | ||||
|       [var_value]() { TFE_DeleteTensorHandle(var_value); }); | ||||
|   ASSERT_EQ(tensorflow::string(name), | ||||
|             tensorflow::string( | ||||
|                 TFE_TensorHandleBackingDeviceName(var_value, status.get()))); | ||||
|   TFE_TensorHandle* var_value_unpacked = | ||||
|       reinterpret_cast<LoggedTensor*>( | ||||
|           TFE_TensorHandleDevicePointer(var_value, status.get())) | ||||
|           ->tensor; | ||||
|   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); | ||||
|   std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> resolved_value( | ||||
|       TFE_TensorHandleResolve(var_value_unpacked, status.get()), | ||||
|       TF_DeleteTensor); | ||||
|   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); | ||||
|   ASSERT_EQ(111., *static_cast<float*>(TF_TensorData(resolved_value.get()))); | ||||
| 
 | ||||
|   // Free the backing buffer for the variable.
 | ||||
|   op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get())); | ||||
|   TFE_OpAddInput(op.get(), var_handle, status.get()); | ||||
|   TFE_OpSetDevice(op.get(), name, status.get()); | ||||
|   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); | ||||
|   num_retvals = 0; | ||||
|   TFE_Execute(op.get(), nullptr, &num_retvals, status.get()); | ||||
|   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); | ||||
| } | ||||
| 
 | ||||
| }  // namespace
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user