Gets base_layer_test working with run_distributed=True.
Also disables autograph in the tf.function() wrapping the v2 loop. Note that this is not being done because of any explicit bug or test failure atm, but as a conservative measure because we already autograph custom layers implemented by users. We will explore re-enabling this in the future and will use that as an opportunity to simplify built in layers, and the __call__ logic. PiperOrigin-RevId: 258898180
This commit is contained in:
parent
11f1d8ccd8
commit
632ff99d56
@ -301,12 +301,8 @@ class BaseLayerTest(keras_parameterized.TestCase):
|
|||||||
# Cannot access tensor.name in eager execution.
|
# Cannot access tensor.name in eager execution.
|
||||||
self.assertTrue('Variable_2/Regularizer' in layer.losses[0].name)
|
self.assertTrue('Variable_2/Regularizer' in layer.losses[0].name)
|
||||||
|
|
||||||
|
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||||
def test_learning_phase_freezing_for_layers(self):
|
def test_learning_phase_freezing_for_layers(self):
|
||||||
# This test is only meant to run in graph functions mode (ambient eager).
|
|
||||||
# In forced eager, `model.predict` ignores the global learning phase
|
|
||||||
# and just uses training=False. TODO(fchollet): consider unifying the
|
|
||||||
# behaviors.
|
|
||||||
|
|
||||||
class LearningPhaseLayer(keras.layers.Layer):
|
class LearningPhaseLayer(keras.layers.Layer):
|
||||||
|
|
||||||
def call(self, inputs):
|
def call(self, inputs):
|
||||||
@ -316,7 +312,9 @@ class BaseLayerTest(keras_parameterized.TestCase):
|
|||||||
|
|
||||||
def get_learning_phase_value():
|
def get_learning_phase_value():
|
||||||
model = keras.models.Sequential([LearningPhaseLayer(input_shape=(1,))])
|
model = keras.models.Sequential([LearningPhaseLayer(input_shape=(1,))])
|
||||||
return np.sum(model.predict(np.ones((1, 1))))
|
model._run_eagerly = testing_utils.should_run_eagerly()
|
||||||
|
model._run_distributed = testing_utils.should_run_distributed()
|
||||||
|
return np.sum(model(np.ones((1, 1))))
|
||||||
|
|
||||||
self.assertEqual(get_learning_phase_value(), 0)
|
self.assertEqual(get_learning_phase_value(), 0)
|
||||||
|
|
||||||
@ -333,6 +331,41 @@ class BaseLayerTest(keras_parameterized.TestCase):
|
|||||||
keras.backend.set_learning_phase(0)
|
keras.backend.set_learning_phase(0)
|
||||||
self.assertEqual(get_learning_phase_value(), 0)
|
self.assertEqual(get_learning_phase_value(), 0)
|
||||||
|
|
||||||
|
@keras_parameterized.run_all_keras_modes
|
||||||
|
def test_learning_phase_freezing_for_layers_in_predict(self):
|
||||||
|
if not (testing_utils.should_run_eagerly() or
|
||||||
|
testing_utils.should_run_distributed()):
|
||||||
|
self.skipTest('Predict fails to override the outer learning phase in'
|
||||||
|
'the FuncGraph path.')
|
||||||
|
|
||||||
|
class LearningPhaseLayer(keras.layers.Layer):
|
||||||
|
|
||||||
|
def call(self, inputs):
|
||||||
|
return keras.backend.in_train_phase(
|
||||||
|
lambda: array_ops.ones_like(inputs),
|
||||||
|
lambda: array_ops.zeros_like(inputs))
|
||||||
|
|
||||||
|
def get_learning_phase_value():
|
||||||
|
model = keras.models.Sequential([LearningPhaseLayer(input_shape=(1,))])
|
||||||
|
model._run_eagerly = testing_utils.should_run_eagerly()
|
||||||
|
model._run_distributed = testing_utils.should_run_distributed()
|
||||||
|
return np.sum(model.predict(np.ones((1, 1))))
|
||||||
|
|
||||||
|
self.assertEqual(get_learning_phase_value(), 0)
|
||||||
|
|
||||||
|
# Test scope.
|
||||||
|
with keras.backend.learning_phase_scope(1):
|
||||||
|
self.assertEqual(get_learning_phase_value(), 0)
|
||||||
|
|
||||||
|
# The effects of the scope end after exiting it.
|
||||||
|
self.assertEqual(get_learning_phase_value(), 0)
|
||||||
|
|
||||||
|
# Test setting.
|
||||||
|
keras.backend.set_learning_phase(1)
|
||||||
|
self.assertEqual(get_learning_phase_value(), 0)
|
||||||
|
keras.backend.set_learning_phase(0)
|
||||||
|
self.assertEqual(get_learning_phase_value(), 0)
|
||||||
|
|
||||||
# Cannot be enabled with `run_eagerly=True`, see b/123904578
|
# Cannot be enabled with `run_eagerly=True`, see b/123904578
|
||||||
@test_util.run_all_in_graph_and_eager_modes
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
def test_layer_can_return_variable(self):
|
def test_layer_can_return_variable(self):
|
||||||
@ -846,6 +879,7 @@ class NameScopingTest(keras_parameterized.TestCase):
|
|||||||
self.assertEqual(layer.kernel.name, 'MyName3/kernel:0')
|
self.assertEqual(layer.kernel.name, 'MyName3/kernel:0')
|
||||||
|
|
||||||
|
|
||||||
|
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||||
class AutographControlFlowTest(keras_parameterized.TestCase):
|
class AutographControlFlowTest(keras_parameterized.TestCase):
|
||||||
|
|
||||||
def test_disabling_in_context_is_matched(self):
|
def test_disabling_in_context_is_matched(self):
|
||||||
@ -866,9 +900,7 @@ class AutographControlFlowTest(keras_parameterized.TestCase):
|
|||||||
|
|
||||||
test_fn()
|
test_fn()
|
||||||
|
|
||||||
@parameterized.named_parameters(('eager', True),
|
def test_if_training_pattern_output(self):
|
||||||
('symbolic', False))
|
|
||||||
def test_if_training_pattern_output(self, eager):
|
|
||||||
|
|
||||||
class MyLayer(keras.layers.Layer):
|
class MyLayer(keras.layers.Layer):
|
||||||
|
|
||||||
@ -880,15 +912,17 @@ class AutographControlFlowTest(keras_parameterized.TestCase):
|
|||||||
inputs = keras.Input((3,))
|
inputs = keras.Input((3,))
|
||||||
outputs = MyLayer()(inputs)
|
outputs = MyLayer()(inputs)
|
||||||
model = keras.Model(inputs, outputs)
|
model = keras.Model(inputs, outputs)
|
||||||
model.compile('sgd', 'mse', run_eagerly=eager)
|
model.compile(
|
||||||
|
'sgd',
|
||||||
|
'mse',
|
||||||
|
run_eagerly=testing_utils.should_run_eagerly(),
|
||||||
|
run_distributed=testing_utils.should_run_distributed())
|
||||||
train_loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
|
train_loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
|
||||||
self.assertEqual(train_loss, 0.)
|
self.assertEqual(train_loss, 0.)
|
||||||
test_loss = model.test_on_batch(np.ones((2, 3)), np.ones((2, 3)))
|
test_loss = model.test_on_batch(np.ones((2, 3)), np.ones((2, 3)))
|
||||||
self.assertEqual(test_loss, 1.)
|
self.assertEqual(test_loss, 1.)
|
||||||
|
|
||||||
@parameterized.named_parameters(('eager', True),
|
def test_if_training_pattern_loss(self):
|
||||||
('symbolic', False))
|
|
||||||
def test_if_training_pattern_loss(self, eager):
|
|
||||||
|
|
||||||
class MyLayer(keras.layers.Layer):
|
class MyLayer(keras.layers.Layer):
|
||||||
|
|
||||||
@ -903,15 +937,17 @@ class AutographControlFlowTest(keras_parameterized.TestCase):
|
|||||||
inputs = keras.Input((3,))
|
inputs = keras.Input((3,))
|
||||||
outputs = MyLayer()(inputs)
|
outputs = MyLayer()(inputs)
|
||||||
model = keras.Model(inputs, outputs)
|
model = keras.Model(inputs, outputs)
|
||||||
model.compile('sgd', 'mse', run_eagerly=eager)
|
model.compile(
|
||||||
|
'sgd',
|
||||||
|
'mse',
|
||||||
|
run_eagerly=testing_utils.should_run_eagerly(),
|
||||||
|
run_distributed=testing_utils.should_run_distributed())
|
||||||
train_loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
|
train_loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
|
||||||
self.assertEqual(train_loss, 2 * 3)
|
self.assertEqual(train_loss, 2 * 3)
|
||||||
test_loss = model.test_on_batch(np.ones((2, 3)), np.ones((2, 3)))
|
test_loss = model.test_on_batch(np.ones((2, 3)), np.ones((2, 3)))
|
||||||
self.assertEqual(test_loss, 0)
|
self.assertEqual(test_loss, 0)
|
||||||
|
|
||||||
@parameterized.named_parameters(('eager', True),
|
def test_if_training_pattern_metric(self):
|
||||||
('symbolic', False))
|
|
||||||
def test_if_training_pattern_metric(self, eager):
|
|
||||||
|
|
||||||
class MyLayer(keras.layers.Layer):
|
class MyLayer(keras.layers.Layer):
|
||||||
|
|
||||||
@ -926,7 +962,11 @@ class AutographControlFlowTest(keras_parameterized.TestCase):
|
|||||||
inputs = keras.Input((3,))
|
inputs = keras.Input((3,))
|
||||||
outputs = MyLayer()(inputs)
|
outputs = MyLayer()(inputs)
|
||||||
model = keras.Model(inputs, outputs)
|
model = keras.Model(inputs, outputs)
|
||||||
model.compile('sgd', 'mse', run_eagerly=eager)
|
model.compile(
|
||||||
|
'sgd',
|
||||||
|
'mse',
|
||||||
|
run_eagerly=testing_utils.should_run_eagerly(),
|
||||||
|
run_distributed=testing_utils.should_run_distributed())
|
||||||
_, train_metric = model.train_on_batch(np.ones((2, 3)),
|
_, train_metric = model.train_on_batch(np.ones((2, 3)),
|
||||||
np.ones((2, 3)))
|
np.ones((2, 3)))
|
||||||
self.assertEqual(train_metric, 2 * 3)
|
self.assertEqual(train_metric, 2 * 3)
|
||||||
@ -934,9 +974,7 @@ class AutographControlFlowTest(keras_parameterized.TestCase):
|
|||||||
np.ones((2, 3)))
|
np.ones((2, 3)))
|
||||||
self.assertEqual(test_metric, 0)
|
self.assertEqual(test_metric, 0)
|
||||||
|
|
||||||
@parameterized.named_parameters(('eager', True),
|
def test_if_training_pattern_update(self):
|
||||||
('symbolic', False))
|
|
||||||
def test_if_training_pattern_update(self, eager):
|
|
||||||
|
|
||||||
class MyLayer(keras.layers.Layer):
|
class MyLayer(keras.layers.Layer):
|
||||||
|
|
||||||
@ -956,18 +994,21 @@ class AutographControlFlowTest(keras_parameterized.TestCase):
|
|||||||
layer = MyLayer()
|
layer = MyLayer()
|
||||||
outputs = layer(inputs)
|
outputs = layer(inputs)
|
||||||
model = keras.Model(inputs, outputs)
|
model = keras.Model(inputs, outputs)
|
||||||
model.compile('sgd', 'mse', run_eagerly=eager)
|
model.compile(
|
||||||
|
'sgd',
|
||||||
|
'mse',
|
||||||
|
run_eagerly=testing_utils.should_run_eagerly(),
|
||||||
|
run_distributed=testing_utils.should_run_distributed())
|
||||||
model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
|
model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
|
||||||
self.assertEqual(keras.backend.get_value(layer.counter), 1.)
|
self.assertEqual(keras.backend.get_value(layer.counter), 1.)
|
||||||
|
|
||||||
@parameterized.named_parameters(('eager', True),
|
def test_conditional_updates_in_call(self):
|
||||||
('symbolic', False))
|
|
||||||
def test_conditional_updates_in_call(self, eager):
|
|
||||||
|
|
||||||
class MyLayer(keras.layers.Layer):
|
class MyLayer(keras.layers.Layer):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(MyLayer, self).__init__(dynamic=eager)
|
super(MyLayer,
|
||||||
|
self).__init__(dynamic=testing_utils.should_run_eagerly())
|
||||||
|
|
||||||
def build(self, input_shape):
|
def build(self, input_shape):
|
||||||
self.counter = self.add_weight(
|
self.counter = self.add_weight(
|
||||||
@ -982,12 +1023,16 @@ class AutographControlFlowTest(keras_parameterized.TestCase):
|
|||||||
def compute_output_shape(self, input_shape):
|
def compute_output_shape(self, input_shape):
|
||||||
return input_shape
|
return input_shape
|
||||||
|
|
||||||
if eager:
|
if testing_utils.should_run_eagerly():
|
||||||
inputs = keras.Input((3,))
|
inputs = keras.Input((3,))
|
||||||
layer = MyLayer()
|
layer = MyLayer()
|
||||||
outputs = layer(inputs)
|
outputs = layer(inputs)
|
||||||
model = keras.Model(inputs, outputs)
|
model = keras.Model(inputs, outputs)
|
||||||
model.compile('sgd', 'mse', run_eagerly=eager)
|
model.compile(
|
||||||
|
'sgd',
|
||||||
|
'mse',
|
||||||
|
run_eagerly=testing_utils.should_run_eagerly(),
|
||||||
|
run_distributed=testing_utils.should_run_distributed())
|
||||||
model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
|
model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
|
||||||
self.assertEqual(keras.backend.get_value(layer.counter), 6.)
|
self.assertEqual(keras.backend.get_value(layer.counter), 6.)
|
||||||
else:
|
else:
|
||||||
@ -998,14 +1043,13 @@ class AutographControlFlowTest(keras_parameterized.TestCase):
|
|||||||
layer(keras.Input((3,)))
|
layer(keras.Input((3,)))
|
||||||
_ = layer.updates
|
_ = layer.updates
|
||||||
|
|
||||||
@parameterized.named_parameters(('eager', True),
|
def test_conditional_losses_in_call(self):
|
||||||
('symbolic', False))
|
|
||||||
def test_conditional_losses_in_call(self, eager):
|
|
||||||
|
|
||||||
class MyLayer(keras.layers.Layer):
|
class MyLayer(keras.layers.Layer):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(MyLayer, self).__init__(dynamic=eager)
|
super(MyLayer,
|
||||||
|
self).__init__(dynamic=testing_utils.should_run_eagerly())
|
||||||
|
|
||||||
def call(self, inputs, training=None):
|
def call(self, inputs, training=None):
|
||||||
if training:
|
if training:
|
||||||
@ -1015,12 +1059,16 @@ class AutographControlFlowTest(keras_parameterized.TestCase):
|
|||||||
def compute_output_shape(self, input_shape):
|
def compute_output_shape(self, input_shape):
|
||||||
return input_shape
|
return input_shape
|
||||||
|
|
||||||
if eager:
|
if testing_utils.should_run_eagerly():
|
||||||
inputs = keras.Input((3,))
|
inputs = keras.Input((3,))
|
||||||
layer = MyLayer()
|
layer = MyLayer()
|
||||||
outputs = layer(inputs)
|
outputs = layer(inputs)
|
||||||
model = keras.Model(inputs, outputs)
|
model = keras.Model(inputs, outputs)
|
||||||
model.compile('sgd', 'mse')
|
model.compile(
|
||||||
|
'sgd',
|
||||||
|
'mse',
|
||||||
|
run_eagerly=testing_utils.should_run_eagerly(),
|
||||||
|
run_distributed=testing_utils.should_run_distributed())
|
||||||
loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
|
loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
|
||||||
self.assertEqual(loss, 2 * 3)
|
self.assertEqual(loss, 2 * 3)
|
||||||
else:
|
else:
|
||||||
@ -1028,12 +1076,13 @@ class AutographControlFlowTest(keras_parameterized.TestCase):
|
|||||||
'`add_loss` in a control flow branch'):
|
'`add_loss` in a control flow branch'):
|
||||||
layer = MyLayer()(keras.Input((3,)))
|
layer = MyLayer()(keras.Input((3,)))
|
||||||
|
|
||||||
@keras_parameterized.run_all_keras_modes
|
|
||||||
def test_conditional_callable_losses(self):
|
def test_conditional_callable_losses(self):
|
||||||
model = keras.Sequential([
|
model = keras.Sequential([
|
||||||
keras.layers.Dense(
|
keras.layers.Dense(
|
||||||
1, kernel_regularizer=keras.regularizers.l2(1e-4), input_shape=(1,))
|
1, kernel_regularizer=keras.regularizers.l2(1e-4), input_shape=(1,))
|
||||||
])
|
])
|
||||||
|
model._run_eagerly = testing_utils.should_run_eagerly()
|
||||||
|
model._run_distributed = testing_utils.should_run_distributed()
|
||||||
|
|
||||||
def assert_graph(t):
|
def assert_graph(t):
|
||||||
if not context.executing_eagerly():
|
if not context.executing_eagerly():
|
||||||
@ -1049,14 +1098,13 @@ class AutographControlFlowTest(keras_parameterized.TestCase):
|
|||||||
assert_graph(get_losses(constant_op.constant(2.)))
|
assert_graph(get_losses(constant_op.constant(2.)))
|
||||||
assert_graph(get_losses(constant_op.constant(0.5)))
|
assert_graph(get_losses(constant_op.constant(0.5)))
|
||||||
|
|
||||||
@parameterized.named_parameters(('eager', True),
|
def test_conditional_metrics_in_call(self):
|
||||||
('symbolic', False))
|
|
||||||
def test_conditional_metrics_in_call(self, eager):
|
|
||||||
|
|
||||||
class MyLayer(keras.layers.Layer):
|
class MyLayer(keras.layers.Layer):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(MyLayer, self).__init__(dynamic=eager)
|
super(MyLayer,
|
||||||
|
self).__init__(dynamic=testing_utils.should_run_eagerly())
|
||||||
|
|
||||||
def call(self, inputs, training=None):
|
def call(self, inputs, training=None):
|
||||||
if training:
|
if training:
|
||||||
@ -1068,12 +1116,16 @@ class AutographControlFlowTest(keras_parameterized.TestCase):
|
|||||||
def compute_output_shape(self, input_shape):
|
def compute_output_shape(self, input_shape):
|
||||||
return input_shape
|
return input_shape
|
||||||
|
|
||||||
if eager:
|
if testing_utils.should_run_eagerly():
|
||||||
inputs = keras.Input((3,))
|
inputs = keras.Input((3,))
|
||||||
layer = MyLayer()
|
layer = MyLayer()
|
||||||
outputs = layer(inputs)
|
outputs = layer(inputs)
|
||||||
model = keras.Model(inputs, outputs)
|
model = keras.Model(inputs, outputs)
|
||||||
model.compile('sgd', 'mse')
|
model.compile(
|
||||||
|
'sgd',
|
||||||
|
'mse',
|
||||||
|
run_eagerly=testing_utils.should_run_eagerly(),
|
||||||
|
run_distributed=testing_utils.should_run_distributed())
|
||||||
history = model.fit(np.ones((2, 3)), np.ones((2, 3)))
|
history = model.fit(np.ones((2, 3)), np.ones((2, 3)))
|
||||||
self.assertEqual(history.history['sum'][-1], 2 * 3)
|
self.assertEqual(history.history['sum'][-1], 2 * 3)
|
||||||
else:
|
else:
|
||||||
@ -1082,58 +1134,66 @@ class AutographControlFlowTest(keras_parameterized.TestCase):
|
|||||||
'`add_metric` in a control flow branch'):
|
'`add_metric` in a control flow branch'):
|
||||||
layer = MyLayer()(keras.Input((3,)))
|
layer = MyLayer()(keras.Input((3,)))
|
||||||
|
|
||||||
@parameterized.named_parameters(('eager', True), ('symbolic', False))
|
def test_conditional_activity_regularizer_in_call(self):
|
||||||
def test_conditional_activity_regularizer_in_call(self, eager):
|
|
||||||
|
|
||||||
class TestModel(keras.Model):
|
class TestModel(keras.Model):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(TestModel, self).__init__(name='test_model', dynamic=eager)
|
super(TestModel, self).__init__(
|
||||||
|
name='test_model', dynamic=testing_utils.should_run_eagerly())
|
||||||
self.layer = keras.layers.Dense(2, activity_regularizer='l2')
|
self.layer = keras.layers.Dense(2, activity_regularizer='l2')
|
||||||
|
|
||||||
def call(self, x, training=None):
|
def call(self, x, training=None):
|
||||||
if training:
|
if math_ops.greater(math_ops.reduce_sum(x), 0.0):
|
||||||
return self.layer(x)
|
return self.layer(x)
|
||||||
else:
|
else:
|
||||||
return self.layer(x)
|
return self.layer(x)
|
||||||
|
|
||||||
model = TestModel()
|
model = TestModel()
|
||||||
model.compile(loss='mse', optimizer='sgd')
|
model.compile(
|
||||||
|
loss='mse',
|
||||||
|
optimizer='sgd',
|
||||||
|
run_eagerly=testing_utils.should_run_eagerly(),
|
||||||
|
run_distributed=testing_utils.should_run_distributed())
|
||||||
|
|
||||||
x = np.ones(shape=(10, 1))
|
x = np.ones(shape=(10, 1))
|
||||||
y = np.ones(shape=(10, 2))
|
y = np.ones(shape=(10, 2))
|
||||||
|
|
||||||
if eager:
|
if testing_utils.should_run_eagerly():
|
||||||
model.fit(x, y, epochs=2, batch_size=5)
|
model.fit(x, y, epochs=2, batch_size=5)
|
||||||
else:
|
else:
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
RuntimeError, '`activity_regularizer` in a control flow branch'):
|
RuntimeError, '`activity_regularizer` in a control flow branch'):
|
||||||
model.fit(x, y, epochs=2, batch_size=5)
|
model.fit(x, y, epochs=2, batch_size=5)
|
||||||
|
|
||||||
@parameterized.named_parameters(('eager', True), ('symbolic', False))
|
def test_conditional_activity_regularizer_with_wrappers_in_call(self):
|
||||||
def test_conditional_activity_regularizer_with_wrappers_in_call(self, eager):
|
|
||||||
|
|
||||||
class TestModel(keras.Model):
|
class TestModel(keras.Model):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(TestModel, self).__init__(name='test_model', dynamic=eager)
|
super(TestModel, self).__init__(
|
||||||
|
name='test_model', dynamic=testing_utils.should_run_eagerly())
|
||||||
self.layer = keras.layers.TimeDistributed(
|
self.layer = keras.layers.TimeDistributed(
|
||||||
keras.layers.Dense(2, activity_regularizer='l2'),
|
keras.layers.Dense(2, activity_regularizer='l2'),
|
||||||
input_shape=(3, 4))
|
input_shape=(3, 4))
|
||||||
|
|
||||||
def call(self, x, training=None):
|
def call(self, x, training=None):
|
||||||
if training:
|
if math_ops.greater(math_ops.reduce_sum(x), 0.0):
|
||||||
return self.layer(x)
|
return self.layer(x)
|
||||||
else:
|
else:
|
||||||
return self.layer(x)
|
return self.layer(x)
|
||||||
|
|
||||||
model = TestModel()
|
model = TestModel()
|
||||||
model.compile(loss='mse', optimizer='sgd')
|
model.compile(
|
||||||
|
loss='mse',
|
||||||
|
optimizer='sgd',
|
||||||
|
run_eagerly=testing_utils.should_run_eagerly(),
|
||||||
|
run_distributed=testing_utils.should_run_distributed())
|
||||||
|
|
||||||
x = np.ones(shape=(10, 3, 4))
|
x = np.ones(shape=(10, 3, 4))
|
||||||
y = np.ones(shape=(10, 3, 2))
|
y = np.ones(shape=(10, 3, 2))
|
||||||
|
|
||||||
if eager:
|
if testing_utils.should_run_eagerly():
|
||||||
model.fit(x, y, epochs=2, batch_size=5)
|
model.fit(x, y, epochs=2, batch_size=5)
|
||||||
else:
|
else:
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
|
@ -73,7 +73,9 @@ def _make_execution_function(model, mode):
|
|||||||
if model.run_eagerly:
|
if model.run_eagerly:
|
||||||
execution_function = distributed_function
|
execution_function = distributed_function
|
||||||
else:
|
else:
|
||||||
distributed_function = def_function.function(distributed_function)
|
distributed_function = def_function.function(
|
||||||
|
distributed_function, autograph=False)
|
||||||
|
|
||||||
def execution_function(input_fn):
|
def execution_function(input_fn):
|
||||||
# `numpy` translates Tensors to values in Eager mode.
|
# `numpy` translates Tensors to values in Eager mode.
|
||||||
return [out.numpy() for out in distributed_function(input_fn)]
|
return [out.numpy() for out in distributed_function(input_fn)]
|
||||||
|
Loading…
Reference in New Issue
Block a user