Clear eager kernel cache when resetting random seed.

"big hammer" required for reproducibility.

PiperOrigin-RevId: 180961787
This commit is contained in:
Alexandre Passos 2018-01-05 12:38:20 -08:00 committed by TensorFlower Gardener
parent 388c9b7331
commit a17038297c
5 changed files with 24 additions and 0 deletions

View File

@ -110,6 +110,10 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
return TF_SessionListDevices(ctx->session, status); return TF_SessionListDevices(ctx->session, status);
} }
void TFE_ContextClearCaches(TFE_Context* ctx) {
tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache);
}
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
tensorflow::Tensor tensor; tensorflow::Tensor tensor;
status->status = tensorflow::TF_TensorToTensor(t, &tensor); status->status = tensorflow::TF_TensorToTensor(t, &tensor);

View File

@ -89,6 +89,10 @@ TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status
TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
TF_Status* status); TF_Status* status);
// Clears the internal caches in the TFE context. Useful when reseeding random
// ops.
TF_CAPI_EXPORT extern void TFE_ContextClearCaches(TFE_Context* ctx);
// A handle to a tensor on a device. // A handle to a tensor on a device.
// //
// Like a TF_Tensor, a TFE_TensorHandle refers to a tensor with a value, shape, // Like a TF_Tensor, a TFE_TensorHandle refers to a tensor with a value, shape,

View File

@ -133,6 +133,9 @@ class Context(object):
"""Set a global eager mode seed for random ops.""" """Set a global eager mode seed for random ops."""
self._seed = seed self._seed = seed
self._rng = random.Random(self._seed) self._rng = random.Random(self._seed)
# Also clear the kernel cache, to reset any existing seeds
if self._context_handle is not None:
pywrap_tensorflow.TFE_ContextClearCaches(self._context_handle)
def _internal_operation_seed(self): def _internal_operation_seed(self):
"""Returns a fake operation seed. """Returns a fake operation seed.

View File

@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
@ -174,6 +175,17 @@ class TruncatedNormalTest(test.TestCase):
diff = rnd2 - rnd1 diff = rnd2 - rnd1
self.assertTrue(np.linalg.norm(diff.eval()) > 0.1) self.assertTrue(np.linalg.norm(diff.eval()) > 0.1)
def testEagerSeed(self):
with context.eager_mode():
# Ensure a context has been created
random_ops.random_normal([])
# Set the same seed twice and check that the values match
context.set_global_seed(42)
rnd1 = random_ops.random_normal([])
context.set_global_seed(42)
rnd2 = random_ops.random_normal([])
self.assertAllEqual(rnd1, rnd2)
class RandomUniformTest(test.TestCase): class RandomUniformTest(test.TestCase):

View File

@ -20,6 +20,7 @@ limitations under the License.
%rename("%s") TFE_ContextListDevices; %rename("%s") TFE_ContextListDevices;
%rename("%s") TFE_ContextAddFunction; %rename("%s") TFE_ContextAddFunction;
%rename("%s") TFE_ContextAddFunctionDef; %rename("%s") TFE_ContextAddFunctionDef;
%rename("%s") TFE_ContextClearCaches;
%rename("%s") TFE_OpNameGetAttrType; %rename("%s") TFE_OpNameGetAttrType;
%rename("%s") TFE_Py_InitEagerTensor; %rename("%s") TFE_Py_InitEagerTensor;
%rename("%s") TFE_Py_RegisterExceptionClass; %rename("%s") TFE_Py_RegisterExceptionClass;