Clean up some methods in EagerOperation
Rename AddInput and make it along with a few other methods private. PiperOrigin-RevId: 306745935 Change-Id: I1e333419552a28e96755bb249448974ba6a49eb7
This commit is contained in:
parent
ac271534b8
commit
f290fd7f48
@ -249,7 +249,7 @@ Status EagerOperation::OutputLength(const char* output_name, int* length) {
|
||||
|
||||
Status EagerOperation::AddInput(AbstractTensorHandleInterface* input) {
|
||||
TensorHandle* h = TensorHandleFromInterface(input);
|
||||
AddInput(h);
|
||||
AddTensorHandle(h);
|
||||
return MaybeInferSingleInputAttrs(h);
|
||||
}
|
||||
|
||||
@ -257,7 +257,7 @@ Status EagerOperation::AddInputList(
|
||||
absl::Span<AbstractTensorHandleInterface*> inputs) {
|
||||
for (auto& input : inputs) {
|
||||
TensorHandle* h = TensorHandleFromInterface(input);
|
||||
AddInput(h);
|
||||
AddTensorHandle(h);
|
||||
}
|
||||
return InferInputListAttrs(inputs.size());
|
||||
}
|
||||
@ -426,4 +426,10 @@ string EagerOperation::DebugString() const {
|
||||
return out;
|
||||
}
|
||||
|
||||
void EagerOperation::AddTensorHandle(TensorHandle* h) {
|
||||
h->Ref();
|
||||
inputs_.push_back(h);
|
||||
attrs_.NumInputs(static_cast<int>(inputs_.size()));
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -134,7 +134,6 @@ class EagerOperation : public AbstractOperationInterface {
|
||||
bool colocation_exempt() const { return colocation_exempt_; }
|
||||
|
||||
tensorflow::EagerContext& EagerContext() { return ctx_; }
|
||||
const tensorflow::EagerContext& EagerContext() const { return ctx_; }
|
||||
|
||||
AttrBuilder* MutableAttrs() { return &attrs_; }
|
||||
const AttrBuilder& Attrs() const { return attrs_; }
|
||||
@ -144,11 +143,8 @@ class EagerOperation : public AbstractOperationInterface {
|
||||
}
|
||||
absl::InlinedVector<TensorHandle*, 4>* MutableInputs() { return &inputs_; }
|
||||
|
||||
void AddInput(TensorHandle* h);
|
||||
void UpdateInput(int i, TensorHandle* h);
|
||||
|
||||
const AttrTypeMap* AttrTypes() const { return attr_types_; }
|
||||
|
||||
// Like TensorHandles, EagerOperations may be placed either on a virtual
|
||||
// CustomDevice or on a physical Device.
|
||||
VariantDevice Device() const { return device_; }
|
||||
@ -174,12 +170,10 @@ class EagerOperation : public AbstractOperationInterface {
|
||||
|
||||
// Op name recorded for memory debugging purpose.
|
||||
const char* op_name() const { return op_name_; }
|
||||
const char* op_name_ = nullptr;
|
||||
|
||||
Status MaybeInferSingleInputAttrs(TensorHandle* handle);
|
||||
Status InferInputListAttrs(int num_inputs);
|
||||
|
||||
private:
|
||||
void AddTensorHandle(TensorHandle* h);
|
||||
|
||||
const tensorflow::OpDef* GetOpDef(Status* status);
|
||||
|
||||
void ClearInferenceState() {
|
||||
@ -187,12 +181,17 @@ class EagerOperation : public AbstractOperationInterface {
|
||||
inference_arg_idx_ = 0;
|
||||
inference_attrs_.clear_no_resize();
|
||||
}
|
||||
|
||||
Status MaybeInferSingleInputAttrs(TensorHandle* handle);
|
||||
Status InferInputListAttrs(int num_inputs);
|
||||
|
||||
void InferSingleTypeInputListAttrs(const OpDef::ArgDef& input_def,
|
||||
const DataType dtype, int num_inputs);
|
||||
void InferMixedTypeInputListAttrs(const OpDef::ArgDef& input_def,
|
||||
const std::vector<DataType>& dtypes);
|
||||
|
||||
tensorflow::EagerContext& ctx_;
|
||||
const char* op_name_ = nullptr;
|
||||
AttrBuilder attrs_;
|
||||
const AttrTypeMap* attr_types_;
|
||||
absl::InlinedVector<TensorHandle*, 4> inputs_;
|
||||
@ -232,12 +231,6 @@ class EagerOperation : public AbstractOperationInterface {
|
||||
gtl::FlatSet<std::string> inference_attrs_; // attributes inferred so far
|
||||
};
|
||||
|
||||
inline void EagerOperation::AddInput(TensorHandle* h) {
|
||||
h->Ref();
|
||||
inputs_.push_back(h);
|
||||
attrs_.NumInputs(static_cast<int>(inputs_.size()));
|
||||
}
|
||||
|
||||
inline void EagerOperation::UpdateInput(int i, TensorHandle* h) {
|
||||
TensorHandle** slot = &inputs_[i];
|
||||
TensorHandle* existing = *slot;
|
||||
|
@ -373,7 +373,7 @@ Status EagerServiceImpl::ExecuteOp(const Operation& operation,
|
||||
TF_RETURN_IF_ERROR(
|
||||
eager_context->RemoteMgr()->DeserializeRemoteTensorHandle(
|
||||
input.remote_handle(), &handle));
|
||||
op->AddInput(handle);
|
||||
TF_RETURN_IF_ERROR(op->AddInput(handle));
|
||||
} else {
|
||||
Tensor tensor;
|
||||
if (!ParseTensorProtoToTensor(input.tensor(), &tensor)) {
|
||||
@ -382,7 +382,7 @@ Status EagerServiceImpl::ExecuteOp(const Operation& operation,
|
||||
} else {
|
||||
handle = TensorHandle::CreateLocalHandle(std::move(tensor), nullptr,
|
||||
nullptr, eager_context);
|
||||
op->AddInput(handle);
|
||||
TF_RETURN_IF_ERROR(op->AddInput(handle));
|
||||
}
|
||||
}
|
||||
// Unref handle since it has a ref as an input now.
|
||||
|
@ -96,7 +96,7 @@ RemoteCopyNode::~RemoteCopyNode() {
|
||||
Status RemoteCopyNode::RunLocalSend(EagerOperation* op) {
|
||||
TF_RETURN_IF_ERROR(executor_->status());
|
||||
|
||||
op->AddInput(src_);
|
||||
TF_RETURN_IF_ERROR(op->AddInput(src_));
|
||||
|
||||
core::RefCountPtr<KernelAndDevice> kernel;
|
||||
TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel));
|
||||
|
Loading…
Reference in New Issue
Block a user