Allow components of parallel device tensors to have different shapes

Throws an error if the shape of the overall tensor is queried for now. The plumbing required to make the shape information look like not-fully-defined-shape graph tensors looks very shallow if we want to go that route.

This means that querying the shape of a parallel tensor is now a blocking operation (and needs a status return) rather than creation itself blocking.

PiperOrigin-RevId: 351907155
Change-Id: I2610613efd4bb6aafa44fc78ee53824fb6020b6a
This commit is contained in:
Allen Lavoie 2021-01-14 17:03:52 -08:00 committed by TensorFlower Gardener
parent 5fb1d0e838
commit 571c19440d
14 changed files with 299 additions and 73 deletions

View File

@ -132,6 +132,7 @@ filegroup(
"tfe_monitoring_internal.h",
"tfe_op_attrs_internal.h",
"tfe_tensor_debug_info_internal.h",
"tfe_tensorhandle_internal.h",
],
visibility = [
"//tensorflow/core:__pkg__",

View File

@ -484,43 +484,73 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
class CAPICustomDeviceTensorHandle
: public tensorflow::CustomDeviceTensorHandle {
public:
using NumDimsCallback = std::function<int(TF_Status* status)>;
using DimCallback = std::function<int64_t(int dim_index, TF_Status* status)>;
using DeallocatorCallback = std::function<void()>;
CAPICustomDeviceTensorHandle(tensorflow::ImmediateExecutionContext* context,
tensorflow::CustomDevice* device,
tensorflow::DataType dtype, void* data,
size_t len, std::vector<tensorflow::int64> shape,
void (*deallocator)(void* data, size_t len,
void* arg),
void* deallocator_arg)
NumDimsCallback num_dims_callback,
DimCallback dim_callback,
DeallocatorCallback deallocator)
: tensorflow::CustomDeviceTensorHandle(context, device, dtype),
data_(data),
len_(len),
shape_(shape),
deallocator_(deallocator),
deallocator_arg_(deallocator_arg) {}
~CAPICustomDeviceTensorHandle() override {
deallocator_(data_, len_, deallocator_arg_);
}
num_dims_callback_(num_dims_callback),
dim_callback_(dim_callback),
deallocator_(deallocator) {}
~CAPICustomDeviceTensorHandle() override { deallocator_(); }
void* DevicePointer() const override { return data_; }
Status NumDims(int* num_dims) const override {
*num_dims = shape_.size();
return Status::OK();
TF_Status s;
*num_dims = num_dims_callback_(&s);
return s.status;
}
Status Dim(int dim_index, int64* dim) const override {
*dim = shape_[dim_index];
return Status::OK();
TF_Status s;
*dim = dim_callback_(dim_index, &s);
return s.status;
}
private:
void* const data_;
size_t len_;
std::vector<tensorflow::int64> shape_;
void (*const deallocator_)(void* data, size_t len, void* arg);
void* const deallocator_arg_;
NumDimsCallback num_dims_callback_;
DimCallback dim_callback_;
DeallocatorCallback deallocator_;
};
} // namespace
} // namespace tensorflow
TFE_TensorHandle* TFE_NewCustomDeviceTensorHandle(
TFE_Context* ctx, const char* device_name, TF_DataType dtype, void* data,
int (*num_dims_callback)(void* data, void* arg, TF_Status* status),
int64_t (*dim_callback)(void* data, int dim_index, void* arg,
TF_Status* status),
void (*deallocator)(void* data, void* arg), void* arg, TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
tensorflow::CustomDevice* device = nullptr;
if (!context->FindCustomDeviceFromName(device_name, &device)) {
deallocator(data, arg);
status->status =
tensorflow::errors::InvalidArgument(device_name, " unknown device.");
return nullptr;
}
return tensorflow::wrap(new tensorflow::CAPICustomDeviceTensorHandle(
context, device, *reinterpret_cast<tensorflow::DataType*>(&dtype), data,
/*num_dims_callback=*/
[num_dims_callback, data, arg](TF_Status* status) {
return num_dims_callback(data, arg, status);
},
/*dim_callback=*/
[dim_callback, data, arg](int dim_index, TF_Status* status) {
return dim_callback(data, dim_index, arg, status);
},
/*deallocator=*/[deallocator, data, arg]() { deallocator(data, arg); }));
}
TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
TFE_Context* ctx, const char* device_name, TF_DataType dtype,
const int64_t* dims, int num_dims, void* data, size_t len,
@ -548,8 +578,17 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
if (custom_device != nullptr) {
return tensorflow::wrap(new tensorflow::CAPICustomDeviceTensorHandle(
context, custom_device,
*reinterpret_cast<tensorflow::DataType*>(&dtype), data, len, dimvec,
deallocator, deallocator_arg));
*reinterpret_cast<tensorflow::DataType*>(&dtype), data,
/*num_dims_callback=*/
[num_dims](TF_Status* status) { return num_dims; },
/*dim_callback=*/
[dimvec](int dim_index, TF_Status* status) {
return dimvec[dim_index];
},
/*deallocator=*/
[data, len, deallocator, deallocator_arg]() {
deallocator(data, len, deallocator_arg);
}));
}
// TODO(apassos) do we need to wrap the deallocator here to make sure to sync

