From 41a576f5051e6e4a1afae4931ad2d38f6568f4aa Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Mon, 2 Dec 2019 17:34:52 -0800
Subject: [PATCH] Throw an explicit error if a user calls TPUStrategy experimental_run_v2 in eager mode with a Python function.

PiperOrigin-RevId: 283453288
Change-Id: I381a61afbaf6cb74ccd1ad1f556d8e5cf3f962f2
---
 .../distribute/custom_training_loop_test.py  |  5 +-
 tensorflow/python/distribute/tpu_strategy.py | 60 ++------------
 tensorflow/python/distribute/values_test.py  | 79 +++----------------
 3 files changed, 21 insertions(+), 123 deletions(-)

diff --git a/tensorflow/python/distribute/custom_training_loop_test.py b/tensorflow/python/distribute/custom_training_loop_test.py
index e9b283d376c..1db9bff21f0 100644
--- a/tensorflow/python/distribute/custom_training_loop_test.py
+++ b/tensorflow/python/distribute/custom_training_loop_test.py
@@ -36,8 +36,9 @@ class InputIterationTest(test.TestCase, parameterized.TestCase):
 
   @combinations.generate(
       combinations.combine(
-          distribution=strategy_combinations.strategies_minus_tpu,
-          mode=["eager"]))
+          distribution=strategy_combinations.all_strategies,
+          mode=["eager"]
+      ))
   def testFullEager(self, distribution):
     dataset = self._get_dataset()
 
diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py
index 8f32e8e2226..2dd4309537a 100644
--- a/tensorflow/python/distribute/tpu_strategy.py
+++ b/tensorflow/python/distribute/tpu_strategy.py
@@ -37,7 +37,6 @@ from tensorflow.python.distribute import values
 from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
-from tensorflow.python.eager import function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import device_spec
 from tensorflow.python.framework import dtypes
@@ -83,29 +82,6 @@ def maybe_init_scope():
     yield
 
 
-def validate_experimental_run_function(fn):
-  """Validate the function passed into strategy.experimental_run_v2."""
-
-  # We allow three types of functions/objects passed into TPUStrategy
-  # experimental_run_v2 in eager mode:
-  # 1. a user annotated tf.function
-  # 2. a ConcreteFunction, this is mostly what you get from loading a saved
-  # model.
-  # 3. a callable object and the `__call__` method itself is a tf.function.
-  #
-  # Otherwise we return an error, because we don't support eagerly running
-  # experimental_run_v2 in TPUStrategy.
-
-  if context.executing_eagerly() and not isinstance(
-      fn, def_function.Function) and not isinstance(
-          fn, function.ConcreteFunction) and not (callable(fn) and isinstance(
-              fn.__call__, def_function.Function)):
-    raise NotImplementedError(
-        "TPUStrategy.experimental_run_v2(fn, ...) does not support eager "
-        "execution. Either convert `fn` into a tf.function or consider "
-        "calling strategy.experimental_run_v2 inside a tf.function.")
-
-
 @tf_export("distribute.experimental.TPUStrategy", v1=[])
 class TPUStrategy(distribute_lib.Strategy):
   """TPU distribution strategy implementation."""
@@ -113,36 +89,14 @@ class TPUStrategy(distribute_lib.Strategy):
   def __init__(self,
                tpu_cluster_resolver=None,
                device_assignment=None):
-    """Synchronous training in TPU donuts or Pods.
- - To construct a TPUStrategy object, you need to run the - initialization code as below: - - ```python - resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu) - tf.config.experimental_connect_to_cluster(resolver) - tf.tpu.experimental.initialize_tpu_system(resolver) - strategy = tf.distribute.experimental.TPUStrategy(resolver) - ``` - - While using distribution strategies, the variables created within strategy's - scope will be replicated across all the replicas and can be kept in sync - using all-reduce algorithms. - - To run TF2 programs on TPUs, you can either use `.compile` and - `.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized - training loop by calling `strategy.experimental_run_v2` directly. Note that - TPUStrategy doesn't support pure eager execution, so please make sure the - function passed into `strategy.experimental_run_v2` is a `tf.function` or - `strategy.experimental_run_v2` us called inside a `tf.function` if running - in eager mode. + """Initializes the TPUStrategy object. Args: tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver, - which provides information about the TPU cluster. + which provides information about the TPU cluster. device_assignment: Optional `tf.tpu.experimental.DeviceAssignment` to - specify the placement of replicas on the TPU cluster. Currently only - supports the usecase of using a single core within a TPU cluster. + specify the placement of replicas on the TPU cluster. Currently only + supports the usecase of using a single core within a TPU cluster. """ super(TPUStrategy, self).__init__(TPUExtended( self, tpu_cluster_resolver, device_assignment=device_assignment)) @@ -157,8 +111,6 @@ class TPUStrategy(distribute_lib.Strategy): # This implementation runs a single step. It does not use infeed or outfeed. def experimental_run_v2(self, fn, args=(), kwargs=None): """See base class.""" - validate_experimental_run_function(fn) - # Note: the target function is converted to graph even when in Eager mode, # so autograph is on by default here. fn = autograph.tf_convert(fn, ag_ctx.control_status_ctx()) @@ -205,8 +157,6 @@ class TPUStrategyV1(distribute_lib.StrategyV1): # This implementation runs a single step. It does not use infeed or outfeed. def experimental_run_v2(self, fn, args=(), kwargs=None): """See base class.""" - validate_experimental_run_function(fn) - fn = autograph.tf_convert(fn, ag_ctx.control_status_ctx()) return self.extended.tpu_run(fn, args, kwargs) @@ -749,7 +699,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): ] # Workaround for `tpu.replicate` behaviour when single `Tensor` returned. 
- if result[0] is None or isinstance(result[0], ops.Operation): + if result[0] is None: replicate_outputs = [None] * len(replicate_outputs) else: replicate_outputs = [ diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py index 26d0eb3ac32..d97d1155c82 100644 --- a/tensorflow/python/distribute/values_test.py +++ b/tensorflow/python/distribute/values_test.py @@ -818,31 +818,13 @@ class SyncOnReadVariablePropertiesTest(test.TestCase): self.assertEqual(2., self.evaluate(add1(replica_local))) -def mirrored_and_tpu_strategy_combinations(): - return combinations.combine( - distribution=[ - strategy_combinations.mirrored_strategy_with_gpu_and_cpu, - strategy_combinations.tpu_strategy, - ], - mode=["graph", "eager"]) - - -def strategy_and_run_tf_function_combinations(): - # Test the combination of different strategies and whether a tf.function - # is passed into strategy.experimental_run_v2.""" - return combinations.combine( - distribution=[ - strategy_combinations.mirrored_strategy_with_gpu_and_cpu, - ], - mode=["graph", "eager"], - experimental_run_tf_function=[True, False]) + combinations.combine( - distribution=[ - strategy_combinations.tpu_strategy, - ], - mode=["graph", "eager"], - experimental_run_tf_function=[True]) - - +@combinations.generate( + combinations.combine( + distribution=[ + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + strategy_combinations.tpu_strategy, + ], + mode=["graph", "eager"])) class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase): def _assign_replica_local(self, v, new): @@ -860,7 +842,6 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase): save_path, _ = self._save_return_saver(sess, var) return save_path - @combinations.generate(mirrored_and_tpu_strategy_combinations()) def testSaveAndRestoreReplicaLocalSumOneGraph(self, distribution): with self.cached_session() as sess: v, replica_local = _make_replica_local( @@ -881,7 +862,6 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase): saver.restore(sess, save_path) self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]])) - @combinations.generate(mirrored_and_tpu_strategy_combinations()) def testSaveAndRestoreReplicaLocalMeanOneGraph(self, distribution): if context.num_gpus() < 1 and context.executing_eagerly(): self.skipTest("A GPU is not available for this test in eager mode.") @@ -998,46 +978,36 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase): saver.restore(sess, save_path) self.assertEqual([1.75, 1.75], self.evaluate([v[0], v[1]])) - @combinations.generate(mirrored_and_tpu_strategy_combinations()) def testSaveReplicaLocalRestoreReplicaLocalMean(self, distribution): save_path = self._save_replica_local_mean(distribution) self._restore_replica_local_mean(save_path, distribution) - @combinations.generate(mirrored_and_tpu_strategy_combinations()) def testSaveReplicaLocalRestoreReplicaLocalSum(self, distribution): save_path = self._save_replica_local_sum(distribution) self._restore_replica_local_sum(save_path, distribution) - @combinations.generate(mirrored_and_tpu_strategy_combinations()) def testSaveReplicaLocalMeanRestoreNormal(self, distribution): save_path = self._save_replica_local_mean(distribution) self._restore_normal(save_path) - @combinations.generate(mirrored_and_tpu_strategy_combinations()) def testSaveReplicaLocalSumRestoreNormal(self, distribution): save_path = self._save_replica_local_sum(distribution) self._restore_normal(save_path) - 
@combinations.generate(mirrored_and_tpu_strategy_combinations()) def testSaveNormalRestoreReplicaLocalMean(self, distribution): save_path = self._save_normal() self._restore_replica_local_mean(save_path, distribution) - @combinations.generate(mirrored_and_tpu_strategy_combinations()) def testSaveNormalRestoreReplicaLocalSum(self, distribution): save_path = self._save_normal() self._restore_replica_local_sum(save_path, distribution) - @combinations.generate(strategy_and_run_tf_function_combinations()) - def testAssign(self, distribution, experimental_run_tf_function): - + def testAssign(self, distribution): def assign(fn, v, update_value, cross_replica): update_fn = lambda: getattr(v, fn)(update_value) if cross_replica: return update_fn() else: - if experimental_run_tf_function: - update_fn = def_function.function(update_fn) return distribution.experimental_local_results( distribution.experimental_run_v2(update_fn)) updates = [("assign", 1.), ("assign_add", 1.), ("assign_sub", -1.)] @@ -1063,17 +1033,12 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase): self.assertAllEqual(self.evaluate(component.read_value()), self.evaluate(array_ops.ones_like(component))) - @combinations.generate(strategy_and_run_tf_function_combinations()) - def testAssignDtypeConversion(self, distribution, - experimental_run_tf_function): - + def testAssignDtypeConversion(self, distribution): def assign(fn, v, update_value, cross_replica): update_fn = lambda: getattr(v, fn)(update_value) if cross_replica: return update_fn() else: - if experimental_run_tf_function: - update_fn = def_function.function(update_fn) return distribution.experimental_local_results( distribution.experimental_run_v2(update_fn)) updates = [("assign", 1), ("assign_add", 1), ("assign_sub", -1)] @@ -1099,7 +1064,6 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase): self.assertAllEqual(self.evaluate(component.read_value()), self.evaluate(array_ops.ones_like(component))) - @combinations.generate(mirrored_and_tpu_strategy_combinations()) def testAssignWithAggregationSum(self, distribution): with distribution.scope(): v = variable_scope.variable( @@ -1112,7 +1076,6 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase): self.assertAllEqual(self.evaluate(component.read_value()), self.evaluate(array_ops.ones_like(component))) - @combinations.generate(mirrored_and_tpu_strategy_combinations()) def testAssignAddSubWithAggregationSum(self, distribution): with distribution.scope(): v = variable_scope.variable( @@ -1127,9 +1090,7 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase): ValueError, "SyncOnReadVariable does not support "): self.evaluate(v.assign_sub(1.)) - @combinations.generate(strategy_and_run_tf_function_combinations()) - def testReadValueInReplicaContext(self, distribution, - experimental_run_tf_function): + def testReadValueInReplicaContext(self, distribution): aggregations = [ variables_lib.VariableAggregation.NONE, variables_lib.VariableAggregation.SUM, @@ -1143,19 +1104,12 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase): synchronization=variables_lib.VariableSynchronization.ON_READ, aggregation=aggregation) self.evaluate(variables_lib.global_variables_initializer()) - if experimental_run_tf_function: - read_var_fn = def_function.function(v.read_value) - else: - read_var_fn = v.read_value - results = self.evaluate( - distribution.experimental_local_results( - distribution.experimental_run_v2(read_var_fn))) + results = 
self.evaluate(distribution.experimental_local_results( + distribution.experimental_run_v2(v.read_value))) for component, value in zip(v._values, results): self.assertAllEqual(self.evaluate(component.read_value()), value) - @combinations.generate(strategy_and_run_tf_function_combinations()) - def testReadValueInCrossReplicaContext(self, distribution, - experimental_run_tf_function): + def testReadValueInCrossReplicaContext(self, distribution): aggregations = [ variables_lib.VariableAggregation.SUM, variables_lib.VariableAggregation.MEAN, @@ -1171,15 +1125,10 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase): synchronization=variables_lib.VariableSynchronization.ON_READ, aggregation=aggregation) self.evaluate(variables_lib.global_variables_initializer()) - def assign(v=v): ctx = distribution_strategy_context.get_replica_context() replica_id = ctx.replica_id_in_sync_group return v.assign(math_ops.cast(replica_id, dtypes.float32)) - - if experimental_run_tf_function: - assign = def_function.function(assign) - self.evaluate(distribution.experimental_local_results( distribution.experimental_run_v2(assign))) result = self.evaluate(v.read_value()) @@ -1193,7 +1142,6 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase): expected = 0 self.assertEqual(expected, result, aggregation) - @combinations.generate(mirrored_and_tpu_strategy_combinations()) def testReadValueWithAggregationNoneInCrossReplicaContext(self, distribution): with distribution.scope(): v = variable_scope.variable( @@ -1205,7 +1153,6 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase): ValueError, "Could not convert from .* VariableAggregation\\.NONE"): self.evaluate(v.read_value()) - @combinations.generate(mirrored_and_tpu_strategy_combinations()) def testInitializedToSameValueInsideEagerRun(self, distribution): if not context.executing_eagerly(): self.skipTest("eager only")
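
For reference, a minimal sketch of the calling patterns this change touches, assuming the TF 2.x APIs named in the removed TPUStrategy docstring above; the TPU resolver argument and the step functions below are illustrative placeholders, not code from this change:

```python
import tensorflow as tf

# TPU setup, following the construction snippet from the removed docstring.
# The `tpu` argument is a placeholder; the docstring used FLAGS.tpu.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)


# Pattern 1 from the removed validation comment: pass a tf.function
# (a ConcreteFunction, or an object whose __call__ is a tf.function, also counts).
@tf.function
def step_fn(x):
  return x * 2.0

per_replica = strategy.experimental_run_v2(step_fn, args=(tf.constant(1.0),))


# Pattern 2: keep the replica function a plain Python callable, but invoke
# experimental_run_v2 from inside a tf.function.
def plain_step_fn(x):
  return x * 2.0

@tf.function
def train_step(x):
  return strategy.experimental_run_v2(plain_step_fn, args=(x,))

per_replica = train_step(tf.constant(1.0))
```

Either way the target function ends up in a graph: as the comment retained in experimental_run_v2 notes, the function is converted to graph even in eager mode, with autograph enabled by default.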