Add support for a device ID op in parallel_device

The op doesn't really make sense to register kernels for, so I'm not
registering it anywhere by default yet; it's currently just registered in the
parallel device tests.

PiperOrigin-RevId: 311141160
Change-Id: Iff1839112dac6fe3406e4b31f0e6f7239809a5bb
commit 8e3bc844b1
parent adb282e47c
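For orientation, a minimal usage sketch of the API this change adds (an
illustration only: the constructor arguments and device names are assumptions
based on the tests below, and this module is TensorFlow-internal, not a public
API):

    # Minimal sketch, assuming two CPU component devices are configured and
    # that ParallelDevice takes the list of component device names, as in the
    # tests in this change.
    from tensorflow.python.distribute.parallel_device import parallel_device

    device = parallel_device.ParallelDevice(
        components=["/job:localhost/replica:0/task:0/device:CPU:0",
                    "/job:localhost/replica:0/task:0/device:CPU:1"])
    # device_ids is a parallel tensor: 0 on the first component device, 1 on
    # the second, each placed on its corresponding device.
    first_id, second_id = device.unpack(device.device_ids)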
tensorflow/c/eager/parallel_device/BUILD

@@ -44,6 +44,7 @@ tf_cc_test(
     srcs = ["parallel_device_test.cc"],
     deps = [
         ":parallel_device",
+        ":parallel_device_ops",
         "//tensorflow/c:c_api",
         "//tensorflow/c:c_api_experimental",
         "//tensorflow/c/eager:c_api",
@@ -53,3 +54,19 @@ tf_cc_test(
         "//tensorflow/core:test_main",
     ],
 )
+
+# Note: ParallelDevice-specific ops are experimental and not currently linked in
+# to TensorFlow by default, just used in a few tests.
+filegroup(
+    name = "parallel_device_ops_srcs",
+    srcs = ["parallel_device_ops.cc"],
+    visibility = ["//tensorflow/python/distribute/parallel_device:__pkg__"],
+)
+
+cc_library(
+    name = "parallel_device_ops",
+    srcs = [":parallel_device_ops_srcs"],
+    visibility = ["//tensorflow:internal"],
+    deps = ["//tensorflow/core:framework"],
+    alwayslink = 1,
+)
tensorflow/c/eager/parallel_device/parallel_device.h

@@ -92,6 +92,10 @@ class ParallelDevice {
                                                  TFE_TensorHandle* tensor,
                                                  TF_Status* status) const;
 
+  // A parallel tensor with scalar integers numbering component devices.
+  std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
+                                            TF_Status* status) const;
+
   // Takes a description of a single operation being executed on the
   // ParallelDevice, and in turn runs one operation per component device with
   // its corresponding inputs from the input ParallelTensors (or
tensorflow/c/eager/parallel_device/parallel_device.cc

@@ -208,6 +212,46 @@ std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
                                            status);
 }
 
+std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
+    TFE_Context* context, TF_Status* status) const {
+  // TODO(allenl): We could cache DeviceIDs (keyed by context).
+  std::vector<TensorHandlePtr> components;
+  components.reserve(underlying_devices_.size());
+  for (int device_index = 0; device_index < underlying_devices_.size();
+       ++device_index) {
+    int64_t* device_id = new int64_t;
+    *device_id = device_index;
+    std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
+        TF_NewTensor(
+            TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
+            sizeof(int64_t),
+            [](void* data, size_t, void* arg) {
+              delete reinterpret_cast<int64_t*>(data);
+            },
+            nullptr),
+        TF_DeleteTensor);
+    // TODO(allenl): Here and when executing regular operations, we could hold
+    // on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
+    // device names repeatedly.
+    OpPtr const_op(TFE_NewOp(context, "Const", status));
+    if (TF_GetCode(status) != TF_OK) return nullptr;
+    TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
+                    status);
+    if (TF_GetCode(status) != TF_OK) return nullptr;
+    TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
+    if (TF_GetCode(status) != TF_OK) return nullptr;
+    TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
+    TFE_TensorHandle* device_handle;
+    int num_outputs = 1;
+    TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
+    if (TF_GetCode(status) != TF_OK) return nullptr;
+    components.emplace_back(device_handle);
+    if (TF_GetCode(status) != TF_OK) return nullptr;
+  }
+  return ParallelTensor::FromTensorHandles(*this, std::move(components),
+                                           status);
+}
+
 absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
     TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
     const char* operation_name, const TFE_OpAttrs* attributes,
@@ -282,6 +326,13 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
     }
     result.emplace(std::move(outputs));
     return result;
+  } else if (operation_name == std::string("DeviceID")) {
+    std::vector<MaybeParallelTensorOwned> result_content;
+    result_content.reserve(1);
+    result_content.push_back(DeviceIDs(context, status));
+    if (TF_GetCode(status) != TF_OK) return result;
+    result.emplace(std::move(result_content));
+    return result;
   }
   absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
       maybe_parallel_results(
tensorflow/c/eager/parallel_device/parallel_device_ops.cc (new file, 26 lines)

@@ -0,0 +1,26 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+
+// TODO(allenl): Figure out if we need this op, and if so whether we should move
+// it to core TF. Right now the eager C API does some checking of op
+// registrations before calling into custom devices, but we may be able to avoid
+// that.
+REGISTER_OP("DeviceID")
+    .Output("device_id: int64")
+    .SetIsStateful()
+    .SetShapeFn(tensorflow::shape_inference::ScalarShape);
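Since no kernels are registered for DeviceID, executing it on an ordinary
device fails; it only works because ParallelDevice::Execute above intercepts it
by name. The tf_gen_op_wrapper_py rule added later in this change turns the
registration into a generated Python wrapper. A hedged sketch of how that
wrapper is invoked (the custom device name here is a hypothetical placeholder):

    # Sketch: the generated wrapper takes no inputs and is declared to return
    # a scalar int64. Under the parallel device's scope, the custom device
    # intercepts the op and returns one ID tensor per component device.
    from tensorflow.python.distribute.parallel_device import gen_parallel_device_ops
    from tensorflow.python.framework import ops

    parallel_device_name = "/job:localhost/device:CUSTOM:0"  # hypothetical
    with ops.device(parallel_device_name):
      device_ids = gen_parallel_device_ops.device_id()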
tensorflow/c/eager/parallel_device/parallel_device_test.cc

@@ -278,14 +278,15 @@ TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
 }
 
 // Assert that `handle` is equal to `expected_value`.
-void AssertScalarFloatEq(TFE_TensorHandle* handle, float expected_value) {
+template <typename value_type>
+void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
       TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
-  ASSERT_EQ(expected_value,
-            *static_cast<float*>(TF_TensorData(value_zero.get())));
+  EXPECT_EQ(expected_value,
+            *static_cast<value_type*>(TF_TensorData(value_zero.get())));
 }
 
 template <std::size_t num_devices>
@@ -343,8 +344,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
   ExtractPerDeviceValues(context, read.get(), &components, status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
-  AssertScalarFloatEq(components[0].get(), 20.);
-  AssertScalarFloatEq(components[1].get(), 20.);
+  ExpectScalarEq<float>(components[0].get(), 20.);
+  ExpectScalarEq<float>(components[1].get(), 20.);
 
   std::string first_device =
       TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
@@ -373,8 +374,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
   ExtractPerDeviceValues(context, read.get(), &components, status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
-  AssertScalarFloatEq(components[0].get(), 23.);
-  AssertScalarFloatEq(components[1].get(), 18.);
+  ExpectScalarEq<float>(components[0].get(), 23.);
+  ExpectScalarEq<float>(components[1].get(), 18.);
 
   std::string first_device =
       TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
@@ -383,6 +384,32 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
         TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
     ASSERT_EQ(underlying_devices[1], second_device);
   }
+  // Compute the device ID twice and verify the result
+  for (int i = 0; i < 2; ++i) {
+    std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
+        TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp);
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+    TFE_OpSetDevice(op.get(), device_name, status.get());
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+    TFE_TensorHandle* result_handle;
+    int num_retvals = 1;
+    TFE_Execute(op.get(), &result_handle, &num_retvals, status.get());
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+    std::array<TensorHandlePtr, 2> components;
+    ExtractPerDeviceValues(context, result_handle, &components, status.get());
+    TFE_DeleteTensorHandle(result_handle);
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+    ExpectScalarEq<int64_t>(components[0].get(), 0);
+    ExpectScalarEq<int64_t>(components[1].get(), 1);
+    std::string first_device =
+        TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
+    ASSERT_EQ(underlying_devices[0], first_device);
+    std::string second_device =
+        TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
+    ASSERT_EQ(underlying_devices[1], second_device);
+  }
 }
 
 TEST(PARALLEL_DEVICE, TestBasicCPU) {
@@ -498,8 +525,8 @@ TEST(PARALLEL_DEVICE, TestExplicitCopies) {
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   // The value of the original tensor is replicated on each device.
-  AssertScalarFloatEq(components[0].get(), 3.);
-  AssertScalarFloatEq(components[1].get(), 3.);
+  ExpectScalarEq<float>(components[0].get(), 3.);
+  ExpectScalarEq<float>(components[1].get(), 3.);
 
   // Verify that the mirrors are placed on the component devices.
   std::string first_device =
@@ -630,7 +657,7 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
                           &second_components, status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
-  AssertScalarFloatEq(second_components[1].get(), 9.);
+  ExpectScalarEq<float>(second_components[1].get(), 9.);
 
   // Verify that the mirrors are placed on the component devices.
   std::string first_device = TFE_TensorHandleBackingDeviceName(
@@ -644,8 +671,8 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
   std::array<TensorHandlePtr, 2> first_components;
   ExtractPerDeviceValues(context.get(), second_components[0].get(),
                          &first_components, status.get());
-  AssertScalarFloatEq(first_components[0].get(), 3.);
-  AssertScalarFloatEq(first_components[1].get(), 6.);
+  ExpectScalarEq<float>(first_components[0].get(), 3.);
+  ExpectScalarEq<float>(first_components[1].get(), 6.);
 
   first_device = TFE_TensorHandleBackingDeviceName(first_components[0].get(),
                                                    status.get());
@@ -806,8 +833,8 @@ TEST(PARALLEL_DEVICE, TestCollective) {
   ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
                          status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
-  AssertScalarFloatEq(result_components[0].get(), 3.);
-  AssertScalarFloatEq(result_components[1].get(), 3.);
+  ExpectScalarEq<float>(result_components[0].get(), 3.);
+  ExpectScalarEq<float>(result_components[1].get(), 3.);
 }
 
 void RegisterCollectiveMulFunction(TFE_Context* context,
@@ -909,8 +936,8 @@ TEST(PARALLEL_DEVICE, TestFunction) {
   ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
                          status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
-  AssertScalarFloatEq(result_components[0].get(), 7. * 9.);
-  AssertScalarFloatEq(result_components[1].get(), 7. * 9.);
+  ExpectScalarEq<float>(result_components[0].get(), 7. * 9.);
+  ExpectScalarEq<float>(result_components[1].get(), 7. * 9.);
 
   std::string first_device = TFE_TensorHandleBackingDeviceName(
       result_components[0].get(), status.get());
tensorflow/python/distribute/parallel_device/BUILD

@@ -1,3 +1,6 @@
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_library", "tf_gen_op_wrapper_py")
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
+
 package(
     default_visibility = ["//tensorflow:internal"],
     licenses = ["notice"],  # Apache 2.0
@@ -14,6 +17,7 @@ py_library(
     srcs = ["parallel_device.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":parallel_device_ops",
         ":saving",
         "//tensorflow/python:_pywrap_parallel_device",
     ],
@@ -26,6 +30,25 @@ py_library(
     deps = ["//tensorflow/python:framework_ops"],
 )
+
+tf_gen_op_wrapper_py(
+    name = "parallel_device_ops_py",
+    out = "gen_parallel_device_ops.py",
+    deps = ["//tensorflow/c/eager/parallel_device:parallel_device_ops"],
+)
+
+tf_custom_op_library(
+    name = "_parallel_device_ops.so",
+    srcs = ["//tensorflow/c/eager/parallel_device:parallel_device_ops_srcs"],
+)
+
+tf_custom_op_py_library(
+    name = "parallel_device_ops",
+    dso = [":_parallel_device_ops.so"],
+    kernels = ["//tensorflow/c/eager/parallel_device:parallel_device_ops"],
+    visibility = ["//tensorflow:internal"],
+    deps = [":parallel_device_ops_py"],
+)
 
 py_test(
     name = "parallel_device_test",
     srcs = ["parallel_device_test.py"],
tensorflow/python/distribute/parallel_device/parallel_device.py

@@ -22,11 +22,17 @@ import contextlib
 import threading
 
 from tensorflow.python import _pywrap_parallel_device
+from tensorflow.python.distribute.parallel_device import gen_parallel_device_ops
 from tensorflow.python.distribute.parallel_device import saving
 from tensorflow.python.eager import context
+from tensorflow.python.framework import load_library
 from tensorflow.python.framework import ops
+from tensorflow.python.platform import resource_loader
 from tensorflow.python.tpu.ops import tpu_ops
 
+load_library.load_op_library(
+    resource_loader.get_path_to_datafile("_parallel_device_ops.so"))
+
 _next_device_number = 0
 _next_device_number_lock = threading.Lock()
@@ -58,6 +64,8 @@ class ParallelDevice(object):
     device, device_info = _pywrap_parallel_device.GetParallelDeviceCapsules(
         self.name, self.components)
     context.register_custom_device(device, self.name, device_info)
+    with ops.device(self.name):
+      self._device_ids = gen_parallel_device_ops.device_id()
 
   def pack(self, tensors):
     """Create a tensor on the parallel device from a sequence of tensors.
@@ -84,6 +92,18 @@ class ParallelDevice(object):
     return tpu_ops.tpu_replicated_output(
         parallel_tensor, num_replicas=len(self.components))
 
+  @property
+  def device_ids(self):
+    """A parallel tensor with scalar integers numbering component devices.
+
+    Each device ID is placed on its corresponding device, in the same order as
+    the `components` constructor argument.
+
+    Returns:
+      A parallel tensor containing 0 on the first device, 1 on the second, etc.
+    """
+    return self._device_ids
+
 # TODO(allenl): Fixing saving in Python is a bit odd. One alternative would be
 # to provide a hook for the custom device to create save specs/etc., then call
 # that hook from the default variable implementation if the variable is on a
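Because `device_ids` is itself a parallel tensor, ordinary eager ops applied to
it run once per component device, which is what makes it useful for
per-replica work. A hedged sketch (the arithmetic is illustrative and not part
of this change; `device` is a ParallelDevice as in the sketch near the top of
this page):

    # Sketch: deriving a distinct per-device value from device_ids. Assumes
    # elementwise ops on parallel tensors dispatch through the custom device,
    # per the Execute path in parallel_device.cc above.
    from tensorflow.python.framework import ops

    with ops.device(device.name):
      per_device_offsets = device.device_ids + 42  # 42 on device 0, 43 on device 1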
tensorflow/python/distribute/parallel_device/parallel_device_test.py

@@ -119,6 +119,12 @@ class ParallelDeviceTests(_VirtualDeviceTestCase):
     self.assertIn(self.device.components[0], outputs[0].backing_device)
     self.assertIn(self.device.components[1], outputs[1].backing_device)
 
+  def test_device_id(self):
+    device_ids = self.device.unpack(self.device.device_ids)
+    self.assertAllClose([0, 1], device_ids)
+    self.assertIn(self.device.components[0], device_ids[0].backing_device)
+    self.assertIn(self.device.components[1], device_ids[1].backing_device)
+
   def test_collective_reduce(self):
     with ops.device(self.device.name):
       x = self.device.pack(