Update device_id to be int32 rather than int64.

PiperOrigin-RevId: 317709385
Change-Id: I577dd469d223cc05c50dbdb6a8bd908e2e757344
This commit is contained in:
Haitang Hu 2020-06-22 12:21:45 -07:00 committed by TensorFlower Gardener
parent dfd21eaec6
commit 780c0a29fe
2 changed files with 7 additions and 7 deletions

View File

@ -262,14 +262,14 @@ std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
components.reserve(underlying_devices_.size());
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
int64_t* device_id = new int64_t;
int32_t* device_id = new int32_t;
*device_id = device_index;
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(
TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
sizeof(int64_t),
TF_INT32, /*dims=*/nullptr, /*num_dims=*/0, device_id,
sizeof(int32_t),
[](void* data, size_t, void* arg) {
delete reinterpret_cast<int64_t*>(data);
delete reinterpret_cast<int32_t*>(data);
},
nullptr),
TF_DeleteTensor);
@ -283,7 +283,7 @@ std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT32);
TFE_TensorHandle* device_handle;
int num_outputs = 1;
TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);

View File

@ -296,8 +296,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
TFE_DeleteTensorHandle(result_handle);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<int64_t>(components[0].get(), 0);
ExpectScalarEq<int64_t>(components[1].get(), 1);
ExpectScalarEq<int32_t>(components[0].get(), 0);
ExpectScalarEq<int32_t>(components[1].get(), 1);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);