diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc index 19c7f99c0d8..8acfb4fcbf2 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc @@ -265,55 +265,6 @@ std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice( status); } -std::unique_ptr<ParallelTensor> ParallelDevice::Vector( - TFE_Context* context, TF_Status* status, - absl::Span<const int32_t> values) const { - // TODO(allenl): We could cache DeviceIDs (keyed by context). - std::vector<TensorHandlePtr> components; - components.reserve(underlying_devices_.size()); - - if (values.size() != num_underlying_devices()) { - TF_SetStatus( - status, TF_INVALID_ARGUMENT, - "Number of values did not match number of underlying devices."); - return nullptr; - } - - for (int device_index = 0; device_index < num_underlying_devices(); - ++device_index) { - int32_t* device_value = new int32_t; - *device_value = values[device_index]; - std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor( - TF_NewTensor( - TF_INT32, /*dims=*/nullptr, /*num_dims=*/0, device_value, - sizeof(int32_t), - [](void* data, size_t, void* arg) { - delete reinterpret_cast<int32_t*>(data); - }, - nullptr), - TF_DeleteTensor); - // TODO(allenl): Here and when executing regular operations, we could hold - // on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing - // device names repeatedly. - OpPtr const_op(TFE_NewOp(context, "Const", status)); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(), - status); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT32); - TFE_TensorHandle* device_handle; - int num_outputs = 1; - TFE_Execute(const_op.get(), &device_handle, &num_outputs, status); - if (TF_GetCode(status) != TF_OK) return nullptr; - components.emplace_back(device_handle); - if (TF_GetCode(status) != TF_OK) return nullptr; - } - return ParallelTensor::FromTensorHandles(*this, std::move(components), - status); -} - std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs( TFE_Context* context, TF_Status* status) const { std::vector<int32_t> ids; @@ -321,7 +272,7 @@ std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs( for (int i = 0; i < num_underlying_devices(); ++i) { ids.push_back(i); } - return Vector(context, status, ids); + return ScalarsFromSequence<int32_t>(ids, context, status); } absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib.h b/tensorflow/c/eager/parallel_device/parallel_device_lib.h index d5b6042b8fc..c510cb5ca7e 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_lib.h +++ b/tensorflow/c/eager/parallel_device/parallel_device_lib.h @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" namespace tensorflow { namespace parallel_device { @@ -67,9 +68,10 @@ class ParallelDevice { TF_Status* status) const; // Construct a parallel tensor consisting of the scalar values from `values`. - std::unique_ptr<ParallelTensor> Vector( - TFE_Context* context, TF_Status* status, - absl::Span<const int32_t> values) const; + template <typename DataType> + std::unique_ptr<ParallelTensor> ScalarsFromSequence( + absl::Span<const DataType> values, TFE_Context* context, + TF_Status* status) const; // A parallel tensor with scalar integers numbering component devices. std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context, @@ -193,6 +195,56 @@ class ParallelTensor { const TF_DataType dtype_; }; +template <typename DataType> +std::unique_ptr<ParallelTensor> ParallelDevice::ScalarsFromSequence( + absl::Span<DataType const> values, TFE_Context* context, + TF_Status* status) const { + std::vector<TensorHandlePtr> components; + components.reserve(underlying_devices_.size()); + + if (values.size() != num_underlying_devices()) { + TF_SetStatus( + status, TF_INVALID_ARGUMENT, + "Number of values did not match number of underlying devices."); + return nullptr; + } + TF_DataType datatype_enum( + static_cast<TF_DataType>(DataTypeToEnum<DataType>().value)); + for (int device_index = 0; device_index < num_underlying_devices(); + ++device_index) { + auto device_value = absl::make_unique<DataType>(); + *device_value = values[device_index]; + std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor( + TF_NewTensor( + datatype_enum, /*dims=*/nullptr, /*num_dims=*/0, + device_value.release(), sizeof(DataType), + [](void* data, size_t, void* arg) { + delete reinterpret_cast<DataType*>(data); + }, + nullptr), + TF_DeleteTensor); + // TODO(allenl): Here and when executing regular operations, we could hold + // on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing + // device names repeatedly. + std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> const_op( + TFE_NewOp(context, "Const", status), TFE_DeleteOp); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(), + status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetAttrType(const_op.get(), "dtype", datatype_enum); + TFE_TensorHandle* device_handle; + int num_outputs = 1; + TFE_Execute(const_op.get(), &device_handle, &num_outputs, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + components.emplace_back(device_handle); + } + return ParallelTensor::FromTensorHandles(*this, std::move(components), + status); +} + } // namespace parallel_device } // namespace tensorflow diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib_test.cc b/tensorflow/c/eager/parallel_device/parallel_device_lib_test.cc index e4c4538d79d..f2d1b31c893 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_lib_test.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device_lib_test.cc @@ -188,5 +188,44 @@ TEST(PARALLEL_DEVICE_LIB, TestDifferentShapes) { EXPECT_EQ(0, shape->size()); } +TEST(PARALLEL_DEVICE_LIB, TestScalarsFromSequence) { + std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( + TF_NewStatus(), TF_DeleteStatus); + std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts( + TFE_NewContextOptions(), TFE_DeleteContextOptions); + std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config( + TF_CreateConfig( + /*enable_xla_compilation=*/false, + /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2), + TF_DeleteBuffer); + TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length, + status.get()); + std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context( + TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + + std::vector<std::string> devices{ + "/job:localhost/replica:0/task:0/device:CPU:0", + "/job:localhost/replica:0/task:0/device:CPU:1"}; + ParallelDevice parallel_device(std::move(devices)); + { + std::unique_ptr<ParallelTensor> float_tensors = + parallel_device.ScalarsFromSequence<float>({10.0, 11.0}, context.get(), + status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + ExpectScalarEq<float>(float_tensors->tensor(0), 10.0); + ExpectScalarEq<float>(float_tensors->tensor(1), 11.0); + } + + { + std::unique_ptr<ParallelTensor> int_tensors = + parallel_device.ScalarsFromSequence<int>({5, 6}, context.get(), + status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + ExpectScalarEq<int>(int_tensors->tensor(0), 5); + ExpectScalarEq<int>(int_tensors->tensor(1), 6); + } +} + } // namespace parallel_device } // namespace tensorflow diff --git a/tensorflow/c/eager/parallel_device/parallel_device_testlib.h b/tensorflow/c/eager/parallel_device/parallel_device_testlib.h index 30d0881b512..ecc96dd66ee 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_testlib.h +++ b/tensorflow/c/eager/parallel_device/parallel_device_testlib.h @@ -140,11 +140,13 @@ template <typename value_type> void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) { std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( TF_NewStatus(), TF_DeleteStatus); - std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero( + std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> actual_value( TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_TensorType(actual_value.get()), + static_cast<TF_DataType>(DataTypeToEnum<value_type>().value)); EXPECT_EQ(expected_value, - *static_cast<value_type*>(TF_TensorData(value_zero.get()))); + *static_cast<value_type*>(TF_TensorData(actual_value.get()))); } template <std::size_t num_devices>