Creating a registry for ops that should be excluded from input resource collocation constraints and making use of it in both graph-mode and eager-mode.
PiperOrigin-RevId: 250518974
This commit is contained in:
parent
acf4e786f7
commit
c7a1366349
tensorflow/core
@ -3165,6 +3165,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
|
||||
"common_runtime/executor.h",
|
||||
"common_runtime/executor_factory.h",
|
||||
"common_runtime/graph_optimizer.h",
|
||||
"common_runtime/input_colocation_exemption_registry.h",
|
||||
"common_runtime/isolate_placer_inspection_required_ops_pass.h",
|
||||
"common_runtime/local_device.h",
|
||||
"common_runtime/lower_function_call_op.h",
|
||||
@ -3225,6 +3226,7 @@ tf_cuda_library(
|
||||
"common_runtime/graph_optimizer.cc",
|
||||
"common_runtime/graph_runner.cc",
|
||||
"common_runtime/hierarchical_tree_broadcaster.cc",
|
||||
"common_runtime/input_colocation_exemption_registry.cc",
|
||||
"common_runtime/inspecting_placer.cc",
|
||||
"common_runtime/inspecting_placer.h",
|
||||
"common_runtime/isolate_placer_inspection_required_ops_pass.cc",
|
||||
@ -5261,6 +5263,19 @@ tf_cc_test_gpu(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_tests(
|
||||
name = "common_runtime_input_colocation_exemption_registry_test",
|
||||
size = "small",
|
||||
srcs = ["common_runtime/input_colocation_exemption_registry_test.cc"],
|
||||
deps = [
|
||||
":core_cpu",
|
||||
":core_cpu_internal",
|
||||
":test",
|
||||
":test_main",
|
||||
":testlib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_tests(
|
||||
name = "common_runtime_lower_function_call_test",
|
||||
size = "small",
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
|
||||
#include "tensorflow/core/common_runtime/inspecting_placer.h"
|
||||
#include "tensorflow/core/common_runtime/partitioning_utils.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
@ -150,8 +151,8 @@ bool IsExemptFromResourceInputColocation(const Node* node) {
|
||||
// ref inputs to operations that are appropriately placed, instead of
|
||||
// dereferencing them.
|
||||
const string& op_type = node->op_def().name();
|
||||
return op_type == "PartitionedCall" || op_type == "StatefulPartitionedCall" ||
|
||||
op_type == "ReduceDataset" || op_type == "ExperimentalScanDataset";
|
||||
auto exempt_ops = InputColocationExemptionRegistry::Global()->Get();
|
||||
return exempt_ops.find(op_type) != exempt_ops.end();
|
||||
}
|
||||
|
||||
bool HasPriorities(const PrioritizedDeviceTypeVector& device_types) {
|
||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/eager/execute_node.h"
|
||||
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/logging.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
@ -997,7 +998,8 @@ bool IsPinnableOp(const string& op_type) {
|
||||
// (int32/int64). This can be disabled by setting the environment variable
|
||||
// "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING" to "0" or "false".
|
||||
Status MaybeUpdateOpDevice(EagerOperation* op) {
|
||||
if (op->is_function()) {
|
||||
auto exempt_ops = InputColocationExemptionRegistry::Global()->Get();
|
||||
if (op->is_function() || exempt_ops.find(op->Name()) != exempt_ops.end()) {
|
||||
// Don't update the device of direct function calls.
|
||||
// Particularly, if the user did not explicitly request any device for this
|
||||
// function, picking a device would result in this device being the default
|
||||
|
@ -0,0 +1,42 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. Al Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
|
||||
|
||||
#include <set>
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
InputColocationExemptionRegistry* InputColocationExemptionRegistry::Global() {
|
||||
static InputColocationExemptionRegistry* registry =
|
||||
new InputColocationExemptionRegistry;
|
||||
return registry;
|
||||
}
|
||||
|
||||
const std::set<string>& InputColocationExemptionRegistry::Get() { return ops_; }
|
||||
|
||||
void InputColocationExemptionRegistry::Register(const string& op) {
|
||||
auto it = ops_.find(op);
|
||||
if (it != ops_.end()) {
|
||||
LOG(WARNING) << "Input colocation exemption for op: " << op
|
||||
<< " already registered";
|
||||
} else {
|
||||
ops_.insert(op);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,76 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. Al Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_INPUT_COLOCATION_EXEMPTION_REGISTRY_H_
|
||||
#define TENSORFLOW_CORE_COMMON_RUNTIME_INPUT_COLOCATION_EXEMPTION_REGISTRY_H_
|
||||
|
||||
#include <set>
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// TensorFlow runtime (both eager and graph) will aim to colocate ops with
|
||||
// their resource inputs so that the ops can access the resource state. In some
|
||||
// cases, such as tf.data ops, this is not desirable as the ops themselves might
|
||||
// not have a kernel registered for the device on which the resource is placed
|
||||
// and instead use a mechanism, such as a multi-device function, to access the
|
||||
// resource state.
|
||||
//
|
||||
// This registry can be used to register and list ops that should be exempt from
|
||||
// the input colocation described above.
|
||||
//
|
||||
// Example usage:
|
||||
// REGISTER_INPUT_COLOCATION_EXEMPTION("MapDataset");
|
||||
class InputColocationExemptionRegistry {
|
||||
public:
|
||||
// Returns a pointer to a global InputColocationExemptionRegistry object.
|
||||
static InputColocationExemptionRegistry* Global();
|
||||
|
||||
// Returns the set of ops exempt from the input colocation constraints.
|
||||
const std::set<string>& Get();
|
||||
|
||||
// Registers an op to be excluded from the input colocation constraints.
|
||||
void Register(const string& op);
|
||||
|
||||
private:
|
||||
std::set<string> ops_;
|
||||
};
|
||||
|
||||
namespace input_colocation_exemption_registration {
|
||||
|
||||
class InputColocationExemptionRegistration {
|
||||
public:
|
||||
explicit InputColocationExemptionRegistration(const string& op) {
|
||||
InputColocationExemptionRegistry::Global()->Register(op);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace input_colocation_exemption_registration
|
||||
|
||||
#define REGISTER_INPUT_COLOCATION_EXEMPTION(op) \
|
||||
REGISTER_INPUT_COLOCATION_EXEMPTION_UNIQ_HELPER(__COUNTER__, op)
|
||||
|
||||
#define REGISTER_INPUT_COLOCATION_EXEMPTION_UNIQ_HELPER(ctr, op) \
|
||||
REGISTER_INPUT_COLOCATION_EXEMPTION_UNIQ(ctr, op)
|
||||
|
||||
#define REGISTER_INPUT_COLOCATION_EXEMPTION_UNIQ(ctr, op) \
|
||||
static input_colocation_exemption_registration:: \
|
||||
InputColocationExemptionRegistration \
|
||||
input_colocation_exemption_registration_fn_##ctr(op)
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_INPUT_COLOCATION_EXEMPTION_REGISTRY_H_
|
@ -0,0 +1,35 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
|
||||
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
REGISTER_INPUT_COLOCATION_EXEMPTION("op 1");
|
||||
REGISTER_INPUT_COLOCATION_EXEMPTION("op 2");
|
||||
|
||||
} // namespace
|
||||
|
||||
TEST(RPCFactoryRegistryTest, TestBasic) {
|
||||
auto exempt_ops = InputColocationExemptionRegistry::Global()->Get();
|
||||
EXPECT_EQ(exempt_ops.size(), 2);
|
||||
EXPECT_NE(exempt_ops.find("op 1"), exempt_ops.end());
|
||||
EXPECT_NE(exempt_ops.find("op 2"), exempt_ops.end());
|
||||
EXPECT_EQ(exempt_ops.find("op 3"), exempt_ops.end());
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
@ -292,6 +293,8 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
|
||||
REGISTER_KERNEL_BUILDER(Name("ExperimentalScanDataset").Device(DEVICE_CPU),
|
||||
ScanDatasetOp);
|
||||
|
||||
REGISTER_INPUT_COLOCATION_EXEMPTION("ExperimentalScanDataset");
|
||||
|
||||
} // namespace
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/core/common_runtime/graph_runner.h"
|
||||
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
|
||||
#include "tensorflow/core/common_runtime/renamed_device.h"
|
||||
#include "tensorflow/core/common_runtime/threadpool_device.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
@ -1303,6 +1304,8 @@ REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU),
|
||||
REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU),
|
||||
DeserializeIteratorOp);
|
||||
|
||||
REGISTER_INPUT_COLOCATION_EXEMPTION("ReduceDataset");
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace data
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
|
||||
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
@ -281,6 +282,10 @@ REGISTER_KERNEL_BUILDER(Name("PartitionedCall").Device(DEVICE_GPU),
|
||||
PartitionedCallOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("StatefulPartitionedCall").Device(DEVICE_GPU),
|
||||
PartitionedCallOp);
|
||||
|
||||
REGISTER_INPUT_COLOCATION_EXEMPTION("PartitionedCall");
|
||||
REGISTER_INPUT_COLOCATION_EXEMPTION("StatefulPartitionedCall");
|
||||
|
||||
#if TENSORFLOW_USE_SYCL
|
||||
REGISTER_KERNEL_BUILDER(Name("PartitionedCall").Device(DEVICE_SYCL),
|
||||
PartitionedCallOp);
|
||||
|
Loading…
Reference in New Issue
Block a user