Update device_id to be int32 rather than int64.
PiperOrigin-RevId: 317709385 Change-Id: I577dd469d223cc05c50dbdb6a8bd908e2e757344
This commit is contained in:
parent
dfd21eaec6
commit
780c0a29fe
@ -262,14 +262,14 @@ std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
|
||||
components.reserve(underlying_devices_.size());
|
||||
for (int device_index = 0; device_index < underlying_devices_.size();
|
||||
++device_index) {
|
||||
int64_t* device_id = new int64_t;
|
||||
int32_t* device_id = new int32_t;
|
||||
*device_id = device_index;
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
|
||||
TF_NewTensor(
|
||||
TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
|
||||
sizeof(int64_t),
|
||||
TF_INT32, /*dims=*/nullptr, /*num_dims=*/0, device_id,
|
||||
sizeof(int32_t),
|
||||
[](void* data, size_t, void* arg) {
|
||||
delete reinterpret_cast<int64_t*>(data);
|
||||
delete reinterpret_cast<int32_t*>(data);
|
||||
},
|
||||
nullptr),
|
||||
TF_DeleteTensor);
|
||||
@ -283,7 +283,7 @@ std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
|
||||
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT32);
|
||||
TFE_TensorHandle* device_handle;
|
||||
int num_outputs = 1;
|
||||
TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
|
||||
|
@ -296,8 +296,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
|
||||
TFE_DeleteTensorHandle(result_handle);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
ExpectScalarEq<int64_t>(components[0].get(), 0);
|
||||
ExpectScalarEq<int64_t>(components[1].get(), 1);
|
||||
ExpectScalarEq<int32_t>(components[0].get(), 0);
|
||||
ExpectScalarEq<int32_t>(components[1].get(), 1);
|
||||
std::string first_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[0], first_device);
|
||||
|
Loading…
Reference in New Issue
Block a user