Getting rid of experimental_allow_stateful now that we have experimental_external_state_policy instead.

that option.

We've replaced experimental_allow_stateful with experimental_external_state_policy with a default of WARN in which we'll only print a warning message and not fail. As a result, we can now safely remove all client code that explicitly set the allow_stateful_ops flag to True.

PiperOrigin-RevId: 276115313
Change-Id: Id1000bd3f404010d609a830a5ab109313e46616b
This commit is contained in:
Rohan Jain 2019-10-22 12:05:29 -07:00 committed by TensorFlower Gardener
parent 1c2fa3dd26
commit 45150cfa79
7 changed files with 2 additions and 116 deletions

View File

@ -21,7 +21,6 @@ from absl.testing import parameterized
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.compat import compat
from tensorflow.python.data.experimental.ops import distribute
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
@ -30,7 +29,6 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
@ -111,38 +109,6 @@ class ReplicateClusterTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.OpError):
sess.run(it1.initializer)
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testAllowStatefulOp(self):
with compat.forward_compatibility_horizon(2019, 9, 12):
with ops.device(self._device0):
dataset0 = dataset_ops.Dataset.range(100).map(
lambda _: random_ops.random_uniform( # pylint:disable=g-long-lambda
[],
minval=1,
maxval=10,
dtype=dtypes.float32))
opt = dataset_ops.Options()
opt.experimental_allow_stateful = True
dataset0 = dataset0.with_options(opt)
replicated_ds = distribute.replicate(dataset0,
[self._device1, self._device2])
dataset1 = replicated_ds[self._device1]
dataset2 = replicated_ds[self._device2]
with ops.device(self._device0):
get_next0 = self.getNext(dataset0)
with ops.device(self._device1):
get_next1 = self.getNext(dataset1)
with ops.device(self._device2):
get_next2 = self.getNext(dataset2)
with session.Session(self._target) as sess:
for _ in range(100):
sess.run(get_next0())
sess.run(get_next1())
sess.run(get_next2())
if __name__ == "__main__":
test.main()

View File

@ -86,36 +86,6 @@ class LocalReplicateTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertDatasetProduces(
dataset2, range(201, 301), requires_initialization=True)
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph", "eager"]))
def testAllowStatefulOp(self):
with ops.device(self._device0):
dataset0 = dataset_ops.Dataset.range(100).map(
lambda _: random_ops.random_uniform( # pylint:disable=g-long-lambda
[],
minval=1,
maxval=10,
dtype=dtypes.float32))
opt = dataset_ops.Options()
opt.experimental_allow_stateful = True
dataset0 = dataset0.with_options(opt)
replicated_ds = distribute.replicate(dataset0,
[self._device1, self._device2])
dataset1 = replicated_ds[self._device1]
dataset2 = replicated_ds[self._device2]
with ops.device(self._device0):
get_next0 = self.getNext(dataset0)
with ops.device(self._device1):
get_next1 = self.getNext(dataset1)
with ops.device(self._device2):
get_next2 = self.getNext(dataset2)
for _ in range(100):
self.evaluate(get_next0())
self.evaluate(get_next1())
self.evaluate(get_next2())
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph", "eager"]))
def testExternalStatePolicyIgnore(self):
@ -316,37 +286,6 @@ class RemoteReplicateTest(test_base.DatasetTestBase, parameterized.TestCase):
self.evaluate(get_next0())
self.evaluate(get_next1())
@combinations.generate(
combinations.combine(tf_api_version=[2], mode=["eager"]))
def testAllowStatefulOp(self):
with compat.forward_compatibility_horizon(2019, 9, 12):
with ops.device(self._device0):
dataset0 = dataset_ops.Dataset.range(100).map(
lambda _: random_ops.random_uniform( # pylint:disable=g-long-lambda
[],
minval=1,
maxval=10,
dtype=dtypes.float32))
opt = dataset_ops.Options()
opt.experimental_allow_stateful = True
dataset0 = dataset0.with_options(opt)
replicated_ds = distribute.replicate(dataset0,
[self._device1, self._device2])
dataset1 = replicated_ds[self._device1]
dataset2 = replicated_ds[self._device2]
with ops.device(self._device0):
get_next0 = self.getNext(dataset0)
with ops.device(self._device1):
get_next1 = self.getNext(dataset1)
with ops.device(self._device2):
get_next2 = self.getNext(dataset2)
for _ in range(100):
self.evaluate(get_next0())
self.evaluate(get_next1())
self.evaluate(get_next2())
if __name__ == "__main__":
ops.enable_eager_execution(

View File

@ -156,10 +156,8 @@ def replicate(dataset, devices):
with ops.colocate_with(dataset._variant_tensor):
dataset = dataset._apply_options()
allow_stateful = dataset.options().experimental_allow_stateful
external_state_policy = dataset.options().experimental_external_state_policy
graph_def = dataset._as_serialized_graph(
allow_stateful=allow_stateful,
strip_device_assignment=True,
external_state_policy=external_state_policy)
for device in devices:

View File

@ -2448,16 +2448,6 @@ class Options(options_lib.OptionsBase):
"`tf.data.experimental.ThreadingOptions` for more details.",
default_factory=threading_options.ThreadingOptions)
experimental_allow_stateful = options_lib.create_option(
name="experimental_allow_stateful",
ty=bool,
docstring="By default, tf.data will refuse to serialize a dataset or "
"checkpoint its iterator if the dataset contains a stateful op as the "
"serialization / checkpointing won't be able to capture its state. "
"Users can -- at their own risk -- override this restriction by "
"explicitly specifying that they are fine throwing away the state "
"in these ops when they turn this option on.")
experimental_external_state_policy = options_lib.create_option(
name="experimental_external_state_policy",
ty=ExternalStatePolicy,

View File

@ -379,7 +379,8 @@ class TensorLikeDataAdapter(DataAdapter):
options.experimental_optimization.apply_default_optimizations = False
if self._shuffle:
# See b/141490660 for more details.
options.experimental_allow_stateful = True
options.experimental_external_state_policy = (
dataset_ops.ExternalStatePolicy.IGNORE)
dataset = dataset.with_options(options)
return dataset

View File

@ -3,10 +3,6 @@ tf_class {
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.Options\'>"
is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
is_instance: "<type \'object\'>"
member {
name: "experimental_allow_stateful"
mtype: "<type \'property\'>"
}
member {
name: "experimental_deterministic"
mtype: "<type \'property\'>"

View File

@ -3,10 +3,6 @@ tf_class {
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.Options\'>"
is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
is_instance: "<type \'object\'>"
member {
name: "experimental_allow_stateful"
mtype: "<type \'property\'>"
}
member {
name: "experimental_deterministic"
mtype: "<type \'property\'>"