Parallel device: Remove the DeviceID op usage for now
It may be good to have later, but for now having an op registered is an unnecessary API surface. There's not a terribly strong case for the op yet, and having it as a dependency but not a core op makes it hard to depend on the parallel device. Keeps a "device_id" property on the ParallelDevice in Python, which will behave identically to the previous version PiperOrigin-RevId: 329313676 Change-Id: Iaf024d0d762061ba90e3568199334b53432001ce
This commit is contained in:
parent
974fbe0ee4
commit
027c55b667
tensorflow
c/eager/parallel_device
python/distribute/parallel_device
@ -103,7 +103,6 @@ cc_library(
|
||||
hdrs = ["parallel_device_testlib.h"],
|
||||
deps = [
|
||||
":parallel_device",
|
||||
":parallel_device_ops",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_api_experimental",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
@ -118,7 +117,6 @@ tf_cc_test(
|
||||
srcs = ["parallel_device_test.cc"],
|
||||
deps = [
|
||||
":parallel_device",
|
||||
":parallel_device_ops",
|
||||
":parallel_device_testlib",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_api_experimental",
|
||||
@ -138,7 +136,6 @@ tf_cc_test(
|
||||
args = ["--heap_check=local"],
|
||||
deps = [
|
||||
":parallel_device",
|
||||
":parallel_device_ops",
|
||||
":parallel_device_testlib",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_api_experimental",
|
||||
@ -150,19 +147,3 @@ tf_cc_test(
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
],
|
||||
)
|
||||
|
||||
# Note: ParallelDevice-specific ops are experimental and not currently linked in
|
||||
# to TensorFlow by default, just used in a few tests.
|
||||
filegroup(
|
||||
name = "parallel_device_ops_srcs",
|
||||
srcs = ["parallel_device_ops.cc"],
|
||||
visibility = ["//tensorflow/python/distribute/parallel_device:__pkg__"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "parallel_device_ops",
|
||||
srcs = [":parallel_device_ops_srcs"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = ["//tensorflow/core:framework"],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -136,13 +136,6 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ExecuteWithSpecialOps(
|
||||
}
|
||||
result.emplace(std::move(outputs));
|
||||
return result;
|
||||
} else if (operation_name == std::string("DeviceID")) {
|
||||
std::vector<MaybeParallelTensorOwned> result_content;
|
||||
result_content.reserve(1);
|
||||
result_content.push_back(parallel_device.DeviceIDs(context, status));
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
result.emplace(std::move(result_content));
|
||||
return result;
|
||||
}
|
||||
std::vector<ParallelTensor*> parallel_inputs;
|
||||
std::vector<std::unique_ptr<ParallelTensor>> implicitly_broadcast_tensors;
|
||||
|
@ -1,26 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
|
||||
// TODO(allenl): Figure out if we need this op, and if so whether we should move
|
||||
// it to core TF. Right now the eager C API does some checking of op
|
||||
// registrations before calling into custom devices, but we may be able to avoid
|
||||
// that.
|
||||
REGISTER_OP("DeviceID")
|
||||
.Output("device_id: int64")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(tensorflow::shape_inference::ScalarShape);
|
@ -279,30 +279,4 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
|
||||
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[1], second_device);
|
||||
}
|
||||
// Compute the device ID twice and verify the result
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_OpSetDevice(op.get(), device_name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
TFE_TensorHandle* result_handle;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op.get(), &result_handle, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
std::array<TensorHandlePtr, 2> components;
|
||||
ExtractPerDeviceValues(context, result_handle, &components, status.get());
|
||||
TFE_DeleteTensorHandle(result_handle);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
ExpectScalarEq<int32_t>(components[0].get(), 0);
|
||||
ExpectScalarEq<int32_t>(components[1].get(), 1);
|
||||
std::string first_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[0], first_device);
|
||||
std::string second_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[1], second_device);
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,3 @@
|
||||
load("//tensorflow:tensorflow.bzl", "tf_custom_op_library", "tf_gen_op_wrapper_py")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
|
||||
|
||||
package(
|
||||
default_visibility = ["//tensorflow:internal"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
@ -17,7 +14,6 @@ py_library(
|
||||
srcs = ["parallel_device.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":parallel_device_ops",
|
||||
":saving",
|
||||
"//tensorflow/python:_pywrap_parallel_device",
|
||||
"//tensorflow/python/distribute:device_util",
|
||||
@ -31,25 +27,6 @@ py_library(
|
||||
deps = ["//tensorflow/python:framework_ops"],
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_py(
|
||||
name = "parallel_device_ops_py",
|
||||
out = "gen_parallel_device_ops.py",
|
||||
deps = ["//tensorflow/c/eager/parallel_device:parallel_device_ops"],
|
||||
)
|
||||
|
||||
tf_custom_op_library(
|
||||
name = "_parallel_device_ops.so",
|
||||
srcs = ["//tensorflow/c/eager/parallel_device:parallel_device_ops_srcs"],
|
||||
)
|
||||
|
||||
tf_custom_op_py_library(
|
||||
name = "parallel_device_ops",
|
||||
dso = [":_parallel_device_ops.so"],
|
||||
kernels = ["//tensorflow/c/eager/parallel_device:parallel_device_ops"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [":parallel_device_ops_py"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "parallel_device_test",
|
||||
srcs = ["parallel_device_test.py"],
|
||||
|
@ -22,23 +22,23 @@ import threading
|
||||
|
||||
from tensorflow.python import _pywrap_parallel_device
|
||||
from tensorflow.python.distribute import device_util
|
||||
from tensorflow.python.distribute.parallel_device import gen_parallel_device_ops
|
||||
from tensorflow.python.distribute.parallel_device import saving
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import load_library
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.platform import resource_loader
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.tpu.ops import tpu_ops
|
||||
|
||||
load_library.load_op_library(
|
||||
resource_loader.get_path_to_datafile("_parallel_device_ops.so"))
|
||||
|
||||
_next_device_number = 0
|
||||
_next_device_number_lock = threading.Lock()
|
||||
|
||||
|
||||
# TODO(allenl): Expand this docstring once things like getting components on and
|
||||
# off the device are stable.
|
||||
#
|
||||
# TODO(allenl): Make multi-client work; we need an offset for device IDs, and an
|
||||
# indication of how many other devices there are total for collectives which
|
||||
# don't have a number of participants hard-coded in their attributes.
|
||||
class ParallelDevice(object):
|
||||
"""A device which executes operations in parallel."""
|
||||
|
||||
@ -64,8 +64,7 @@ class ParallelDevice(object):
|
||||
device, device_info = _pywrap_parallel_device.GetParallelDeviceCapsules(
|
||||
self._name, self.components)
|
||||
context.register_custom_device(device, self._name, device_info)
|
||||
with ops.device(self._name):
|
||||
self._device_ids = gen_parallel_device_ops.device_id()
|
||||
self._device_ids = None
|
||||
self._device_scope = None
|
||||
self._saving_scope = None
|
||||
|
||||
@ -106,6 +105,27 @@ class ParallelDevice(object):
|
||||
Returns:
|
||||
A parallel tensor containing 0 on the first device, 1 on the second, etc.
|
||||
"""
|
||||
if self._device_ids is None:
|
||||
# device_ids may be called from inside a tf.function, in which case the
|
||||
# function captures the eager tensor. We can't pack tensors in a function
|
||||
# at the moment, and even if we could we don't want to hold on to a
|
||||
# symbolic tensor, so we need to init_scope out of the function
|
||||
# temporarily.
|
||||
with ops.init_scope():
|
||||
# TODO(allenl): Functions which capture eager device ID tensors won't be
|
||||
# saveable in SavedModels. Ideally we'd run a DeviceID op every time
|
||||
# device IDs are required, with functions using the op in their bodies
|
||||
# but not hard-coding a fixed number of devices (so they can be re-used
|
||||
# with a different replica count).
|
||||
device_ids_list = []
|
||||
for index, device in enumerate(self.components):
|
||||
with ops.device(device):
|
||||
# The identity op ensures each device ID tensor is placed on its
|
||||
# device.
|
||||
device_ids_list.append(
|
||||
array_ops.identity(constant_op.constant(index)))
|
||||
self._device_ids = self.pack(device_ids_list)
|
||||
|
||||
return self._device_ids
|
||||
|
||||
def _assert_eager(self):
|
||||
|
Loading…
Reference in New Issue
Block a user