Clean up some methods in EagerOperation
Rename AddInput and make it along with a few other methods private. PiperOrigin-RevId: 306745935 Change-Id: I1e333419552a28e96755bb249448974ba6a49eb7
This commit is contained in:
parent
ac271534b8
commit
f290fd7f48
@ -249,7 +249,7 @@ Status EagerOperation::OutputLength(const char* output_name, int* length) {
|
|||||||
|
|
||||||
Status EagerOperation::AddInput(AbstractTensorHandleInterface* input) {
|
Status EagerOperation::AddInput(AbstractTensorHandleInterface* input) {
|
||||||
TensorHandle* h = TensorHandleFromInterface(input);
|
TensorHandle* h = TensorHandleFromInterface(input);
|
||||||
AddInput(h);
|
AddTensorHandle(h);
|
||||||
return MaybeInferSingleInputAttrs(h);
|
return MaybeInferSingleInputAttrs(h);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -257,7 +257,7 @@ Status EagerOperation::AddInputList(
|
|||||||
absl::Span<AbstractTensorHandleInterface*> inputs) {
|
absl::Span<AbstractTensorHandleInterface*> inputs) {
|
||||||
for (auto& input : inputs) {
|
for (auto& input : inputs) {
|
||||||
TensorHandle* h = TensorHandleFromInterface(input);
|
TensorHandle* h = TensorHandleFromInterface(input);
|
||||||
AddInput(h);
|
AddTensorHandle(h);
|
||||||
}
|
}
|
||||||
return InferInputListAttrs(inputs.size());
|
return InferInputListAttrs(inputs.size());
|
||||||
}
|
}
|
||||||
@ -426,4 +426,10 @@ string EagerOperation::DebugString() const {
|
|||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void EagerOperation::AddTensorHandle(TensorHandle* h) {
|
||||||
|
h->Ref();
|
||||||
|
inputs_.push_back(h);
|
||||||
|
attrs_.NumInputs(static_cast<int>(inputs_.size()));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|||||||
@ -134,7 +134,6 @@ class EagerOperation : public AbstractOperationInterface {
|
|||||||
bool colocation_exempt() const { return colocation_exempt_; }
|
bool colocation_exempt() const { return colocation_exempt_; }
|
||||||
|
|
||||||
tensorflow::EagerContext& EagerContext() { return ctx_; }
|
tensorflow::EagerContext& EagerContext() { return ctx_; }
|
||||||
const tensorflow::EagerContext& EagerContext() const { return ctx_; }
|
|
||||||
|
|
||||||
AttrBuilder* MutableAttrs() { return &attrs_; }
|
AttrBuilder* MutableAttrs() { return &attrs_; }
|
||||||
const AttrBuilder& Attrs() const { return attrs_; }
|
const AttrBuilder& Attrs() const { return attrs_; }
|
||||||
@ -144,11 +143,8 @@ class EagerOperation : public AbstractOperationInterface {
|
|||||||
}
|
}
|
||||||
absl::InlinedVector<TensorHandle*, 4>* MutableInputs() { return &inputs_; }
|
absl::InlinedVector<TensorHandle*, 4>* MutableInputs() { return &inputs_; }
|
||||||
|
|
||||||
void AddInput(TensorHandle* h);
|
|
||||||
void UpdateInput(int i, TensorHandle* h);
|
void UpdateInput(int i, TensorHandle* h);
|
||||||
|
|
||||||
const AttrTypeMap* AttrTypes() const { return attr_types_; }
|
|
||||||
|
|
||||||
// Like TensorHandles, EagerOperations may be placed either on a virtual
|
// Like TensorHandles, EagerOperations may be placed either on a virtual
|
||||||
// CustomDevice or on a physical Device.
|
// CustomDevice or on a physical Device.
|
||||||
VariantDevice Device() const { return device_; }
|
VariantDevice Device() const { return device_; }
|
||||||
@ -174,12 +170,10 @@ class EagerOperation : public AbstractOperationInterface {
|
|||||||
|
|
||||||
// Op name recorded for memory debugging purpose.
|
// Op name recorded for memory debugging purpose.
|
||||||
const char* op_name() const { return op_name_; }
|
const char* op_name() const { return op_name_; }
|
||||||
const char* op_name_ = nullptr;
|
|
||||||
|
|
||||||
Status MaybeInferSingleInputAttrs(TensorHandle* handle);
|
|
||||||
Status InferInputListAttrs(int num_inputs);
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
void AddTensorHandle(TensorHandle* h);
|
||||||
|
|
||||||
const tensorflow::OpDef* GetOpDef(Status* status);
|
const tensorflow::OpDef* GetOpDef(Status* status);
|
||||||
|
|
||||||
void ClearInferenceState() {
|
void ClearInferenceState() {
|
||||||
@ -187,12 +181,17 @@ class EagerOperation : public AbstractOperationInterface {
|
|||||||
inference_arg_idx_ = 0;
|
inference_arg_idx_ = 0;
|
||||||
inference_attrs_.clear_no_resize();
|
inference_attrs_.clear_no_resize();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status MaybeInferSingleInputAttrs(TensorHandle* handle);
|
||||||
|
Status InferInputListAttrs(int num_inputs);
|
||||||
|
|
||||||
void InferSingleTypeInputListAttrs(const OpDef::ArgDef& input_def,
|
void InferSingleTypeInputListAttrs(const OpDef::ArgDef& input_def,
|
||||||
const DataType dtype, int num_inputs);
|
const DataType dtype, int num_inputs);
|
||||||
void InferMixedTypeInputListAttrs(const OpDef::ArgDef& input_def,
|
void InferMixedTypeInputListAttrs(const OpDef::ArgDef& input_def,
|
||||||
const std::vector<DataType>& dtypes);
|
const std::vector<DataType>& dtypes);
|
||||||
|
|
||||||
tensorflow::EagerContext& ctx_;
|
tensorflow::EagerContext& ctx_;
|
||||||
|
const char* op_name_ = nullptr;
|
||||||
AttrBuilder attrs_;
|
AttrBuilder attrs_;
|
||||||
const AttrTypeMap* attr_types_;
|
const AttrTypeMap* attr_types_;
|
||||||
absl::InlinedVector<TensorHandle*, 4> inputs_;
|
absl::InlinedVector<TensorHandle*, 4> inputs_;
|
||||||
@ -232,12 +231,6 @@ class EagerOperation : public AbstractOperationInterface {
|
|||||||
gtl::FlatSet<std::string> inference_attrs_; // attributes inferred so far
|
gtl::FlatSet<std::string> inference_attrs_; // attributes inferred so far
|
||||||
};
|
};
|
||||||
|
|
||||||
inline void EagerOperation::AddInput(TensorHandle* h) {
|
|
||||||
h->Ref();
|
|
||||||
inputs_.push_back(h);
|
|
||||||
attrs_.NumInputs(static_cast<int>(inputs_.size()));
|
|
||||||
}
|
|
||||||
|
|
||||||
inline void EagerOperation::UpdateInput(int i, TensorHandle* h) {
|
inline void EagerOperation::UpdateInput(int i, TensorHandle* h) {
|
||||||
TensorHandle** slot = &inputs_[i];
|
TensorHandle** slot = &inputs_[i];
|
||||||
TensorHandle* existing = *slot;
|
TensorHandle* existing = *slot;
|
||||||
|
|||||||
@ -373,7 +373,7 @@ Status EagerServiceImpl::ExecuteOp(const Operation& operation,
|
|||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
eager_context->RemoteMgr()->DeserializeRemoteTensorHandle(
|
eager_context->RemoteMgr()->DeserializeRemoteTensorHandle(
|
||||||
input.remote_handle(), &handle));
|
input.remote_handle(), &handle));
|
||||||
op->AddInput(handle);
|
TF_RETURN_IF_ERROR(op->AddInput(handle));
|
||||||
} else {
|
} else {
|
||||||
Tensor tensor;
|
Tensor tensor;
|
||||||
if (!ParseTensorProtoToTensor(input.tensor(), &tensor)) {
|
if (!ParseTensorProtoToTensor(input.tensor(), &tensor)) {
|
||||||
@ -382,7 +382,7 @@ Status EagerServiceImpl::ExecuteOp(const Operation& operation,
|
|||||||
} else {
|
} else {
|
||||||
handle = TensorHandle::CreateLocalHandle(std::move(tensor), nullptr,
|
handle = TensorHandle::CreateLocalHandle(std::move(tensor), nullptr,
|
||||||
nullptr, eager_context);
|
nullptr, eager_context);
|
||||||
op->AddInput(handle);
|
TF_RETURN_IF_ERROR(op->AddInput(handle));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Unref handle since it has a ref as an input now.
|
// Unref handle since it has a ref as an input now.
|
||||||
|
|||||||
@ -96,7 +96,7 @@ RemoteCopyNode::~RemoteCopyNode() {
|
|||||||
Status RemoteCopyNode::RunLocalSend(EagerOperation* op) {
|
Status RemoteCopyNode::RunLocalSend(EagerOperation* op) {
|
||||||
TF_RETURN_IF_ERROR(executor_->status());
|
TF_RETURN_IF_ERROR(executor_->status());
|
||||||
|
|
||||||
op->AddInput(src_);
|
TF_RETURN_IF_ERROR(op->AddInput(src_));
|
||||||
|
|
||||||
core::RefCountPtr<KernelAndDevice> kernel;
|
core::RefCountPtr<KernelAndDevice> kernel;
|
||||||
TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel));
|
TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel));
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user