View File

@ -517,6 +517,29 @@ TF_CAPI_EXPORT extern void TFE_RegisterCustomDevice(TFE_Context* ctx,
void* device_info,
TF_Status* status);
// Creates a new TensorHandle from memory residing in a custom device. Takes
// ownership of the memory, and will call `deallocator` to release it after TF
// no longer needs it or in case of error.
//
// `num_dims_callback` is a callback computing the rank of the tensor, and
// `dim_callback` computes the axis length at `dim_index`. Shapes are specified
// via callbacks because retrieving the shape of a tensor is a blocking
// operation for async eager; custom devices should avoid retrieving shapes of
// tensors they wrap until the custom device tensor's shape is explicitly
// requested where possible.
//
// `arg` is passed to the callbacks unmodified for any extra information the
// caller needs to provide them.
//
// This call is similar to `TFE_NewTensorHandleFromDeviceMemory`, but does not
// require blocking waiting for exact shapes.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewCustomDeviceTensorHandle(
TFE_Context* ctx, const char* device_name, TF_DataType, void* data,
int (*num_dims_callback)(void* data, void* arg, TF_Status* status),
int64_t (*dim_callback)(void* data, int dim_index, void* arg,
TF_Status* status),
void (*deallocator)(void* data, void* arg), void* arg, TF_Status* status);
TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx,
const char* function_name,
TF_Buffer* buf,

View File

