diff --git a/tensorflow/core/common_runtime/eager/eager_operation.cc b/tensorflow/core/common_runtime/eager/eager_operation.cc index 2be516382aa..e84d3b0e9bf 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.cc +++ b/tensorflow/core/common_runtime/eager/eager_operation.cc @@ -53,8 +53,7 @@ Status EagerOperation::Reset( return SetDeviceName(raw_device_name, true); } -tensorflow::Status EagerOperation::MaybeInferSingleInputAttrs( - TensorHandle* handle) { +Status EagerOperation::MaybeInferSingleInputAttrs(TensorHandle* handle) { if (!op_def_) return Status::OK(); const auto& input_def = op_def_->input_arg(inference_arg_idx_++); @@ -78,8 +77,7 @@ tensorflow::Status EagerOperation::MaybeInferSingleInputAttrs( } void EagerOperation::InferSingleTypeInputListAttrs( - const tensorflow::OpDef::ArgDef& input_def, - const tensorflow::DataType dtype, int num_inputs) { + const OpDef::ArgDef& input_def, const DataType dtype, int num_inputs) { if (inference_attrs_.find(input_def.number_attr()) == inference_attrs_.end()) { MutableAttrs()->Set(input_def.number_attr(), num_inputs); @@ -92,24 +90,23 @@ void EagerOperation::InferSingleTypeInputListAttrs( } void EagerOperation::InferMixedTypeInputListAttrs( - const tensorflow::OpDef::ArgDef& input_def, - const std::vector<tensorflow::DataType>& dtypes) { + const OpDef::ArgDef& input_def, const std::vector<DataType>& dtypes) { if (inference_attrs_.find(input_def.type_list_attr()) == inference_attrs_.end()) { - MutableAttrs()->Set(input_def.type_list_attr(), - tensorflow::gtl::ArraySlice<const tensorflow::DataType>( - dtypes.data(), dtypes.size())); + MutableAttrs()->Set( + input_def.type_list_attr(), + gtl::ArraySlice<const DataType>(dtypes.data(), dtypes.size())); inference_attrs_.insert(input_def.type_list_attr()); } } -tensorflow::Status EagerOperation::InferInputListAttrs(int num_inputs) { +Status EagerOperation::InferInputListAttrs(int num_inputs) { if (!op_def_) return Status::OK(); int start = inference_arg_idx_; const auto& input_def = op_def_->input_arg(inference_arg_idx_++); if (!input_def.type_list_attr().empty()) { - std::vector<tensorflow::DataType> dtypes(num_inputs); + std::vector<DataType> dtypes(num_inputs); for (int i = 0; i < num_inputs; ++i) { dtypes[i] = inputs_[start + i]->dtype; } @@ -118,13 +115,12 @@ tensorflow::Status EagerOperation::InferInputListAttrs(int num_inputs) { !input_def.number_attr().empty()) { InferSingleTypeInputListAttrs(input_def, inputs_[start]->dtype, num_inputs); } else { - return tensorflow::errors::InvalidArgument("Invalid input list definition"); + return errors::InvalidArgument("Invalid input list definition"); } - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status EagerOperation::SetDeviceName(const char* device, - const bool reset) { +Status EagerOperation::SetDeviceName(const char* device, const bool reset) { if (device != nullptr && strlen(device) > 0) { if (device != raw_device_name_) { if (!DeviceNameUtils::ParseFullName(device, &device_parsed_name_)) { diff --git a/tensorflow/core/common_runtime/eager/eager_operation.h b/tensorflow/core/common_runtime/eager/eager_operation.h index c7bc8a4543e..c653a92058a 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.h +++ b/tensorflow/core/common_runtime/eager/eager_operation.h @@ -30,7 +30,7 @@ class EagerOperation { public: explicit EagerOperation(tensorflow::EagerContext* ctx) : ctx_(*ctx) {} ~EagerOperation() { - for (tensorflow::TensorHandle* h : inputs_) { + for (TensorHandle* h : inputs_) { h->Unref(); } } @@ -39,41 +39,35 @@ class EagerOperation { // Clear(), and then Reset(...) with the same arguments that would have // been provided to the constructor. void Clear() { - for (tensorflow::TensorHandle* h : inputs_) { + for (TensorHandle* h : inputs_) { h->Unref(); } inputs_.clear(); ClearInferenceState(); } - tensorflow::Status Reset(const char* op, const char* raw_device_name, - bool remote, EagerExecutor* executor, - const absl::optional<EagerRemoteFunctionParams> - remote_func_params = absl::nullopt); + Status Reset(const char* op, const char* raw_device_name, bool remote, + EagerExecutor* executor, + const absl::optional<EagerRemoteFunctionParams> + remote_func_params = absl::nullopt); bool is_function() const { return is_function_; } tensorflow::EagerContext& EagerContext() { return ctx_; } - tensorflow::AttrBuilder* MutableAttrs() { return &attrs_; } - const tensorflow::AttrBuilder& Attrs() const { return attrs_; } + AttrBuilder* MutableAttrs() { return &attrs_; } + const AttrBuilder& Attrs() const { return attrs_; } const tensorflow::OpDef* OpDef() const { return op_def_; } - const tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4>& Inputs() - const { - return inputs_; - } - tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4>* - MutableInputs() { - return &inputs_; - } + const gtl::InlinedVector<TensorHandle*, 4>& Inputs() const { return inputs_; } + gtl::InlinedVector<TensorHandle*, 4>* MutableInputs() { return &inputs_; } - void AddInput(tensorflow::TensorHandle* h); - void UpdateInput(int i, tensorflow::TensorHandle* h); - void ConsumeInput(tensorflow::TensorHandle* h); + void AddInput(TensorHandle* h); + void UpdateInput(int i, TensorHandle* h); + void ConsumeInput(TensorHandle* h); - const tensorflow::string& Name() const { return attrs_.op_name(); } - const tensorflow::AttrTypeMap* AttrTypes() const { return attr_types_; } + const string& Name() const { return attrs_.op_name(); } + const AttrTypeMap* AttrTypes() const { return attr_types_; } tensorflow::Device* Device() const { return device_; } void SetDevice(tensorflow::Device* device) { @@ -87,8 +81,7 @@ class EagerOperation { const DeviceNameUtils::ParsedName& GetDeviceParsedName() const { return device_parsed_name_; } - tensorflow::Status SetDeviceName(const char* device, - const bool reset = false); + Status SetDeviceName(const char* device, const bool reset = false); // Indicates whether the op is assigned to a device that is local to the // current host. @@ -116,7 +109,7 @@ class EagerOperation { const char* op_name_ = nullptr; #endif - Status MaybeInferSingleInputAttrs(tensorflow::TensorHandle* handle); + Status MaybeInferSingleInputAttrs(TensorHandle* handle); Status InferInputListAttrs(int num_inputs); private: @@ -125,17 +118,15 @@ class EagerOperation { inference_arg_idx_ = 0; inference_attrs_.clear_no_resize(); } - void InferSingleTypeInputListAttrs(const tensorflow::OpDef::ArgDef& input_def, - const tensorflow::DataType dtype, - int num_inputs); - void InferMixedTypeInputListAttrs( - const tensorflow::OpDef::ArgDef& input_def, - const std::vector<tensorflow::DataType>& dtypes); + void InferSingleTypeInputListAttrs(const OpDef::ArgDef& input_def, + const DataType dtype, int num_inputs); + void InferMixedTypeInputListAttrs(const OpDef::ArgDef& input_def, + const std::vector<DataType>& dtypes); tensorflow::EagerContext& ctx_; - tensorflow::AttrBuilder attrs_; - const tensorflow::AttrTypeMap* attr_types_; - tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4> inputs_; + AttrBuilder attrs_; + const AttrTypeMap* attr_types_; + gtl::InlinedVector<TensorHandle*, 4> inputs_; tensorflow::Device* device_; string raw_device_name_; string device_name_; @@ -150,19 +141,18 @@ class EagerOperation { const tensorflow::OpDef* op_def_; // op definition from protobuf int inference_arg_idx_; // arg definition index for the next input to be // added - tensorflow::gtl::FlatSet<std::string> - inference_attrs_; // attributes inferred so far + gtl::FlatSet<std::string> inference_attrs_; // attributes inferred so far }; -inline void EagerOperation::AddInput(tensorflow::TensorHandle* h) { +inline void EagerOperation::AddInput(TensorHandle* h) { h->Ref(); inputs_.push_back(h); attrs_.NumInputs(static_cast<int>(inputs_.size())); } -inline void EagerOperation::UpdateInput(int i, tensorflow::TensorHandle* h) { - tensorflow::TensorHandle** slot = &inputs_[i]; - tensorflow::TensorHandle* existing = *slot; +inline void EagerOperation::UpdateInput(int i, TensorHandle* h) { + TensorHandle** slot = &inputs_[i]; + TensorHandle* existing = *slot; if (existing != h) { h->Ref(); existing->Unref(); @@ -170,11 +160,10 @@ inline void EagerOperation::UpdateInput(int i, tensorflow::TensorHandle* h) { } } -inline void EagerOperation::ConsumeInput(tensorflow::TensorHandle* h) { +inline void EagerOperation::ConsumeInput(TensorHandle* h) { inputs_.push_back(h); attrs_.NumInputs(static_cast<int>(inputs_.size())); } - } // namespace tensorflow #endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_