Parallel device: Remove the DeviceID op usage for now

It may be good to have later, but for now a registered op is unnecessary API surface. There isn't a strong case for the op yet, and having it as a dependency without it being a core op makes the parallel device hard to depend on.

Keeps a "device_id" property on the ParallelDevice in Python, which will behave identically to the previous version

PiperOrigin-RevId: 329313676
Change-Id: Iaf024d0d762061ba90e3568199334b53432001ce
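For illustration only (not part of this change): a minimal usage sketch of the retained Python property. The constructor argument and the two CPU device names are assumptions based on the diff below rather than verified API.

# Hedged sketch: assumes ParallelDevice accepts a list of component device
# names and that two local CPU devices are configured in the process.
from tensorflow.python.distribute.parallel_device import parallel_device

pd = parallel_device.ParallelDevice(
    ["/job:localhost/replica:0/task:0/device:CPU:0",
     "/job:localhost/replica:0/task:0/device:CPU:1"])

# Same behavior as before this change: a packed parallel tensor holding
# 0 on the first component device and 1 on the second.
ids = pd.device_ids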
Allen Lavoie, 2020-08-31 09:08:26 -07:00; committed by TensorFlower Gardener
parent 974fbe0ee4
commit 027c55b667
6 changed files with 28 additions and 109 deletions

tensorflow/c/eager/parallel_device/BUILD

@@ -103,7 +103,6 @@ cc_library(
hdrs = ["parallel_device_testlib.h"],
deps = [
":parallel_device",
":parallel_device_ops",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental",
"//tensorflow/c/eager:c_api",
@@ -118,7 +117,6 @@ tf_cc_test(
srcs = ["parallel_device_test.cc"],
deps = [
":parallel_device",
":parallel_device_ops",
":parallel_device_testlib",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental",
@@ -138,7 +136,6 @@ tf_cc_test(
args = ["--heap_check=local"],
deps = [
":parallel_device",
":parallel_device_ops",
":parallel_device_testlib",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental",
@@ -150,19 +147,3 @@ tf_cc_test(
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
],
)
# Note: ParallelDevice-specific ops are experimental and not currently linked in
# to TensorFlow by default, just used in a few tests.
filegroup(
name = "parallel_device_ops_srcs",
srcs = ["parallel_device_ops.cc"],
visibility = ["//tensorflow/python/distribute/parallel_device:__pkg__"],
)
cc_library(
name = "parallel_device_ops",
srcs = [":parallel_device_ops_srcs"],
visibility = ["//tensorflow:internal"],
deps = ["//tensorflow/core:framework"],
alwayslink = 1,
)

tensorflow/c/eager/parallel_device/parallel_device.cc

@@ -136,13 +136,6 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ExecuteWithSpecialOps(
}
result.emplace(std::move(outputs));
return result;
} else if (operation_name == std::string("DeviceID")) {
std::vector<MaybeParallelTensorOwned> result_content;
result_content.reserve(1);
result_content.push_back(parallel_device.DeviceIDs(context, status));
if (TF_GetCode(status) != TF_OK) return result;
result.emplace(std::move(result_content));
return result;
}
std::vector<ParallelTensor*> parallel_inputs;
std::vector<std::unique_ptr<ParallelTensor>> implicitly_broadcast_tensors;

tensorflow/c/eager/parallel_device/parallel_device_ops.cc (deleted)

@ -1,26 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
// TODO(allenl): Figure out if we need this op, and if so whether we should move
// it to core TF. Right now the eager C API does some checking of op
// registrations before calling into custom devices, but we may be able to avoid
// that.
REGISTER_OP("DeviceID")
.Output("device_id: int64")
.SetIsStateful()
.SetShapeFn(tensorflow::shape_inference::ScalarShape);

tensorflow/c/eager/parallel_device/parallel_device_testlib.cc

@@ -279,30 +279,4 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
// Compute the device ID twice and verify the result
for (int i = 0; i < 2; ++i) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetDevice(op.get(), device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context, result_handle, &components, status.get());
TFE_DeleteTensorHandle(result_handle);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<int32_t>(components[0].get(), 0);
ExpectScalarEq<int32_t>(components[1].get(), 1);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
}

tensorflow/python/distribute/parallel_device/BUILD

@@ -1,6 +1,3 @@
load("//tensorflow:tensorflow.bzl", "tf_custom_op_library", "tf_gen_op_wrapper_py")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
package(
default_visibility = ["//tensorflow:internal"],
licenses = ["notice"], # Apache 2.0
@@ -17,7 +14,6 @@ py_library(
srcs = ["parallel_device.py"],
srcs_version = "PY2AND3",
deps = [
":parallel_device_ops",
":saving",
"//tensorflow/python:_pywrap_parallel_device",
"//tensorflow/python/distribute:device_util",
@@ -31,25 +27,6 @@ py_library(
deps = ["//tensorflow/python:framework_ops"],
)
tf_gen_op_wrapper_py(
name = "parallel_device_ops_py",
out = "gen_parallel_device_ops.py",
deps = ["//tensorflow/c/eager/parallel_device:parallel_device_ops"],
)
tf_custom_op_library(
name = "_parallel_device_ops.so",
srcs = ["//tensorflow/c/eager/parallel_device:parallel_device_ops_srcs"],
)
tf_custom_op_py_library(
name = "parallel_device_ops",
dso = [":_parallel_device_ops.so"],
kernels = ["//tensorflow/c/eager/parallel_device:parallel_device_ops"],
visibility = ["//tensorflow:internal"],
deps = [":parallel_device_ops_py"],
)
py_test(
name = "parallel_device_test",
srcs = ["parallel_device_test.py"],

tensorflow/python/distribute/parallel_device/parallel_device.py

@@ -22,23 +22,23 @@ import threading
from tensorflow.python import _pywrap_parallel_device
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute.parallel_device import gen_parallel_device_ops
from tensorflow.python.distribute.parallel_device import saving
from tensorflow.python.eager import context
from tensorflow.python.framework import load_library
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.platform import resource_loader
from tensorflow.python.ops import array_ops
from tensorflow.python.tpu.ops import tpu_ops
load_library.load_op_library(
resource_loader.get_path_to_datafile("_parallel_device_ops.so"))
_next_device_number = 0
_next_device_number_lock = threading.Lock()
# TODO(allenl): Expand this docstring once things like getting components on and
# off the device are stable.
#
# TODO(allenl): Make multi-client work; we need an offset for device IDs, and an
# indication of how many other devices there are total for collectives which
# don't have a number of participants hard-coded in their attributes.
class ParallelDevice(object):
"""A device which executes operations in parallel."""
@@ -64,8 +64,7 @@ class ParallelDevice(object):
device, device_info = _pywrap_parallel_device.GetParallelDeviceCapsules(
self._name, self.components)
context.register_custom_device(device, self._name, device_info)
with ops.device(self._name):
self._device_ids = gen_parallel_device_ops.device_id()
self._device_ids = None
self._device_scope = None
self._saving_scope = None
@@ -106,6 +105,27 @@
Returns:
A parallel tensor containing 0 on the first device, 1 on the second, etc.
"""
if self._device_ids is None:
# device_ids may be called from inside a tf.function, in which case the
# function captures the eager tensor. We can't pack tensors in a function
# at the moment, and even if we could we don't want to hold on to a
# symbolic tensor, so we need to init_scope out of the function
# temporarily.
with ops.init_scope():
# TODO(allenl): Functions which capture eager device ID tensors won't be
# saveable in SavedModels. Ideally we'd run a DeviceID op every time
# device IDs are required, with functions using the op in their bodies
# but not hard-coding a fixed number of devices (so they can be re-used
# with a different replica count).
device_ids_list = []
for index, device in enumerate(self.components):
with ops.device(device):
# The identity op ensures each device ID tensor is placed on its
# device.
device_ids_list.append(
array_ops.identity(constant_op.constant(index)))
self._device_ids = self.pack(device_ids_list)
return self._device_ids
def _assert_eager(self):