@ -59,6 +59,7 @@ cc_library(
deps = [
":parallel_device_lib",
"//tensorflow/c:c_api",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"@com_google_absl//absl/strings",
@ -74,8 +75,10 @@ cc_library(
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c:tf_status_internal",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/c/eager:tfe_tensorhandle_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"@com_google_absl//absl/types:optional",
@ -89,6 +92,7 @@ tf_cc_test(
srcs = ["parallel_device_lib_test.cc"],
deps = [
":parallel_device_lib",
":parallel_device_testlib",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental",
"//tensorflow/c/eager:c_api",
@ -106,6 +110,7 @@ cc_library(
hdrs = ["parallel_device_testlib.h"],
deps = [
":parallel_device",
":parallel_device_lib",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental",
"//tensorflow/c/eager:c_api",

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
namespace tensorflow {
namespace parallel_device {
@ -177,13 +178,38 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ExecuteWithSpecialOps(
return result;
}
// Used as an argument to TFE_NewTensorHandleFromDeviceMemory, indicating how
// Used as an argument to TFE_NewCustomDeviceTensorHandle, indicating how
// ParallelTensors wrapped in TFE_TensorHandles should be cleaned up once their
// reference counts drop to zero.
void ParallelTensorDeallocator(void* data, size_t len, void* arg) {
void ParallelTensorDeallocator(void* data, void* arg) {
delete reinterpret_cast<ParallelTensor*>(data);
}
// Used as an argument to TFE_NewCustomDeviceTensorHandle, for computing the
// number of dimensions of a parallel tensor.
int ParallelTensorNumDims(void* data, void* arg, TF_Status* status) {
const std::vector<int64_t>* shape;
Status s = reinterpret_cast<ParallelTensor*>(data)->Shape(&shape);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return -1;
}
return shape->size();
}
// Used as an argument to TFE_NewCustomDeviceTensorHandle, for computing a
// dimension of a parallel tensor.
int64_t ParallelTensorDim(void* data, int dim_index, void* arg,
TF_Status* status) {
const std::vector<int64_t>* shape;
Status s = reinterpret_cast<ParallelTensor*>(data)->Shape(&shape);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return -1;
}
return (*shape)[dim_index];
}
TensorHandlePtr ParallelTensorToTensorHandle(
const std::string& parallel_device_name, TFE_Context* context,
std::unique_ptr<ParallelTensor> t, TF_Status* status) {
@ -191,11 +217,10 @@ TensorHandlePtr ParallelTensorToTensorHandle(
// for a ParallelDevice is really a ParallelTensor. When the TensorHandle is
// deleted, it will call ParallelTensorDeallocator to free the struct.
ParallelTensor* t_released = t.release();
const std::vector<int64_t>& shape(t_released->shape());
return TensorHandlePtr(TFE_NewTensorHandleFromDeviceMemory(
context, parallel_device_name.c_str(), t_released->dtype(), shape.data(),
shape.size(), t_released, 1, &ParallelTensorDeallocator, nullptr,
status));
return TensorHandlePtr(TFE_NewCustomDeviceTensorHandle(
context, parallel_device_name.c_str(), t_released->dtype(), t_released,
&ParallelTensorNumDims, &ParallelTensorDim, &ParallelTensorDeallocator,
nullptr, status));
}
// For TFE_CustomDevice::copy_tensor_to_device in the parallel device

View File

@ -15,7 +15,9 @@ limitations under the License.
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
@ -434,30 +436,45 @@ std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status) {
std::vector<int64> shape(
TFE_TensorHandleNumDims(components[0].get(), status));
if (TF_GetCode(status) != TF_OK) return nullptr;
for (int i = 0; i < shape.size(); ++i) {
shape[i] = TFE_TensorHandleDim(components[0].get(), i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
// Verify that the TensorHandle's shape matches all of the component shapes.
TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
// Verify that the combined TensorHandle's dtype matches all of the component
// dtypes.
for (TensorHandlePtr& component : components) {
for (int i = 0; i < shape.size(); ++i) {
int64 tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
if (tensor_dim != shape[i]) {
// TODO(allenl): Allow shapes to differ.
TF_SetStatus(status, TF_UNIMPLEMENTED,
"Components of a ParallelTensor must currently all have "
"the same shape");
return nullptr;
}
if (TFE_TensorHandleDataType(component.get()) != dtype) {
TF_SetStatus(status, TF_INTERNAL,
"Components of a ParallelTensor must all have "
"the same dtype");
return nullptr;
}
}
return FromTensorHandles(parallel_device, std::move(components),
absl::Span<const int64>(shape), status);
return std::unique_ptr<ParallelTensor>(
new ParallelTensor(parallel_device, std::move(components), dtype));
}
Status ParallelTensor::Shape(const std::vector<int64_t>** shape) const {
if (!shape_.has_value()) {
TF_Status status;
PartialTensorShape first_shape;
TF_RETURN_IF_ERROR(unwrap(tensors_[0].get())->Shape(&first_shape));
// Verify that the TensorHandle's shape matches all of the component shapes.
for (const TensorHandlePtr& component : tensors_) {
PartialTensorShape component_shape;
TF_RETURN_IF_ERROR(unwrap(component.get())->Shape(&component_shape));
if (!first_shape.IsIdenticalTo(component_shape)) {
return errors::Unimplemented(absl::StrCat(
"Computing the shape of a ParallelTensor when the components do "
"not all have the same shapes is not supported. One tensor had "
"shape ",
first_shape.DebugString(), " and another had shape ",
component_shape.DebugString()));
}
}
auto dim_sizes = first_shape.dim_sizes();
shape_ = std::vector<int64_t>(dim_sizes.begin(), dim_sizes.end());
}
*shape = &*shape_;
return Status::OK();
}
} // namespace parallel_device

View File

@ -127,11 +127,13 @@ class ParallelDevice {
class ParallelTensor {
public:
// Construct a ParallelTensor from TensorHandles placed on the component
// devices of a ParallelDevice. Inspects `components` to determine a shape.
// devices of a ParallelDevice. If called, ParallelTensor::Shape inspects
// `components` to determine a shape.
static std::unique_ptr<ParallelTensor> FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status);
// Uses the provided shape without additional checks, which avoids blocking.
// Uses the provided shape without additional checks, which avoids blocking
// when ParallelTensor::Shape is called.
static std::unique_ptr<ParallelTensor> FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, absl::Span<const int64> shape,
@ -140,8 +142,13 @@ class ParallelTensor {
size_t num_tensors() const { return tensors_.size(); }
TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }
// A generalization of the shapes of the underlying tensors.
const std::vector<int64_t>& shape() const { return shape_; }
// If the `shape` argument to `FromTensorHandles` is specified, returns that.
//
// Otherwise if all of the tensors have the same shape, returns that via the
// `shape` output argument. This blocks waiting for async tensors, may return
// a delayed bad status encountered during async execution, and will return a
// bad status unless all tensors have the same shape.
Status Shape(const std::vector<int64_t>** shape) const;
TF_DataType dtype() const { return dtype_; }
private:
@ -150,12 +157,21 @@ class ParallelTensor {
absl::Span<const int64> shape, const TF_DataType dtype)
: device_(device),
tensors_(std::move(tensors)),
shape_(shape.begin(), shape.end()),
shape_(std::vector<int64_t>(shape.begin(), shape.end())),
dtype_(dtype) {}
ParallelTensor(const ParallelDevice& device,
std::vector<TensorHandlePtr> tensors, const TF_DataType dtype)
: device_(device),
tensors_(std::move(tensors)),
shape_(absl::nullopt),
dtype_(dtype) {}
const ParallelDevice& device_;
const std::vector<TensorHandlePtr> tensors_;
const std::vector<int64_t> shape_;
// Parallel tensors are immutable but compute their shape lazily unless it is
// provided on construction. The optional has a value if the lazy computation
// has been completed or the shape was provided on construction.
mutable absl::optional<std::vector<int64_t>> shape_;
const TF_DataType dtype_;
};

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@ -113,7 +114,76 @@ TEST(PARALLEL_DEVICE_LIB, TestExplicitOutputShape) {
TFE_OpGetAttrs(handle_op.get()), {PartialTensorShape({})}, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const std::vector<std::unique_ptr<ParallelTensor>>& handles = *outputs;
EXPECT_EQ(0, handles[0]->shape().size());
const std::vector<int64_t>* shape;
Status s = handles[0]->Shape(&shape);
ASSERT_TRUE(s.ok());
EXPECT_EQ(0, shape->size());
}
TEST(PARALLEL_DEVICE_LIB, TestDifferentShapes) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
TF_CreateConfig(
/*xla*/ false,
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
2),
TF_DeleteBuffer);
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
status.get());
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::vector<std::string> devices{
"/job:localhost/replica:0/task:0/device:CPU:0",
"/job:localhost/replica:0/task:0/device:CPU:1"};
ParallelDevice parallel_device(std::move(devices));
TensorHandlePtr two_vector = VectorFloatTensorHandle({3., 4.}, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TensorHandlePtr three_vector =
VectorFloatTensorHandle({5., 6., 7.}, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::vector<TensorHandlePtr> vector_handles;
vector_handles.reserve(2);
vector_handles.push_back(std::move(two_vector));
vector_handles.push_back(std::move(three_vector));
std::unique_ptr<ParallelTensor> unknown_length_vector =
ParallelTensor::FromTensorHandles(
parallel_device, std::move(vector_handles), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const std::vector<int64_t>* shape;
Status s = unknown_length_vector->Shape(&shape);
EXPECT_FALSE(s.ok());
TensorHandlePtr scalar = FloatTensorHandle(2., status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
two_vector = VectorFloatTensorHandle({3., 4.}, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::vector<TensorHandlePtr> mixed_handles;
mixed_handles.reserve(2);
mixed_handles.push_back(std::move(scalar));
mixed_handles.push_back(std::move(two_vector));
std::unique_ptr<ParallelTensor> unknown_dims_vector =
ParallelTensor::FromTensorHandles(parallel_device,
std::move(mixed_handles), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
// Can't take the shape of a parallel tensor with varying numbers of axes, but
// running operations on them is OK.
s = unknown_length_vector->Shape(&shape);
EXPECT_FALSE(s.ok());
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> size_op(
TFE_NewOp(context.get(), "Size", status.get()), TFE_DeleteOp);
auto result = parallel_device.Execute(
context.get(), {unknown_dims_vector.get()}, "Size",
TFE_OpGetAttrs(size_op.get()), 1, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
s = (*result)[0]->Shape(&shape);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
EXPECT_EQ(0, shape->size());
}
} // namespace parallel_device

View File

@ -41,6 +41,9 @@ tensorflow::ServerDef GetServerDef(const std::string& job_name, int num_tasks) {
return server_def;
}
namespace tensorflow {
namespace parallel_device {
TEST(PARALLEL_DEVICE, TestRemoteBasic) {
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
@ -145,3 +148,5 @@ TEST(PARALLEL_DEVICE, TestAsyncCopyOff) {
worker_server1.release();
worker_server2.release();
}
} // namespace parallel_device
} // namespace tensorflow

View File

@ -29,6 +29,9 @@ limitations under the License.
// correspond fairly well to the implementation, but testing the C++ directly is
// another option.
namespace tensorflow {
namespace parallel_device {
TEST(PARALLEL_DEVICE, TestBasicCPU) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
@ -199,8 +202,10 @@ TEST(PARALLEL_DEVICE, TestDifferentShapes) {
std::array<TFE_TensorHandle*, 2> components{size_two.get(), size_three.get()};
TensorHandlePtr combined_value = CreatePerDeviceValues(
context.get(), components, device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_UNIMPLEMENTED)
<< TF_Message(status.get());
// We can create the handle, but fetching the shape is an error at the moment.
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandleNumDims(combined_value.get(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_UNIMPLEMENTED);
}
TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
@ -570,3 +575,6 @@ TEST(PARALLEL_DEVICE, TestFunction) {
result_components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
} // namespace parallel_device
} // namespace tensorflow

View File

@ -28,6 +28,8 @@ limitations under the License.
// correspond fairly well to the implementation, but testing the C++ directly is
// another option.
namespace tensorflow {
namespace parallel_device {
Variable* Variable::Create(TFE_Context* context, TF_DataType type,
const int64_t* dims, const int num_dims,
@ -280,3 +282,6 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
ASSERT_EQ(underlying_devices[1], second_device);
}
}
} // namespace parallel_device
} // namespace tensorflow

View File

@ -16,29 +16,18 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_
#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_
#include "tensorflow/c/eager/parallel_device/parallel_device.h"
#include <array>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/parallel_device/parallel_device.h"
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
#include "tensorflow/core/platform/test.h"
// Functor for making unique_ptr to TFE_TensorHandle slightly more
// ergonomic. Using decltype(TFE_DeleteTensorHandle) in the unique_ptr's second
// template argument requires passing a function pointer to
// TFE_DeleteTensorHandle when constructing the unique_ptr.
class TensorHandleDeleter {
public:
void operator()(TFE_TensorHandle* to_delete) {
TFE_DeleteTensorHandle(to_delete);
}
};
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
namespace tensorflow {
namespace parallel_device {
// A helper for performing common operations on variables. A much more
// restricted stand-in for tf.Variable in Python.
@ -171,4 +160,7 @@ void RegisterParallelDevice(
TFE_RegisterCustomDevice(context, device, device_name, device_info, status);
}
} // namespace parallel_device
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_

View File

@ -6748,6 +6748,8 @@ tf_python_pybind_extension(
module_name = "_pywrap_parallel_device",
visibility = ["//tensorflow/python/distribute/parallel_device:__pkg__"],
deps = [
"//tensorflow/c:pywrap_required_hdrs",
"//tensorflow/c/eager:tfe_tensorhandle_internal_hdrs_only",
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:lib_headers_for_pybind",
"//tensorflow/core:protos_all_cc",

View File

@ -405,6 +405,24 @@ class ParallelDeviceTests(_VirtualDeviceTestCase, parameterized.TestCase):
outputs = self.device.unpack(packed_outputs)
self.assertAllClose([16., 16.], outputs)
def test_different_shapes(self):
with self.device:
x = self.device.pack(
[constant_op.constant([1., 2.]),
constant_op.constant([5.])])
y = x * 2.
with self.assertRaisesRegex(Exception,
"components do not all have the same shape"):
y.shape # pylint: disable=pointless-statement
self.assertAllClose([[2., 4.], [10.]], self.device.unpack(y))
different_axes = self.device.pack(
[constant_op.constant([1., 2.]),
constant_op.constant([[5.]])])
with self.assertRaisesRegex(Exception,
"components do not all have the same shape"):
different_axes.shape # pylint: disable=pointless-statement
class LayerTests(_VirtualDeviceTestCase):