Clean up some methods in EagerOperation

Rename AddInput and make it along with a few other methods private.

PiperOrigin-RevId: 306745935
Change-Id: I1e333419552a28e96755bb249448974ba6a49eb7
This commit is contained in:
Gaurav Jain 2020-04-15 16:45:46 -07:00 committed by TensorFlower Gardener
parent ac271534b8
commit f290fd7f48
4 changed files with 18 additions and 19 deletions

View File

@ -249,7 +249,7 @@ Status EagerOperation::OutputLength(const char* output_name, int* length) {
Status EagerOperation::AddInput(AbstractTensorHandleInterface* input) {
TensorHandle* h = TensorHandleFromInterface(input);
AddInput(h);
AddTensorHandle(h);
return MaybeInferSingleInputAttrs(h);
}
@ -257,7 +257,7 @@ Status EagerOperation::AddInputList(
absl::Span<AbstractTensorHandleInterface*> inputs) {
for (auto& input : inputs) {
TensorHandle* h = TensorHandleFromInterface(input);
AddInput(h);
AddTensorHandle(h);
}
return InferInputListAttrs(inputs.size());
}
@ -426,4 +426,10 @@ string EagerOperation::DebugString() const {
return out;
}
void EagerOperation::AddTensorHandle(TensorHandle* h) {
h->Ref();
inputs_.push_back(h);
attrs_.NumInputs(static_cast<int>(inputs_.size()));
}
} // namespace tensorflow

View File

@ -134,7 +134,6 @@ class EagerOperation : public AbstractOperationInterface {
bool colocation_exempt() const { return colocation_exempt_; }
tensorflow::EagerContext& EagerContext() { return ctx_; }
const tensorflow::EagerContext& EagerContext() const { return ctx_; }
AttrBuilder* MutableAttrs() { return &attrs_; }
const AttrBuilder& Attrs() const { return attrs_; }
@ -144,11 +143,8 @@ class EagerOperation : public AbstractOperationInterface {
}
absl::InlinedVector<TensorHandle*, 4>* MutableInputs() { return &inputs_; }
void AddInput(TensorHandle* h);
void UpdateInput(int i, TensorHandle* h);
const AttrTypeMap* AttrTypes() const { return attr_types_; }
// Like TensorHandles, EagerOperations may be placed either on a virtual
// CustomDevice or on a physical Device.
VariantDevice Device() const { return device_; }
@ -174,12 +170,10 @@ class EagerOperation : public AbstractOperationInterface {
// Op name recorded for memory debugging purpose.
const char* op_name() const { return op_name_; }
const char* op_name_ = nullptr;
Status MaybeInferSingleInputAttrs(TensorHandle* handle);
Status InferInputListAttrs(int num_inputs);
private:
void AddTensorHandle(TensorHandle* h);
const tensorflow::OpDef* GetOpDef(Status* status);
void ClearInferenceState() {
@ -187,12 +181,17 @@ class EagerOperation : public AbstractOperationInterface {
inference_arg_idx_ = 0;
inference_attrs_.clear_no_resize();
}
Status MaybeInferSingleInputAttrs(TensorHandle* handle);
Status InferInputListAttrs(int num_inputs);
void InferSingleTypeInputListAttrs(const OpDef::ArgDef& input_def,
const DataType dtype, int num_inputs);
void InferMixedTypeInputListAttrs(const OpDef::ArgDef& input_def,
const std::vector<DataType>& dtypes);
tensorflow::EagerContext& ctx_;
const char* op_name_ = nullptr;
AttrBuilder attrs_;
const AttrTypeMap* attr_types_;
absl::InlinedVector<TensorHandle*, 4> inputs_;
@ -232,12 +231,6 @@ class EagerOperation : public AbstractOperationInterface {
gtl::FlatSet<std::string> inference_attrs_; // attributes inferred so far
};
inline void EagerOperation::AddInput(TensorHandle* h) {
h->Ref();
inputs_.push_back(h);
attrs_.NumInputs(static_cast<int>(inputs_.size()));
}
inline void EagerOperation::UpdateInput(int i, TensorHandle* h) {
TensorHandle** slot = &inputs_[i];
TensorHandle* existing = *slot;

View File

@ -373,7 +373,7 @@ Status EagerServiceImpl::ExecuteOp(const Operation& operation,
TF_RETURN_IF_ERROR(
eager_context->RemoteMgr()->DeserializeRemoteTensorHandle(
input.remote_handle(), &handle));
op->AddInput(handle);
TF_RETURN_IF_ERROR(op->AddInput(handle));
} else {
Tensor tensor;
if (!ParseTensorProtoToTensor(input.tensor(), &tensor)) {
@ -382,7 +382,7 @@ Status EagerServiceImpl::ExecuteOp(const Operation& operation,
} else {
handle = TensorHandle::CreateLocalHandle(std::move(tensor), nullptr,
nullptr, eager_context);
op->AddInput(handle);
TF_RETURN_IF_ERROR(op->AddInput(handle));
}
}
// Unref handle since it has a ref as an input now.

View File

@ -96,7 +96,7 @@ RemoteCopyNode::~RemoteCopyNode() {
Status RemoteCopyNode::RunLocalSend(EagerOperation* op) {
TF_RETURN_IF_ERROR(executor_->status());
op->AddInput(src_);
TF_RETURN_IF_ERROR(op->AddInput(src_));
core::RefCountPtr<KernelAndDevice> kernel;
TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel));