Support running a tf.function with packed variable inputs both locally and remotely.

- Support packing multiple EagerTensors of the same dtype and shape (see the usage sketch after this list).
- Create CompositeDevices on the same task as the local host CPU, in order to correctly trigger packed TensorHandle copy from a client to a remote worker.
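A minimal local sketch of the intended usage, modeled on the new testPackedVariable test further below. It assumes at least two CPU devices are available (e.g. two virtual CPUs configured via tf.config.set_logical_device_configuration) and uses the internal ops.pack_eager_tensors helper added in this change:

from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import resource_variable_ops

with ops.device('/cpu:0'):
  v0 = resource_variable_ops.ResourceVariable(1.0)
with ops.device('/cpu:1'):
  v1 = resource_variable_ops.ResourceVariable(2.0)

# Pack the two resource handles into a single packed EagerTensor.
packed = ops.pack_eager_tensors([v0.handle, v1.handle])

@def_function.function
def read_both():
  # Inside a tf.function, the packed handle resolves to the per-device handle.
  with ops.device('/cpu:0'):
    r0 = resource_variable_ops.read_variable_op(packed, dtype=dtypes.float32)
  with ops.device('/cpu:1'):
    r1 = resource_variable_ops.read_variable_op(packed, dtype=dtypes.float32)
  return r0 + r1

read_both()  # expected: 3.0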

PiperOrigin-RevId: 312164194
Change-Id: Ia15718309c8c68eb645bfe0bf967ddd6d2551b3a
commit 3c54ef5ab9 (parent 1a07ecf852)
Author: Yujing Zhang
Committer: TensorFlower Gardener
Date: 2020-05-18 15:17:54 -07:00

17 changed files with 274 additions and 23 deletions


@@ -24,7 +24,7 @@ const char* const kCompositeDeviceType = "COMPOSITE";
 std::unique_ptr<CompositeDevice> CompositeDevice::MakeDevice(
     const std::vector<string>& underlying_devices, const int unique_device_id,
-    Status* status) {
+    const DeviceNameUtils::ParsedName& host_name, Status* status) {
   if (underlying_devices.empty()) {
     status->Update(
         errors::InvalidArgument("underlying_devices should not be empty."));
@@ -62,13 +62,15 @@ std::unique_ptr<CompositeDevice> CompositeDevice::MakeDevice(
       return nullptr;
     }
   }
+  DeviceNameUtils::ParsedName parsed_composite_name = host_name;
   DeviceAttributes device_attributes;
-  parsed_name.type = kCompositeDeviceType;
-  device_attributes.set_device_type(parsed_name.type);
-  parsed_name.id = unique_device_id;
+  parsed_composite_name.type = kCompositeDeviceType;
+  parsed_composite_name.id = unique_device_id;
   const string composite_name =
-      DeviceNameUtils::ParsedNameToString(parsed_name);
+      DeviceNameUtils::ParsedNameToString(parsed_composite_name);
   device_attributes.set_name(composite_name);
+  device_attributes.set_device_type(kCompositeDeviceType);
   return absl::WrapUnique(
       new CompositeDevice(device_attributes, underlying_devices));
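The composite device now keeps the job/replica/task of the given host CPU and only swaps in the COMPOSITE type and the unique id. A rough Python analogue of that naming rule, illustrative only and using the public tf.DeviceSpec API rather than the C++ code above:

import tensorflow as tf

host = tf.DeviceSpec.from_string('/job:localhost/replica:0/task:0/device:CPU:0')
composite = host.replace(device_type='COMPOSITE', device_index=0)
print(composite.to_string())
# /job:localhost/replica:0/task:0/device:COMPOSITE:0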


@@ -42,10 +42,11 @@ class CompositeDevice : public Device {
     return &underlying_devices_;
   }

-  // Helper for creating a CompositeDevice.
+  // Helper for creating a CompositeDevice on the same task as the given host
+  // CPU.
   static std::unique_ptr<CompositeDevice> MakeDevice(
       const std::vector<string>& underlying_devices, const int unique_device_id,
-      Status* status);
+      const DeviceNameUtils::ParsedName& host_name, Status* status);

  private:
   CompositeDevice(const DeviceAttributes& device_attributes,


@@ -20,12 +20,15 @@ limitations under the License.
 namespace tensorflow {

 TEST(CompositeDeviceTest, Basic) {
+  const string host_name = "/job:localhost/replica:0/task:0/device:CPU:0";
+  DeviceNameUtils::ParsedName parsed_host_name;
+  EXPECT_TRUE(DeviceNameUtils::ParseFullName(host_name, &parsed_host_name));
   std::vector<string> underlying_devices;
   {
     Status status;
     std::unique_ptr<CompositeDevice> composite_device =
         CompositeDevice::MakeDevice(underlying_devices, /*unique_device_id=*/0,
-                                    &status);
+                                    parsed_host_name, &status);
     EXPECT_EQ(composite_device, nullptr);
     EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
     EXPECT_TRUE(absl::StrContains(status.error_message(),
@@ -41,7 +44,7 @@ TEST(CompositeDeviceTest, Basic) {
         "/job:localhost/replica:0/task:0/device:CPU:1");
     std::unique_ptr<CompositeDevice> composite_device =
         CompositeDevice::MakeDevice(underlying_devices, /*unique_device_id=*/0,
-                                    &status);
+                                    parsed_host_name, &status);
     TF_ASSERT_OK(status);
     EXPECT_EQ(composite_device->device_type(), kCompositeDeviceType);
     EXPECT_EQ(underlying_devices, *composite_device->underlying_devices());
@@ -53,7 +56,7 @@ TEST(CompositeDeviceTest, Basic) {
         "/job:localhost/replica:0/task:0/device:CPU:0");
     std::unique_ptr<CompositeDevice> composite_device =
         CompositeDevice::MakeDevice(underlying_devices, /*unique_device_id=*/1,
-                                    &status);
+                                    parsed_host_name, &status);
     EXPECT_EQ(composite_device, nullptr);
     EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
     EXPECT_TRUE(
@@ -68,7 +71,7 @@ TEST(CompositeDeviceTest, Basic) {
         "/job:localhost/replica:0/task:0/device:GPU:0");
     std::unique_ptr<CompositeDevice> composite_device =
         CompositeDevice::MakeDevice(underlying_devices, /*unique_device_id=*/1,
-                                    &status);
+                                    parsed_host_name, &status);
     EXPECT_EQ(composite_device, nullptr);
     EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
     EXPECT_TRUE(absl::StrContains(status.error_message(),


@@ -935,8 +935,11 @@ Status EagerContext::FindOrCreateCompositeDevice(
   }

   Status s;
-  auto device = CompositeDevice::MakeDevice(underlying_devices,
-                                            composite_devices_.size(), &s);
+  // Create a CompositeDevice on the same task as the host CPU, in order to
+  // trigger packed TensorHandle copy from a client to a remote worker.
+  auto device =
+      CompositeDevice::MakeDevice(underlying_devices, composite_devices_.size(),
+                                  HostCPU()->parsed_name(), &s);
   TF_RETURN_IF_ERROR(s);
   *composite_device = device.get();
   pflr_->AddCompositeDevice(*composite_device);


@@ -31,7 +31,7 @@ static Device* CreateDevice(const string& type, int n) {
     Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
   };
   DeviceAttributes attr;
-  attr.set_name("/job:a/replica:0/task:0/device:" + type + ":" +
+  attr.set_name("/job:localhost/replica:0/task:0/device:" + type + ":" +
                 std::to_string(n));
   attr.set_device_type(type);
   return new FakeDevice(attr);
@@ -179,10 +179,10 @@ TEST_F(EagerContextTest, CompositeDevice) {
   TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(underlying_devices,
                                                       &composite_device_0));
   EXPECT_EQ(composite_device_0->name(),
-            "/job:worker/replica:0/task:0/device:COMPOSITE:0");
+            "/job:localhost/replica:0/task:0/device:COMPOSITE:0");
   CompositeDevice* device = nullptr;
   TF_EXPECT_OK(context()->FindCompositeDeviceFromName(
-      "/job:worker/replica:0/task:0/device:COMPOSITE:0", &device));
+      "/job:localhost/replica:0/task:0/device:COMPOSITE:0", &device));
   EXPECT_EQ(device, composite_device_0);
   CompositeDevice* composite_device_1 = nullptr;
   TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(underlying_devices,
@@ -193,13 +193,13 @@ TEST_F(EagerContextTest, CompositeDevice) {
   TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(underlying_devices,
                                                       &composite_device_2));
   EXPECT_EQ(composite_device_2->name(),
-            "/job:worker/replica:0/task:0/device:COMPOSITE:1");
+            "/job:localhost/replica:0/task:0/device:COMPOSITE:1");
   TF_EXPECT_OK(context()->FindCompositeDeviceFromName(
-      "/job:worker/replica:0/task:0/device:COMPOSITE:1", &device));
+      "/job:localhost/replica:0/task:0/device:COMPOSITE:1", &device));
   EXPECT_EQ(device, composite_device_2);
   EXPECT_TRUE(errors::IsNotFound(context()->FindCompositeDeviceFromName(
-      "/job:worker/replica:0/task:0/device:COMPOSITE:2", &device)));
+      "/job:localhost/replica:0/task:0/device:COMPOSITE:2", &device)));
 }

 }  // namespace


@@ -61,7 +61,8 @@ TEST(ExecuteNodeTest, ExecuteNodeArgs) {
   Status s;
   std::unique_ptr<CompositeDevice> composite_device =
       CompositeDevice::MakeDevice({device0->name(), device1->name()},
-                                  /*unique_device_id=*/0, &s);
+                                  /*unique_device_id=*/0,
+                                  device_mgr.HostCPU()->parsed_name(), &s);
   TF_ASSERT_OK(s);

   auto ctx = new EagerContext(


@@ -100,6 +100,7 @@ class PackedTensorHandleTest : public ::testing::Test {
     for (const char* name : device_names_) {
      devices.emplace_back(CreateDevice("GPU", name));
     }
+    devices.emplace_back(CreateDevice("CPU", host_name_));
     device_mgr_ = new StaticDeviceMgr(std::move(devices));

     context_ = new EagerContext(
@@ -132,6 +133,8 @@ class PackedTensorHandleTest : public ::testing::Test {
       "/job:worker/replica:0/task:1/device:GPU:0",
       "/job:worker/replica:0/task:1/device:GPU:1"};

+  const char* host_name_ = "/job:worker/replica:0/task:0/device:CPU:0";
+
   StaticDeviceMgr* device_mgr_;
   EagerContext* context_;
 };


@@ -820,7 +820,8 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_CompositeDevice) {
   Status s;
   std::unique_ptr<CompositeDevice> composite_device =
       CompositeDevice::MakeDevice({device0_->name(), device1_->name()},
-                                  /*unique_device_id=*/0, &s);
+                                  /*unique_device_id=*/0,
+                                  device_mgr_->HostCPU()->parsed_name(), &s);
   TF_ASSERT_OK(s);
   AddCompositeDevice(composite_device.get());


@@ -241,6 +241,11 @@ def implicit_val_and_grad(f):
                        "function was being computed.")
     sources = [v.handle for v in variables]
+    for s in sources:
+      if getattr(s, "is_packed", False):
+        raise ValueError(
+            "GradientTape.gradient is not supported on packed EagerTensors yet."
+        )
     grad = imperative_grad.imperative_grad(this_tape, nest.flatten(end_node),
                                            sources)
     return end_node, list(zip(grad, variables))
@@ -548,6 +553,10 @@ def make_vjp(f, params=None, persistent=True):
       ]
       args = _ensure_unique_tensor_objects(parameter_positions, args)
       for i in parameter_positions:
+        if getattr(args[i], "is_packed", False):
+          raise ValueError(
+              "GradientTape.gradient is not supported on packed EagerTensors "
+              "yet.")
         sources.append(args[i])
         tape.watch(this_tape, args[i])
       result = f(*args)
@@ -1032,6 +1041,10 @@ class GradientTape(object):
             logging.WARN, "The dtype of the source tensor must be "
             "floating (e.g. tf.float32) when calling GradientTape.gradient, "
             "got %r", t.dtype)
+      if getattr(t, "is_packed", False):
+        raise ValueError(
+            "GradientTape.gradient is not supported on packed EagerTensors yet."
+        )
     if output_gradients is not None:
       output_gradients = [None if x is None else ops.convert_to_tensor(x)
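All three guards reject packed handles as gradient sources, so for now gradients have to be taken with respect to the individual (unpacked) handles. Illustrative only; `loss` and `packed` are hypothetical names, with `packed` standing for any packed EagerTensor:

import tensorflow as tf

with tf.GradientTape() as tape:
  ...  # compute `loss` from the packed handle
tape.gradient(loss, packed)
# ValueError: GradientTape.gradient is not supported on packed EagerTensors yet.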


@@ -1123,6 +1123,22 @@ class Context(object):
     pywrap_tfe.TFE_Py_RegisterCustomDevice(self._handle, device_capsule,
                                            device_name, device_info_capsule)

+  def pack_eager_tensors(self, tensors):
+    """Pack multiple `EagerTensor`s of the same dtype and shape.
+
+    Args:
+      tensors: a list of EagerTensors to pack.
+
+    Returns:
+      A packed EagerTensor.
+    """
+    self.ensure_initialized()
+    if self._lazy_remote_inputs_copy is not None and (
+        not self._lazy_remote_inputs_copy):
+      raise ValueError("Packing eager tensors is not supported when "
+                       "lazy_remote_inputs_copy is disabled.")
+    return pywrap_tfe.TFE_Py_PackEagerTensors(self._handle, tensors)
+
   def remove_function(self, name):
     """Remove a function from the context.


@@ -186,6 +186,43 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
     with self.assertRaisesRegexp(AttributeError, 'no attribute'):
       add(c)

+  def testPackedVariable(self):
+    with ops.device('/cpu:0'):
+      v0_0 = resource_variable_ops.ResourceVariable(1.0)
+    with ops.device('/cpu:1'):
+      v0_1 = resource_variable_ops.ResourceVariable(2.0)
+      v1_0 = resource_variable_ops.ResourceVariable(3.0)
+    with ops.device('/cpu:2'):
+      v1_1 = resource_variable_ops.ResourceVariable(4.0)
+
+    packed_var_0 = ops.pack_eager_tensors([v0_0.handle, v0_1.handle])
+    packed_var_1 = ops.pack_eager_tensors([v1_0.handle, v1_1.handle])
+
+    # TODO(b/145922293): use ResourceVariable.assign_add and
+    # ResourceVariable.read_value directly once we support packing multiple
+    # ResourceVariable into one ResourceVariable.
+    @def_function.function
+    def read_var():
+      resource_variable_ops.assign_add_variable_op(
+          packed_var_0, constant_op.constant(5.0))
+      resource_variable_ops.assign_add_variable_op(
+          packed_var_1, constant_op.constant(6.0))
+      with ops.device('/cpu:0'):
+        read0 = resource_variable_ops.read_variable_op(
+            packed_var_0, dtype=dtypes.float32)
+      with ops.device('/cpu:1'):
+        read1 = resource_variable_ops.read_variable_op(
+            packed_var_0, dtype=dtypes.float32)
+        read2 = resource_variable_ops.read_variable_op(
+            packed_var_1, dtype=dtypes.float32)
+      with ops.device('/cpu:2'):
+        read3 = resource_variable_ops.read_variable_op(
+            packed_var_1, dtype=dtypes.float32)
+
+      return read0, read1, read2, read3
+
+    self.assertAllEqual(read_var(), (1 + 5, 2 + 5, 3 + 6, 4 + 6))
+
   def testImplementsAttributeBasic(self):
     v = def_function.function(
         experimental_implements='func')(lambda x, y: x + y)


@@ -345,6 +345,8 @@ typedef struct EagerTensor {
   char unused[kMaxEagerTensorParentSize];
   TFE_TensorHandle* handle;
   int64_t id;
+  // Indicates whether it's a packed tensor or not.
+  bool is_packed;
   // This mirrors tensorflow.core.framework.ops.Tensor._handle_data Which will
   // be None for tensors of type other than DT_RESOURCE. For DT_RESOURCE
   // tensors, this will contain a serialized HandleData proto with shape
@@ -418,6 +420,7 @@ bool MaybeInvokeCreatedOnEagerTensorProfiler(EagerTensor* created_tensor) {
 int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
   self->id = get_uid();
   self->handle = nullptr;
+  self->is_packed = false;
   Py_INCREF(Py_None);
   self->handle_data = Py_None;
   Py_INCREF(Py_None);
@@ -647,6 +650,11 @@ static PyObject* EagerTensor_backing_device(EagerTensor* self) {
 #endif
 }

+// Getter `is_packed`.
+static PyObject* EagerTensor_is_packed(EagerTensor* self) {
+  return PyBool_FromLong(self->is_packed);
+}
+
 static PyGetSetDef EagerTensor_getsetters[] = {
     {const_cast<char*>("_id"), (getter)EagerTensor_getid, nullptr,
      const_cast<char*>("Tensor ID."), nullptr},
@@ -655,6 +663,9 @@ static PyGetSetDef EagerTensor_getsetters[] = {
     {const_cast<char*>("backing_device"), (getter)EagerTensor_backing_device,
      nullptr, const_cast<char*>("Device on which tensor's memory is resident."),
      nullptr},
+    {const_cast<char*>("is_packed"), (getter)EagerTensor_is_packed, nullptr,
+     const_cast<char*>("Whether the EagerTensor is a packed tensor or not."),
+     nullptr},
     {const_cast<char*>("_handle_data"), (getter)EagerTensor_handle_data,
      (setter)EagerTensor_sethandle_data,
      const_cast<char*>("Shape/DType data if the EagerTensor is a DT_RESOURCE"),
@@ -813,7 +824,8 @@ TFE_TensorHandle* EagerTensor_Handle(const PyObject* o) {
   return reinterpret_cast<const EagerTensor*>(o)->handle;
 }

-PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) {
+PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle,
+                                const bool is_packed) {
   if (handle == nullptr) {
     return nullptr;
   }
@@ -821,6 +833,7 @@ PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) {
       EagerTensorType->tp_new(EagerTensorType, EmptyTuple(), EmptyDict()));
   if (t != nullptr) {
     t->id = get_uid();
+    t->is_packed = is_packed;
     Py_INCREF(Py_None);
     t->handle_data = Py_None;
     Py_INCREF(Py_None);


@@ -129,7 +129,8 @@ void TFE_DeleteContextCapsule(PyObject* context);
 bool EagerTensor_CheckExact(const PyObject* o);

 // Helper function to construct a new EagerTensor from a TFE_TensorHandle.
-PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle);
+PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle,
+                                const bool is_packed = false);

 // Extracts the handle inside EagerTensor object `o`. Returns nullptr on error.
 TFE_TensorHandle* EagerTensor_Handle(const PyObject* o);
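On the Python side the new flag surfaces as a read-only is_packed attribute on every EagerTensor. A quick check (sketch):

import tensorflow as tf

t = tf.constant(1.0)
print(t.is_packed)  # False for ordinary eager tensors
# Tensors returned by ops.pack_eager_tensors(...) report is_packed == True,
# as asserted in PackEagerTensorTest.testPack below.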


@@ -40,6 +40,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.training import server_lib
 from tensorflow.python.training.server_lib import ClusterSpec
@@ -324,6 +325,36 @@ class MultiWorkersTest(test.TestCase, parameterized.TestCase):
     self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

+  def testMultiDeviceFunctionWithPackedVariable(self):
+    with ops.device('/job:worker/replica:0/task:0/device:CPU:0'):
+      var0 = resource_variable_ops.ResourceVariable(1.0)
+    with ops.device('/job:worker/replica:0/task:1/device:CPU:0'):
+      var1 = resource_variable_ops.ResourceVariable(2.0)
+
+    packed_var = ops.pack_eager_tensors([var0.handle, var1.handle])
+    self.assertEqual(packed_var.device,
+                     '/job:localhost/replica:0/task:0/device:COMPOSITE:0')
+    self.assertEqual(packed_var.backing_device,
+                     '/job:localhost/replica:0/task:0/device:COMPOSITE:0')
+
+    @def_function.function
+    def add_variables():
+      with ops.device('/job:worker/replica:0/task:0/device:CPU:0'):
+        read0 = resource_variable_ops.read_variable_op(
+            packed_var, dtype=dtypes.float32)
+      with ops.device('/job:worker/replica:0/task:1/device:CPU:0'):
+        read1 = resource_variable_ops.read_variable_op(
+            packed_var, dtype=dtypes.float32)
+
+      return read0 + read1
+
+    # Run the function on a remote device
+    with ops.device('/job:worker/replica:0/task:0'):
+      self.assertAllEqual(add_variables().numpy(), 3.0)
+
+    # Run the function on a local worker
+    self.assertAllEqual(add_variables().numpy(), 3.0)
+
   @test_util.eager_lazy_remote_copy_on_and_off
   def testMultiDeviceFunctionOnRemoteDeviceWithWait(self):
     with ops.device('/job:worker/replica:0/task:1'):


@@ -1394,6 +1394,65 @@ def _error_prefix(name):
   return "" if name is None else "%s: " % name


+def pack_eager_tensors(tensors, ctx=None):
+  """Pack multiple `EagerTensor`s of the same dtype and shape.
+
+  Args:
+    tensors: a list of EagerTensors to pack.
+    ctx: context.context().
+
+  Returns:
+    A packed EagerTensor.
+  """
+  if not isinstance(tensors, list):
+    raise TypeError("tensors must be a list or a tuple: %s" % tensors)
+
+  if not tensors:
+    raise ValueError("Empty tensors is unexpected for packing.")
+
+  dtype = tensors[0].dtype
+  shape = tensors[0].shape
+  handle_data = tensors[0]._handle_data  # pylint: disable=protected-access
+  is_resource = dtype == dtypes.resource
+  for i in range(len(tensors)):
+    t = tensors[i]
+    if not isinstance(t, EagerTensor):
+      raise TypeError("tensors must be a list of EagerTensors: %s" % t)
+
+    if t.dtype != dtype:
+      raise ValueError(
+          "All tensors being packed should have the same dtype %s, "
+          "but the %d-th tensor is of dtype %s" % (dtype, i, t.dtype))
+    if t.shape != shape:
+      raise ValueError(
+          "All tensors being packed should have the same shape %s, "
+          "but the %d-th tensor is of shape %s" % (shape, i, t.shape))
+    # pylint: disable=protected-access
+    if is_resource and t._handle_data != handle_data:
+      raise ValueError(
+          "All tensors being packed should have the same handle data %s, "
+          "but the %d-th tensor is of handle data %s" %
+          (handle_data, i, t._handle_data))
+    # pylint: enable=protected-access
+
+  if ctx is None:
+    ctx = context.context()
+
+  # Propagate handle data for resource variables
+  packed_tensor = ctx.pack_eager_tensors(tensors)
+  if handle_data is not None:
+    packed_tensor._handle_data = handle_data  # pylint: disable=protected-access
+
+  def grad_fun(_):
+    raise ValueError(
+        "Gradients through pack_eager_tensors are not supported yet.")
+
+  tape.record_operation("pack_eager_tensors", [packed_tensor], tensors,
+                        grad_fun)
+
+  return packed_tensor
+
+
 def convert_to_tensor(value,
                       dtype=None,
                       name=None,


@@ -34,6 +34,7 @@ from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import function as eager_function
 from tensorflow.python.eager import wrap_function
+from tensorflow.python.framework import config
 from tensorflow.python.framework import composite_tensor
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import device as pydev
@@ -3408,5 +3409,51 @@ class CustomConvertToCompositeTensorTest(test_util.TensorFlowTestCase):
     self.assertAllEqual(x_, tensor_util.constant_value(y_))

+
+@test_util.disable_tfrt("Packing EagerTensors is not supported yet.")
+class PackEagerTensorTest(test_util.TensorFlowTestCase):
+
+  def setUp(self):
+    super(PackEagerTensorTest, self).setUp()
+    context._reset_context()
+    cpus = config.list_physical_devices("CPU")
+    # Set 2 virtual CPUs
+    config.set_logical_device_configuration(cpus[0], [
+        context.LogicalDeviceConfiguration(),
+        context.LogicalDeviceConfiguration(),
+    ])
+
+  def testPack(self):
+    with context.eager_mode():
+      with ops.device("CPU:0"):
+        var0 = resource_variable_ops.ResourceVariable(1.0)
+        c0 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+      with ops.device("CPU:1"):
+        var1 = resource_variable_ops.ResourceVariable(2.0)
+        var2 = resource_variable_ops.ResourceVariable([3.0])
+        c1 = constant_op.constant([9.0])
+
+      packed_var0 = ops.pack_eager_tensors([var0.handle, var1.handle])
+      self.assertTrue(packed_var0.is_packed)
+      self.assertEqual(packed_var0.dtype, var0.handle.dtype)
+      self.assertEqual(packed_var0.shape, var0.handle.shape)
+      self.assertEqual(packed_var0._handle_data, var0.handle._handle_data)
+      self.assertIn("COMPOSITE:0", packed_var0.device)
+      self.assertIn("COMPOSITE:0", packed_var0.backing_device)
+      with self.assertRaises(errors.InvalidArgumentError):
+        packed_var0.numpy()
+
+      # Different dtypes
+      with self.assertRaises(ValueError):
+        ops.pack_eager_tensors([var0.handle, c1])
+
+      # Different shapes
+      with self.assertRaises(ValueError):
+        ops.pack_eager_tensors([c0, c1])
+
+      # Different handle data
+      with self.assertRaises(ValueError):
+        ops.pack_eager_tensors([var0.handle, var2.handle])
+
+
 if __name__ == "__main__":
   googletest.main()


@@ -210,6 +210,22 @@ TFE_OutputTensorHandles InputTFE_OutputTensorHandles(
   return output_tensor_handles;
 }

+// Packs multiple `EagerTensor`s of the same dtype and shape into one
+// `EagerTensor`.
+py::object TFE_Py_PackEagerTensors_wrapper(const py::handle& context,
+                                           const py::handle& tensors) {
+  TFE_Context* ctx = tensorflow::InputTFE_Context(context);
+  TFE_InputTensorHandles handles = InputTFE_InputTensorHandles(tensors);
+  tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus());
+  int size = handles.size();
+  TFE_TensorHandle* packed_handle =
+      TFE_CreatePackedTensorHandle(ctx, handles.data(), &size, status.get());
+  tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
+  PyObject* packed_tensor =
+      EagerTensorFromHandle(packed_handle, /*is_packed=*/true);
+  return tensorflow::PyoOrThrow(packed_tensor);
+}
+
 // This function was created from fusing the typemap logic in platform/base.i.
 py::object TFE_Py_ExecuteCancelable_wrapper(
     const py::handle& context, const char* device_name, const char* op_name,
@@ -558,6 +574,10 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
   m.def("TFE_Py_InitEagerTensor", [](const py::handle& o) {
     return tensorflow::PyoOrThrow(TFE_Py_InitEagerTensor(o.ptr()));
   });
+  m.def("TFE_Py_PackEagerTensors",
+        [](const py::handle& context, const py::handle& handles) {
+          return tensorflow::TFE_Py_PackEagerTensors_wrapper(context, handles);
+        });
   m.def("TFE_Py_SetEagerTensorProfiler", &TFE_Py_SetEagerTensorProfiler);
   m.def("TFE_Py_RegisterJVPFunction", [](const py::handle& o) {
     return tensorflow::PyoOrThrow(TFE_Py_RegisterJVPFunction(o.ptr()));
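End to end, the Python call chain added by this commit is roughly ops.pack_eager_tensors -> Context.pack_eager_tensors -> pywrap_tfe.TFE_Py_PackEagerTensors -> TFE_CreatePackedTensorHandle. A sketch of the lowest Python-visible layer, shown only to trace the wiring; `tensors` is assumed to be a list of same-dtype, same-shape EagerTensors:

from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import context

ctx = context.context()
ctx.ensure_initialized()
packed = pywrap_tfe.TFE_Py_PackEagerTensors(ctx._handle, tensors)
print(packed.is_packed)  # True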