Getting rid of experimental_allow_stateful now that we have experimental_external_state_policy instead.
that option. We've replaced experimental_allow_stateful with experimental_external_state_policy with a default of WARN in which we'll only print a warning message and not fail. As a result, we can now safely remove all client code that explicitly set the allow_stateful_ops flag to True. PiperOrigin-RevId: 276115313 Change-Id: Id1000bd3f404010d609a830a5ab109313e46616b
This commit is contained in:
parent
1c2fa3dd26
commit
45150cfa79
@ -21,7 +21,6 @@ from absl.testing import parameterized
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.experimental.ops import distribute
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
@ -30,7 +29,6 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -111,38 +109,6 @@ class ReplicateClusterTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
with self.assertRaises(errors.OpError):
|
||||
sess.run(it1.initializer)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(tf_api_version=[1], mode=["graph"]))
|
||||
def testAllowStatefulOp(self):
|
||||
with compat.forward_compatibility_horizon(2019, 9, 12):
|
||||
with ops.device(self._device0):
|
||||
dataset0 = dataset_ops.Dataset.range(100).map(
|
||||
lambda _: random_ops.random_uniform( # pylint:disable=g-long-lambda
|
||||
[],
|
||||
minval=1,
|
||||
maxval=10,
|
||||
dtype=dtypes.float32))
|
||||
opt = dataset_ops.Options()
|
||||
opt.experimental_allow_stateful = True
|
||||
dataset0 = dataset0.with_options(opt)
|
||||
replicated_ds = distribute.replicate(dataset0,
|
||||
[self._device1, self._device2])
|
||||
dataset1 = replicated_ds[self._device1]
|
||||
dataset2 = replicated_ds[self._device2]
|
||||
|
||||
with ops.device(self._device0):
|
||||
get_next0 = self.getNext(dataset0)
|
||||
with ops.device(self._device1):
|
||||
get_next1 = self.getNext(dataset1)
|
||||
with ops.device(self._device2):
|
||||
get_next2 = self.getNext(dataset2)
|
||||
|
||||
with session.Session(self._target) as sess:
|
||||
for _ in range(100):
|
||||
sess.run(get_next0())
|
||||
sess.run(get_next1())
|
||||
sess.run(get_next2())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -86,36 +86,6 @@ class LocalReplicateTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertDatasetProduces(
|
||||
dataset2, range(201, 301), requires_initialization=True)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(tf_api_version=[1], mode=["graph", "eager"]))
|
||||
def testAllowStatefulOp(self):
|
||||
with ops.device(self._device0):
|
||||
dataset0 = dataset_ops.Dataset.range(100).map(
|
||||
lambda _: random_ops.random_uniform( # pylint:disable=g-long-lambda
|
||||
[],
|
||||
minval=1,
|
||||
maxval=10,
|
||||
dtype=dtypes.float32))
|
||||
opt = dataset_ops.Options()
|
||||
opt.experimental_allow_stateful = True
|
||||
dataset0 = dataset0.with_options(opt)
|
||||
replicated_ds = distribute.replicate(dataset0,
|
||||
[self._device1, self._device2])
|
||||
dataset1 = replicated_ds[self._device1]
|
||||
dataset2 = replicated_ds[self._device2]
|
||||
|
||||
with ops.device(self._device0):
|
||||
get_next0 = self.getNext(dataset0)
|
||||
with ops.device(self._device1):
|
||||
get_next1 = self.getNext(dataset1)
|
||||
with ops.device(self._device2):
|
||||
get_next2 = self.getNext(dataset2)
|
||||
|
||||
for _ in range(100):
|
||||
self.evaluate(get_next0())
|
||||
self.evaluate(get_next1())
|
||||
self.evaluate(get_next2())
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(tf_api_version=[1], mode=["graph", "eager"]))
|
||||
def testExternalStatePolicyIgnore(self):
|
||||
@ -316,37 +286,6 @@ class RemoteReplicateTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.evaluate(get_next0())
|
||||
self.evaluate(get_next1())
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(tf_api_version=[2], mode=["eager"]))
|
||||
def testAllowStatefulOp(self):
|
||||
with compat.forward_compatibility_horizon(2019, 9, 12):
|
||||
with ops.device(self._device0):
|
||||
dataset0 = dataset_ops.Dataset.range(100).map(
|
||||
lambda _: random_ops.random_uniform( # pylint:disable=g-long-lambda
|
||||
[],
|
||||
minval=1,
|
||||
maxval=10,
|
||||
dtype=dtypes.float32))
|
||||
opt = dataset_ops.Options()
|
||||
opt.experimental_allow_stateful = True
|
||||
dataset0 = dataset0.with_options(opt)
|
||||
replicated_ds = distribute.replicate(dataset0,
|
||||
[self._device1, self._device2])
|
||||
dataset1 = replicated_ds[self._device1]
|
||||
dataset2 = replicated_ds[self._device2]
|
||||
|
||||
with ops.device(self._device0):
|
||||
get_next0 = self.getNext(dataset0)
|
||||
with ops.device(self._device1):
|
||||
get_next1 = self.getNext(dataset1)
|
||||
with ops.device(self._device2):
|
||||
get_next2 = self.getNext(dataset2)
|
||||
|
||||
for _ in range(100):
|
||||
self.evaluate(get_next0())
|
||||
self.evaluate(get_next1())
|
||||
self.evaluate(get_next2())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ops.enable_eager_execution(
|
||||
|
@ -156,10 +156,8 @@ def replicate(dataset, devices):
|
||||
|
||||
with ops.colocate_with(dataset._variant_tensor):
|
||||
dataset = dataset._apply_options()
|
||||
allow_stateful = dataset.options().experimental_allow_stateful
|
||||
external_state_policy = dataset.options().experimental_external_state_policy
|
||||
graph_def = dataset._as_serialized_graph(
|
||||
allow_stateful=allow_stateful,
|
||||
strip_device_assignment=True,
|
||||
external_state_policy=external_state_policy)
|
||||
for device in devices:
|
||||
|
@ -2448,16 +2448,6 @@ class Options(options_lib.OptionsBase):
|
||||
"`tf.data.experimental.ThreadingOptions` for more details.",
|
||||
default_factory=threading_options.ThreadingOptions)
|
||||
|
||||
experimental_allow_stateful = options_lib.create_option(
|
||||
name="experimental_allow_stateful",
|
||||
ty=bool,
|
||||
docstring="By default, tf.data will refuse to serialize a dataset or "
|
||||
"checkpoint its iterator if the dataset contains a stateful op as the "
|
||||
"serialization / checkpointing won't be able to capture its state. "
|
||||
"Users can -- at their own risk -- override this restriction by "
|
||||
"explicitly specifying that they are fine throwing away the state "
|
||||
"in these ops when they turn this option on.")
|
||||
|
||||
experimental_external_state_policy = options_lib.create_option(
|
||||
name="experimental_external_state_policy",
|
||||
ty=ExternalStatePolicy,
|
||||
|
@ -379,7 +379,8 @@ class TensorLikeDataAdapter(DataAdapter):
|
||||
options.experimental_optimization.apply_default_optimizations = False
|
||||
if self._shuffle:
|
||||
# See b/141490660 for more details.
|
||||
options.experimental_allow_stateful = True
|
||||
options.experimental_external_state_policy = (
|
||||
dataset_ops.ExternalStatePolicy.IGNORE)
|
||||
dataset = dataset.with_options(options)
|
||||
return dataset
|
||||
|
||||
|
@ -3,10 +3,6 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.Options\'>"
|
||||
is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "experimental_allow_stateful"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "experimental_deterministic"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -3,10 +3,6 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.Options\'>"
|
||||
is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "experimental_allow_stateful"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "experimental_deterministic"
|
||||
mtype: "<type \'property\'>"
|
||||
|
Loading…
Reference in New Issue
Block a user