Clean up some methods in EagerOperation

Rename AddInput and make it along with a few other methods private.

PiperOrigin-RevId: 306745935
Change-Id: I1e333419552a28e96755bb249448974ba6a49eb7
This commit is contained in:
Gaurav Jain 2020-04-15 16:45:46 -07:00 committed by TensorFlower Gardener
parent ac271534b8
commit f290fd7f48
4 changed files with 18 additions and 19 deletions

View File

@ -249,7 +249,7 @@ Status EagerOperation::OutputLength(const char* output_name, int* length) {
Status EagerOperation::AddInput(AbstractTensorHandleInterface* input) { Status EagerOperation::AddInput(AbstractTensorHandleInterface* input) {
TensorHandle* h = TensorHandleFromInterface(input); TensorHandle* h = TensorHandleFromInterface(input);
AddInput(h); AddTensorHandle(h);
return MaybeInferSingleInputAttrs(h); return MaybeInferSingleInputAttrs(h);
} }
@ -257,7 +257,7 @@ Status EagerOperation::AddInputList(
absl::Span<AbstractTensorHandleInterface*> inputs) { absl::Span<AbstractTensorHandleInterface*> inputs) {
for (auto& input : inputs) { for (auto& input : inputs) {
TensorHandle* h = TensorHandleFromInterface(input); TensorHandle* h = TensorHandleFromInterface(input);
AddInput(h); AddTensorHandle(h);
} }
return InferInputListAttrs(inputs.size()); return InferInputListAttrs(inputs.size());
} }
@ -426,4 +426,10 @@ string EagerOperation::DebugString() const {
return out; return out;
} }
void EagerOperation::AddTensorHandle(TensorHandle* h) {
h->Ref();
inputs_.push_back(h);
attrs_.NumInputs(static_cast<int>(inputs_.size()));
}
} // namespace tensorflow } // namespace tensorflow

View File

