Update device_id to be int32 rather than int64.
PiperOrigin-RevId: 317709385 Change-Id: I577dd469d223cc05c50dbdb6a8bd908e2e757344
This commit is contained in:
parent
dfd21eaec6
commit
780c0a29fe
@ -262,14 +262,14 @@ std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
|
|||||||
components.reserve(underlying_devices_.size());
|
components.reserve(underlying_devices_.size());
|
||||||
for (int device_index = 0; device_index < underlying_devices_.size();
|
for (int device_index = 0; device_index < underlying_devices_.size();
|
||||||
++device_index) {
|
++device_index) {
|
||||||
int64_t* device_id = new int64_t;
|
int32_t* device_id = new int32_t;
|
||||||
*device_id = device_index;
|
*device_id = device_index;
|
||||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
|
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
|
||||||
TF_NewTensor(
|
TF_NewTensor(
|
||||||
TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
|
TF_INT32, /*dims=*/nullptr, /*num_dims=*/0, device_id,
|
||||||
sizeof(int64_t),
|
sizeof(int32_t),
|
||||||
[](void* data, size_t, void* arg) {
|
[](void* data, size_t, void* arg) {
|
||||||
delete reinterpret_cast<int64_t*>(data);
|
delete reinterpret_cast<int32_t*>(data);
|
||||||
},
|
},
|
||||||
nullptr),
|
nullptr),
|
||||||
TF_DeleteTensor);
|
TF_DeleteTensor);
|
||||||
@ -283,7 +283,7 @@ std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
|
|||||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||||
TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
|
TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
|
||||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||||
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
|
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT32);
|
||||||
TFE_TensorHandle* device_handle;
|
TFE_TensorHandle* device_handle;
|
||||||
int num_outputs = 1;
|
int num_outputs = 1;
|
||||||
TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
|
TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
|
||||||
|
@ -296,8 +296,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
|
|||||||
TFE_DeleteTensorHandle(result_handle);
|
TFE_DeleteTensorHandle(result_handle);
|
||||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
|
||||||
ExpectScalarEq<int64_t>(components[0].get(), 0);
|
ExpectScalarEq<int32_t>(components[0].get(), 0);
|
||||||
ExpectScalarEq<int64_t>(components[1].get(), 1);
|
ExpectScalarEq<int32_t>(components[1].get(), 1);
|
||||||
std::string first_device =
|
std::string first_device =
|
||||||
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
||||||
ASSERT_EQ(underlying_devices[0], first_device);
|
ASSERT_EQ(underlying_devices[0], first_device);
|
||||||
|
Loading…
Reference in New Issue
Block a user