Clear eager kernel cache when resetting random seed.
"big hammer" required for reproducibility. PiperOrigin-RevId: 180961787
This commit is contained in:
parent
388c9b7331
commit
a17038297c
@ -110,6 +110,10 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
|
||||
return TF_SessionListDevices(ctx->session, status);
|
||||
}
|
||||
|
||||
void TFE_ContextClearCaches(TFE_Context* ctx) {
|
||||
tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
|
||||
tensorflow::Tensor tensor;
|
||||
status->status = tensorflow::TF_TensorToTensor(t, &tensor);
|
||||
|
@ -89,6 +89,10 @@ TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status
|
||||
TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
|
||||
TF_Status* status);
|
||||
|
||||
// Clears the internal caches in the TFE context. Useful when reseeding random
|
||||
// ops.
|
||||
TF_CAPI_EXPORT extern void TFE_ContextClearCaches(TFE_Context* ctx);
|
||||
|
||||
// A handle to a tensor on a device.
|
||||
//
|
||||
// Like a TF_Tensor, a TFE_TensorHandle refers to a tensor with a value, shape,
|
||||
|
@ -133,6 +133,9 @@ class Context(object):
|
||||
"""Set a global eager mode seed for random ops."""
|
||||
self._seed = seed
|
||||
self._rng = random.Random(self._seed)
|
||||
# Also clear the kernel cache, to reset any existing seeds
|
||||
if self._context_handle is not None:
|
||||
pywrap_tensorflow.TFE_ContextClearCaches(self._context_handle)
|
||||
|
||||
def _internal_operation_seed(self):
|
||||
"""Returns a fake operation seed.
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -174,6 +175,17 @@ class TruncatedNormalTest(test.TestCase):
|
||||
diff = rnd2 - rnd1
|
||||
self.assertTrue(np.linalg.norm(diff.eval()) > 0.1)
|
||||
|
||||
def testEagerSeed(self):
|
||||
with context.eager_mode():
|
||||
# Ensure a context has been created
|
||||
random_ops.random_normal([])
|
||||
# Set the same seed twice and check that the values match
|
||||
context.set_global_seed(42)
|
||||
rnd1 = random_ops.random_normal([])
|
||||
context.set_global_seed(42)
|
||||
rnd2 = random_ops.random_normal([])
|
||||
self.assertAllEqual(rnd1, rnd2)
|
||||
|
||||
|
||||
class RandomUniformTest(test.TestCase):
|
||||
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
%rename("%s") TFE_ContextListDevices;
|
||||
%rename("%s") TFE_ContextAddFunction;
|
||||
%rename("%s") TFE_ContextAddFunctionDef;
|
||||
%rename("%s") TFE_ContextClearCaches;
|
||||
%rename("%s") TFE_OpNameGetAttrType;
|
||||
%rename("%s") TFE_Py_InitEagerTensor;
|
||||
%rename("%s") TFE_Py_RegisterExceptionClass;
|
||||
|
Loading…
x
Reference in New Issue
Block a user