@ -134,7 +134,6 @@ class EagerOperation : public AbstractOperationInterface {
bool colocation_exempt() const { return colocation_exempt_; } bool colocation_exempt() const { return colocation_exempt_; }
tensorflow::EagerContext& EagerContext() { return ctx_; } tensorflow::EagerContext& EagerContext() { return ctx_; }
const tensorflow::EagerContext& EagerContext() const { return ctx_; }
AttrBuilder* MutableAttrs() { return &attrs_; } AttrBuilder* MutableAttrs() { return &attrs_; }
const AttrBuilder& Attrs() const { return attrs_; } const AttrBuilder& Attrs() const { return attrs_; }
@ -144,11 +143,8 @@ class EagerOperation : public AbstractOperationInterface {
} }
absl::InlinedVector<TensorHandle*, 4>* MutableInputs() { return &inputs_; } absl::InlinedVector<TensorHandle*, 4>* MutableInputs() { return &inputs_; }
void AddInput(TensorHandle* h);
void UpdateInput(int i, TensorHandle* h); void UpdateInput(int i, TensorHandle* h);
const AttrTypeMap* AttrTypes() const { return attr_types_; }
// Like TensorHandles, EagerOperations may be placed either on a virtual // Like TensorHandles, EagerOperations may be placed either on a virtual
// CustomDevice or on a physical Device. // CustomDevice or on a physical Device.
VariantDevice Device() const { return device_; } VariantDevice Device() const { return device_; }
@ -174,12 +170,10 @@ class EagerOperation : public AbstractOperationInterface {
// Op name recorded for memory debugging purpose. // Op name recorded for memory debugging purpose.
const char* op_name() const { return op_name_; } const char* op_name() const { return op_name_; }
const char* op_name_ = nullptr;
Status MaybeInferSingleInputAttrs(TensorHandle* handle);
Status InferInputListAttrs(int num_inputs);
private: private:
void AddTensorHandle(TensorHandle* h);
const tensorflow::OpDef* GetOpDef(Status* status); const tensorflow::OpDef* GetOpDef(Status* status);
void ClearInferenceState() { void ClearInferenceState() {
@ -187,12 +181,17 @@ class EagerOperation : public AbstractOperationInterface {
inference_arg_idx_ = 0; inference_arg_idx_ = 0;
inference_attrs_.clear_no_resize(); inference_attrs_.clear_no_resize();
} }
Status MaybeInferSingleInputAttrs(TensorHandle* handle);
Status InferInputListAttrs(int num_inputs);
void InferSingleTypeInputListAttrs(const OpDef::ArgDef& input_def, void InferSingleTypeInputListAttrs(const OpDef::ArgDef& input_def,
const DataType dtype, int num_inputs); const DataType dtype, int num_inputs);
void InferMixedTypeInputListAttrs(const OpDef::ArgDef& input_def, void InferMixedTypeInputListAttrs(const OpDef::ArgDef& input_def,
const std::vector<DataType>& dtypes); const std::vector<DataType>& dtypes);
tensorflow::EagerContext& ctx_; tensorflow::EagerContext& ctx_;
const char* op_name_ = nullptr;
AttrBuilder attrs_; AttrBuilder attrs_;
const AttrTypeMap* attr_types_; const AttrTypeMap* attr_types_;
absl::InlinedVector<TensorHandle*, 4> inputs_; absl::InlinedVector<TensorHandle*, 4> inputs_;
@ -232,12 +231,6 @@ class EagerOperation : public AbstractOperationInterface {
gtl::FlatSet<std::string> inference_attrs_; // attributes inferred so far gtl::FlatSet<std::string> inference_attrs_; // attributes inferred so far
}; };
inline void EagerOperation::AddInput(TensorHandle* h) {
h->Ref();
inputs_.push_back(h);
attrs_.NumInputs(static_cast<int>(inputs_.size()));
}
inline void EagerOperation::UpdateInput(int i, TensorHandle* h) { inline void EagerOperation::UpdateInput(int i, TensorHandle* h) {
TensorHandle** slot = &inputs_[i]; TensorHandle** slot = &inputs_[i];
TensorHandle* existing = *slot; TensorHandle* existing = *slot;

View File

@ -373,7 +373,7 @@ Status EagerServiceImpl::ExecuteOp(const Operation& operation,
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
eager_context->RemoteMgr()->DeserializeRemoteTensorHandle( eager_context->RemoteMgr()->DeserializeRemoteTensorHandle(
input.remote_handle(), &handle)); input.remote_handle(), &handle));
op->AddInput(handle); TF_RETURN_IF_ERROR(op->AddInput(handle));
} else { } else {
Tensor tensor; Tensor tensor;
if (!ParseTensorProtoToTensor(input.tensor(), &tensor)) { if (!ParseTensorProtoToTensor(input.tensor(), &tensor)) {
@ -382,7 +382,7 @@ Status EagerServiceImpl::ExecuteOp(const Operation& operation,
} else { } else {
handle = TensorHandle::CreateLocalHandle(std::move(tensor), nullptr, handle = TensorHandle::CreateLocalHandle(std::move(tensor), nullptr,
nullptr, eager_context); nullptr, eager_context);
op->AddInput(handle); TF_RETURN_IF_ERROR(op->AddInput(handle));
} }
} }
// Unref handle since it has a ref as an input now. // Unref handle since it has a ref as an input now.

View File

@ -96,7 +96,7 @@ RemoteCopyNode::~RemoteCopyNode() {
Status RemoteCopyNode::RunLocalSend(EagerOperation* op) { Status RemoteCopyNode::RunLocalSend(EagerOperation* op) {
TF_RETURN_IF_ERROR(executor_->status()); TF_RETURN_IF_ERROR(executor_->status());
op->AddInput(src_); TF_RETURN_IF_ERROR(op->AddInput(src_));
core::RefCountPtr<KernelAndDevice> kernel; core::RefCountPtr<KernelAndDevice> kernel;
TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel)); TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel));