Parallel device: Minor rename+slight generalization to make it easier to create literal/constant parallel tensors

PiperOrigin-RevId: 356387020
Change-Id: I0c0b88f8eccaf9b0b4fed90390295f44c0c11388
Allen Lavoie 2021-02-08 17:06:53 -08:00 committed by TensorFlower Gardener
parent 655dde4d59
commit 358a61ab90
4 changed files with 100 additions and 56 deletions
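For orientation, here is a minimal usage sketch of the generalized helper added in this commit; the include path, the MakeScalarLiterals helper name, and the single-CPU device list are illustrative assumptions, while the ScalarsFromSequence signature and the ParallelDevice construction follow the header and test changes shown in the diff below.

// Sketch (not part of this commit): build a constant parallel tensor from
// one scalar per underlying device using the new templated helper.
// The include path below is an assumption; adjust to the actual location
// of parallel_device_lib.h in the tree.
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"

namespace tensorflow {
namespace parallel_device {

// Hypothetical helper: one value per component device, with the dtype chosen
// by the template argument instead of being fixed to int32 as in the old
// Vector() method.
std::unique_ptr<ParallelTensor> MakeScalarLiterals(TFE_Context* context,
                                                   TF_Status* status) {
  // A parallel device over a single CPU; with one underlying device the
  // value sequence must contain exactly one element.
  ParallelDevice device(std::vector<std::string>{
      "/job:localhost/replica:0/task:0/device:CPU:0"});
  // Returns nullptr and sets `status` on failure; callers should check it.
  return device.ScalarsFromSequence<float>({42.0f}, context, status);
}

}  // namespace parallel_device
}  // namespace tensorflow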

@@ -265,55 +265,6 @@ std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
status);
}
std::unique_ptr<ParallelTensor> ParallelDevice::Vector(
TFE_Context* context, TF_Status* status,
absl::Span<const int32_t> values) const {
// TODO(allenl): We could cache DeviceIDs (keyed by context).
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
if (values.size() != num_underlying_devices()) {
TF_SetStatus(
status, TF_INVALID_ARGUMENT,
"Number of values did not match number of underlying devices.");
return nullptr;
}
for (int device_index = 0; device_index < num_underlying_devices();
++device_index) {
int32_t* device_value = new int32_t;
*device_value = values[device_index];
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(
TF_INT32, /*dims=*/nullptr, /*num_dims=*/0, device_value,
sizeof(int32_t),
[](void* data, size_t, void* arg) {
delete reinterpret_cast<int32_t*>(data);
},
nullptr),
TF_DeleteTensor);
// TODO(allenl): Here and when executing regular operations, we could hold
// on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
// device names repeatedly.
OpPtr const_op(TFE_NewOp(context, "Const", status));
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT32);
TFE_TensorHandle* device_handle;
int num_outputs = 1;
TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
components.emplace_back(device_handle);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
return ParallelTensor::FromTensorHandles(*this, std::move(components),
status);
}
std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
TFE_Context* context, TF_Status* status) const {
std::vector<int32_t> ids;
@@ -321,7 +272,7 @@ std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
for (int i = 0; i < num_underlying_devices(); ++i) {
ids.push_back(i);
}
return Vector(context, status, ids);
return ScalarsFromSequence<int32_t>(ids, context, status);
}
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>

@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
namespace tensorflow {
namespace parallel_device {
@@ -67,9 +68,10 @@ class ParallelDevice {
TF_Status* status) const;
// Construct a parallel tensor consisting of the scalar values from `values`.
std::unique_ptr<ParallelTensor> Vector(
TFE_Context* context, TF_Status* status,
absl::Span<const int32_t> values) const;
template <typename DataType>
std::unique_ptr<ParallelTensor> ScalarsFromSequence(
absl::Span<const DataType> values, TFE_Context* context,
TF_Status* status) const;
// A parallel tensor with scalar integers numbering component devices.
std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
@@ -193,6 +195,56 @@ class ParallelTensor {
const TF_DataType dtype_;
};
template <typename DataType>
std::unique_ptr<ParallelTensor> ParallelDevice::ScalarsFromSequence(
absl::Span<DataType const> values, TFE_Context* context,
TF_Status* status) const {
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
if (values.size() != num_underlying_devices()) {
TF_SetStatus(
status, TF_INVALID_ARGUMENT,
"Number of values did not match number of underlying devices.");
return nullptr;
}
TF_DataType datatype_enum(
static_cast<TF_DataType>(DataTypeToEnum<DataType>().value));
for (int device_index = 0; device_index < num_underlying_devices();
++device_index) {
auto device_value = absl::make_unique<DataType>();
*device_value = values[device_index];
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(
datatype_enum, /*dims=*/nullptr, /*num_dims=*/0,
device_value.release(), sizeof(DataType),
[](void* data, size_t, void* arg) {
delete reinterpret_cast<DataType*>(data);
},
nullptr),
TF_DeleteTensor);
// TODO(allenl): Here and when executing regular operations, we could hold
// on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
// device names repeatedly.
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> const_op(
TFE_NewOp(context, "Const", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(const_op.get(), "dtype", datatype_enum);
TFE_TensorHandle* device_handle;
int num_outputs = 1;
TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
components.emplace_back(device_handle);
}
return ParallelTensor::FromTensorHandles(*this, std::move(components),
status);
}
} // namespace parallel_device
} // namespace tensorflow

@@ -188,5 +188,44 @@ TEST(PARALLEL_DEVICE_LIB, TestDifferentShapes) {
EXPECT_EQ(0, shape->size());
}
TEST(PARALLEL_DEVICE_LIB, TestScalarsFromSequence) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
TF_CreateConfig(
/*enable_xla_compilation=*/false,
/*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
TF_DeleteBuffer);
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
status.get());
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
std::vector<std::string> devices{
"/job:localhost/replica:0/task:0/device:CPU:0",
"/job:localhost/replica:0/task:0/device:CPU:1"};
ParallelDevice parallel_device(std::move(devices));
{
std::unique_ptr<ParallelTensor> float_tensors =
parallel_device.ScalarsFromSequence<float>({10.0, 11.0}, context.get(),
status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
ExpectScalarEq<float>(float_tensors->tensor(0), 10.0);
ExpectScalarEq<float>(float_tensors->tensor(1), 11.0);
}
{
std::unique_ptr<ParallelTensor> int_tensors =
parallel_device.ScalarsFromSequence<int>({5, 6}, context.get(),
status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
ExpectScalarEq<int>(int_tensors->tensor(0), 5);
ExpectScalarEq<int>(int_tensors->tensor(1), 6);
}
}
} // namespace parallel_device
} // namespace tensorflow

@@ -140,11 +140,13 @@ template <typename value_type>
void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> actual_value(
TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
ASSERT_EQ(TF_TensorType(actual_value.get()),
static_cast<TF_DataType>(DataTypeToEnum<value_type>().value));
EXPECT_EQ(expected_value,
*static_cast<value_type*>(TF_TensorData(value_zero.get())));
*static_cast<value_type*>(TF_TensorData(actual_value.get())));
}
template <std::size_t num_devices>