Parallel device: Minor rename+slight generalization to make it easier to create literal/constant parallel tensors
PiperOrigin-RevId: 356387020
Change-Id: I0c0b88f8eccaf9b0b4fed90390295f44c0c11388
parent 655dde4d59
commit 358a61ab90
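The generalization replaces the int32-only `ParallelDevice::Vector` with a templated `ScalarsFromSequence<DataType>`, so literal/constant parallel tensors can be built for any dtype that `DataTypeToEnum` knows about. A minimal sketch of the new call sites, assuming `parallel_device`, `context`, and `status` are set up as in the test added further down in this change:

// Sketch only: `parallel_device`, `context`, and `status` are assumed to be
// initialized as in TestScalarsFromSequence below.
std::unique_ptr<ParallelTensor> floats =
    parallel_device.ScalarsFromSequence<float>({10.0f, 11.0f}, context.get(),
                                                status.get());
std::unique_ptr<ParallelTensor> ids =
    parallel_device.ScalarsFromSequence<int32_t>({5, 6}, context.get(),
                                                 status.get());
// The old, int32-only spelling was:
//   parallel_device.Vector(context.get(), status.get(), {5, 6});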
@@ -265,55 +265,6 @@ std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
                                            status);
 }
 
-std::unique_ptr<ParallelTensor> ParallelDevice::Vector(
-    TFE_Context* context, TF_Status* status,
-    absl::Span<const int32_t> values) const {
-  // TODO(allenl): We could cache DeviceIDs (keyed by context).
-  std::vector<TensorHandlePtr> components;
-  components.reserve(underlying_devices_.size());
-
-  if (values.size() != num_underlying_devices()) {
-    TF_SetStatus(
-        status, TF_INVALID_ARGUMENT,
-        "Number of values did not match number of underlying devices.");
-    return nullptr;
-  }
-
-  for (int device_index = 0; device_index < num_underlying_devices();
-       ++device_index) {
-    int32_t* device_value = new int32_t;
-    *device_value = values[device_index];
-    std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
-        TF_NewTensor(
-            TF_INT32, /*dims=*/nullptr, /*num_dims=*/0, device_value,
-            sizeof(int32_t),
-            [](void* data, size_t, void* arg) {
-              delete reinterpret_cast<int32_t*>(data);
-            },
-            nullptr),
-        TF_DeleteTensor);
-    // TODO(allenl): Here and when executing regular operations, we could hold
-    // on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
-    // device names repeatedly.
-    OpPtr const_op(TFE_NewOp(context, "Const", status));
-    if (TF_GetCode(status) != TF_OK) return nullptr;
-    TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
-                    status);
-    if (TF_GetCode(status) != TF_OK) return nullptr;
-    TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
-    if (TF_GetCode(status) != TF_OK) return nullptr;
-    TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT32);
-    TFE_TensorHandle* device_handle;
-    int num_outputs = 1;
-    TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
-    if (TF_GetCode(status) != TF_OK) return nullptr;
-    components.emplace_back(device_handle);
-    if (TF_GetCode(status) != TF_OK) return nullptr;
-  }
-  return ParallelTensor::FromTensorHandles(*this, std::move(components),
-                                           status);
-}
-
 std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
     TFE_Context* context, TF_Status* status) const {
   std::vector<int32_t> ids;
@@ -321,7 +272,7 @@ std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
   for (int i = 0; i < num_underlying_devices(); ++i) {
     ids.push_back(i);
   }
-  return Vector(context, status, ids);
+  return ScalarsFromSequence<int32_t>(ids, context, status);
 }
 
 absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
@@ -27,6 +27,7 @@ limitations under the License.
 #include "tensorflow/c/eager/c_api.h"
 #include "tensorflow/c/eager/c_api_experimental.h"
 #include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
 
 namespace tensorflow {
 namespace parallel_device {
@@ -67,9 +68,10 @@ class ParallelDevice {
                                                        TF_Status* status) const;
 
   // Construct a parallel tensor consisting of the scalar values from `values`.
-  std::unique_ptr<ParallelTensor> Vector(
-      TFE_Context* context, TF_Status* status,
-      absl::Span<const int32_t> values) const;
+  template <typename DataType>
+  std::unique_ptr<ParallelTensor> ScalarsFromSequence(
+      absl::Span<const DataType> values, TFE_Context* context,
+      TF_Status* status) const;
 
   // A parallel tensor with scalar integers numbering component devices.
   std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
@@ -193,6 +195,56 @@ class ParallelTensor {
   const TF_DataType dtype_;
 };
 
+template <typename DataType>
+std::unique_ptr<ParallelTensor> ParallelDevice::ScalarsFromSequence(
+    absl::Span<DataType const> values, TFE_Context* context,
+    TF_Status* status) const {
+  std::vector<TensorHandlePtr> components;
+  components.reserve(underlying_devices_.size());
+
+  if (values.size() != num_underlying_devices()) {
+    TF_SetStatus(
+        status, TF_INVALID_ARGUMENT,
+        "Number of values did not match number of underlying devices.");
+    return nullptr;
+  }
+  TF_DataType datatype_enum(
+      static_cast<TF_DataType>(DataTypeToEnum<DataType>().value));
+  for (int device_index = 0; device_index < num_underlying_devices();
+       ++device_index) {
+    auto device_value = absl::make_unique<DataType>();
+    *device_value = values[device_index];
+    std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
+        TF_NewTensor(
+            datatype_enum, /*dims=*/nullptr, /*num_dims=*/0,
+            device_value.release(), sizeof(DataType),
+            [](void* data, size_t, void* arg) {
+              delete reinterpret_cast<DataType*>(data);
+            },
+            nullptr),
+        TF_DeleteTensor);
+    // TODO(allenl): Here and when executing regular operations, we could hold
+    // on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
+    // device names repeatedly.
+    std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> const_op(
+        TFE_NewOp(context, "Const", status), TFE_DeleteOp);
+    if (TF_GetCode(status) != TF_OK) return nullptr;
+    TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
+                    status);
+    if (TF_GetCode(status) != TF_OK) return nullptr;
+    TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
+    if (TF_GetCode(status) != TF_OK) return nullptr;
+    TFE_OpSetAttrType(const_op.get(), "dtype", datatype_enum);
+    TFE_TensorHandle* device_handle;
+    int num_outputs = 1;
+    TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
+    if (TF_GetCode(status) != TF_OK) return nullptr;
+    components.emplace_back(device_handle);
+  }
+  return ParallelTensor::FromTensorHandles(*this, std::move(components),
+                                           status);
+}
+
 } // namespace parallel_device
 } // namespace tensorflow
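One detail worth noting in the templated implementation above: each scalar is heap-allocated and ownership is handed to TF_NewTensor through its deallocator callback, which frees the value once the tensor is destroyed. A standalone sketch of that pattern follows; the helper name and the choice of TF_FLOAT are illustrative, not part of this change.

#include <memory>

#include "tensorflow/c/c_api.h"

// Builds a scalar TF_FLOAT tensor that owns its heap-allocated payload; the
// deallocator runs when the TF_Tensor is deleted. Sketch only.
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> MakeScalarFloatTensor(
    float value) {
  auto heap_value = std::make_unique<float>(value);
  return std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)>(
      TF_NewTensor(TF_FLOAT, /*dims=*/nullptr, /*num_dims=*/0,
                   heap_value.release(), sizeof(float),
                   [](void* data, size_t, void* arg) {
                     delete static_cast<float*>(data);
                   },
                   /*deallocator_arg=*/nullptr),
      TF_DeleteTensor);
}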
@@ -188,5 +188,44 @@ TEST(PARALLEL_DEVICE_LIB, TestDifferentShapes) {
   EXPECT_EQ(0, shape->size());
 }
 
+TEST(PARALLEL_DEVICE_LIB, TestScalarsFromSequence) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
+      TFE_NewContextOptions(), TFE_DeleteContextOptions);
+  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
+      TF_CreateConfig(
+          /*enable_xla_compilation=*/false,
+          /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
+      TF_DeleteBuffer);
+  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
+                              status.get());
+  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
+      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
+  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
+
+  std::vector<std::string> devices{
+      "/job:localhost/replica:0/task:0/device:CPU:0",
+      "/job:localhost/replica:0/task:0/device:CPU:1"};
+  ParallelDevice parallel_device(std::move(devices));
+  {
+    std::unique_ptr<ParallelTensor> float_tensors =
+        parallel_device.ScalarsFromSequence<float>({10.0, 11.0}, context.get(),
+                                                   status.get());
+    ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
+    ExpectScalarEq<float>(float_tensors->tensor(0), 10.0);
+    ExpectScalarEq<float>(float_tensors->tensor(1), 11.0);
+  }
+
+  {
+    std::unique_ptr<ParallelTensor> int_tensors =
+        parallel_device.ScalarsFromSequence<int>({5, 6}, context.get(),
+                                                 status.get());
+    ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
+    ExpectScalarEq<int>(int_tensors->tensor(0), 5);
+    ExpectScalarEq<int>(int_tensors->tensor(1), 6);
+  }
+}
+
 } // namespace parallel_device
 } // namespace tensorflow
@@ -140,11 +140,13 @@ template <typename value_type>
 void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
-  std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
+  std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> actual_value(
       TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
-  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
+  ASSERT_EQ(TF_TensorType(actual_value.get()),
+            static_cast<TF_DataType>(DataTypeToEnum<value_type>().value));
   EXPECT_EQ(expected_value,
-            *static_cast<value_type*>(TF_TensorData(value_zero.get())));
+            *static_cast<value_type*>(TF_TensorData(actual_value.get())));
 }
 
 template <std::size_t num_devices>
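The new dtype assertion in ExpectScalarEq, like the datatype_enum computation in ScalarsFromSequence, leans on DataTypeToEnum from the newly included tensorflow/core/framework/types.h, which maps C++ types to DataType enum values whose numeric values line up with the C API's TF_DataType. A small sketch of that mapping as I understand it, not part of the commit itself:

#include "tensorflow/core/framework/types.h"

// Compile-time mapping from C++ types to dtype enums; the numeric values are
// shared with TF_DataType, which is what makes the static_cast in this change
// valid. Sketch based on my reading of types.h.
static_assert(tensorflow::DataTypeToEnum<float>::value == tensorflow::DT_FLOAT,
              "float maps to DT_FLOAT");
static_assert(tensorflow::DataTypeToEnum<tensorflow::int32>::value ==
                  tensorflow::DT_INT32,
              "int32 maps to DT_INT32");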