Small cleanups for experimental TFE attribute APIs
The op name was included twice, and TFE_OpGetAttrs is unusable without a way to allocate a TFE_OpAttrs on the heap (and so has no callers). I'm removing it for now. PiperOrigin-RevId: 308859222 Change-Id: Ibb3901a1821ffc2e9ebc0efb26592e5b3d8bb88f
This commit is contained in:
		
							parent
							
								
									cb77658d4e
								
							
						
					
					
						commit
						e0606af65f
					
				@ -1485,11 +1485,6 @@ void TFE_ContextEndStep(TFE_Context* ctx) {
 | 
				
			|||||||
  context->EndStep();
 | 
					  context->EndStep();
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs) {
 | 
					 | 
				
			||||||
  tensorflow::EagerOperation* operation = OperationFromInterface(op->operation);
 | 
					 | 
				
			||||||
  *attrs = TFE_OpAttrs(&operation->Attrs(), operation->Name().c_str());
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
 | 
					void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
 | 
				
			||||||
  tensorflow::AttrValueMap m;
 | 
					  tensorflow::AttrValueMap m;
 | 
				
			||||||
  attrs->attributes->FillAttrValueMap(&m);
 | 
					  attrs->attributes->FillAttrValueMap(&m);
 | 
				
			||||||
@ -1504,7 +1499,7 @@ void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,
 | 
				
			|||||||
                          TF_Status* status) {
 | 
					                          TF_Status* status) {
 | 
				
			||||||
  tensorflow::NameAttrList name_and_attrs;
 | 
					  tensorflow::NameAttrList name_and_attrs;
 | 
				
			||||||
  attrs->attributes->FillAttrValueMap(name_and_attrs.mutable_attr());
 | 
					  attrs->attributes->FillAttrValueMap(name_and_attrs.mutable_attr());
 | 
				
			||||||
  name_and_attrs.set_name(attrs->name);
 | 
					  name_and_attrs.set_name(attrs->attributes->op_name());
 | 
				
			||||||
  status->status = MessageToBuffer(name_and_attrs, buf);
 | 
					  status->status = MessageToBuffer(name_and_attrs, buf);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -1624,7 +1619,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
    std::vector<TFE_TensorHandle*> outputs(*num_retvals);
 | 
					    std::vector<TFE_TensorHandle*> outputs(*num_retvals);
 | 
				
			||||||
    TF_Status status;
 | 
					    TF_Status status;
 | 
				
			||||||
    TFE_OpAttrs attributes(&op->Attrs(), op->Name().c_str());
 | 
					    TFE_OpAttrs attributes(&op->Attrs());
 | 
				
			||||||
    device_.execute(context_, inputs.size(), inputs.data(), op->Name().c_str(),
 | 
					    device_.execute(context_, inputs.size(), inputs.data(), op->Name().c_str(),
 | 
				
			||||||
                    &attributes, num_retvals, outputs.data(), &status, info_);
 | 
					                    &attributes, num_retvals, outputs.data(), &status, info_);
 | 
				
			||||||
    if (status.status.ok()) {
 | 
					    if (status.status.ok()) {
 | 
				
			||||||
 | 
				
			|||||||
@ -431,11 +431,6 @@ TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
 | 
				
			|||||||
// A reference to an op's name -> attribute mapping
 | 
					// A reference to an op's name -> attribute mapping
 | 
				
			||||||
typedef struct TFE_OpAttrs TFE_OpAttrs;
 | 
					typedef struct TFE_OpAttrs TFE_OpAttrs;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Fetch a struct with a reference to information about attributes of `op`.
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
// The `attrs` struct does not own any memory, and `op` must outlive it.
 | 
					 | 
				
			||||||
TF_CAPI_EXPORT extern void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Add attributes in `attrs` to `op`.
 | 
					// Add attributes in `attrs` to `op`.
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
// Does not overwrite or update existing attributes, but adds new ones.
 | 
					// Does not overwrite or update existing attributes, but adds new ones.
 | 
				
			||||||
 | 
				
			|||||||
@ -1577,7 +1577,7 @@ TEST(CAPI, TestTFE_OpGetInputAndOutputLengthsFailForUnknownArguments) {
 | 
				
			|||||||
  TFE_DeleteContext(ctx);
 | 
					  TFE_DeleteContext(ctx);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
TEST(CAPI, TestTFE_OpGetAttrs) {
 | 
					TEST(CAPI, TestTFE_OpAddAttrs) {
 | 
				
			||||||
  TF_Status* status = TF_NewStatus();
 | 
					  TF_Status* status = TF_NewStatus();
 | 
				
			||||||
  TFE_ContextOptions* opts = TFE_NewContextOptions();
 | 
					  TFE_ContextOptions* opts = TFE_NewContextOptions();
 | 
				
			||||||
  TFE_Context* ctx = TFE_NewContext(opts, status);
 | 
					  TFE_Context* ctx = TFE_NewContext(opts, status);
 | 
				
			||||||
@ -1587,8 +1587,11 @@ TEST(CAPI, TestTFE_OpGetAttrs) {
 | 
				
			|||||||
  TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
 | 
					  TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
 | 
				
			||||||
  TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
 | 
					  TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
 | 
				
			||||||
  TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
 | 
					  TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
 | 
				
			||||||
  TFE_OpAttrs attributes;
 | 
					  // There is currently no API to fetch attributes from an operation, fetching
 | 
				
			||||||
  TFE_OpGetAttrs(var_op, &attributes);
 | 
					  // happens only as an implementation detail of custom devices.
 | 
				
			||||||
 | 
					  tensorflow::EagerOperation* operation =
 | 
				
			||||||
 | 
					      OperationFromInterface(var_op->operation);
 | 
				
			||||||
 | 
					  TFE_OpAttrs attributes{&operation->Attrs()};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status);
 | 
					  TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status);
 | 
				
			||||||
  TFE_OpSetAttrType(copy_op, "dtype", TF_FLOAT);
 | 
					  TFE_OpSetAttrType(copy_op, "dtype", TF_FLOAT);
 | 
				
			||||||
@ -1624,8 +1627,11 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) {
 | 
				
			|||||||
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
 | 
					  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
 | 
				
			||||||
  TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
 | 
					  TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
 | 
				
			||||||
  TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
 | 
					  TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
 | 
				
			||||||
  TFE_OpAttrs attributes;
 | 
					  // There is currently no API to fetch attributes from an operation, fetching
 | 
				
			||||||
  TFE_OpGetAttrs(var_op, &attributes);
 | 
					  // happens only as an implementation detail of custom devices.
 | 
				
			||||||
 | 
					  tensorflow::EagerOperation* operation =
 | 
				
			||||||
 | 
					      OperationFromInterface(var_op->operation);
 | 
				
			||||||
 | 
					  TFE_OpAttrs attributes{&operation->Attrs()};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  TF_Buffer* serialized_attr_values = TF_NewBuffer();
 | 
					  TF_Buffer* serialized_attr_values = TF_NewBuffer();
 | 
				
			||||||
  TFE_OpAttrsSerialize(&attributes, serialized_attr_values, status);
 | 
					  TFE_OpAttrsSerialize(&attributes, serialized_attr_values, status);
 | 
				
			||||||
 | 
				
			|||||||
@ -32,13 +32,11 @@ limitations under the License.
 | 
				
			|||||||
// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways
 | 
					// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways
 | 
				
			||||||
// that sometimes do not require serialization.
 | 
					// that sometimes do not require serialization.
 | 
				
			||||||
struct TFE_OpAttrs {
 | 
					struct TFE_OpAttrs {
 | 
				
			||||||
  explicit TFE_OpAttrs() : name(nullptr), attributes(nullptr) {}
 | 
					  explicit TFE_OpAttrs() : attributes(nullptr) {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value,
 | 
					  explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value)
 | 
				
			||||||
                       const char* op_name)
 | 
					      : attributes(value) {}
 | 
				
			||||||
      : name(op_name), attributes(value) {}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
  const char* name;
 | 
					 | 
				
			||||||
  const tensorflow::AttrBuilder* attributes;
 | 
					  const tensorflow::AttrBuilder* attributes;
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user