Clear eager kernel cache when resetting random seed.
"big hammer" required for reproducibility. PiperOrigin-RevId: 180961787
This commit is contained in:
parent
388c9b7331
commit
a17038297c
@ -110,6 +110,10 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
|
|||||||
return TF_SessionListDevices(ctx->session, status);
|
return TF_SessionListDevices(ctx->session, status);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void TFE_ContextClearCaches(TFE_Context* ctx) {
|
||||||
|
tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache);
|
||||||
|
}
|
||||||
|
|
||||||
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
|
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
|
||||||
tensorflow::Tensor tensor;
|
tensorflow::Tensor tensor;
|
||||||
status->status = tensorflow::TF_TensorToTensor(t, &tensor);
|
status->status = tensorflow::TF_TensorToTensor(t, &tensor);
|
||||||
|
@ -89,6 +89,10 @@ TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status
|
|||||||
TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
|
TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
|
||||||
TF_Status* status);
|
TF_Status* status);
|
||||||
|
|
||||||
|
// Clears the internal caches in the TFE context. Useful when reseeding random
|
||||||
|
// ops.
|
||||||
|
TF_CAPI_EXPORT extern void TFE_ContextClearCaches(TFE_Context* ctx);
|
||||||
|
|
||||||
// A handle to a tensor on a device.
|
// A handle to a tensor on a device.
|
||||||
//
|
//
|
||||||
// Like a TF_Tensor, a TFE_TensorHandle refers to a tensor with a value, shape,
|
// Like a TF_Tensor, a TFE_TensorHandle refers to a tensor with a value, shape,
|
||||||
|
@ -133,6 +133,9 @@ class Context(object):
|
|||||||
"""Set a global eager mode seed for random ops."""
|
"""Set a global eager mode seed for random ops."""
|
||||||
self._seed = seed
|
self._seed = seed
|
||||||
self._rng = random.Random(self._seed)
|
self._rng = random.Random(self._seed)
|
||||||
|
# Also clear the kernel cache, to reset any existing seeds
|
||||||
|
if self._context_handle is not None:
|
||||||
|
pywrap_tensorflow.TFE_ContextClearCaches(self._context_handle)
|
||||||
|
|
||||||
def _internal_operation_seed(self):
|
def _internal_operation_seed(self):
|
||||||
"""Returns a fake operation seed.
|
"""Returns a fake operation seed.
|
||||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||||
|
|
||||||
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
@ -174,6 +175,17 @@ class TruncatedNormalTest(test.TestCase):
|
|||||||
diff = rnd2 - rnd1
|
diff = rnd2 - rnd1
|
||||||
self.assertTrue(np.linalg.norm(diff.eval()) > 0.1)
|
self.assertTrue(np.linalg.norm(diff.eval()) > 0.1)
|
||||||
|
|
||||||
|
def testEagerSeed(self):
|
||||||
|
with context.eager_mode():
|
||||||
|
# Ensure a context has been created
|
||||||
|
random_ops.random_normal([])
|
||||||
|
# Set the same seed twice and check that the values match
|
||||||
|
context.set_global_seed(42)
|
||||||
|
rnd1 = random_ops.random_normal([])
|
||||||
|
context.set_global_seed(42)
|
||||||
|
rnd2 = random_ops.random_normal([])
|
||||||
|
self.assertAllEqual(rnd1, rnd2)
|
||||||
|
|
||||||
|
|
||||||
class RandomUniformTest(test.TestCase):
|
class RandomUniformTest(test.TestCase):
|
||||||
|
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
%rename("%s") TFE_ContextListDevices;
|
%rename("%s") TFE_ContextListDevices;
|
||||||
%rename("%s") TFE_ContextAddFunction;
|
%rename("%s") TFE_ContextAddFunction;
|
||||||
%rename("%s") TFE_ContextAddFunctionDef;
|
%rename("%s") TFE_ContextAddFunctionDef;
|
||||||
|
%rename("%s") TFE_ContextClearCaches;
|
||||||
%rename("%s") TFE_OpNameGetAttrType;
|
%rename("%s") TFE_OpNameGetAttrType;
|
||||||
%rename("%s") TFE_Py_InitEagerTensor;
|
%rename("%s") TFE_Py_InitEagerTensor;
|
||||||
%rename("%s") TFE_Py_RegisterExceptionClass;
|
%rename("%s") TFE_Py_RegisterExceptionClass;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user