Clean up some methods in EagerOperation
Rename AddInput and make it along with a few other methods private. PiperOrigin-RevId: 306745935 Change-Id: I1e333419552a28e96755bb249448974ba6a49eb7
This commit is contained in:
		
							parent
							
								
									ac271534b8
								
							
						
					
					
						commit
						f290fd7f48
					
				| @ -249,7 +249,7 @@ Status EagerOperation::OutputLength(const char* output_name, int* length) { | |||||||
| 
 | 
 | ||||||
| Status EagerOperation::AddInput(AbstractTensorHandleInterface* input) { | Status EagerOperation::AddInput(AbstractTensorHandleInterface* input) { | ||||||
|   TensorHandle* h = TensorHandleFromInterface(input); |   TensorHandle* h = TensorHandleFromInterface(input); | ||||||
|   AddInput(h); |   AddTensorHandle(h); | ||||||
|   return MaybeInferSingleInputAttrs(h); |   return MaybeInferSingleInputAttrs(h); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -257,7 +257,7 @@ Status EagerOperation::AddInputList( | |||||||
|     absl::Span<AbstractTensorHandleInterface*> inputs) { |     absl::Span<AbstractTensorHandleInterface*> inputs) { | ||||||
|   for (auto& input : inputs) { |   for (auto& input : inputs) { | ||||||
|     TensorHandle* h = TensorHandleFromInterface(input); |     TensorHandle* h = TensorHandleFromInterface(input); | ||||||
|     AddInput(h); |     AddTensorHandle(h); | ||||||
|   } |   } | ||||||
|   return InferInputListAttrs(inputs.size()); |   return InferInputListAttrs(inputs.size()); | ||||||
| } | } | ||||||
| @ -426,4 +426,10 @@ string EagerOperation::DebugString() const { | |||||||
|   return out; |   return out; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | void EagerOperation::AddTensorHandle(TensorHandle* h) { | ||||||
|  |   h->Ref(); | ||||||
|  |   inputs_.push_back(h); | ||||||
|  |   attrs_.NumInputs(static_cast<int>(inputs_.size())); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| }  // namespace tensorflow
 | }  // namespace tensorflow
 | ||||||
|  | |||||||
| @ -134,7 +134,6 @@ class EagerOperation : public AbstractOperationInterface { | |||||||
|   bool colocation_exempt() const { return colocation_exempt_; } |   bool colocation_exempt() const { return colocation_exempt_; } | ||||||
| 
 | 
 | ||||||
|   tensorflow::EagerContext& EagerContext() { return ctx_; } |   tensorflow::EagerContext& EagerContext() { return ctx_; } | ||||||
|   const tensorflow::EagerContext& EagerContext() const { return ctx_; } |  | ||||||
| 
 | 
 | ||||||
|   AttrBuilder* MutableAttrs() { return &attrs_; } |   AttrBuilder* MutableAttrs() { return &attrs_; } | ||||||
|   const AttrBuilder& Attrs() const { return attrs_; } |   const AttrBuilder& Attrs() const { return attrs_; } | ||||||
| @ -144,11 +143,8 @@ class EagerOperation : public AbstractOperationInterface { | |||||||
|   } |   } | ||||||
|   absl::InlinedVector<TensorHandle*, 4>* MutableInputs() { return &inputs_; } |   absl::InlinedVector<TensorHandle*, 4>* MutableInputs() { return &inputs_; } | ||||||
| 
 | 
 | ||||||
|   void AddInput(TensorHandle* h); |  | ||||||
|   void UpdateInput(int i, TensorHandle* h); |   void UpdateInput(int i, TensorHandle* h); | ||||||
| 
 | 
 | ||||||
|   const AttrTypeMap* AttrTypes() const { return attr_types_; } |  | ||||||
| 
 |  | ||||||
|   // Like TensorHandles, EagerOperations may be placed either on a virtual
 |   // Like TensorHandles, EagerOperations may be placed either on a virtual
 | ||||||
|   // CustomDevice or on a physical Device.
 |   // CustomDevice or on a physical Device.
 | ||||||
|   VariantDevice Device() const { return device_; } |   VariantDevice Device() const { return device_; } | ||||||
| @ -174,12 +170,10 @@ class EagerOperation : public AbstractOperationInterface { | |||||||
| 
 | 
 | ||||||
|   // Op name recorded for memory debugging purpose.
 |   // Op name recorded for memory debugging purpose.
 | ||||||
|   const char* op_name() const { return op_name_; } |   const char* op_name() const { return op_name_; } | ||||||
|   const char* op_name_ = nullptr; |  | ||||||
| 
 |  | ||||||
|   Status MaybeInferSingleInputAttrs(TensorHandle* handle); |  | ||||||
|   Status InferInputListAttrs(int num_inputs); |  | ||||||
| 
 | 
 | ||||||
|  private: |  private: | ||||||
|  |   void AddTensorHandle(TensorHandle* h); | ||||||
|  | 
 | ||||||
|   const tensorflow::OpDef* GetOpDef(Status* status); |   const tensorflow::OpDef* GetOpDef(Status* status); | ||||||
| 
 | 
 | ||||||
|   void ClearInferenceState() { |   void ClearInferenceState() { | ||||||
| @ -187,12 +181,17 @@ class EagerOperation : public AbstractOperationInterface { | |||||||
|     inference_arg_idx_ = 0; |     inference_arg_idx_ = 0; | ||||||
|     inference_attrs_.clear_no_resize(); |     inference_attrs_.clear_no_resize(); | ||||||
|   } |   } | ||||||
|  | 
 | ||||||
|  |   Status MaybeInferSingleInputAttrs(TensorHandle* handle); | ||||||
|  |   Status InferInputListAttrs(int num_inputs); | ||||||
|  | 
 | ||||||
|   void InferSingleTypeInputListAttrs(const OpDef::ArgDef& input_def, |   void InferSingleTypeInputListAttrs(const OpDef::ArgDef& input_def, | ||||||
|                                      const DataType dtype, int num_inputs); |                                      const DataType dtype, int num_inputs); | ||||||
|   void InferMixedTypeInputListAttrs(const OpDef::ArgDef& input_def, |   void InferMixedTypeInputListAttrs(const OpDef::ArgDef& input_def, | ||||||
|                                     const std::vector<DataType>& dtypes); |                                     const std::vector<DataType>& dtypes); | ||||||
| 
 | 
 | ||||||
|   tensorflow::EagerContext& ctx_; |   tensorflow::EagerContext& ctx_; | ||||||
|  |   const char* op_name_ = nullptr; | ||||||
|   AttrBuilder attrs_; |   AttrBuilder attrs_; | ||||||
|   const AttrTypeMap* attr_types_; |   const AttrTypeMap* attr_types_; | ||||||
|   absl::InlinedVector<TensorHandle*, 4> inputs_; |   absl::InlinedVector<TensorHandle*, 4> inputs_; | ||||||
| @ -232,12 +231,6 @@ class EagerOperation : public AbstractOperationInterface { | |||||||
|   gtl::FlatSet<std::string> inference_attrs_;  // attributes inferred so far
 |   gtl::FlatSet<std::string> inference_attrs_;  // attributes inferred so far
 | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| inline void EagerOperation::AddInput(TensorHandle* h) { |  | ||||||
|   h->Ref(); |  | ||||||
|   inputs_.push_back(h); |  | ||||||
|   attrs_.NumInputs(static_cast<int>(inputs_.size())); |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| inline void EagerOperation::UpdateInput(int i, TensorHandle* h) { | inline void EagerOperation::UpdateInput(int i, TensorHandle* h) { | ||||||
|   TensorHandle** slot = &inputs_[i]; |   TensorHandle** slot = &inputs_[i]; | ||||||
|   TensorHandle* existing = *slot; |   TensorHandle* existing = *slot; | ||||||
|  | |||||||
| @ -373,7 +373,7 @@ Status EagerServiceImpl::ExecuteOp(const Operation& operation, | |||||||
|         TF_RETURN_IF_ERROR( |         TF_RETURN_IF_ERROR( | ||||||
|             eager_context->RemoteMgr()->DeserializeRemoteTensorHandle( |             eager_context->RemoteMgr()->DeserializeRemoteTensorHandle( | ||||||
|                 input.remote_handle(), &handle)); |                 input.remote_handle(), &handle)); | ||||||
|         op->AddInput(handle); |         TF_RETURN_IF_ERROR(op->AddInput(handle)); | ||||||
|       } else { |       } else { | ||||||
|         Tensor tensor; |         Tensor tensor; | ||||||
|         if (!ParseTensorProtoToTensor(input.tensor(), &tensor)) { |         if (!ParseTensorProtoToTensor(input.tensor(), &tensor)) { | ||||||
| @ -382,7 +382,7 @@ Status EagerServiceImpl::ExecuteOp(const Operation& operation, | |||||||
|         } else { |         } else { | ||||||
|           handle = TensorHandle::CreateLocalHandle(std::move(tensor), nullptr, |           handle = TensorHandle::CreateLocalHandle(std::move(tensor), nullptr, | ||||||
|                                                    nullptr, eager_context); |                                                    nullptr, eager_context); | ||||||
|           op->AddInput(handle); |           TF_RETURN_IF_ERROR(op->AddInput(handle)); | ||||||
|         } |         } | ||||||
|       } |       } | ||||||
|       // Unref handle since it has a ref as an input now.
 |       // Unref handle since it has a ref as an input now.
 | ||||||
|  | |||||||
| @ -96,7 +96,7 @@ RemoteCopyNode::~RemoteCopyNode() { | |||||||
| Status RemoteCopyNode::RunLocalSend(EagerOperation* op) { | Status RemoteCopyNode::RunLocalSend(EagerOperation* op) { | ||||||
|   TF_RETURN_IF_ERROR(executor_->status()); |   TF_RETURN_IF_ERROR(executor_->status()); | ||||||
| 
 | 
 | ||||||
|   op->AddInput(src_); |   TF_RETURN_IF_ERROR(op->AddInput(src_)); | ||||||
| 
 | 
 | ||||||
|   core::RefCountPtr<KernelAndDevice> kernel; |   core::RefCountPtr<KernelAndDevice> kernel; | ||||||
|   TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel)); |   TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel)); | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user