Add function tf.config.experimental.get_memory_info.
PiperOrigin-RevId: 353324579
Change-Id: I71bfeb9fd08a6834ef07ccabb359065aec1ba641
commit 697f117f36
parent 6bb7f19c3d
@@ -97,6 +97,9 @@
   `tf.while_loop`, and compositions like `tf.foldl`) computed with
   `tf.GradientTape` inside a `tf.function`.
 * Changed the default step size in `gradient_checker_v2.compute_gradients` to be exactly representable as binary floating point numbers. This avoids polluting gradient approximations needlessly, which in some cases leads to false negatives in op gradient tests.
+* Added `tf.config.experimental.get_memory_info`, returning a dict with the
+  current and peak memory usage. Deprecated
+  `tf.config.experimental.get_memory_usage` in favor of this new function.

 * `tf.summary`:
   * New `tf.summary.graph` allows manual write of TensorFlow graph
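A minimal usage sketch of the release note above, assuming a machine with a visible GPU (the snippet is illustrative and not part of the commit):

    import tensorflow as tf

    if tf.config.list_physical_devices('GPU'):
        # New API: one dict carrying both current usage and the high-water mark.
        info = tf.config.experimental.get_memory_info('GPU:0')
        print(info['current'], info['peak'])  # both in bytes

        # Deprecated API: equivalent to get_memory_info('GPU:0')['current'].
        current = tf.config.experimental.get_memory_usage('GPU:0')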
@@ -1438,11 +1438,16 @@ class Context(object):
     self._visible_device_list = visible_device_list

-  def get_total_memory_usage(self, dev):
-    """Returns total memory usage in bytes for the current device."""
+  def get_memory_info(self, dev):
+    """Returns a dict of memory info for the device."""
     self._initialize_physical_devices()
     self.ensure_initialized()
-    return pywrap_tfe.TFE_GetTotalMemoryUsage(self._context_handle, dev)
+    return pywrap_tfe.TFE_GetMemoryInfo(self._context_handle, dev)
+
+  # TODO(reedwm): Remove this function
+  def get_total_memory_usage(self, dev):
+    """Returns total memory usage in bytes for the current device."""
+    return self.get_memory_info(dev)["current"]

   def get_memory_growth(self, dev):
     """Get if memory growth is enabled for a PhysicalDevice."""
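The hunk above keeps the old method alive as a thin shim over the new one. A minimal, self-contained sketch of that backward-compatibility pattern, with hypothetical names (this is not TensorFlow code):

    class MemoryTracker:
        """Toy stand-in for the Context class, for illustration only."""

        def __init__(self):
            self._current = 0
            self._peak = 0

        def allocate(self, nbytes):
            self._current += nbytes
            self._peak = max(self._peak, self._current)

        def get_memory_info(self):
            # New API: a dict, leaving room to add more keys later.
            return {'current': self._current, 'peak': self._peak}

        def get_total_memory_usage(self):
            # Deprecated API, kept working by delegating to the new one.
            return self.get_memory_info()['current']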
@@ -510,9 +510,58 @@ def set_visible_devices(devices, device_type=None):
   context.context().set_visible_devices(devices, device_type)


+@tf_export('config.experimental.get_memory_info')
+def get_memory_info(device):
+  """Get memory info for the chosen device, as a dict.
+
+  This function returns a dict containing information about the device's memory
+  usage. For example:
+
+  >>> if tf.config.list_physical_devices('GPU'):
+  ...   # Returns a dict in the form {'current': <current mem usage>,
+  ...   #                             'peak': <peak mem usage>}
+  ...   tf.config.experimental.get_memory_info('GPU:0')
+
+  Currently returns the following keys:
+    `'current'`: The current memory used by the device, in bytes.
+    `'peak'`: The peak memory used by the device across the run of the program,
+      in bytes.
+
+  More keys may be added in the future, including device-specific keys.
+
+  Currently raises an exception for the CPU.
+
+  For GPUs, TensorFlow will allocate all the memory by default, unless changed
+  with `tf.config.experimental.set_memory_growth`. The dict specifies only the
+  current and peak memory that TensorFlow is actually using, not the memory that
+  TensorFlow has allocated on the GPU.
+
+  Args:
+    device: Device string to get the memory information for, e.g. `"GPU:0"`. See
+      https://www.tensorflow.org/api_docs/python/tf/device for specifying device
+      strings.
+
+  Returns:
+    A dict with keys `'current'` and `'peak'`, specifying the current and peak
+    memory usage respectively.
+
+  Raises:
+    ValueError: Non-existent or CPU device specified.
+  """
+  return context.context().get_memory_info(device)
+
+
+@deprecation.deprecated(
+    None,
+    "Use tf.config.experimental.get_memory_info(device)['current'] instead.")
 @tf_export('config.experimental.get_memory_usage')
 def get_memory_usage(device):
-  """Get the memory usage, in bytes, for the chosen device.
+  """Get the current memory usage, in bytes, for the chosen device.
+
+  This function is deprecated in favor of
+  `tf.config.experimental.get_memory_info`. Calling this function is equivalent
+  to calling `tf.config.experimental.get_memory_info()['current']`.

   See https://www.tensorflow.org/api_docs/python/tf/device for specifying device
   strings.
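To make the docstring concrete, a hedged example session (a GPU is assumed; the numbers in the comment are invented, since real values depend on the device and allocator):

    import tensorflow as tf

    if tf.config.list_physical_devices('GPU'):
        x = tf.zeros((1000, 1000))  # ~4 MB of float32 data on the GPU
        info = tf.config.experimental.get_memory_info('GPU:0')
        # e.g. {'current': 4000256, 'peak': 4000256}; exact values vary
        assert info.keys() == {'current', 'peak'}
        assert info['peak'] >= info['current']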
@@ -525,8 +574,13 @@ def get_memory_usage(device):

   Does not work for CPU.

+  For GPUs, TensorFlow will allocate all the memory by default, unless changed
+  with `tf.config.experimental.set_memory_growth`. This function only returns
+  the memory that TensorFlow is actually using, not the memory that TensorFlow
+  has allocated on the GPU.
+
   Args:
-    device: Device string to get the bytes in use for.
+    device: Device string to get the bytes in use for, e.g. `"GPU:0"`

   Returns:
     Total memory usage in bytes.
@@ -534,7 +588,7 @@ def get_memory_usage(device):
   Raises:
     ValueError: Non-existent or CPU device specified.
   """
-  return context.context().get_total_memory_usage(device)
+  return get_memory_info(device)['current']


 @tf_export('config.experimental.get_memory_growth')
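The deprecation message promises a drop-in replacement; a small sketch of the equivalence (it mirrors the assertion in the tests further down; a GPU is assumed):

    import tensorflow as tf

    if tf.config.list_physical_devices('GPU'):
        # Both calls read the same allocator stats; the deprecated one just
        # extracts the 'current' key and emits a deprecation warning.
        old = tf.config.experimental.get_memory_usage('GPU:0')
        new = tf.config.experimental.get_memory_info('GPU:0')['current']
        assert old == new  # holds as long as nothing allocates in between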
@@ -599,25 +599,48 @@ class DeviceTest(test.TestCase):

   @test_util.run_gpu_only
   @reset_eager
-  def testGetMemoryUsage(self):
+  def testGetMemoryInfoBasic(self):
     device = array_ops.zeros([]).backing_device
-    self.assertGreater(config.get_memory_usage(device), 0)
+    info = config.get_memory_info(device)
+    self.assertGreater(info['current'], 0)
+    self.assertGreater(info['peak'], 0)
+    self.assertEqual(info.keys(), {'current', 'peak'})
+    self.assertEqual(config.get_memory_usage(device), info['current'])

   @test_util.run_gpu_only
   @reset_eager
   def testGetMemoryUsageSubstring(self):
-    self.assertGreater(config.get_memory_usage('GPU:0'), 0)
+    info = config.get_memory_info('GPU:0')
+    self.assertGreater(info['current'], 0)

   @reset_eager
-  def testGetMemoryUsageCPU(self):
+  def testGetMemoryInfoCPU(self):
+    with self.assertRaisesRegex(ValueError, 'CPU does not support'):
+      config.get_memory_info('CPU:0')
     with self.assertRaisesRegex(ValueError, 'CPU does not support'):
       config.get_memory_usage('CPU:0')

   @reset_eager
-  def testGetMemoryUsageUnknownDevice(self):
+  def testGetMemoryInfoUnknownDevice(self):
+    with self.assertRaisesRegex(ValueError, 'Failed parsing device name'):
+      config.get_memory_info('unknown_device')
     with self.assertRaisesRegex(ValueError, 'Failed parsing device name'):
       config.get_memory_usage('unknown_device')

+  @test_util.run_gpu_only
+  @reset_eager
+  def testPeakMemoryUsage(self):
+    x1 = array_ops.zeros((1000, 1000))
+    peak1 = config.get_memory_info('GPU:0')['peak']
+    self.assertGreaterEqual(peak1, 4 * 1000 * 1000)
+    x2 = array_ops.ones((1000, 1000))
+    peak2 = config.get_memory_info('GPU:0')['peak']
+    self.assertGreaterEqual(peak2, peak1 + 4 * 1000 * 1000)
+    del x1, x2  # With CPython, causes tensor memory to be immediately freed
+    peak3 = config.get_memory_info('GPU:0')['peak']
+    self.assertGreaterEqual(peak3, peak2)
+    self.assertGreaterEqual(peak3, config.get_memory_info('GPU:0')['current'])
+
   @test_util.run_gpu_only
   @reset_eager
   def testGetMemoryUsageAmbiguousDevice(self):
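The `4 * 1000 * 1000` lower bound in `testPeakMemoryUsage` is plain float32 arithmetic; a quick standalone check (no TensorFlow required):

    # A (1000, 1000) float32 tensor occupies at least
    # 1000 * 1000 elements * 4 bytes/element = 4,000,000 bytes (~3.8 MiB).
    elements = 1000 * 1000
    bytes_per_float32 = 4
    assert elements * bytes_per_float32 == 4 * 1000 * 1000  # bound in the test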
@@ -517,7 +517,7 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
   });

   m.def(
-      "TFE_GetTotalMemoryUsage", [](py::handle& ctx, const char* device_name) {
+      "TFE_GetMemoryInfo", [](py::handle& ctx, const char* device_name) {
        tensorflow::EagerContext* context = tensorflow::ContextFromInterface(
            reinterpret_cast<tensorflow::ImmediateExecutionContext*>(
                tensorflow::InputTFE_Context(ctx)));
@@ -568,7 +568,9 @@ PYBIND11_MODULE(_pywrap_tfe, m) {

         if (absl::optional<tensorflow::AllocatorStats> stats =
                 allocator->GetStats()) {
-          return stats->bytes_in_use;
+          return std::map<std::string, int64_t>{
+              {"current", stats->bytes_in_use},
+              {"peak", stats->peak_bytes_in_use}};
         }

         tensorflow::ThrowTypeError(
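pybind11's STL casters convert the returned `std::map<std::string, int64_t>` into a plain Python dict, which `Context.get_memory_info` then passes through unchanged. A hedged check of that shape via the public API (a GPU is assumed):

    import tensorflow as tf

    if tf.config.list_physical_devices('GPU'):
        info = tf.config.experimental.get_memory_info('GPU:0')
        # The keys come straight from the C++ hunk above.
        assert set(info) == {'current', 'peak'}
        assert all(isinstance(v, int) for v in info.values())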
@@ -40,6 +40,10 @@ tf_module {
     name: "get_memory_growth"
     argspec: "args=[\'device\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "get_memory_info"
+    argspec: "args=[\'device\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "get_memory_usage"
     argspec: "args=[\'device\'], varargs=None, keywords=None, defaults=None"
@@ -40,6 +40,10 @@ tf_module {
     name: "get_memory_growth"
     argspec: "args=[\'device\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "get_memory_info"
+    argspec: "args=[\'device\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "get_memory_usage"
     argspec: "args=[\'device\'], varargs=None, keywords=None, defaults=None"