PSv2: Update in keras training_v1 message regarding TF2 ParameterServerStrategy support.

PiperOrigin-RevId: 339366494
Change-Id: I69cd1f7bd4fac4555d3859943323a4983ee3058b
This commit is contained in:
Rick Chao 2020-10-27 17:40:53 -07:00 committed by TensorFlower Gardener
parent f6ba0dfbbc
commit f0e5b4e28d
2 changed files with 11 additions and 3 deletions

View File

@ -54,6 +54,7 @@ py_library(
"//tensorflow/python/distribute:distribute_lib",
"//tensorflow/python/distribute:input_lib",
"//tensorflow/python/distribute:parameter_server_strategy",
"//tensorflow/python/distribute:parameter_server_strategy_v2",
"//tensorflow/python/distribute:reduce_util",
"//tensorflow/python/eager:monitoring",
"//tensorflow/python/keras:activations",

View File

@ -27,6 +27,7 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
@ -360,9 +361,15 @@ class Model(training_lib.Model):
if isinstance(self._distribution_strategy,
parameter_server_strategy.ParameterServerStrategyV1):
raise NotImplementedError('`tf.compat.v1.distribute.experimental.Paramet'
'erServerStrategy` currently only works '
'with the tf.Estimator API')
raise NotImplementedError(
'`tf.compat.v1.distribute.experimental.ParameterServerStrategy` '
'currently only works with the tf.Estimator API')
if isinstance(self._distribution_strategy,
parameter_server_strategy_v2.ParameterServerStrategyV2):
raise NotImplementedError(
'`tf.distribute.experimental.ParameterServerStrategy` is only '
'supported in TF2.')
if not self._experimental_run_tf_function:
self._validate_compile_param_for_distribution_strategy(self.run_eagerly,