Creating a registry for ops that should be excluded from input resource collocation constraints and making use of it in both graph-mode and eager-mode.

PiperOrigin-RevId: 250518974
This commit is contained in:
Jiri Simsa 2019-05-29 10:32:54 -07:00 committed by TensorFlower Gardener
parent acf4e786f7
commit c7a1366349
9 changed files with 185 additions and 3 deletions

View File

@ -3165,6 +3165,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
"common_runtime/executor.h",
"common_runtime/executor_factory.h",
"common_runtime/graph_optimizer.h",
"common_runtime/input_colocation_exemption_registry.h",
"common_runtime/isolate_placer_inspection_required_ops_pass.h",
"common_runtime/local_device.h",
"common_runtime/lower_function_call_op.h",
@ -3225,6 +3226,7 @@ tf_cuda_library(
"common_runtime/graph_optimizer.cc",
"common_runtime/graph_runner.cc",
"common_runtime/hierarchical_tree_broadcaster.cc",
"common_runtime/input_colocation_exemption_registry.cc",
"common_runtime/inspecting_placer.cc",
"common_runtime/inspecting_placer.h",
"common_runtime/isolate_placer_inspection_required_ops_pass.cc",
@ -5261,6 +5263,19 @@ tf_cc_test_gpu(
],
)
tf_cc_tests(
name = "common_runtime_input_colocation_exemption_registry_test",
size = "small",
srcs = ["common_runtime/input_colocation_exemption_registry_test.cc"],
deps = [
":core_cpu",
":core_cpu_internal",
":test",
":test_main",
":testlib",
],
)
tf_cc_tests(
name = "common_runtime_lower_function_call_test",
size = "small",

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/common_runtime/inspecting_placer.h"
#include "tensorflow/core/common_runtime/partitioning_utils.h"
#include "tensorflow/core/framework/attr_value.pb.h"
@ -150,8 +151,8 @@ bool IsExemptFromResourceInputColocation(const Node* node) {
// ref inputs to operations that are appropriately placed, instead of
// dereferencing them.
const string& op_type = node->op_def().name();
return op_type == "PartitionedCall" || op_type == "StatefulPartitionedCall" ||
op_type == "ReduceDataset" || op_type == "ExperimentalScanDataset";
auto exempt_ops = InputColocationExemptionRegistry::Global()->Get();
return exempt_ops.find(op_type) != exempt_ops.end();
}
bool HasPriorities(const PrioritizedDeviceTypeVector& device_types) {

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/execute_node.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/logging.h"
#include "tensorflow/core/framework/node_def_util.h"
@ -997,7 +998,8 @@ bool IsPinnableOp(const string& op_type) {
// (int32/int64). This can be disabled by setting the environment variable
// "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING" to "0" or "false".
Status MaybeUpdateOpDevice(EagerOperation* op) {
if (op->is_function()) {
auto exempt_ops = InputColocationExemptionRegistry::Global()->Get();
if (op->is_function() || exempt_ops.find(op->Name()) != exempt_ops.end()) {
// Don't update the device of direct function calls.
// Particularly, if the user did not explicitly request any device for this
// function, picking a device would result in this device being the default

View File

@ -0,0 +1,42 @@
/* Copyright 2019 The TensorFlow Authors. Al Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include <set>
#include <string>
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
InputColocationExemptionRegistry* InputColocationExemptionRegistry::Global() {
static InputColocationExemptionRegistry* registry =
new InputColocationExemptionRegistry;
return registry;
}
const std::set<string>& InputColocationExemptionRegistry::Get() { return ops_; }
void InputColocationExemptionRegistry::Register(const string& op) {
auto it = ops_.find(op);
if (it != ops_.end()) {
LOG(WARNING) << "Input colocation exemption for op: " << op
<< " already registered";
} else {
ops_.insert(op);
}
}
} // namespace tensorflow

View File

@ -0,0 +1,76 @@
/* Copyright 2019 The TensorFlow Authors. Al Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_INPUT_COLOCATION_EXEMPTION_REGISTRY_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_INPUT_COLOCATION_EXEMPTION_REGISTRY_H_
#include <set>
#include <string>
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
// TensorFlow runtime (both eager and graph) will aim to colocate ops with
// their resource inputs so that the ops can access the resource state. In some
// cases, such as tf.data ops, this is not desirable as the ops themselves might
// not have a kernel registered for the device on which the resource is placed
// and instead use a mechanism, such as a multi-device function, to access the
// resource state.
//
// This registry can be used to register and list ops that should be exempt from
// the input colocation described above.
//
// Example usage:
// REGISTER_INPUT_COLOCATION_EXEMPTION("MapDataset");
class InputColocationExemptionRegistry {
public:
// Returns a pointer to a global InputColocationExemptionRegistry object.
static InputColocationExemptionRegistry* Global();
// Returns the set of ops exempt from the input colocation constraints.
const std::set<string>& Get();
// Registers an op to be excluded from the input colocation constraints.
void Register(const string& op);
private:
std::set<string> ops_;
};
namespace input_colocation_exemption_registration {
class InputColocationExemptionRegistration {
public:
explicit InputColocationExemptionRegistration(const string& op) {
InputColocationExemptionRegistry::Global()->Register(op);
}
};
} // namespace input_colocation_exemption_registration
#define REGISTER_INPUT_COLOCATION_EXEMPTION(op) \
REGISTER_INPUT_COLOCATION_EXEMPTION_UNIQ_HELPER(__COUNTER__, op)
#define REGISTER_INPUT_COLOCATION_EXEMPTION_UNIQ_HELPER(ctr, op) \
REGISTER_INPUT_COLOCATION_EXEMPTION_UNIQ(ctr, op)
#define REGISTER_INPUT_COLOCATION_EXEMPTION_UNIQ(ctr, op) \
static input_colocation_exemption_registration:: \
InputColocationExemptionRegistration \
input_colocation_exemption_registration_fn_##ctr(op)
} // namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_INPUT_COLOCATION_EXEMPTION_REGISTRY_H_

View File

@ -0,0 +1,35 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
REGISTER_INPUT_COLOCATION_EXEMPTION("op 1");
REGISTER_INPUT_COLOCATION_EXEMPTION("op 2");
} // namespace
TEST(RPCFactoryRegistryTest, TestBasic) {
auto exempt_ops = InputColocationExemptionRegistry::Global()->Get();
EXPECT_EQ(exempt_ops.size(), 2);
EXPECT_NE(exempt_ops.find("op 1"), exempt_ops.end());
EXPECT_NE(exempt_ops.find("op 2"), exempt_ops.end());
EXPECT_EQ(exempt_ops.find("op 3"), exempt_ops.end());
}
} // namespace tensorflow

View File

@ -16,6 +16,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
@ -292,6 +293,8 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("ExperimentalScanDataset").Device(DEVICE_CPU),
ScanDatasetOp);
REGISTER_INPUT_COLOCATION_EXEMPTION("ExperimentalScanDataset");
} // namespace
} // namespace data
} // namespace tensorflow

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/framework/function.h"
@ -1303,6 +1304,8 @@ REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU),
DeserializeIteratorOp);
REGISTER_INPUT_COLOCATION_EXEMPTION("ReduceDataset");
} // namespace
} // namespace data

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "absl/strings/match.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
@ -281,6 +282,10 @@ REGISTER_KERNEL_BUILDER(Name("PartitionedCall").Device(DEVICE_GPU),
PartitionedCallOp);
REGISTER_KERNEL_BUILDER(Name("StatefulPartitionedCall").Device(DEVICE_GPU),
PartitionedCallOp);
REGISTER_INPUT_COLOCATION_EXEMPTION("PartitionedCall");
REGISTER_INPUT_COLOCATION_EXEMPTION("StatefulPartitionedCall");
#if TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("PartitionedCall").Device(DEVICE_SYCL),
PartitionedCallOp);