Expose device memory information via XRT API.
PiperOrigin-RevId: 318893328 Change-Id: I1bdfc8c6fcabe7b4f9a662272aa0c40a795299da
This commit is contained in:
parent
1b53b995da
commit
1e1ce81457
@ -195,4 +195,9 @@ REGISTER_KERNEL_BUILDER(Name("XRTCompactAllocations").Device(DEVICE_XLA_CPU),
|
|||||||
REGISTER_KERNEL_BUILDER(Name("XRTMetricsCollect").Device(DEVICE_CPU),
|
REGISTER_KERNEL_BUILDER(Name("XRTMetricsCollect").Device(DEVICE_CPU),
|
||||||
XRTMetricsCollectOp);
|
XRTMetricsCollectOp);
|
||||||
|
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("XRTMemoryInfo").Device(DEVICE_XLA_GPU),
|
||||||
|
XRTMemoryInfoOp<XRTGenericDeviceAccessor>);
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("XRTMemoryInfo").Device(DEVICE_XLA_CPU),
|
||||||
|
XRTMemoryInfoOp<XRTGenericDeviceAccessor>);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -739,6 +739,42 @@ class XRTCompactAllocationsOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <class DeviceAccessor>
|
||||||
|
class XRTMemoryInfoOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit XRTMemoryInfoOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||||
|
~XRTMemoryInfoOp() override = default;
|
||||||
|
XRTMemoryInfoOp(const XRTMemoryInfoOp&) = delete;
|
||||||
|
XRTMemoryInfoOp& operator=(const XRTMemoryInfoOp&) = delete;
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* ctx) override {
|
||||||
|
auto kernel_fn = [&]() -> Status {
|
||||||
|
VLOG(1) << "XRTMemoryInfoOp::Compute";
|
||||||
|
|
||||||
|
class DeviceAccessor::ScopedRef device_ref;
|
||||||
|
TF_RETURN_IF_ERROR(DeviceAccessor::InitScopedRef(ctx, &device_ref));
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
se::StreamExecutor * stream_executor,
|
||||||
|
device_ref.backend()->stream_executor(device_ref.device_ordinal()));
|
||||||
|
int64 mem_free = -1;
|
||||||
|
int64 mem_total = -1;
|
||||||
|
if (!stream_executor->DeviceMemoryUsage(&mem_free, &mem_total)) {
|
||||||
|
VLOG(2) << "Device " << ctx->device()->name()
|
||||||
|
<< " does not expose memory information";
|
||||||
|
}
|
||||||
|
xrt::MemoryInfo mem_info;
|
||||||
|
mem_info.set_kb_total((mem_total >= 0) ? mem_total / 1024 : -1);
|
||||||
|
mem_info.set_kb_free((mem_free >= 0) ? mem_free / 1024 : -1);
|
||||||
|
|
||||||
|
Tensor output(DT_STRING, TensorShape({}));
|
||||||
|
output.scalar<tstring>()() = mem_info.SerializeAsString();
|
||||||
|
ctx->set_output(0, output);
|
||||||
|
return Status::OK();
|
||||||
|
};
|
||||||
|
OP_REQUIRES_OK(ctx, kernel_fn());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_
|
#endif // TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_
|
||||||
|
@ -228,4 +228,14 @@ Reads the selected metric values from the metrics collection registry.
|
|||||||
'result' is a serialized xrt::MetricsReport proto.
|
'result' is a serialized xrt::MetricsReport proto.
|
||||||
)");
|
)");
|
||||||
|
|
||||||
|
REGISTER_OP("XRTMemoryInfo")
|
||||||
|
.Output("result: string")
|
||||||
|
.SetShapeFn(tensorflow::shape_inference::ScalarShape)
|
||||||
|
.Doc(
|
||||||
|
R"(
|
||||||
|
Returns the memory information of the device this op executes on/
|
||||||
|
|
||||||
|
'result' is a serialized xrt::MemoryInfo proto.
|
||||||
|
)");
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -2206,6 +2206,22 @@ TEST(RawApiTest, TestMetricsFetch) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(RawApiTest, TestMemoryInfo) {
|
||||||
|
Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
|
||||||
|
Output result = ops::XRTMemoryInfo(root);
|
||||||
|
TF_ASSERT_OK(root.status());
|
||||||
|
|
||||||
|
ClientSession session(root);
|
||||||
|
std::vector<Tensor> outputs;
|
||||||
|
TF_EXPECT_OK(session.Run({result}, &outputs));
|
||||||
|
ASSERT_EQ(outputs.size(), 1);
|
||||||
|
|
||||||
|
xrt::MemoryInfo mem_info;
|
||||||
|
EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &mem_info));
|
||||||
|
EXPECT_GT(mem_info.kb_total(), 0);
|
||||||
|
EXPECT_GT(mem_info.kb_free(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -268,3 +268,10 @@ message MetricValues {
|
|||||||
message MetricsReport {
|
message MetricsReport {
|
||||||
repeated MetricValues metrics = 1;
|
repeated MetricValues metrics = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
message MemoryInfo {
|
||||||
|
// The total memory on a device, in KB.
|
||||||
|
int64 kb_total = 1;
|
||||||
|
// The free memory on a device, in KB.
|
||||||
|
int64 kb_free = 2;
|
||||||
|
}
|
||||||
|
@ -348,15 +348,17 @@ double NominalCPUFrequency() {
|
|||||||
return absl::base_internal::NominalCPUFrequency();
|
return absl::base_internal::NominalCPUFrequency();
|
||||||
}
|
}
|
||||||
|
|
||||||
int64 AvailableRam() {
|
MemoryInfo GetMemoryInfo() {
|
||||||
|
MemoryInfo mem_info = {INT64_MAX, INT64_MAX};
|
||||||
#if defined(__linux__) && !defined(__ANDROID__)
|
#if defined(__linux__) && !defined(__ANDROID__)
|
||||||
struct sysinfo info;
|
struct sysinfo info;
|
||||||
int err = sysinfo(&info);
|
int err = sysinfo(&info);
|
||||||
if (err == 0) {
|
if (err == 0) {
|
||||||
return info.freeram;
|
mem_info.free = info.freeram;
|
||||||
|
mem_info.total = info.totalram;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
return INT64_MAX;
|
return mem_info;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace port
|
} // namespace port
|
||||||
|
@ -59,8 +59,18 @@ void MallocExtension_ReleaseToSystem(std::size_t num_bytes);
|
|||||||
// routine, this routine returns 0.
|
// routine, this routine returns 0.
|
||||||
std::size_t MallocExtension_GetAllocatedSize(const void* p);
|
std::size_t MallocExtension_GetAllocatedSize(const void* p);
|
||||||
|
|
||||||
|
struct MemoryInfo {
|
||||||
|
int64 total = 0;
|
||||||
|
int64 free = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Retrieves the host memory information. If any of the fields in the returned
|
||||||
|
// MemoryInfo structure is INT64_MAX, it means such information is not
|
||||||
|
// available.
|
||||||
|
MemoryInfo GetMemoryInfo();
|
||||||
|
|
||||||
// Returns the amount of RAM available in bytes, or INT64_MAX if unknown.
|
// Returns the amount of RAM available in bytes, or INT64_MAX if unknown.
|
||||||
int64 AvailableRam();
|
static inline int64 AvailableRam() { return GetMemoryInfo().free; }
|
||||||
|
|
||||||
} // namespace port
|
} // namespace port
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
// class declaration].
|
// class declaration].
|
||||||
#include "tensorflow/stream_executor/host/host_gpu_executor.h"
|
#include "tensorflow/stream_executor/host/host_gpu_executor.h"
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
|
|
||||||
#include "absl/strings/numbers.h"
|
#include "absl/strings/numbers.h"
|
||||||
@ -58,6 +59,13 @@ port::Status HostExecutor::Init(int device_ordinal,
|
|||||||
return port::Status::OK();
|
return port::Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool HostExecutor::DeviceMemoryUsage(int64 *free, int64 *total) const {
|
||||||
|
tensorflow::port::MemoryInfo mem_info = tensorflow::port::GetMemoryInfo();
|
||||||
|
*free = (mem_info.free != INT64_MAX) ? mem_info.free : -1;
|
||||||
|
*total = (mem_info.total != INT64_MAX) ? mem_info.total : -1;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
DeviceMemoryBase HostExecutor::Allocate(uint64 size, int64 memory_space) {
|
DeviceMemoryBase HostExecutor::Allocate(uint64 size, int64 memory_space) {
|
||||||
CHECK_EQ(memory_space, 0);
|
CHECK_EQ(memory_space, 0);
|
||||||
// Use a minimum alignment of 64 bytes to be friendly to AVX512 code.
|
// Use a minimum alignment of 64 bytes to be friendly to AVX512 code.
|
||||||
|
@ -130,9 +130,7 @@ class HostExecutor : public internal::StreamExecutorInterface {
|
|||||||
|
|
||||||
int PlatformDeviceCount() override { return 1; }
|
int PlatformDeviceCount() override { return 1; }
|
||||||
|
|
||||||
bool DeviceMemoryUsage(int64 *free, int64 *total) const override {
|
bool DeviceMemoryUsage(int64 *free, int64 *total) const override;
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
port::StatusOr<std::unique_ptr<DeviceDescription>> CreateDeviceDescription()
|
port::StatusOr<std::unique_ptr<DeviceDescription>> CreateDeviceDescription()
|
||||||
const override {
|
const override {
|
||||||
|
Loading…
Reference in New Issue
Block a user