Parallel device: Minor rename+slight generalization to make it easier to create literal/constant parallel tensors
PiperOrigin-RevId: 356387020 Change-Id: I0c0b88f8eccaf9b0b4fed90390295f44c0c11388
This commit is contained in:
parent
655dde4d59
commit
358a61ab90
@ -265,55 +265,6 @@ std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
|
||||
status);
|
||||
}
|
||||
|
||||
std::unique_ptr<ParallelTensor> ParallelDevice::Vector(
|
||||
TFE_Context* context, TF_Status* status,
|
||||
absl::Span<const int32_t> values) const {
|
||||
// TODO(allenl): We could cache DeviceIDs (keyed by context).
|
||||
std::vector<TensorHandlePtr> components;
|
||||
components.reserve(underlying_devices_.size());
|
||||
|
||||
if (values.size() != num_underlying_devices()) {
|
||||
TF_SetStatus(
|
||||
status, TF_INVALID_ARGUMENT,
|
||||
"Number of values did not match number of underlying devices.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (int device_index = 0; device_index < num_underlying_devices();
|
||||
++device_index) {
|
||||
int32_t* device_value = new int32_t;
|
||||
*device_value = values[device_index];
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
|
||||
TF_NewTensor(
|
||||
TF_INT32, /*dims=*/nullptr, /*num_dims=*/0, device_value,
|
||||
sizeof(int32_t),
|
||||
[](void* data, size_t, void* arg) {
|
||||
delete reinterpret_cast<int32_t*>(data);
|
||||
},
|
||||
nullptr),
|
||||
TF_DeleteTensor);
|
||||
// TODO(allenl): Here and when executing regular operations, we could hold
|
||||
// on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
|
||||
// device names repeatedly.
|
||||
OpPtr const_op(TFE_NewOp(context, "Const", status));
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
|
||||
status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT32);
|
||||
TFE_TensorHandle* device_handle;
|
||||
int num_outputs = 1;
|
||||
TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
components.emplace_back(device_handle);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
}
|
||||
return ParallelTensor::FromTensorHandles(*this, std::move(components),
|
||||
status);
|
||||
}
|
||||
|
||||
std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
|
||||
TFE_Context* context, TF_Status* status) const {
|
||||
std::vector<int32_t> ids;
|
||||
@ -321,7 +272,7 @@ std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
|
||||
for (int i = 0; i < num_underlying_devices(); ++i) {
|
||||
ids.push_back(i);
|
||||
}
|
||||
return Vector(context, status, ids);
|
||||
return ScalarsFromSequence<int32_t>(ids, context, status);
|
||||
}
|
||||
|
||||
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
|
||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace parallel_device {
|
||||
@ -67,9 +68,10 @@ class ParallelDevice {
|
||||
TF_Status* status) const;
|
||||
|
||||
// Construct a parallel tensor consisting of the scalar values from `values`.
|
||||
std::unique_ptr<ParallelTensor> Vector(
|
||||
TFE_Context* context, TF_Status* status,
|
||||
absl::Span<const int32_t> values) const;
|
||||
template <typename DataType>
|
||||
std::unique_ptr<ParallelTensor> ScalarsFromSequence(
|
||||
absl::Span<const DataType> values, TFE_Context* context,
|
||||
TF_Status* status) const;
|
||||
|
||||
// A parallel tensor with scalar integers numbering component devices.
|
||||
std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
|
||||
@ -193,6 +195,56 @@ class ParallelTensor {
|
||||
const TF_DataType dtype_;
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
std::unique_ptr<ParallelTensor> ParallelDevice::ScalarsFromSequence(
|
||||
absl::Span<DataType const> values, TFE_Context* context,
|
||||
TF_Status* status) const {
|
||||
std::vector<TensorHandlePtr> components;
|
||||
components.reserve(underlying_devices_.size());
|
||||
|
||||
if (values.size() != num_underlying_devices()) {
|
||||
TF_SetStatus(
|
||||
status, TF_INVALID_ARGUMENT,
|
||||
"Number of values did not match number of underlying devices.");
|
||||
return nullptr;
|
||||
}
|
||||
TF_DataType datatype_enum(
|
||||
static_cast<TF_DataType>(DataTypeToEnum<DataType>().value));
|
||||
for (int device_index = 0; device_index < num_underlying_devices();
|
||||
++device_index) {
|
||||
auto device_value = absl::make_unique<DataType>();
|
||||
*device_value = values[device_index];
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
|
||||
TF_NewTensor(
|
||||
datatype_enum, /*dims=*/nullptr, /*num_dims=*/0,
|
||||
device_value.release(), sizeof(DataType),
|
||||
[](void* data, size_t, void* arg) {
|
||||
delete reinterpret_cast<DataType*>(data);
|
||||
},
|
||||
nullptr),
|
||||
TF_DeleteTensor);
|
||||
// TODO(allenl): Here and when executing regular operations, we could hold
|
||||
// on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
|
||||
// device names repeatedly.
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> const_op(
|
||||
TFE_NewOp(context, "Const", status), TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
|
||||
status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(const_op.get(), "dtype", datatype_enum);
|
||||
TFE_TensorHandle* device_handle;
|
||||
int num_outputs = 1;
|
||||
TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
components.emplace_back(device_handle);
|
||||
}
|
||||
return ParallelTensor::FromTensorHandles(*this, std::move(components),
|
||||
status);
|
||||
}
|
||||
|
||||
} // namespace parallel_device
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -188,5 +188,44 @@ TEST(PARALLEL_DEVICE_LIB, TestDifferentShapes) {
|
||||
EXPECT_EQ(0, shape->size());
|
||||
}
|
||||
|
||||
TEST(PARALLEL_DEVICE_LIB, TestScalarsFromSequence) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
|
||||
TF_CreateConfig(
|
||||
/*enable_xla_compilation=*/false,
|
||||
/*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
|
||||
TF_DeleteBuffer);
|
||||
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
|
||||
status.get());
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
|
||||
|
||||
std::vector<std::string> devices{
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
"/job:localhost/replica:0/task:0/device:CPU:1"};
|
||||
ParallelDevice parallel_device(std::move(devices));
|
||||
{
|
||||
std::unique_ptr<ParallelTensor> float_tensors =
|
||||
parallel_device.ScalarsFromSequence<float>({10.0, 11.0}, context.get(),
|
||||
status.get());
|
||||
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
|
||||
ExpectScalarEq<float>(float_tensors->tensor(0), 10.0);
|
||||
ExpectScalarEq<float>(float_tensors->tensor(1), 11.0);
|
||||
}
|
||||
|
||||
{
|
||||
std::unique_ptr<ParallelTensor> int_tensors =
|
||||
parallel_device.ScalarsFromSequence<int>({5, 6}, context.get(),
|
||||
status.get());
|
||||
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
|
||||
ExpectScalarEq<int>(int_tensors->tensor(0), 5);
|
||||
ExpectScalarEq<int>(int_tensors->tensor(1), 6);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace parallel_device
|
||||
} // namespace tensorflow
|
||||
|
@ -140,11 +140,13 @@ template <typename value_type>
|
||||
void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> actual_value(
|
||||
TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
|
||||
ASSERT_EQ(TF_TensorType(actual_value.get()),
|
||||
static_cast<TF_DataType>(DataTypeToEnum<value_type>().value));
|
||||
EXPECT_EQ(expected_value,
|
||||
*static_cast<value_type*>(TF_TensorData(value_zero.get())));
|
||||
*static_cast<value_type*>(TF_TensorData(actual_value.get())));
|
||||
}
|
||||
|
||||
template <std::size_t num_devices>
|
||||
|
Loading…
Reference in New Issue
Block a user