Remove unnecessary namespacing
PiperOrigin-RevId: 294298556 Change-Id: I2c3d1eb46aa150a8165f502c3e6a08b99b4377c9
This commit is contained in:
parent
fcd06dbbeb
commit
9759636453
tensorflow/core/common_runtime/eager
@ -53,8 +53,7 @@ Status EagerOperation::Reset(
|
|||||||
return SetDeviceName(raw_device_name, true);
|
return SetDeviceName(raw_device_name, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
tensorflow::Status EagerOperation::MaybeInferSingleInputAttrs(
|
Status EagerOperation::MaybeInferSingleInputAttrs(TensorHandle* handle) {
|
||||||
TensorHandle* handle) {
|
|
||||||
if (!op_def_) return Status::OK();
|
if (!op_def_) return Status::OK();
|
||||||
|
|
||||||
const auto& input_def = op_def_->input_arg(inference_arg_idx_++);
|
const auto& input_def = op_def_->input_arg(inference_arg_idx_++);
|
||||||
@ -78,8 +77,7 @@ tensorflow::Status EagerOperation::MaybeInferSingleInputAttrs(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void EagerOperation::InferSingleTypeInputListAttrs(
|
void EagerOperation::InferSingleTypeInputListAttrs(
|
||||||
const tensorflow::OpDef::ArgDef& input_def,
|
const OpDef::ArgDef& input_def, const DataType dtype, int num_inputs) {
|
||||||
const tensorflow::DataType dtype, int num_inputs) {
|
|
||||||
if (inference_attrs_.find(input_def.number_attr()) ==
|
if (inference_attrs_.find(input_def.number_attr()) ==
|
||||||
inference_attrs_.end()) {
|
inference_attrs_.end()) {
|
||||||
MutableAttrs()->Set(input_def.number_attr(), num_inputs);
|
MutableAttrs()->Set(input_def.number_attr(), num_inputs);
|
||||||
@ -92,24 +90,23 @@ void EagerOperation::InferSingleTypeInputListAttrs(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void EagerOperation::InferMixedTypeInputListAttrs(
|
void EagerOperation::InferMixedTypeInputListAttrs(
|
||||||
const tensorflow::OpDef::ArgDef& input_def,
|
const OpDef::ArgDef& input_def, const std::vector<DataType>& dtypes) {
|
||||||
const std::vector<tensorflow::DataType>& dtypes) {
|
|
||||||
if (inference_attrs_.find(input_def.type_list_attr()) ==
|
if (inference_attrs_.find(input_def.type_list_attr()) ==
|
||||||
inference_attrs_.end()) {
|
inference_attrs_.end()) {
|
||||||
MutableAttrs()->Set(input_def.type_list_attr(),
|
MutableAttrs()->Set(
|
||||||
tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
|
input_def.type_list_attr(),
|
||||||
dtypes.data(), dtypes.size()));
|
gtl::ArraySlice<const DataType>(dtypes.data(), dtypes.size()));
|
||||||
inference_attrs_.insert(input_def.type_list_attr());
|
inference_attrs_.insert(input_def.type_list_attr());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tensorflow::Status EagerOperation::InferInputListAttrs(int num_inputs) {
|
Status EagerOperation::InferInputListAttrs(int num_inputs) {
|
||||||
if (!op_def_) return Status::OK();
|
if (!op_def_) return Status::OK();
|
||||||
|
|
||||||
int start = inference_arg_idx_;
|
int start = inference_arg_idx_;
|
||||||
const auto& input_def = op_def_->input_arg(inference_arg_idx_++);
|
const auto& input_def = op_def_->input_arg(inference_arg_idx_++);
|
||||||
if (!input_def.type_list_attr().empty()) {
|
if (!input_def.type_list_attr().empty()) {
|
||||||
std::vector<tensorflow::DataType> dtypes(num_inputs);
|
std::vector<DataType> dtypes(num_inputs);
|
||||||
for (int i = 0; i < num_inputs; ++i) {
|
for (int i = 0; i < num_inputs; ++i) {
|
||||||
dtypes[i] = inputs_[start + i]->dtype;
|
dtypes[i] = inputs_[start + i]->dtype;
|
||||||
}
|
}
|
||||||
@ -118,13 +115,12 @@ tensorflow::Status EagerOperation::InferInputListAttrs(int num_inputs) {
|
|||||||
!input_def.number_attr().empty()) {
|
!input_def.number_attr().empty()) {
|
||||||
InferSingleTypeInputListAttrs(input_def, inputs_[start]->dtype, num_inputs);
|
InferSingleTypeInputListAttrs(input_def, inputs_[start]->dtype, num_inputs);
|
||||||
} else {
|
} else {
|
||||||
return tensorflow::errors::InvalidArgument("Invalid input list definition");
|
return errors::InvalidArgument("Invalid input list definition");
|
||||||
}
|
}
|
||||||
return tensorflow::Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
tensorflow::Status EagerOperation::SetDeviceName(const char* device,
|
Status EagerOperation::SetDeviceName(const char* device, const bool reset) {
|
||||||
const bool reset) {
|
|
||||||
if (device != nullptr && strlen(device) > 0) {
|
if (device != nullptr && strlen(device) > 0) {
|
||||||
if (device != raw_device_name_) {
|
if (device != raw_device_name_) {
|
||||||
if (!DeviceNameUtils::ParseFullName(device, &device_parsed_name_)) {
|
if (!DeviceNameUtils::ParseFullName(device, &device_parsed_name_)) {
|
||||||
|
@ -30,7 +30,7 @@ class EagerOperation {
|
|||||||
public:
|
public:
|
||||||
explicit EagerOperation(tensorflow::EagerContext* ctx) : ctx_(*ctx) {}
|
explicit EagerOperation(tensorflow::EagerContext* ctx) : ctx_(*ctx) {}
|
||||||
~EagerOperation() {
|
~EagerOperation() {
|
||||||
for (tensorflow::TensorHandle* h : inputs_) {
|
for (TensorHandle* h : inputs_) {
|
||||||
h->Unref();
|
h->Unref();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -39,15 +39,15 @@ class EagerOperation {
|
|||||||
// Clear(), and then Reset(...) with the same arguments that would have
|
// Clear(), and then Reset(...) with the same arguments that would have
|
||||||
// been provided to the constructor.
|
// been provided to the constructor.
|
||||||
void Clear() {
|
void Clear() {
|
||||||
for (tensorflow::TensorHandle* h : inputs_) {
|
for (TensorHandle* h : inputs_) {
|
||||||
h->Unref();
|
h->Unref();
|
||||||
}
|
}
|
||||||
inputs_.clear();
|
inputs_.clear();
|
||||||
ClearInferenceState();
|
ClearInferenceState();
|
||||||
}
|
}
|
||||||
|
|
||||||
tensorflow::Status Reset(const char* op, const char* raw_device_name,
|
Status Reset(const char* op, const char* raw_device_name, bool remote,
|
||||||
bool remote, EagerExecutor* executor,
|
EagerExecutor* executor,
|
||||||
const absl::optional<EagerRemoteFunctionParams>
|
const absl::optional<EagerRemoteFunctionParams>
|
||||||
remote_func_params = absl::nullopt);
|
remote_func_params = absl::nullopt);
|
||||||
|
|
||||||
@ -55,25 +55,19 @@ class EagerOperation {
|
|||||||
|
|
||||||
tensorflow::EagerContext& EagerContext() { return ctx_; }
|
tensorflow::EagerContext& EagerContext() { return ctx_; }
|
||||||
|
|
||||||
tensorflow::AttrBuilder* MutableAttrs() { return &attrs_; }
|
AttrBuilder* MutableAttrs() { return &attrs_; }
|
||||||
const tensorflow::AttrBuilder& Attrs() const { return attrs_; }
|
const AttrBuilder& Attrs() const { return attrs_; }
|
||||||
const tensorflow::OpDef* OpDef() const { return op_def_; }
|
const tensorflow::OpDef* OpDef() const { return op_def_; }
|
||||||
|
|
||||||
const tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4>& Inputs()
|
const gtl::InlinedVector<TensorHandle*, 4>& Inputs() const { return inputs_; }
|
||||||
const {
|
gtl::InlinedVector<TensorHandle*, 4>* MutableInputs() { return &inputs_; }
|
||||||
return inputs_;
|
|
||||||
}
|
|
||||||
tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4>*
|
|
||||||
MutableInputs() {
|
|
||||||
return &inputs_;
|
|
||||||
}
|
|
||||||
|
|
||||||
void AddInput(tensorflow::TensorHandle* h);
|
void AddInput(TensorHandle* h);
|
||||||
void UpdateInput(int i, tensorflow::TensorHandle* h);
|
void UpdateInput(int i, TensorHandle* h);
|
||||||
void ConsumeInput(tensorflow::TensorHandle* h);
|
void ConsumeInput(TensorHandle* h);
|
||||||
|
|
||||||
const tensorflow::string& Name() const { return attrs_.op_name(); }
|
const string& Name() const { return attrs_.op_name(); }
|
||||||
const tensorflow::AttrTypeMap* AttrTypes() const { return attr_types_; }
|
const AttrTypeMap* AttrTypes() const { return attr_types_; }
|
||||||
|
|
||||||
tensorflow::Device* Device() const { return device_; }
|
tensorflow::Device* Device() const { return device_; }
|
||||||
void SetDevice(tensorflow::Device* device) {
|
void SetDevice(tensorflow::Device* device) {
|
||||||
@ -87,8 +81,7 @@ class EagerOperation {
|
|||||||
const DeviceNameUtils::ParsedName& GetDeviceParsedName() const {
|
const DeviceNameUtils::ParsedName& GetDeviceParsedName() const {
|
||||||
return device_parsed_name_;
|
return device_parsed_name_;
|
||||||
}
|
}
|
||||||
tensorflow::Status SetDeviceName(const char* device,
|
Status SetDeviceName(const char* device, const bool reset = false);
|
||||||
const bool reset = false);
|
|
||||||
|
|
||||||
// Indicates whether the op is assigned to a device that is local to the
|
// Indicates whether the op is assigned to a device that is local to the
|
||||||
// current host.
|
// current host.
|
||||||
@ -116,7 +109,7 @@ class EagerOperation {
|
|||||||
const char* op_name_ = nullptr;
|
const char* op_name_ = nullptr;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
Status MaybeInferSingleInputAttrs(tensorflow::TensorHandle* handle);
|
Status MaybeInferSingleInputAttrs(TensorHandle* handle);
|
||||||
Status InferInputListAttrs(int num_inputs);
|
Status InferInputListAttrs(int num_inputs);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -125,17 +118,15 @@ class EagerOperation {
|
|||||||
inference_arg_idx_ = 0;
|
inference_arg_idx_ = 0;
|
||||||
inference_attrs_.clear_no_resize();
|
inference_attrs_.clear_no_resize();
|
||||||
}
|
}
|
||||||
void InferSingleTypeInputListAttrs(const tensorflow::OpDef::ArgDef& input_def,
|
void InferSingleTypeInputListAttrs(const OpDef::ArgDef& input_def,
|
||||||
const tensorflow::DataType dtype,
|
const DataType dtype, int num_inputs);
|
||||||
int num_inputs);
|
void InferMixedTypeInputListAttrs(const OpDef::ArgDef& input_def,
|
||||||
void InferMixedTypeInputListAttrs(
|
const std::vector<DataType>& dtypes);
|
||||||
const tensorflow::OpDef::ArgDef& input_def,
|
|
||||||
const std::vector<tensorflow::DataType>& dtypes);
|
|
||||||
|
|
||||||
tensorflow::EagerContext& ctx_;
|
tensorflow::EagerContext& ctx_;
|
||||||
tensorflow::AttrBuilder attrs_;
|
AttrBuilder attrs_;
|
||||||
const tensorflow::AttrTypeMap* attr_types_;
|
const AttrTypeMap* attr_types_;
|
||||||
tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4> inputs_;
|
gtl::InlinedVector<TensorHandle*, 4> inputs_;
|
||||||
tensorflow::Device* device_;
|
tensorflow::Device* device_;
|
||||||
string raw_device_name_;
|
string raw_device_name_;
|
||||||
string device_name_;
|
string device_name_;
|
||||||
@ -150,19 +141,18 @@ class EagerOperation {
|
|||||||
const tensorflow::OpDef* op_def_; // op definition from protobuf
|
const tensorflow::OpDef* op_def_; // op definition from protobuf
|
||||||
int inference_arg_idx_; // arg definition index for the next input to be
|
int inference_arg_idx_; // arg definition index for the next input to be
|
||||||
// added
|
// added
|
||||||
tensorflow::gtl::FlatSet<std::string>
|
gtl::FlatSet<std::string> inference_attrs_; // attributes inferred so far
|
||||||
inference_attrs_; // attributes inferred so far
|
|
||||||
};
|
};
|
||||||
|
|
||||||
inline void EagerOperation::AddInput(tensorflow::TensorHandle* h) {
|
inline void EagerOperation::AddInput(TensorHandle* h) {
|
||||||
h->Ref();
|
h->Ref();
|
||||||
inputs_.push_back(h);
|
inputs_.push_back(h);
|
||||||
attrs_.NumInputs(static_cast<int>(inputs_.size()));
|
attrs_.NumInputs(static_cast<int>(inputs_.size()));
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void EagerOperation::UpdateInput(int i, tensorflow::TensorHandle* h) {
|
inline void EagerOperation::UpdateInput(int i, TensorHandle* h) {
|
||||||
tensorflow::TensorHandle** slot = &inputs_[i];
|
TensorHandle** slot = &inputs_[i];
|
||||||
tensorflow::TensorHandle* existing = *slot;
|
TensorHandle* existing = *slot;
|
||||||
if (existing != h) {
|
if (existing != h) {
|
||||||
h->Ref();
|
h->Ref();
|
||||||
existing->Unref();
|
existing->Unref();
|
||||||
@ -170,11 +160,10 @@ inline void EagerOperation::UpdateInput(int i, tensorflow::TensorHandle* h) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void EagerOperation::ConsumeInput(tensorflow::TensorHandle* h) {
|
inline void EagerOperation::ConsumeInput(TensorHandle* h) {
|
||||||
inputs_.push_back(h);
|
inputs_.push_back(h);
|
||||||
attrs_.NumInputs(static_cast<int>(inputs_.size()));
|
attrs_.NumInputs(static_cast<int>(inputs_.size()));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_
|
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_
|
||||||
|
Loading…
Reference in New Issue
Block a user