Support running a tf.function with packed variable inputs both locally and remotely.

- Support packing multiple EagerTensors of the same dtype and shape.
- Create CompositeDevices on the same task as the local host CPU, in order to correctly trigger packed TensorHandle copy from a client to a remote worker.

PiperOrigin-RevId: 312164194
Change-Id: Ia15718309c8c68eb645bfe0bf967ddd6d2551b3a
parent 1a07ecf852
commit 3c54ef5ab9
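For orientation, here is a minimal usage sketch of the API this commit adds, assembled from the testPackedVariable and PackEagerTensorTest cases in the diff below. It is illustrative only: it assumes an eager context that exposes at least two CPU devices (e.g. virtual CPUs, as configured in the tests) and uses internal tensorflow.python modules rather than a public endpoint.

from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import resource_variable_ops

with ops.device('/cpu:0'):
  v0 = resource_variable_ops.ResourceVariable(1.0)
with ops.device('/cpu:1'):
  v1 = resource_variable_ops.ResourceVariable(2.0)

# Pack two resource handles of the same dtype and shape into one EagerTensor;
# the packed handle lives on a COMPOSITE device on the local host task.
packed = ops.pack_eager_tensors([v0.handle, v1.handle])

@def_function.function
def read_both():
  # Inside a tf.function the packed handle can be read per device.
  with ops.device('/cpu:0'):
    r0 = resource_variable_ops.read_variable_op(packed, dtype=dtypes.float32)
  with ops.device('/cpu:1'):
    r1 = resource_variable_ops.read_variable_op(packed, dtype=dtypes.float32)
  return r0 + r1

print(read_both())  # tf.Tensor(3.0, ...)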
@@ -24,7 +24,7 @@ const char* const kCompositeDeviceType = "COMPOSITE";
 
 std::unique_ptr<CompositeDevice> CompositeDevice::MakeDevice(
     const std::vector<string>& underlying_devices, const int unique_device_id,
-    Status* status) {
+    const DeviceNameUtils::ParsedName& host_name, Status* status) {
   if (underlying_devices.empty()) {
     status->Update(
         errors::InvalidArgument("underlying_devices should not be empty."));
@@ -62,13 +62,15 @@ std::unique_ptr<CompositeDevice> CompositeDevice::MakeDevice(
       return nullptr;
     }
   }
 
+  DeviceNameUtils::ParsedName parsed_composite_name = host_name;
   DeviceAttributes device_attributes;
-  parsed_name.type = kCompositeDeviceType;
-  device_attributes.set_device_type(parsed_name.type);
-  parsed_name.id = unique_device_id;
+  parsed_composite_name.type = kCompositeDeviceType;
+  parsed_composite_name.id = unique_device_id;
   const string composite_name =
-      DeviceNameUtils::ParsedNameToString(parsed_name);
+      DeviceNameUtils::ParsedNameToString(parsed_composite_name);
   device_attributes.set_name(composite_name);
+  device_attributes.set_device_type(kCompositeDeviceType);
 
   return absl::WrapUnique(
       new CompositeDevice(device_attributes, underlying_devices));
@@ -42,10 +42,11 @@ class CompositeDevice : public Device {
     return &underlying_devices_;
   }
 
-  // Helper for creating a CompositeDevice.
+  // Helper for creating a CompositeDevice on the same task as the given host
+  // CPU.
   static std::unique_ptr<CompositeDevice> MakeDevice(
       const std::vector<string>& underlying_devices, const int unique_device_id,
-      Status* status);
+      const DeviceNameUtils::ParsedName& host_name, Status* status);
 
  private:
   CompositeDevice(const DeviceAttributes& device_attributes,
@@ -20,12 +20,15 @@ limitations under the License.
 namespace tensorflow {
 
 TEST(CompositeDeviceTest, Basic) {
+  const string host_name = "/job:localhost/replica:0/task:0/device:CPU:0";
+  DeviceNameUtils::ParsedName parsed_host_name;
+  EXPECT_TRUE(DeviceNameUtils::ParseFullName(host_name, &parsed_host_name));
   std::vector<string> underlying_devices;
   {
     Status status;
     std::unique_ptr<CompositeDevice> composite_device =
         CompositeDevice::MakeDevice(underlying_devices, /*unique_device_id=*/0,
-                                    &status);
+                                    parsed_host_name, &status);
     EXPECT_EQ(composite_device, nullptr);
     EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
     EXPECT_TRUE(absl::StrContains(status.error_message(),
@@ -41,7 +44,7 @@ TEST(CompositeDeviceTest, Basic) {
         "/job:localhost/replica:0/task:0/device:CPU:1");
     std::unique_ptr<CompositeDevice> composite_device =
         CompositeDevice::MakeDevice(underlying_devices, /*unique_device_id=*/0,
-                                    &status);
+                                    parsed_host_name, &status);
     TF_ASSERT_OK(status);
     EXPECT_EQ(composite_device->device_type(), kCompositeDeviceType);
     EXPECT_EQ(underlying_devices, *composite_device->underlying_devices());
@@ -53,7 +56,7 @@ TEST(CompositeDeviceTest, Basic) {
         "/job:localhost/replica:0/task:0/device:CPU:0");
     std::unique_ptr<CompositeDevice> composite_device =
         CompositeDevice::MakeDevice(underlying_devices, /*unique_device_id=*/1,
-                                    &status);
+                                    parsed_host_name, &status);
     EXPECT_EQ(composite_device, nullptr);
     EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
     EXPECT_TRUE(
@@ -68,7 +71,7 @@ TEST(CompositeDeviceTest, Basic) {
         "/job:localhost/replica:0/task:0/device:GPU:0");
     std::unique_ptr<CompositeDevice> composite_device =
         CompositeDevice::MakeDevice(underlying_devices, /*unique_device_id=*/1,
-                                    &status);
+                                    parsed_host_name, &status);
     EXPECT_EQ(composite_device, nullptr);
     EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
     EXPECT_TRUE(absl::StrContains(status.error_message(),
@@ -935,8 +935,11 @@ Status EagerContext::FindOrCreateCompositeDevice(
   }
 
   Status s;
-  auto device = CompositeDevice::MakeDevice(underlying_devices,
-                                            composite_devices_.size(), &s);
+  // Create a CompositeDevice on the same task as the host CPU, in order to
+  // trigger packed TensorHandle copy from a client to a remote worker.
+  auto device =
+      CompositeDevice::MakeDevice(underlying_devices, composite_devices_.size(),
+                                  HostCPU()->parsed_name(), &s);
   TF_RETURN_IF_ERROR(s);
   *composite_device = device.get();
   pflr_->AddCompositeDevice(*composite_device);
@@ -31,7 +31,7 @@ static Device* CreateDevice(const string& type, int n) {
     Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
   };
   DeviceAttributes attr;
-  attr.set_name("/job:a/replica:0/task:0/device:" + type + ":" +
+  attr.set_name("/job:localhost/replica:0/task:0/device:" + type + ":" +
                 std::to_string(n));
   attr.set_device_type(type);
   return new FakeDevice(attr);
@@ -179,10 +179,10 @@ TEST_F(EagerContextTest, CompositeDevice) {
   TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(underlying_devices,
                                                       &composite_device_0));
   EXPECT_EQ(composite_device_0->name(),
-            "/job:worker/replica:0/task:0/device:COMPOSITE:0");
+            "/job:localhost/replica:0/task:0/device:COMPOSITE:0");
   CompositeDevice* device = nullptr;
   TF_EXPECT_OK(context()->FindCompositeDeviceFromName(
-      "/job:worker/replica:0/task:0/device:COMPOSITE:0", &device));
+      "/job:localhost/replica:0/task:0/device:COMPOSITE:0", &device));
   EXPECT_EQ(device, composite_device_0);
   CompositeDevice* composite_device_1 = nullptr;
   TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(underlying_devices,
@@ -193,13 +193,13 @@ TEST_F(EagerContextTest, CompositeDevice) {
   TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(underlying_devices,
                                                       &composite_device_2));
   EXPECT_EQ(composite_device_2->name(),
-            "/job:worker/replica:0/task:0/device:COMPOSITE:1");
+            "/job:localhost/replica:0/task:0/device:COMPOSITE:1");
   TF_EXPECT_OK(context()->FindCompositeDeviceFromName(
-      "/job:worker/replica:0/task:0/device:COMPOSITE:1", &device));
+      "/job:localhost/replica:0/task:0/device:COMPOSITE:1", &device));
   EXPECT_EQ(device, composite_device_2);
 
   EXPECT_TRUE(errors::IsNotFound(context()->FindCompositeDeviceFromName(
-      "/job:worker/replica:0/task:0/device:COMPOSITE:2", &device)));
+      "/job:localhost/replica:0/task:0/device:COMPOSITE:2", &device)));
 }
 
 }  // namespace
@@ -61,7 +61,8 @@ TEST(ExecuteNodeTest, ExecuteNodeArgs) {
   Status s;
   std::unique_ptr<CompositeDevice> composite_device =
       CompositeDevice::MakeDevice({device0->name(), device1->name()},
-                                  /*unique_device_id=*/0, &s);
+                                  /*unique_device_id=*/0,
+                                  device_mgr.HostCPU()->parsed_name(), &s);
   TF_ASSERT_OK(s);
 
   auto ctx = new EagerContext(
@@ -100,6 +100,7 @@ class PackedTensorHandleTest : public ::testing::Test {
     for (const char* name : device_names_) {
      devices.emplace_back(CreateDevice("GPU", name));
     }
+    devices.emplace_back(CreateDevice("CPU", host_name_));
     device_mgr_ = new StaticDeviceMgr(std::move(devices));
 
     context_ = new EagerContext(
|
@ -132,6 +133,8 @@ class PackedTensorHandleTest : public ::testing::Test {
|
||||||
"/job:worker/replica:0/task:1/device:GPU:0",
|
"/job:worker/replica:0/task:1/device:GPU:0",
|
||||||
"/job:worker/replica:0/task:1/device:GPU:1"};
|
"/job:worker/replica:0/task:1/device:GPU:1"};
|
||||||
|
|
||||||
|
const char* host_name_ = "/job:worker/replica:0/task:0/device:CPU:0";
|
||||||
|
|
||||||
StaticDeviceMgr* device_mgr_;
|
StaticDeviceMgr* device_mgr_;
|
||||||
EagerContext* context_;
|
EagerContext* context_;
|
||||||
};
|
};
|
||||||
|
|
|
@@ -820,7 +820,8 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_CompositeDevice) {
   Status s;
   std::unique_ptr<CompositeDevice> composite_device =
       CompositeDevice::MakeDevice({device0_->name(), device1_->name()},
-                                  /*unique_device_id=*/0, &s);
+                                  /*unique_device_id=*/0,
+                                  device_mgr_->HostCPU()->parsed_name(), &s);
   TF_ASSERT_OK(s);
   AddCompositeDevice(composite_device.get());
 
@@ -241,6 +241,11 @@ def implicit_val_and_grad(f):
                        "function was being computed.")
 
     sources = [v.handle for v in variables]
+    for s in sources:
+      if getattr(s, "is_packed", False):
+        raise ValueError(
+            "GradientTape.gradient is not supported on packed EagerTensors yet."
+        )
     grad = imperative_grad.imperative_grad(this_tape, nest.flatten(end_node),
                                            sources)
     return end_node, list(zip(grad, variables))
@@ -548,6 +553,10 @@ def make_vjp(f, params=None, persistent=True):
      ]
      args = _ensure_unique_tensor_objects(parameter_positions, args)
      for i in parameter_positions:
+        if getattr(args[i], "is_packed", False):
+          raise ValueError(
+              "GradientTape.gradient is not supported on packed EagerTensors"
+              "yet.")
        sources.append(args[i])
        tape.watch(this_tape, args[i])
      result = f(*args)
|
@ -1032,6 +1041,10 @@ class GradientTape(object):
|
||||||
logging.WARN, "The dtype of the source tensor must be "
|
logging.WARN, "The dtype of the source tensor must be "
|
||||||
"floating (e.g. tf.float32) when calling GradientTape.gradient, "
|
"floating (e.g. tf.float32) when calling GradientTape.gradient, "
|
||||||
"got %r", t.dtype)
|
"got %r", t.dtype)
|
||||||
|
if getattr(t, "is_packed", False):
|
||||||
|
raise ValueError(
|
||||||
|
"GradientTape.gradient is not supported on packed EagerTensors yet."
|
||||||
|
)
|
||||||
|
|
||||||
if output_gradients is not None:
|
if output_gradients is not None:
|
||||||
output_gradients = [None if x is None else ops.convert_to_tensor(x)
|
output_gradients = [None if x is None else ops.convert_to_tensor(x)
|
||||||
|
|
|
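The backprop.py guards above all key off the new is_packed property. A small sketch of the behavior they enforce (assumptions: eager mode, two CPU devices as in the tests, internal tensorflow.python modules; not part of the diff itself):

from tensorflow.python.eager import backprop
from tensorflow.python.framework import ops
from tensorflow.python.ops import resource_variable_ops

with ops.device('/cpu:0'):
  v0 = resource_variable_ops.ResourceVariable(1.0)
with ops.device('/cpu:1'):
  v1 = resource_variable_ops.ResourceVariable(2.0)
packed = ops.pack_eager_tensors([v0.handle, v1.handle])  # packed.is_packed == True

with backprop.GradientTape() as tape:
  y = v0 * v0
# Using the packed handle as a gradient source hits the new is_packed check:
try:
  tape.gradient(y, [packed])
except ValueError as e:
  print(e)  # GradientTape.gradient is not supported on packed EagerTensors yet.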
@@ -1123,6 +1123,22 @@ class Context(object):
     pywrap_tfe.TFE_Py_RegisterCustomDevice(self._handle, device_capsule,
                                            device_name, device_info_capsule)
 
+  def pack_eager_tensors(self, tensors):
+    """Pack multiple `EagerTensor`s of the same dtype and shape.
+
+    Args:
+      tensors: a list of EagerTensors to pack.
+
+    Returns:
+      A packed EagerTensor.
+    """
+    self.ensure_initialized()
+    if self._lazy_remote_inputs_copy is not None and (
+        not self._lazy_remote_inputs_copy):
+      raise ValueError("Packing eager tensors is not supported when "
+                       "lazy_remote_inputs_copy is disabled.")
+    return pywrap_tfe.TFE_Py_PackEagerTensors(self._handle, tensors)
+
   def remove_function(self, name):
     """Remove a function from the context.
 
@@ -186,6 +186,43 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
     with self.assertRaisesRegexp(AttributeError, 'no attribute'):
       add(c)
 
+  def testPackedVariable(self):
+    with ops.device('/cpu:0'):
+      v0_0 = resource_variable_ops.ResourceVariable(1.0)
+    with ops.device('/cpu:1'):
+      v0_1 = resource_variable_ops.ResourceVariable(2.0)
+      v1_0 = resource_variable_ops.ResourceVariable(3.0)
+    with ops.device('/cpu:2'):
+      v1_1 = resource_variable_ops.ResourceVariable(4.0)
+
+    packed_var_0 = ops.pack_eager_tensors([v0_0.handle, v0_1.handle])
+    packed_var_1 = ops.pack_eager_tensors([v1_0.handle, v1_1.handle])
+
+    # TODO(b/145922293): use ResourceVariable.assign_add and
+    # ResourceVariable.read_value directly once we support packing multiple
+    # ResourceVariable into one ResourceVariable.
+    @def_function.function
+    def read_var():
+      resource_variable_ops.assign_add_variable_op(
+          packed_var_0, constant_op.constant(5.0))
+      resource_variable_ops.assign_add_variable_op(
+          packed_var_1, constant_op.constant(6.0))
+      with ops.device('/cpu:0'):
+        read0 = resource_variable_ops.read_variable_op(
+            packed_var_0, dtype=dtypes.float32)
+      with ops.device('/cpu:1'):
+        read1 = resource_variable_ops.read_variable_op(
+            packed_var_0, dtype=dtypes.float32)
+        read2 = resource_variable_ops.read_variable_op(
+            packed_var_1, dtype=dtypes.float32)
+      with ops.device('/cpu:2'):
+        read3 = resource_variable_ops.read_variable_op(
+            packed_var_1, dtype=dtypes.float32)
+
+      return read0, read1, read2, read3
+
+    self.assertAllEqual(read_var(), (1 + 5, 2 + 5, 3 + 6, 4 + 6))
+
   def testImplementsAttributeBasic(self):
     v = def_function.function(
         experimental_implements='func')(lambda x, y: x + y)
@@ -345,6 +345,8 @@ typedef struct EagerTensor {
   char unused[kMaxEagerTensorParentSize];
   TFE_TensorHandle* handle;
   int64_t id;
+  // Indicates whether it's a packed tensor or not.
+  bool is_packed;
   // This mirrors tensorflow.core.framework.ops.Tensor._handle_data Which will
   // be None for tensors of type other than DT_RESOURCE. For DT_RESOURCE
   // tensors, this will contain a serialized HandleData proto with shape
@@ -418,6 +420,7 @@ bool MaybeInvokeCreatedOnEagerTensorProfiler(EagerTensor* created_tensor) {
 int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
   self->id = get_uid();
   self->handle = nullptr;
+  self->is_packed = false;
   Py_INCREF(Py_None);
   self->handle_data = Py_None;
   Py_INCREF(Py_None);
@@ -647,6 +650,11 @@ static PyObject* EagerTensor_backing_device(EagerTensor* self) {
 #endif
 }
 
+// Getter `is_packed`.
+static PyObject* EagerTensor_is_packed(EagerTensor* self) {
+  return PyBool_FromLong(self->is_packed);
+}
+
 static PyGetSetDef EagerTensor_getsetters[] = {
     {const_cast<char*>("_id"), (getter)EagerTensor_getid, nullptr,
      const_cast<char*>("Tensor ID."), nullptr},
@@ -655,6 +663,9 @@ static PyGetSetDef EagerTensor_getsetters[] = {
     {const_cast<char*>("backing_device"), (getter)EagerTensor_backing_device,
      nullptr, const_cast<char*>("Device on which tensor's memory is resident."),
      nullptr},
+    {const_cast<char*>("is_packed"), (getter)EagerTensor_is_packed, nullptr,
+     const_cast<char*>("Whether the EagerTensor is a packed tensor or not."),
+     nullptr},
     {const_cast<char*>("_handle_data"), (getter)EagerTensor_handle_data,
      (setter)EagerTensor_sethandle_data,
      const_cast<char*>("Shape/DType data if the EagerTensor is a DT_RESOURCE"),
@@ -813,7 +824,8 @@ TFE_TensorHandle* EagerTensor_Handle(const PyObject* o) {
   return reinterpret_cast<const EagerTensor*>(o)->handle;
 }
 
-PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) {
+PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle,
+                                const bool is_packed) {
   if (handle == nullptr) {
     return nullptr;
   }
@@ -821,6 +833,7 @@ PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) {
       EagerTensorType->tp_new(EagerTensorType, EmptyTuple(), EmptyDict()));
   if (t != nullptr) {
     t->id = get_uid();
+    t->is_packed = is_packed;
     Py_INCREF(Py_None);
     t->handle_data = Py_None;
     Py_INCREF(Py_None);
@@ -129,7 +129,8 @@ void TFE_DeleteContextCapsule(PyObject* context);
 bool EagerTensor_CheckExact(const PyObject* o);
 
 // Helper function to construct a new EagerTensor from a TFE_TensorHandle.
-PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle);
+PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle,
+                                const bool is_packed = false);
 
 // Extracts the handle inside EagerTensor object `o`. Returns nullptr on error.
 TFE_TensorHandle* EagerTensor_Handle(const PyObject* o);
@@ -40,6 +40,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.training import server_lib
 from tensorflow.python.training.server_lib import ClusterSpec
@@ -324,6 +325,36 @@ class MultiWorkersTest(test.TestCase, parameterized.TestCase):
 
     self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])
 
+  def testMultiDeviceFunctionWithPackedVariable(self):
+    with ops.device('/job:worker/replica:0/task:0/device:CPU:0'):
+      var0 = resource_variable_ops.ResourceVariable(1.0)
+    with ops.device('/job:worker/replica:0/task:1/device:CPU:0'):
+      var1 = resource_variable_ops.ResourceVariable(2.0)
+
+    packed_var = ops.pack_eager_tensors([var0.handle, var1.handle])
+    self.assertEqual(packed_var.device,
+                     '/job:localhost/replica:0/task:0/device:COMPOSITE:0')
+    self.assertEqual(packed_var.backing_device,
+                     '/job:localhost/replica:0/task:0/device:COMPOSITE:0')
+
+    @def_function.function
+    def add_variables():
+      with ops.device('/job:worker/replica:0/task:0/device:CPU:0'):
+        read0 = resource_variable_ops.read_variable_op(
+            packed_var, dtype=dtypes.float32)
+      with ops.device('/job:worker/replica:0/task:1/device:CPU:0'):
+        read1 = resource_variable_ops.read_variable_op(
+            packed_var, dtype=dtypes.float32)
+
+      return read0 + read1
+
+    # Run the function on a remote device
+    with ops.device('/job:worker/replica:0/task:0'):
+      self.assertAllEqual(add_variables().numpy(), 3.0)
+
+    # Run the function on a local worker
+    self.assertAllEqual(add_variables().numpy(), 3.0)
+
   @test_util.eager_lazy_remote_copy_on_and_off
   def testMultiDeviceFunctionOnRemoteDeviceWithWait(self):
     with ops.device('/job:worker/replica:0/task:1'):
@@ -1394,6 +1394,65 @@ def _error_prefix(name):
   return "" if name is None else "%s: " % name
 
 
+def pack_eager_tensors(tensors, ctx=None):
+  """Pack multiple `EagerTensor`s of the same dtype and shape.
+
+  Args:
+    tensors: a list of EagerTensors to pack.
+    ctx: context.context().
+
+  Returns:
+    A packed EagerTensor.
+  """
+  if not isinstance(tensors, list):
+    raise TypeError("tensors must be a list or a tuple: %s" % tensors)
+
+  if not tensors:
+    raise ValueError("Empty tensors is unexpected for packing.")
+
+  dtype = tensors[0].dtype
+  shape = tensors[0].shape
+  handle_data = tensors[0]._handle_data  # pylint: disable=protected-access
+  is_resource = dtype == dtypes.resource
+  for i in range(len(tensors)):
+    t = tensors[i]
+    if not isinstance(t, EagerTensor):
+      raise TypeError("tensors must be a list of EagerTensors: %s" % t)
+
+    if t.dtype != dtype:
+      raise ValueError(
+          "All tensors being packed should have the same dtype %s, "
+          "but the %d-th tensor is of dtype %s" % (dtype, i, t.dtype))
+    if t.shape != shape:
+      raise ValueError(
+          "All tensors being packed should have the same shape %s, "
+          "but the %d-th tensor is of shape %s" % (shape, i, t.shape))
+    # pylint: disable=protected-access
+    if is_resource and t._handle_data != handle_data:
+      raise ValueError(
+          "All tensors being packed should have the same handle data %s, "
+          "but the %d-th tensor is of handle data %s" %
+          (handle_data, i, t._handle_data))
+    # pylint: enable=protected-access
+
+  if ctx is None:
+    ctx = context.context()
+
+  # Propogate handle data for resource variables
+  packed_tensor = ctx.pack_eager_tensors(tensors)
+  if handle_data is not None:
+    packed_tensor._handle_data = handle_data  # pylint: disable=protected-access
+
+  def grad_fun(_):
+    raise ValueError(
+        "Gradients through pack_eager_tensors are not supported yet.")
+
+  tape.record_operation("pack_eager_tensors", [packed_tensor], tensors,
+                        grad_fun)
+
+  return packed_tensor
+
+
 def convert_to_tensor(value,
                       dtype=None,
                       name=None,
@@ -34,6 +34,7 @@ from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import function as eager_function
 from tensorflow.python.eager import wrap_function
+from tensorflow.python.framework import config
 from tensorflow.python.framework import composite_tensor
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import device as pydev
@@ -3408,5 +3409,51 @@ class CustomConvertToCompositeTensorTest(test_util.TensorFlowTestCase):
     self.assertAllEqual(x_, tensor_util.constant_value(y_))
 
 
+@test_util.disable_tfrt("Packing EagerTensors is not supported yet.")
+class PackEagerTensorTest(test_util.TensorFlowTestCase):
+
+  def setUp(self):
+    super(PackEagerTensorTest, self).setUp()
+    context._reset_context()
+    cpus = config.list_physical_devices("CPU")
+    # Set 2 virtual CPUs
+    config.set_logical_device_configuration(cpus[0], [
+        context.LogicalDeviceConfiguration(),
+        context.LogicalDeviceConfiguration(),
+    ])
+
+  def testPack(self):
+    with context.eager_mode():
+      with ops.device("CPU:0"):
+        var0 = resource_variable_ops.ResourceVariable(1.0)
+        c0 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+      with ops.device("CPU:1"):
+        var1 = resource_variable_ops.ResourceVariable(2.0)
+        var2 = resource_variable_ops.ResourceVariable([3.0])
+        c1 = constant_op.constant([9.0])
+
+      packed_var0 = ops.pack_eager_tensors([var0.handle, var1.handle])
+      self.assertTrue(packed_var0.is_packed)
+      self.assertEqual(packed_var0.dtype, var0.handle.dtype)
+      self.assertEqual(packed_var0.shape, var0.handle.shape)
+      self.assertEqual(packed_var0._handle_data, var0.handle._handle_data)
+      self.assertIn("COMPOSITE:0", packed_var0.device)
+      self.assertIn("COMPOSITE:0", packed_var0.backing_device)
+      with self.assertRaises(errors.InvalidArgumentError):
+        packed_var0.numpy()
+
+      # Different dtypes
+      with self.assertRaises(ValueError):
+        ops.pack_eager_tensors([var0.handle, c1])
+
+      # Different shapes
+      with self.assertRaises(ValueError):
+        ops.pack_eager_tensors([c0, c1])
+
+      # Different handle data
+      with self.assertRaises(ValueError):
+        ops.pack_eager_tensors([var0.handle, var2.handle])
+
+
 if __name__ == "__main__":
   googletest.main()
@@ -210,6 +210,22 @@ TFE_OutputTensorHandles InputTFE_OutputTensorHandles(
   return output_tensor_handles;
 }
 
+// Packs multiple `EagerTensor`s of the same dtype and shape into one
+// `EagerTensor`.
+py::object TFE_Py_PackEagerTensors_wrapper(const py::handle& context,
+                                           const py::handle& tensors) {
+  TFE_Context* ctx = tensorflow::InputTFE_Context(context);
+  TFE_InputTensorHandles handles = InputTFE_InputTensorHandles(tensors);
+  tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus());
+  int size = handles.size();
+  TFE_TensorHandle* packed_handle =
+      TFE_CreatePackedTensorHandle(ctx, handles.data(), &size, status.get());
+  tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
+  PyObject* packed_tensor =
+      EagerTensorFromHandle(packed_handle, /*is_packed=*/true);
+  return tensorflow::PyoOrThrow(packed_tensor);
+}
+
 // This function was created from fusing the typemap logic in platform/base.i.
 py::object TFE_Py_ExecuteCancelable_wrapper(
     const py::handle& context, const char* device_name, const char* op_name,
@@ -558,6 +574,10 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
   m.def("TFE_Py_InitEagerTensor", [](const py::handle& o) {
     return tensorflow::PyoOrThrow(TFE_Py_InitEagerTensor(o.ptr()));
   });
+  m.def("TFE_Py_PackEagerTensors",
+        [](const py::handle& context, const py::handle& handles) {
+          return tensorflow::TFE_Py_PackEagerTensors_wrapper(context, handles);
+        });
   m.def("TFE_Py_SetEagerTensorProfiler", &TFE_Py_SetEagerTensorProfiler);
   m.def("TFE_Py_RegisterJVPFunction", [](const py::handle& o) {
     return tensorflow::PyoOrThrow(TFE_Py_RegisterJVPFunction(o.ptr()));