eager: Rename in_eager_mode to executing_eagerly and get rid of in_graph_mode.

This is in preparation to introduce one public, stable symbol: tf.executing_eagerly()
(i.e., part of moving APIs related to eager execution from "contrib" to a namespace
where we provide API stability guarantees)

PiperOrigin-RevId: 188212646
This commit is contained in:
Asim Shankar 2018-03-07 12:03:56 -08:00 committed by TensorFlower Gardener
parent 808b569e85
commit 37cef895bf
110 changed files with 790 additions and 854 deletions
tensorflow
contrib
python

View File

@ -44,7 +44,7 @@ class PrivateThreadPool(object):
def __init__(self, num_threads, display_name=None):
"""Creates a `PrivateThreadPool` with the given number of threads."""
if context.in_eager_mode():
if context.executing_eagerly():
shared_name = _generate_shared_name("privatethreadpool")
self._resource = gen_dataset_ops.thread_pool_handle(
num_threads=num_threads,

View File

@ -395,7 +395,7 @@ class CheckpointLoadStatus(_LoadStatus):
def run_restore_ops(self, session=None):
"""Run operations to restore objects in the dependency graph."""
if context.in_eager_mode():
if context.executing_eagerly():
return # Run eagerly
if session is None:
session = ops.get_default_session()
@ -459,7 +459,7 @@ class InitializationOnlyStatus(_LoadStatus):
session: The session to run initialization ops in. If `None`, uses the
default session.
"""
if context.in_eager_mode():
if context.executing_eagerly():
return # run eagerly
if session is None:
session = ops.get_default_session()
@ -491,7 +491,7 @@ class NameBasedSaverStatus(_LoadStatus):
date=None, instructions=_DEPRECATED_RESTORE_INSTRUCTIONS)
def run_restore_ops(self, session=None):
"""Load the name-based training checkpoint using a new `tf.train.Saver`."""
if session is None and context.in_graph_mode():
if session is None and not context.executing_eagerly():
session = ops.get_default_session()
saver_lib.Saver(self._object_saver._global_variable_names()).restore( # pylint: disable=protected-access
sess=session, save_path=self._save_path)
@ -548,7 +548,7 @@ class CheckpointableSaver(object):
# Allow passing in a weak reference to avoid reference cycles when
# `Checkpointable` objects save themselves.
self._root_checkpointable_ref = root_checkpointable
if context.in_graph_mode():
if not context.executing_eagerly():
with ops.device("/cpu:0"):
self._file_prefix_placeholder = constant_op.constant("model")
else:
@ -597,7 +597,7 @@ class CheckpointableSaver(object):
"""
named_variables, graph_proto = _serialize_object_graph(
self._root_checkpointable)
in_graph_mode = context.in_graph_mode()
in_graph_mode = not context.executing_eagerly()
if in_graph_mode:
if session is None:
session = ops.get_default_session()
@ -714,7 +714,7 @@ class CheckpointableSaver(object):
"""
if save_path is None:
return InitializationOnlyStatus(self._root_checkpointable)
in_graph_mode = context.in_graph_mode()
in_graph_mode = not context.executing_eagerly()
if in_graph_mode:
if session is None:
session = ops.get_default_session()
@ -850,7 +850,7 @@ class Checkpoint(core_checkpointable.Checkpointable):
def save(self, file_prefix, session=None):
"""Save a checkpoint. Wraps `tfe.CheckpointableSaver.save`."""
in_graph_mode = context.in_graph_mode()
in_graph_mode = not context.executing_eagerly()
if in_graph_mode:
if session is None:
session = ops.get_default_session()

View File

@ -108,14 +108,14 @@ class InterfaceTests(test.TestCase):
[0., 0.]], self.evaluate(bare_initializer))
self.assertEqual("a_variable:0", obj.a_variable.name)
self.assertEqual("duplicate:0", other_duplicate.name)
if context.in_graph_mode():
# The .name attribute may be globally influenced, but the checkpoint name
# won't be (tested below).
self.assertEqual("duplicate_1:0", duplicate.name)
else:
if context.executing_eagerly():
# When executing eagerly, there's no uniquification of variable names. The
# checkpoint name will be the same.
self.assertEqual("duplicate:0", duplicate.name)
else:
# The .name attribute may be globally influenced, but the checkpoint name
# won't be (tested below).
self.assertEqual("duplicate_1:0", duplicate.name)
named_variables, _ = checkpointable_utils._serialize_object_graph(obj)
expected_checkpoint_names = (
"a_variable/.ATTRIBUTES/VARIABLE_VALUE",
@ -165,7 +165,7 @@ class CheckpointingTests(test.TestCase):
optimizer_step = training_util.get_or_create_global_step()
root_checkpointable = checkpointable_utils.Checkpoint(
optimizer=optimizer, model=model, optimizer_step=optimizer_step)
if context.in_eager_mode():
if context.executing_eagerly():
optimizer.minimize(
lambda: model(input_value),
global_step=optimizer_step)
@ -268,7 +268,7 @@ class CheckpointingTests(test.TestCase):
root_checkpointable = checkpointable_utils.Checkpoint(
optimizer=optimizer, model=model)
input_value = constant_op.constant([[3.]])
if context.in_eager_mode():
if context.executing_eagerly():
optimizer.minimize(
lambda: model(input_value))
else:
@ -293,7 +293,7 @@ class CheckpointingTests(test.TestCase):
self.assertAllEqual([42.], self.evaluate(model._named_dense.variables[1]))
self.assertAllEqual(1, self.evaluate(root_checkpointable.save_counter))
self.assertAllEqual([1.5], self.evaluate(m_bias_slot))
if context.in_graph_mode():
if not context.executing_eagerly():
return # Restore-on-create is only supported when executing eagerly
on_create_model = MyModel()
on_create_optimizer = adam.AdamOptimizer(0.001)
@ -400,7 +400,7 @@ class CheckpointingTests(test.TestCase):
optimizer.minimize,
functools.partial(model, input_value),
global_step=root.global_step)
if context.in_graph_mode():
if not context.executing_eagerly():
train_fn = functools.partial(self.evaluate, train_fn())
status.initialize_or_restore()
for _ in range(num_training_steps):
@ -524,7 +524,9 @@ class CheckpointingTests(test.TestCase):
root.var = checkpointable_utils.add_variable(
root, name="var", initializer=0.)
optimizer = adam.AdamOptimizer(0.1)
if context.in_graph_mode():
if context.executing_eagerly():
optimizer.minimize(root.var.read_value)
else:
train_op = optimizer.minimize(root.var)
# Note that `optimizer` has not been added as a dependency of
# `root`. Create a one-off grouping so that slot variables for `root.var`
@ -532,8 +534,6 @@ class CheckpointingTests(test.TestCase):
self.evaluate(checkpointable_utils.gather_initializers(
checkpointable_utils.Checkpoint(root=root, optimizer=optimizer)))
self.evaluate(train_op)
else:
optimizer.minimize(root.var.read_value)
self.evaluate(state_ops.assign(root.var, 12.))
no_slots_path = checkpointable_utils.CheckpointableSaver(root).save(
os.path.join(checkpoint_directory, "no_slots"))
@ -561,7 +561,7 @@ class CheckpointingTests(test.TestCase):
with self.assertRaisesRegexp(AssertionError, "beta1_power"):
slot_status.assert_consumed()
self.assertEqual(12., self.evaluate(new_root.var))
if context.in_eager_mode():
if context.executing_eagerly():
# Slot variables are only created with restoring initializers when
# executing eagerly.
self.assertEqual(14., self.evaluate(
@ -569,7 +569,9 @@ class CheckpointingTests(test.TestCase):
else:
self.assertIs(new_root.optimizer.get_slot(name="m", var=new_root.var),
None)
if context.in_graph_mode():
if context.executing_eagerly():
new_root.optimizer.minimize(new_root.var.read_value)
else:
train_op = new_root.optimizer.minimize(new_root.var)
# The slot variable now exists; restore() didn't create it, but we should
# now have a restore op for it.
@ -577,8 +579,6 @@ class CheckpointingTests(test.TestCase):
self.assertEqual(14., self.evaluate(
new_root.optimizer.get_slot(name="m", var=new_root.var)))
self.evaluate(train_op)
else:
new_root.optimizer.minimize(new_root.var.read_value)
slot_status.assert_consumed()
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)

View File

@ -68,7 +68,7 @@ class Iterator(object):
RuntimeError: When invoked without eager execution enabled.
"""
if not context.in_eager_mode():
if not context.executing_eagerly():
raise RuntimeError(
"{} objects can only be used when eager execution is enabled, use "
"tf.data.Dataset.make_initializable_iterator or "

View File

@ -57,7 +57,7 @@ class Evaluator(object):
self._model = model
self._metrics = {}
self._evaluators = {}
if context.in_graph_mode():
if not context.executing_eagerly():
self.call = function.defun(self.call)
# ---- API for users ----
@ -90,7 +90,7 @@ class Evaluator(object):
Only for graph execution.
@end_compatibility
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError("Evaluator.init_variables() not needed when "
"eager execution is enabled.")
return control_flow_ops.group([m.init_variables() for _, m in self.metrics])
@ -113,7 +113,8 @@ class Evaluator(object):
with summary_ops.create_file_writer(
summary_logdir).as_default(), summary_ops.always_record_summaries():
return self._all_metric_results()
if context.in_eager_mode():
if context.executing_eagerly():
return f()
else:
return function.defun(f)()
@ -158,16 +159,16 @@ class Evaluator(object):
@end_compatibility
"""
summary_logdir = kwargs.pop("summary_logdir", None)
if context.in_graph_mode():
call_op = self.__call__(dataset.make_one_shot_iterator().get_next(),
*args, **kwargs)
init_op = self.init_variables()
results_op = self.all_metric_results(summary_logdir)
return (init_op, call_op, results_op)
# Eager case
for example in datasets.Iterator(dataset):
self.__call__(example, *args, **kwargs)
return self.all_metric_results(summary_logdir)
if context.executing_eagerly():
for example in datasets.Iterator(dataset):
self.__call__(example, *args, **kwargs)
return self.all_metric_results(summary_logdir)
# Graph construction
call_op = self.__call__(dataset.make_one_shot_iterator().get_next(), *args,
**kwargs)
init_op = self.init_variables()
results_op = self.all_metric_results(summary_logdir)
return (init_op, call_op, results_op)
@staticmethod
def run_evaluation(init_op, call_op, results_op, sess=None):
@ -192,7 +193,7 @@ class Evaluator(object):
Only for graph execution.
@end_compatibility
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError("Evaluator.run_evaluation() not supported when "
"eager execution is enabled.")
sess = sess or ops.get_default_session()

View File

@ -109,13 +109,13 @@ class Metric(checkpointable.CheckpointableBase):
pos = scope.name.rfind(scope_name)
self._name = name + scope.name[pos + len(scope_name):]
self._scope = scope
if context.in_graph_mode():
if context.executing_eagerly():
self._construction_scope = context.eager_mode
else:
# We make self.call() into a graph callable here, so that we can
# return a single op that performs all of the variable updates.
self._construction_scope = ops.get_default_graph().as_default
self.call = function.defun(self.call)
else:
self._construction_scope = context.eager_mode
# ---- API for users ----
def __call__(self, *args, **kwargs):
@ -156,10 +156,11 @@ class Metric(checkpointable.CheckpointableBase):
initialization. Under eager execution, the variables are reset to their
initial values as a side effect and this function returns None.
"""
if context.in_graph_mode():
if context.executing_eagerly():
for v in self._vars:
v.assign(self._initial_values[v])
else:
return control_flow_ops.group([v.initializer for v in self._vars])
for v in self._vars:
v.assign(self._initial_values[v])
# ---- To be implemented by descendants ---
def build(self, *args, **kwargs):
@ -201,10 +202,10 @@ class Metric(checkpointable.CheckpointableBase):
def value(self):
"""In graph mode returns the result Tensor while in eager the callable."""
if context.in_graph_mode():
return self.result()
else:
if context.executing_eagerly():
return self.result
else:
return self.result()
# We can support two different strategies of for doing data-parallel
# distributed metric computations:
@ -246,7 +247,7 @@ class Metric(checkpointable.CheckpointableBase):
"""***Only for use by descendants of Metric***."""
if self._built:
raise RuntimeError("Can't call add_variable() except in build().")
if context.in_eager_mode():
if context.executing_eagerly():
collections = None
else:
if self._use_global_variables:
@ -270,7 +271,7 @@ class Metric(checkpointable.CheckpointableBase):
# Checkpointable.
overwrite=True)
self._vars.append(v)
if context.in_eager_mode():
if context.executing_eagerly():
self._initial_values[v] = v.value()
return v

View File

@ -639,7 +639,7 @@ def _make_custom_getter_for_deferred_restorations():
# Mark as already restored from this checkpoint.
delayed_restoration.checkpointed_variables_to_restore[
checkpoint_name] = None
if context.in_graph_mode():
if not context.executing_eagerly():
delayed_restoration.session.run(variable.initializer)
if found_value:
# Error checking should run even if we've already restored a value.
@ -772,7 +772,7 @@ def save_network_checkpoint(
variable_map[mapped_name]._shared_name,
variable._shared_name,
network.scope_name))
if context.in_eager_mode():
if context.executing_eagerly():
sess = None
else:
sess = ops.get_default_session()
@ -853,7 +853,7 @@ def _restore_existing_variables(network, save_path, map_func, user_map_func):
network_name=network.name,
network_scope_name=network.scope_name))
if existing_variables_by_checkpoint_name:
if context.in_eager_mode():
if context.executing_eagerly():
sess = None
else:
sess = ops.get_default_session()
@ -880,7 +880,7 @@ def _set_restore_on_create(network, save_path, map_func, user_map_func,
# _DeferredRestoration objects once a Network has been built (so that
# restoring in a loop does not take increasing amounts of memory).
if checkpointed_variables_to_restore:
if context.in_eager_mode():
if context.executing_eagerly():
sess = None
else:
sess = ops.get_default_session()

View File

@ -73,7 +73,7 @@ def restore_variables_on_create(save_path, map_func=None):
NotFoundError: If the variable is not found in checkpoint.
ValueError: If not used in eager mode or map_func is not callable.
"""
if context.in_graph_mode():
if not context.executing_eagerly():
raise ValueError(
"Currently, restore_variables_on_create can only be used with "
"eager execution enabled.")
@ -131,7 +131,7 @@ class Saver(object):
Raises:
RuntimeError: if invoked when eager execution has not been enabled.
"""
if context.in_graph_mode():
if not context.executing_eagerly():
raise RuntimeError("tfe.Saver can only be used when eager "
"execution is enabled. Use tf.train.Saver when "
"building graphs.")

View File

@ -60,8 +60,8 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
@@Checkpointable
@@CheckpointableSaver
@@executing_eagerly
@@in_eager_mode
@@in_graph_mode
@@run_test_in_graph_and_eager_modes
@ -93,8 +93,7 @@ from tensorflow.python.eager import function
from tensorflow.python.eager.context import DEVICE_PLACEMENT_EXPLICIT
from tensorflow.python.eager.context import DEVICE_PLACEMENT_WARN
from tensorflow.python.eager.context import DEVICE_PLACEMENT_SILENT
from tensorflow.python.eager.context import in_eager_mode
from tensorflow.python.eager.context import in_graph_mode
from tensorflow.python.eager.context import executing_eagerly
from tensorflow.python.eager.context import list_devices
from tensorflow.python.eager.context import num_gpus
from tensorflow.python.eager.execution_callbacks import add_execution_callback
@ -122,5 +121,6 @@ implicit_value_and_gradients = backprop.implicit_val_and_grad
gradients_function = backprop.gradients_function
value_and_gradients_function = backprop.val_and_grad_function
GradientTape = backprop.GradientTape # pylint: disable=invalid-name
in_eager_mode = executing_eagerly
remove_undocumented(__name__)

View File

@ -47,7 +47,8 @@ class TFETest(test_util.TensorFlowTestCase):
def testVariableError(self):
with self.assertRaisesRegexp(
RuntimeError, r'Variable not supported in Eager mode'):
RuntimeError,
r'Variable not supported when eager execution is enabled'):
variables.Variable(initial_value=1.0)
def testGradients(self):

View File

@ -154,7 +154,7 @@ class CriticalSection(object):
self._handle = gen_resource_variable_ops.mutex_v2(
shared_name=shared_name, container=container, name=name)
if context.in_graph_mode():
if not context.executing_eagerly():
ops.add_to_collections(CRITICAL_SECTIONS, self)
@property
@ -221,7 +221,7 @@ class CriticalSection(object):
"This is illegal and would cause deadlocks. "
"CriticalSection: %s." % self._handle)
if context.in_graph_mode():
if not context.executing_eagerly():
# Collections and op introspection does not work in eager
# mode. This is generally ok; since eager mode (as of
# writing) executes sequentially anyway.
@ -250,7 +250,7 @@ class CriticalSection(object):
return x.identity()
elif isinstance(x, ops.Operation):
return control_flow_ops.group(x)
elif context.in_eager_mode() and x is None:
elif context.executing_eagerly() and x is None:
return None
else:
return array_ops.identity(x)
@ -274,7 +274,7 @@ class CriticalSection(object):
with ops.control_dependencies([ensure_lock_exists]):
outputs = nest.map_structure(identity, r)
if context.in_graph_mode():
if not context.executing_eagerly():
signature = _ExecutionSignature(
op=lock.op,
handle=self._handle,

View File

@ -2746,7 +2746,7 @@ def softmax(logits, scope=None):
logits_2d = array_ops.reshape(logits, [-1, num_logits])
predictions = nn.softmax(logits_2d)
predictions = array_ops.reshape(predictions, array_ops.shape(logits))
if context.in_graph_mode():
if not context.executing_eagerly():
predictions.set_shape(logits.get_shape())
return predictions

View File

@ -1263,7 +1263,7 @@ def _compute_placement_auc(labels, predictions, weights, alpha,
weights_for_true = ordered_weights * float_labels_for_true
weights_for_false = ordered_weights * float_labels_for_false
# For each set of weights with the same segmented indices, we add up the
# For each set of weights with the same segmented indices, we add up the
# weight values. Note that for each label, we deliberately rely on weights
# for the opposite label.
weight_totals_for_true = math_ops.segment_sum(weights_for_false,
@ -3646,7 +3646,7 @@ def cohen_kappa(labels,
`updates_collections` are not a list or tuple.
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('tf.contrib.metrics.cohen_kappa is not supported'
'when eager execution is enabled.')
if num_classes < 2:

View File

@ -267,5 +267,5 @@ def _check_device(tensor, expected=None):
def _check_graph_mode():
if context.in_eager_mode():
if context.executing_eagerly():
raise ValueError('Nccl ops are not supported in eager mode')

View File

@ -97,7 +97,7 @@ class AddSignTest(test.TestCase):
global_step=global_step)
neg_update = opt.apply_gradients(zip([-grads0, -grads1], [var0, var1]),
global_step=global_step)
if context.in_graph_mode():
if not context.executing_eagerly():
self.evaluate(variables.global_variables_initializer())
# Fetch params to validate initial values
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
@ -108,13 +108,13 @@ class AddSignTest(test.TestCase):
# last 3 steps with negative gradient (sign(gm) should be -1)
for t in range(1, 8):
if t < 5:
if context.in_graph_mode():
if not context.executing_eagerly():
self.evaluate(update)
elif t > 1:
opt.apply_gradients(zip([grads0, grads1], [var0, var1]),
global_step=global_step)
else:
if context.in_graph_mode():
if not context.executing_eagerly():
self.evaluate(neg_update)
elif t > 1:
opt.apply_gradients(zip([-grads0, -grads1], [var0, var1]),

View File

@ -99,7 +99,7 @@ class PowerSignTest(test.TestCase):
neg_update = opt.apply_gradients(zip([-grads0, -grads1], [var0, var1]),
global_step=global_step)
if context.in_graph_mode():
if not context.executing_eagerly():
self.evaluate(variables.global_variables_initializer())
# Fetch params to validate initial values
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
@ -110,13 +110,13 @@ class PowerSignTest(test.TestCase):
# last 3 steps with negative gradient (sign(gm) should be -1)
for t in range(1, 8):
if t < 5:
if context.in_graph_mode():
if not context.executing_eagerly():
self.evaluate(update)
elif t > 1:
opt.apply_gradients(zip([grads0, grads1], [var0, var1]),
global_step=global_step)
else:
if context.in_graph_mode():
if not context.executing_eagerly():
self.evaluate(neg_update)
elif t > 1:
opt.apply_gradients(zip([-grads0, -grads1], [var0, var1]),

View File

@ -869,7 +869,7 @@ class LSTMTest(test.TestCase):
num_proj = 4
max_length = 8
sequence_length = [4, 6]
in_graph_mode = context.in_graph_mode()
in_graph_mode = not context.executing_eagerly()
with self.test_session(graph=ops_lib.Graph()) as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
@ -934,8 +934,7 @@ class LSTMTest(test.TestCase):
if in_graph_mode:
self.assertAllEqual(outputs_static, outputs_dynamic)
else:
self.assertAllEqual(
array_ops.stack(outputs_static).numpy(), outputs_dynamic.numpy())
self.assertAllEqual(array_ops.stack(outputs_static), outputs_dynamic)
self.assertAllEqual(np.hstack(state_static), np.hstack(state_dynamic))
@test_util.run_in_graph_and_eager_modes()
@ -946,7 +945,7 @@ class LSTMTest(test.TestCase):
num_proj = 4
max_length = 8
sequence_length = [4, 6]
in_graph_mode = context.in_graph_mode()
in_graph_mode = not context.executing_eagerly()
with self.test_session(graph=ops_lib.Graph()) as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
@ -1022,10 +1021,9 @@ class LSTMTest(test.TestCase):
if in_graph_mode:
self.assertAllEqual(outputs_static, outputs_dynamic)
else:
self.assertAllEqual(
array_ops.stack(outputs_static).numpy(), outputs_dynamic.numpy())
state_static = [s.numpy() for s in nest.flatten(state_static)]
state_dynamic = [s.numpy() for s in nest.flatten(state_dynamic)]
self.assertAllEqual(array_ops.stack(outputs_static), outputs_dynamic)
state_static = nest.flatten(state_static)
state_dynamic = nest.flatten(state_dynamic)
self.assertAllEqual(np.hstack(state_static), np.hstack(state_dynamic))
def _testDynamicEquivalentToStaticRNN(self, use_sequence_length):
@ -1043,7 +1041,7 @@ class LSTMTest(test.TestCase):
else:
sequence_length = None
in_graph_mode = context.in_graph_mode()
in_graph_mode = not context.executing_eagerly()
# TODO(b/68017812): Eager ignores operation seeds, so we need to create a
# single cell and reuse it across the static and dynamic RNNs. Remove this

View File

@ -110,7 +110,7 @@ class SummaryWriter(object):
def __init__(self, resource):
self._resource = resource
if context.in_eager_mode() and self._resource is not None:
if context.executing_eagerly() and self._resource is not None:
self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
handle=self._resource, handle_device="cpu:0")
@ -158,7 +158,7 @@ def initialize(
@{tf.contrib.summary.SummaryWriter}.
ValueError: If session wasn't passed and no default session.
"""
if context.in_eager_mode():
if context.executing_eagerly():
return
if context.context().summary_writer_resource is None:
raise RuntimeError("No default tf.contrib.summary.SummaryWriter found")
@ -269,7 +269,7 @@ def _make_summary_writer(name, factory, **kwargs):
resource = gen_summary_ops.summary_writer(shared_name=name)
# TODO(apassos): Consider doing this instead.
# node = factory(resource, **kwargs)
# if not context.in_eager_mode():
# if not context.executing_eagerly():
# ops.get_default_session().run(node)
ops.add_to_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME,
factory(resource, **kwargs))
@ -295,7 +295,7 @@ def all_summary_ops():
Returns:
The summary ops.
"""
if context.in_eager_mode():
if context.executing_eagerly():
return None
return ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
@ -309,7 +309,7 @@ def summary_writer_initializer_op():
Raises:
RuntimeError: If in Eager mode.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError(
"tf.contrib.summary.summary_writer_initializer_op is only "
"supported in graph mode.")
@ -477,7 +477,7 @@ def graph(param, step=None, name=None):
Raises:
TypeError: If `param` isn't already a @{tf.Tensor} in graph mode.
"""
if not context.in_eager_mode() and not isinstance(param, ops.Tensor):
if not context.executing_eagerly() and not isinstance(param, ops.Tensor):
raise TypeError("graph() needs a tf.Tensor (e.g. tf.placeholder) in graph "
"mode, but was: %s" % type(param))
writer = context.context().summary_writer_resource

View File

@ -91,7 +91,7 @@ class Dataset(object):
Raises:
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError(
"dataset.make_initializable_iterator is not supported when eager "
"execution is enabled.")
@ -123,7 +123,7 @@ class Dataset(object):
Raises:
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError(
"dataset.make_one_shot_iterator is not supported when eager "
"execution is enabled.")

View File

@ -65,7 +65,7 @@ class RandomSeedTest(test.TestCase):
self.assertEqual((g_seed, op_seed), toutput, msg=msg)
random_seed.set_random_seed(None)
if context.in_graph_mode():
if not context.executing_eagerly():
random_seed.set_random_seed(1)
tinput = (1, None)
toutput = (1, ops.get_default_graph()._last_id) # pylint: disable=protected-access

View File

@ -55,7 +55,7 @@ def c_tfe_py_fastpath_execute(a,
transpose_b=False,
name=None):
ctx = context.context()
assert not ctx.in_graph_mode(
assert ctx.in_eager_mode(
), "The prototype doesn't contain C code for graph construction"
try:
return pywrap_tensorflow.TFE_Py_FastPathExecute(

View File

@ -260,12 +260,8 @@ class Context(object):
if mode == EAGER_MODE:
context_stack.pop()
def in_graph_mode(self):
"""Returns True if current thread is in GRAPH mode."""
return self._eager_context.mode == GRAPH_MODE
def in_eager_mode(self):
"""Returns True if current thread is in EAGER mode."""
def executing_eagerly(self):
"""Returns True if current thread has eager executing enabled."""
return self._eager_context.mode == EAGER_MODE
def scalar_cache(self):
@ -522,23 +518,23 @@ def internal_operation_seed():
return context()._internal_operation_seed() # pylint: disable=protected-access
def in_graph_mode():
"""Returns True if current thread is in GRAPH mode for default context."""
return context().in_graph_mode()
def executing_eagerly():
"""Returns True if the current thread has eager execution enabled."""
return context().executing_eagerly()
def in_eager_mode():
"""Returns True if current thread is in EAGER mode for default context."""
return context().in_eager_mode()
"""Use executing_eagerly() instead. This function will be removed."""
return executing_eagerly()
def graph_mode():
"""Context-manager to enable GRAPH mode for current thread."""
"""Context-manager to disable eager execution for the current thread."""
return context()._mode(GRAPH_MODE) # pylint: disable=protected-access
def eager_mode():
"""Context-manager to enable EAGER mode for current thread."""
"""Context-manager to enable eager execution for the current thread."""
return context()._mode(EAGER_MODE) # pylint: disable=protected-access
@ -631,4 +627,8 @@ def export_run_metadata():
# (for example, enable_eager_execution in python/framework/ops.py),
# but they do all import this file. Note that IS_IN_GRAPH_MODE and
# in_graph_mode are both parameterless functions.
is_in_graph_mode.IS_IN_GRAPH_MODE = in_graph_mode
def _tmp_in_graph_mode():
return not executing_eagerly()
is_in_graph_mode.IS_IN_GRAPH_MODE = _tmp_in_graph_mode

View File

@ -57,8 +57,7 @@ class TFETest(test_util.TensorFlowTestCase):
def testContext(self):
ctx = context.Context()
self.assertFalse(ctx.in_graph_mode())
self.assertTrue(ctx.in_eager_mode())
self.assertTrue(ctx.executing_eagerly())
self.assertEqual('', ctx.scope_name)
ctx.scope_name = 'foo'
@ -150,9 +149,9 @@ class TFETest(test_util.TensorFlowTestCase):
def get_context_values(ctx):
return [
ctx.in_graph_mode(),
ctx.in_eager_mode(), ctx.scope_name, ctx.summary_writer_resource,
ctx.device_name, ctx.num_gpus()
ctx.executing_eagerly(), ctx.scope_name, ctx.summary_writer_resource,
ctx.device_name,
ctx.num_gpus()
]
def get_values(ctx, values):

View File

@ -112,7 +112,7 @@ def _convert_to_graph_tensor(value, dtype=None, name=None, as_ref=False):
"""
del as_ref # Unused.
if context.in_eager_mode():
if context.executing_eagerly():
return value
default_graph = ops.get_default_graph()
@ -295,7 +295,7 @@ class _EagerDefinedFunction(object):
proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
function_def = function_pb2.FunctionDef()
function_def.ParseFromString(compat.as_bytes(proto_data))
if context.in_eager_mode():
if context.executing_eagerly():
_register(fn)
self.definition = function_def
self.name = function_def.signature.name
@ -438,7 +438,14 @@ class GraphModeFunction(object):
all_args = args + self._extra_inputs
signature = self._forward_fdef.signature
ctx = context.context()
if ctx.in_graph_mode():
if ctx.executing_eagerly():
outputs = execute.execute(
str(signature.name),
num_outputs=len(signature.output_arg),
inputs=all_args,
attrs=None,
ctx=ctx)
else:
g = ops.get_default_graph()
g._add_function(self._forward_fdef) # pylint: disable=protected-access
op = g.create_op(
@ -453,13 +460,6 @@ class GraphModeFunction(object):
outputs, (ops.Tensor, type(None))) else list(outputs)
for i, s in enumerate(self._output_shapes):
outputs[i].set_shape(s)
else:
outputs = execute.execute(
str(signature.name),
num_outputs=len(signature.output_arg),
inputs=all_args,
attrs=None,
ctx=ctx)
real_outputs = outputs[:len(self._returns)]
side_outputs = outputs[len(self._returns):]
@ -530,7 +530,14 @@ class GraphModeFunction(object):
return self._backprop_call(tensor_inputs)
ctx = context.context()
if ctx.in_graph_mode():
if ctx.executing_eagerly():
result = execute.execute(
str(self._func_name),
num_outputs=self._num_outputs,
inputs=tensor_inputs + self._extra_inputs,
attrs=None,
ctx=ctx)
else:
g = ops.get_default_graph()
self.add_to_graph(g)
signature = self._function_def.definition.signature
@ -547,13 +554,6 @@ class GraphModeFunction(object):
return op
for i, s in enumerate(self._output_shapes):
result[i].set_shape(s)
else:
result = execute.execute(
str(self._func_name),
num_outputs=self._num_outputs,
inputs=tensor_inputs + self._extra_inputs,
attrs=None,
ctx=ctx)
return self._build_call_outputs(result)
@ -666,7 +666,7 @@ def _defun_internal(name, func, args, kwds):
if x not in all_ignored_ops)
# Register any other functions defined in the graph
# TODO(ashankar): Oh lord, forgive me for this lint travesty.
if context.in_eager_mode():
if context.executing_eagerly():
for f in tmp_graph._functions.values(): # pylint: disable=protected-access
# TODO(ashankar): What about the gradient registry?
_register(f._c_func) # pylint: disable=protected-access
@ -906,7 +906,7 @@ class AutomaticControlDependencies(object):
return tensor
def __enter__(self):
if context.in_eager_mode():
if context.executing_eagerly():
return self
# This code assumes no other thread is adding ops to the graph while
# we're adding ops to the graph.
@ -977,7 +977,7 @@ class AutomaticControlDependencies(object):
merge_for_resource[o] = new_merge[0].op
def __exit__(self, unused_type, unused_value, unused_traceback):
if context.in_eager_mode():
if context.executing_eagerly():
return
if self._graph is not ops.get_default_graph():

View File

@ -406,7 +406,7 @@ def graph_callable(shape_and_dtypes):
A callable graph object.
"""
# TODO(alive,apassos): support initialized_value and friends from tf.Variable.
assert context.in_eager_mode(), (
assert context.executing_eagerly(), (
"graph_callable can only be used when Eager execution is enabled.")
def decorator(func):
return tf_decorator.make_decorator(func,

View File

@ -367,7 +367,7 @@ void GenEagerPythonOp::HandleGraphMode(const string& function_setup) {
// Handle graph-mode case
strings::StrAppend(&result_,
" _ctx = _context.context()\n"
" if _ctx.in_graph_mode():\n",
" if not _ctx.executing_eagerly():\n",
function_setup,
" _, _, _op = _op_def_lib._apply_op_helper(\n");
AddBodyNoReturn(" ");

View File

@ -169,7 +169,7 @@ class Tests(test.TestCase):
def testFastpathExecute_InvalidInputs(self):
a_2_by_2 = random_ops.random_uniform((2, 2))
ctx = context.context()
assert not ctx.in_graph_mode(
assert ctx.executing_eagerly(
), "The prototype doesn't contain C code for graph construction"
ctx_handle = ctx._handle # pylint: disable=protected-access

View File

@ -166,7 +166,7 @@ class Estimator(object):
ValueError: if this is called via a subclass and if that class overrides
a member of `Estimator`.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError(
'Estimators are not supported when eager execution is enabled.')

View File

@ -181,7 +181,7 @@ def constant(value, dtype=None, shape=None, name="Const", verify_shape=False):
TypeError: if shape is incorrectly specified or unsupported.
"""
ctx = context.context()
if not ctx.in_graph_mode():
if ctx.executing_eagerly():
t = convert_to_eager_tensor(value, ctx, dtype)
if shape is None:
return t

View File

@ -489,10 +489,10 @@ class _DefinedFunction(object):
# Adds this function into 'g'.
# pylint: disable=protected-access
if context.in_graph_mode():
g._add_function(self)
else:
if context.executing_eagerly():
context.context().add_function_def(self.definition)
else:
g._add_function(self)
# pylint: enable=protected-access
# Ensures related sub-routines are defined in 'g', too.

View File

@ -695,7 +695,7 @@ def import_scoped_meta_graph(meta_graph_or_file,
Raises:
ValueError: If the graph_def contains unbound inputs.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise ValueError("Exporting/importing meta graphs is not supported when "
"eager execution is enabled.")
if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
@ -856,7 +856,7 @@ def export_scoped_meta_graph(filename=None,
Raises:
ValueError: When the `GraphDef` is larger than 2GB.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise ValueError("Exporting/importing meta graphs is not supported when "
"Eager Execution is enabled.")
graph = graph or ops.get_default_graph()

View File

@ -395,10 +395,10 @@ class Tensor(_TensorLike):
"Tensor._shape cannot be assigned, use Tensor.set_shape instead.")
def __iter__(self):
if context.in_graph_mode():
if not context.executing_eagerly():
raise TypeError(
"`Tensor` objects are not iterable when eager execution is not "
"enabled. To iterate over this tensor use `tf.map_fn`.")
"Tensor objects are not iterable when eager execution is not "
"enabled. To iterate over this tensor use tf.map_fn.")
shape = self._shape_tuple()
if shape is None:
raise TypeError("Cannot iterate over a tensor with unknown shape.")
@ -772,7 +772,7 @@ class _EagerTensorBase(Tensor):
six.raise_from(core._status_to_exception(e.code, e.message), None)
# Record the copy on tape and define backprop copy as well.
if not context.in_graph_mode():
if context.executing_eagerly():
self_device = self.device
def grad_fun(dresult):
return [dresult._copy(device_name=self_device)]
@ -993,7 +993,7 @@ def internal_convert_to_tensor(value,
"""
if ctx is None: ctx = context.context()
if ctx.in_eager_mode():
if ctx.executing_eagerly():
# Fast path for EagerTensors that don't need any conversion.
if isinstance(value, EagerTensor):
# Note that we don't check that value's dtype matches the dtype
@ -4797,15 +4797,15 @@ def device(device_name_or_function):
Raises:
RuntimeError: If eager execution is enabled and a function is passed in.
"""
if context.in_graph_mode():
return get_default_graph().device(device_name_or_function)
else:
if context.executing_eagerly():
# TODO(agarwal): support device functions in EAGER mode.
if callable(device_name_or_function):
raise RuntimeError(
"tf.device does not support functions when eager execution "
"is enabled.")
return context.device(device_name_or_function)
else:
return get_default_graph().device(device_name_or_function)
@tf_export("container")
@ -4824,7 +4824,12 @@ def container(container_name):
@tf_export("colocate_with")
def colocate_with(op, ignore_existing=False):
if context.in_graph_mode():
if context.executing_eagerly():
if op is not None:
return device(op.device)
else:
return _NullContextmanager()
else:
default_graph = get_default_graph()
if isinstance(op, EagerTensor):
if default_graph.building_function:
@ -4833,11 +4838,6 @@ def colocate_with(op, ignore_existing=False):
raise ValueError("Encountered an Eager-defined Tensor during graph "
"construction, but a function was not being built.")
return default_graph.colocate_with(op, ignore_existing)
else:
if op is not None:
return device(op.device)
else:
return _NullContextmanager()
@tf_export("control_dependencies")
@ -4857,10 +4857,10 @@ def control_dependencies(control_inputs):
A context manager that specifies control dependencies for all
operations constructed within the context.
"""
if context.in_graph_mode():
return get_default_graph().control_dependencies(control_inputs)
else:
if context.executing_eagerly():
return _NullContextmanager()
else:
return get_default_graph().control_dependencies(control_inputs)
class _DefaultStack(threading.local):
@ -5123,7 +5123,7 @@ def init_scope():
"""
# pylint: enable=g-doc-return-or-yield,line-too-long
if context.in_eager_mode():
if context.executing_eagerly():
# Fastpath.
with tape.stop_recording():
yield
@ -5705,7 +5705,7 @@ class name_scope(object): # pylint: disable=invalid-name
self._default_name = default_name
self._values = values
self._ctx = context.context()
self._in_eager_mode = self._ctx.in_eager_mode()
self._in_eager_mode = self._ctx.executing_eagerly()
def __enter__(self):
"""Start the scope block.
@ -5884,7 +5884,7 @@ def get_from_proto_function(collection_name):
def _assert_collection_is_ok(collection_name):
if context.in_eager_mode():
if context.executing_eagerly():
if collection_name in GraphKeys._VARIABLE_COLLECTIONS: # pylint: disable=protected-access
raise ValueError("When Eager Execution is enabled, variable "
"collections are not supported.")

View File

@ -1763,7 +1763,13 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase):
return constant_op.constant(2.0)
future.calls = 0
if context.in_graph_mode():
if context.executing_eagerly():
a = constant_op.constant(1.0)
b = future()
with ops.control_dependencies([a, b]):
c = constant_op.constant(3.0)
self.assertEqual(future.calls, 1)
else:
g = ops.Graph()
with g.as_default():
a = constant_op.constant(1.0)
@ -1772,12 +1778,6 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase):
c = constant_op.constant(3.0)
self.assertEqual(c.op.control_inputs, [a.op, b.op])
self.assertEqual(future.calls, 1)
else:
a = constant_op.constant(1.0)
b = future()
with ops.control_dependencies([a, b]):
c = constant_op.constant(3.0)
self.assertEqual(future.calls, 1)
def testBasicWithConversion(self):
g = ops.Graph()
@ -2150,11 +2150,11 @@ class InitScopeTest(test_util.TensorFlowTestCase):
with ops.init_scope():
# Because g is building a function, init_scope should
# escape out to the eager context.
self.assertTrue(context.in_eager_mode())
self.assertTrue(context.executing_eagerly())
# g should be reinstated as the default graph, and the
# graph context should be re-entered.
self.assertIs(g, ops.get_default_graph())
self.assertTrue(context.in_graph_mode())
self.assertFalse(context.executing_eagerly())
def testStaysInEagerWhenOnlyEagerContextActive(self):
with context.eager_mode():
@ -2277,12 +2277,13 @@ class InitScopeTest(test_util.TensorFlowTestCase):
with context.eager_mode():
def foo():
with ops.name_scope("inner"), ops.init_scope():
if context.in_graph_mode():
self.assertEqual(ops.get_name_scope(), "inner")
else:
if context.executing_eagerly():
# A trailing slash is always appended when eager execution is
# enabled.
self.assertEqual(context.context().scope_name, "inner/")
else:
self.assertEqual(ops.get_name_scope(), "inner")
foo()
self.assertEqual(ops.get_name_scope(), "")
foo_compiled = eager_function.defun(foo)

View File

@ -52,20 +52,20 @@ def get_seed(op_seed):
A tuple of two integers that should be used for the local seed of this
operation.
"""
is_graph_mode = context.in_graph_mode()
eager = context.executing_eagerly()
if is_graph_mode:
global_seed = ops.get_default_graph().seed
else:
if eager:
global_seed = context.global_seed()
else:
global_seed = ops.get_default_graph().seed
if global_seed is not None:
if op_seed is None:
# pylint: disable=protected-access
if is_graph_mode:
op_seed = ops.get_default_graph()._last_id
else:
if eager:
op_seed = context.internal_operation_seed()
else:
op_seed = ops.get_default_graph()._last_id
seeds = _truncate_seed(global_seed), _truncate_seed(op_seed)
else:
@ -176,7 +176,7 @@ def set_random_seed(seed):
Args:
seed: integer.
"""
if context.in_graph_mode():
ops.get_default_graph().seed = seed
else:
if context.executing_eagerly():
context.set_global_seed(seed)
else:
ops.get_default_graph().seed = seed

View File

@ -40,13 +40,13 @@ class RandomSeedTest(test.TestCase):
((2**31 - 1, 0), (0, 2**31 - 1)), # Don't wrap to (0, 0) either
((0, 2**31 - 1), (0, 2**31 - 1)), # Wrapping for the other argument
]
if context.in_graph_mode():
# 0 will be the default_graph._lastid.
test_cases.append(((1, None), (1, 0)))
else:
if context.executing_eagerly():
# operation seed is random number generated based on global seed.
# it's not tested due to possibility of platform or version difference.
pass
else:
# 0 will be the default_graph._lastid.
test_cases.append(((1, None), (1, 0)))
for tc in test_cases:
tinput, toutput = tc[0], tc[1]
random_seed.set_random_seed(tinput[0])

View File

@ -828,7 +828,7 @@ def constant_value_as_shape(tensor): # pylint: disable=invalid-name
Returns:
A `TensorShape` based on the constant value of the given `tensor`.
"""
if context.in_eager_mode():
if context.executing_eagerly():
return tensor_shape.as_shape(
[dim if dim != -1 else None for dim in tensor.numpy()])

View File

@ -816,7 +816,7 @@ class TensorFlowTestCase(googletest.TestCase):
Returns:
tensors numpy values.
"""
if context.in_eager_mode():
if context.executing_eagerly():
return self._eval_helper(tensors)
else:
sess = ops.get_default_session()

View File

@ -343,7 +343,7 @@ def learning_phase():
Returns:
Learning phase (scalar integer tensor or Python integer).
"""
if context.in_eager_mode():
if context.executing_eagerly():
if 'eager' not in _GRAPH_LEARNING_PHASES:
# Fallback to inference mode as default.
return 0
@ -370,7 +370,7 @@ def set_learning_phase(value):
global _GRAPH_LEARNING_PHASES # pylint: disable=global-variable-not-assigned
if value not in {0, 1}:
raise ValueError('Expected learning phase to be 0 or 1.')
if context.in_eager_mode():
if context.executing_eagerly():
_GRAPH_LEARNING_PHASES['eager'] = value
else:
_GRAPH_LEARNING_PHASES[ops.get_default_graph()] = value
@ -399,7 +399,7 @@ def learning_phase_scope(value):
yield value
finally:
# Restore learning phase to initial value.
if context.in_eager_mode():
if context.executing_eagerly():
_GRAPH_LEARNING_PHASES['eager'] = previous_value
else:
_GRAPH_LEARNING_PHASES[ops.get_default_graph()] = previous_value
@ -2625,7 +2625,7 @@ def get_value(x):
Returns:
A Numpy array.
"""
if context.in_eager_mode():
if context.executing_eagerly():
return x.numpy()
return x.eval(session=get_session())
@ -2640,7 +2640,7 @@ def batch_get_value(tensors):
Returns:
A list of Numpy arrays.
"""
if context.in_eager_mode():
if context.executing_eagerly():
return [x.numpy() for x in tensors]
if tensors:
return get_session().run(tensors)
@ -2658,7 +2658,7 @@ def set_value(x, value):
(of the same shape).
"""
value = np.asarray(value, dtype=dtype(x))
if context.in_eager_mode():
if context.executing_eagerly():
x.assign(value)
else:
tf_dtype = dtypes_module.as_dtype(x.dtype.name.split('_')[0])
@ -2681,7 +2681,7 @@ def batch_set_value(tuples):
tuples: a list of tuples `(tensor, value)`.
`value` should be a Numpy array.
"""
if context.in_eager_mode():
if context.executing_eagerly():
for x, value in tuples:
x.assign(np.asarray(value, dtype=dtype(x)))
else:
@ -3123,7 +3123,7 @@ def rnn(step_function,
outputs_shape[1] = inputs_shape[1]
outputs.set_shape(outputs_shape)
if not context.in_eager_mode():
if not context.executing_eagerly():
last_output._uses_learning_phase = uses_learning_phase
return last_output, outputs, new_states

View File

@ -237,7 +237,7 @@ class Layer(tf_base_layers.Layer):
"""
# Actually call the layer (optionally building it).
output = super(Layer, self).__call__(inputs, **kwargs)
if context.in_eager_mode():
if context.executing_eagerly():
return output
if hasattr(self, '_symbolic_set_inputs') and not self.inputs:

View File

@ -92,7 +92,7 @@ class InputLayer(base_layer.Layer):
else:
batch_input_shape = None
if context.in_eager_mode():
if context.executing_eagerly():
# In eager mode, create a temporary placeholder to call the layer on.
input_tensor = tf_base_layers._DeferredTensor( # pylint: disable=protected-access
shape=batch_input_shape,

View File

@ -99,11 +99,11 @@ class Network(base_layer.Layer):
self._losses = [] # Used in symbolic mode only.
self._scope = None # Never used.
self._reuse = None # Never used.
if context.in_eager_mode():
if context.executing_eagerly():
self._graph = None
else:
self._graph = ops.get_default_graph() # Used in symbolic mode only.
# A Network does not create weights of its own, thus has no dtype.
# A Network does not create weights of its own, thus has no dtype.
self._dtype = None
# All layers in order of horizontal graph traversal.
@ -126,7 +126,7 @@ class Network(base_layer.Layer):
self.outputs = [outputs]
# User-prodived argument validation.
if context.in_eager_mode():
if context.executing_eagerly():
# Check that all inputs/outputs are DeferredTensors.
for tensor in self.inputs:
if not isinstance(tensor, tf_base_layers._DeferredTensor): # pylint: disable=protected-access
@ -275,7 +275,7 @@ class Network(base_layer.Layer):
self._feed_input_names.append(layer.name)
self._feed_input_shapes.append(K.int_shape(self.inputs[i]))
# layer.input gives an error in eager mode
if context.in_graph_mode():
if not context.executing_eagerly():
self._feed_inputs.append(layer.input)
for layer in self._output_layers:
self.output_names.append(layer.name)
@ -317,7 +317,7 @@ class Network(base_layer.Layer):
raise NotImplementedError('`add_variable` is not supported on Networks.')
def add_loss(self, *args, **kwargs):
if context.in_eager_mode():
if context.executing_eagerly():
raise NotImplementedError('`add_loss` is not supported on Networks '
'when eager execution is enabled.')
super(Network, self).add_loss(*args, **kwargs)
@ -483,7 +483,7 @@ class Network(base_layer.Layer):
Returns:
A list of update ops.
"""
if context.in_eager_mode():
if context.executing_eagerly():
return []
if not self.trainable and not self.stateful:
@ -530,7 +530,7 @@ class Network(base_layer.Layer):
losses = []
for layer in self.layers:
losses += layer.losses
if context.in_eager_mode():
if context.executing_eagerly():
return losses
if self.inputs:
@ -623,7 +623,7 @@ class Network(base_layer.Layer):
else:
masks = nest.flatten(mask)
if context.in_graph_mode():
if not context.executing_eagerly():
# Try to retrieve cached outputs if the layer has already been called
# on these exact inputs.
cache_key = (tf_layers_util.object_list_uid(inputs)
@ -829,7 +829,7 @@ class Network(base_layer.Layer):
else:
output_masks = [None for _ in range(len(output_tensors))]
if context.in_graph_mode():
if not context.executing_eagerly():
if layer.activity_regularizer is not None:
regularization_losses = [
layer.activity_regularizer(x) for x in output_tensors
@ -859,7 +859,7 @@ class Network(base_layer.Layer):
if output_masks is not None:
output_masks = output_masks[0]
if context.in_graph_mode():
if not context.executing_eagerly():
# Update cache;
# keys are based on ids on input tensors and inputs masks.
cache_key = (tf_layers_util.object_list_uid(inputs)

View File

@ -755,7 +755,17 @@ class TopologyConstructionTest(test.TestCase):
def compute_mask(self, inputs, mask=None):
return array_ops.ones_like(inputs)
if context.in_graph_mode():
if context.executing_eagerly():
a = constant_op.constant([2] * 32)
mask = constant_op.constant([0, 1] * 16)
a._keras_mask = mask
b = MaskedLayer().apply(a)
self.assertTrue(hasattr(b, '_keras_mask'))
self.assertAllEqual(
self.evaluate(array_ops.ones_like(mask)),
self.evaluate(getattr(b, '_keras_mask')))
self.assertAllEqual(self.evaluate(a * mask), self.evaluate(b))
else:
x = keras.Input(shape=(32,))
y = MaskedLayer()(x) # pylint: disable=not-callable
network = keras.engine.Network(x, y)
@ -769,15 +779,6 @@ class TopologyConstructionTest(test.TestCase):
x_2 = array_ops.placeholder(dtype='float32', shape=(None, 32))
y_2 = network(x_2)
self.assertEqual(y_2.get_shape().as_list(), [None, 32])
else:
a = constant_op.constant([2] * 32)
mask = constant_op.constant([0, 1] * 16)
a._keras_mask = mask
b = MaskedLayer().apply(a)
self.assertTrue(hasattr(b, '_keras_mask'))
self.assertAllEqual(self.evaluate(array_ops.ones_like(mask)),
self.evaluate(getattr(b, '_keras_mask')))
self.assertAllEqual(self.evaluate(a * mask), self.evaluate(b))
def test_activity_regularization_with_model_composition(self):
@ -885,13 +886,13 @@ class DeferredModeTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
def testSimpleNetworkBuilding(self):
inputs = keras.engine.Input(shape=(32,))
if context.in_eager_mode():
if context.executing_eagerly():
self.assertIsInstance(inputs, tf_base_layers._DeferredTensor)
self.assertEqual(inputs.dtype.name, 'float32')
self.assertEqual(inputs.shape.as_list(), [None, 32])
x = keras.layers.Dense(2)(inputs)
if context.in_eager_mode():
if context.executing_eagerly():
self.assertIsInstance(x, tf_base_layers._DeferredTensor)
self.assertEqual(x.dtype.name, 'float32')
self.assertEqual(x.shape.as_list(), [None, 2])
@ -900,7 +901,7 @@ class DeferredModeTest(test.TestCase):
network = keras.engine.Network(inputs, outputs)
self.assertIsInstance(network, keras.engine.Network)
if context.in_eager_mode():
if context.executing_eagerly():
# It should be possible to call such a network on EagerTensors.
inputs = constant_op.constant(
np.random.random((10, 32)).astype('float32'))
@ -925,7 +926,7 @@ class DeferredModeTest(test.TestCase):
c = keras.layers.Dense(2)(c)
network = keras.engine.Network([input_a, input_b], [a, c])
if context.in_eager_mode():
if context.executing_eagerly():
a_val = constant_op.constant(
np.random.random((10, 32)).astype('float32'))
b_val = constant_op.constant(

View File

@ -162,7 +162,7 @@ class Model(Network):
`optimizer`, `loss`, `metrics` or `sample_weight_mode`.
"""
loss = loss or {}
if context.in_eager_mode() and not isinstance(
if context.executing_eagerly() and not isinstance(
optimizer, (tf_optimizer_module.Optimizer, optimizers.TFOptimizer)):
raise ValueError('Only TF native optimizers are supported in Eager mode.')
@ -170,13 +170,13 @@ class Model(Network):
self.loss = loss
self.metrics = metrics or []
self.loss_weights = loss_weights
if context.in_eager_mode() and sample_weight_mode is not None:
if context.executing_eagerly() and sample_weight_mode is not None:
raise ValueError('sample_weight_mode is not supported in Eager mode.')
self.sample_weight_mode = sample_weight_mode
if context.in_eager_mode() and weighted_metrics is not None:
if context.executing_eagerly() and weighted_metrics is not None:
raise ValueError('weighted_metrics is not supported in Eager mode.')
self.weighted_metrics = weighted_metrics
if context.in_eager_mode() and target_tensors is not None:
if context.executing_eagerly() and target_tensors is not None:
raise ValueError('target_tensors is not supported in Eager mode.')
self.target_tensors = target_tensors
@ -230,7 +230,7 @@ class Model(Network):
skip_target_weighing_indices.append(i)
# Prepare output masks.
if context.in_graph_mode():
if not context.executing_eagerly():
masks = self.compute_mask(self.inputs, mask=None)
if masks is None:
masks = [None for _ in self.outputs]
@ -264,7 +264,7 @@ class Model(Network):
self.loss_weights_list = loss_weights_list
# initialization for Eager mode execution
if context.in_eager_mode():
if context.executing_eagerly():
if target_tensors is not None:
raise ValueError('target_tensors are not currently supported in Eager'
'mode.')
@ -738,13 +738,13 @@ class Model(Network):
'TensorFlow tensors. '
'You passed: x=' + str(x) + '; y=' + str(y))
if context.in_graph_mode():
if context.executing_eagerly():
target_tensors = None
else:
# Handle target tensors if any passed.
if not isinstance(y, (list, tuple)):
y = [y]
target_tensors = [v for v in y if tensor_util.is_tensor(v)]
else:
target_tensors = None
self.compile(optimizer=self.optimizer,
loss=self.loss,
metrics=self.metrics,
@ -761,7 +761,7 @@ class Model(Network):
# What follows is input validation and standardization to list format,
# in the case where all inputs are value arrays.
if context.in_eager_mode():
if context.executing_eagerly():
# In eager mode, do not do shape validation.
feed_input_names = self.input_names
feed_input_shapes = None
@ -784,7 +784,7 @@ class Model(Network):
exception_prefix='input')
if y is not None:
if context.in_eager_mode():
if context.executing_eagerly():
feed_output_names = self.output_names
feed_output_shapes = None
# Sample weighting not supported in this case.
@ -835,7 +835,7 @@ class Model(Network):
]
# Check that all arrays have the same length.
training_utils.check_array_lengths(x, y, sample_weights)
if self._is_graph_network and not context.in_eager_mode():
if self._is_graph_network and not context.executing_eagerly():
# Additional checks to avoid users mistakenly using improper loss fns.
training_utils.check_loss_and_target_compatibility(
y, self._feed_loss_fns, feed_output_shapes)
@ -874,7 +874,7 @@ class Model(Network):
whether to build the model's graph in inference mode (False), training
mode (True), or using the Keras learning phase (None).
"""
if context.in_eager_mode():
if context.executing_eagerly():
self._eager_set_inputs(inputs)
else:
self._symbolic_set_inputs(inputs, training=training)
@ -903,7 +903,7 @@ class Model(Network):
Raises:
ValueError: If the model's inputs are already set.
"""
assert context.in_eager_mode()
assert context.executing_eagerly()
if self.inputs:
raise ValueError('Model inputs are already set.')
# On-the-fly setting of model inputs/outputs as DeferredTensors,
@ -950,7 +950,7 @@ class Model(Network):
Raises:
ValueError: If the model's inputs are already set.
"""
assert context.in_graph_mode()
assert not context.executing_eagerly()
if self.inputs:
raise ValueError('Model inputs are already set.')
@ -1186,7 +1186,7 @@ class Model(Network):
val_y = None
val_sample_weights = None
if context.in_eager_mode():
if context.executing_eagerly():
return training_eager.fit_loop(
self,
inputs=x,
@ -1289,7 +1289,7 @@ class Model(Network):
sample_weight=sample_weight,
batch_size=batch_size)
if context.in_eager_mode():
if context.executing_eagerly():
return training_eager.test_loop(
self, inputs=x, targets=y, sample_weights=sample_weights,
batch_size=batch_size, verbose=verbose, steps=steps)
@ -1330,7 +1330,7 @@ class Model(Network):
'argument.')
x, _, _ = self._standardize_user_data(x)
if context.in_eager_mode():
if context.executing_eagerly():
return training_eager.predict_loop(
self, x, batch_size=batch_size, verbose=verbose, steps=steps)
else:
@ -1381,7 +1381,7 @@ class Model(Network):
sample_weight=sample_weight,
class_weight=class_weight)
if context.in_eager_mode():
if context.executing_eagerly():
outputs = training_eager.train_on_batch(
self, x, y, sample_weights=sample_weights)
else:
@ -1431,7 +1431,7 @@ class Model(Network):
x, y, sample_weights = self._standardize_user_data(
x, y, sample_weight=sample_weight)
if context.in_eager_mode():
if context.executing_eagerly():
outputs = training_eager.test_on_batch(
self, x, y, sample_weights=sample_weights)
else:
@ -1458,11 +1458,11 @@ class Model(Network):
"""
x, _, _ = self._standardize_user_data(x)
if context.in_eager_mode():
if context.executing_eagerly():
inputs = [ops.convert_to_tensor(val, dtype=K.floatx()) for val in x]
return self(inputs) # pylint: disable=not-callable
if context.in_graph_mode():
if not context.executing_eagerly():
if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
ins = x + [0]
else:

View File

@ -553,7 +553,7 @@ class ZeroPaddingTest(test.TestCase):
layer = keras.layers.ZeroPadding1D(padding=2)
layer.build(shape)
output = layer(keras.backend.variable(inputs))
if context.in_eager_mode():
if context.executing_eagerly():
np_output = output.numpy()
else:
np_output = keras.backend.eval(output)
@ -564,7 +564,7 @@ class ZeroPaddingTest(test.TestCase):
layer = keras.layers.ZeroPadding1D(padding=(1, 2))
layer.build(shape)
output = layer(keras.backend.variable(inputs))
if context.in_eager_mode():
if context.executing_eagerly():
np_output = output.numpy()
else:
np_output = keras.backend.eval(output)
@ -610,7 +610,7 @@ class ZeroPaddingTest(test.TestCase):
padding=(2, 2), data_format=data_format)
layer.build(inputs.shape)
output = layer(keras.backend.variable(inputs))
if context.in_eager_mode():
if context.executing_eagerly():
np_output = output.numpy()
else:
np_output = keras.backend.eval(output)
@ -629,7 +629,7 @@ class ZeroPaddingTest(test.TestCase):
padding=((1, 2), (3, 4)), data_format=data_format)
layer.build(inputs.shape)
output = layer(keras.backend.variable(inputs))
if context.in_eager_mode():
if context.executing_eagerly():
np_output = output.numpy()
else:
np_output = keras.backend.eval(output)
@ -683,7 +683,7 @@ class ZeroPaddingTest(test.TestCase):
layer = keras.layers.ZeroPadding3D(padding=(2, 2, 2))
layer.build(inputs.shape)
output = layer(keras.backend.variable(inputs))
if context.in_eager_mode():
if context.executing_eagerly():
np_output = output.numpy()
else:
np_output = keras.backend.eval(output)
@ -737,7 +737,7 @@ class UpSamplingTest(test.TestCase):
size=(length_row, length_col), data_format=data_format)
layer.build(inputs.shape)
output = layer(keras.backend.variable(inputs))
if context.in_eager_mode():
if context.executing_eagerly():
np_output = output.numpy()
else:
np_output = keras.backend.eval(output)
@ -790,7 +790,7 @@ class UpSamplingTest(test.TestCase):
data_format=data_format)
layer.build(inputs.shape)
output = layer(keras.backend.variable(inputs))
if context.in_eager_mode():
if context.executing_eagerly():
np_output = output.numpy()
else:
np_output = keras.backend.eval(output)
@ -865,7 +865,7 @@ class CroppingTest(test.TestCase):
cropping=cropping, data_format=data_format)
layer.build(inputs.shape)
output = layer(keras.backend.variable(inputs))
if context.in_eager_mode():
if context.executing_eagerly():
np_output = output.numpy()
else:
np_output = keras.backend.eval(output)
@ -892,7 +892,7 @@ class CroppingTest(test.TestCase):
cropping=cropping, data_format=data_format)
layer.build(inputs.shape)
output = layer(keras.backend.variable(inputs))
if context.in_eager_mode():
if context.executing_eagerly():
np_output = output.numpy()
else:
np_output = keras.backend.eval(output)
@ -937,7 +937,7 @@ class CroppingTest(test.TestCase):
cropping=cropping, data_format=data_format)
layer.build(inputs.shape)
output = layer(keras.backend.variable(inputs))
if context.in_eager_mode():
if context.executing_eagerly():
np_output = output.numpy()
else:
np_output = keras.backend.eval(output)
@ -954,7 +954,7 @@ class CroppingTest(test.TestCase):
cropping[2][0]:-cropping[2][1], :]
np.testing.assert_allclose(np_output, expected_out)
# test incorrect use
# test incorrect use
with self.assertRaises(ValueError):
keras.layers.Cropping3D(cropping=(1, 1))
with self.assertRaises(ValueError):

View File

@ -124,7 +124,7 @@ class Dropout(tf_core_layers.Dropout, Layer):
training = K.learning_phase()
output = super(Dropout, self).call(inputs, training=training)
# EagerTensor object has no attribute _uses_learning_phase
if not context.in_eager_mode() and training is K.learning_phase():
if not context.executing_eagerly() and training is K.learning_phase():
output._uses_learning_phase = True # pylint: disable=protected-access
return output

View File

@ -111,7 +111,7 @@ class BatchNormalization(tf_normalization_layers.BatchNormalization, Layer):
if training is None:
training = K.learning_phase()
output = super(BatchNormalization, self).call(inputs, training=training)
if context.in_graph_mode() and training is K.learning_phase():
if not context.executing_eagerly() and training is K.learning_phase():
output._uses_learning_phase = True # pylint: disable=protected-access
return output

View File

@ -105,7 +105,7 @@ class Pooling2DTest(test.TestCase):
# This part of the test can only run on GPU but doesn't appear
# to be properly assigned to a GPU when running in eager mode.
if not context.in_eager_mode():
if not context.executing_eagerly():
# Only runs on GPU with CUDA, channels_first is not supported on CPU.
# TODO(b/62340061): Support channels_first on CPU.
if test.is_gpu_available(cuda_only=True):

View File

@ -936,7 +936,7 @@ class SimpleRNNCell(Layer):
# Properly set learning phase on output tensor.
if 0 < self.dropout + self.recurrent_dropout:
if training is None and not context.in_eager_mode():
if training is None and not context.executing_eagerly():
# This would be harmless to set in eager mode, but eager tensors
# disallow setting arbitrary attributes.
output._uses_learning_phase = True
@ -1384,7 +1384,7 @@ class GRUCell(Layer):
hh = self.activation(x_h + recurrent_h)
h = z * h_tm1 + (1 - z) * hh
if 0 < self.dropout + self.recurrent_dropout:
if training is None and not context.in_eager_mode():
if training is None and not context.executing_eagerly():
# This would be harmless to set in eager mode, but eager tensors
# disallow setting arbitrary attributes.
h._uses_learning_phase = True
@ -1877,7 +1877,7 @@ class LSTMCell(Layer):
h = o * self.activation(c)
if 0 < self.dropout + self.recurrent_dropout:
if training is None and not context.in_eager_mode():
if training is None and not context.executing_eagerly():
# This would be harmless to set in eager mode, but eager tensors
# disallow setting arbitrary attributes.
h._uses_learning_phase = True

View File

@ -83,14 +83,14 @@ class AtrousConvolutionTest(test.TestCase):
checks = []
def add_check(check, *args, **kwargs):
if context.in_eager_mode():
if context.executing_eagerly():
args_val, kwargs_val = self.evaluate([args, kwargs])
check(*args_val, **kwargs_val)
else:
checks.append((check, args, kwargs))
yield add_check
if context.in_graph_mode():
if not context.executing_eagerly():
all_values = self.evaluate([[args, kwargs] for _, args, kwargs in checks])
for (check, _, _), (args, kwargs) in zip(checks, all_values):
check(*args, **kwargs)

View File

@ -102,17 +102,15 @@ class AssertEqualTest(test.TestCase):
with self.assertRaisesRegexp(errors.InvalidArgumentError, "fail"):
check_ops.assert_equal(static_big, static_small, message="fail")
# Dynamic check
if context.in_graph_mode():
with self.test_session():
small = array_ops.placeholder(dtypes.int32, name="small")
big = array_ops.placeholder(dtypes.int32, name="big")
with ops.control_dependencies(
[check_ops.assert_equal(
big, small, message="fail")]):
out = array_ops.identity(small)
with self.assertRaisesOpError("fail.*big.*small"):
out.eval(feed_dict={small: [1, 2], big: [3, 4]})
def test_raises_when_greater_dynamic(self):
with self.test_session():
small = array_ops.placeholder(dtypes.int32, name="small")
big = array_ops.placeholder(dtypes.int32, name="big")
with ops.control_dependencies(
[check_ops.assert_equal(big, small, message="fail")]):
out = array_ops.identity(small)
with self.assertRaisesOpError("fail.*big.*small"):
out.eval(feed_dict={small: [1, 2], big: [3, 4]})
def test_error_message_eager(self):
expected_error_msg_full = r"""big does not equal small
@ -182,15 +180,14 @@ First 2 elements of y:
with self.assertRaisesRegexp(errors.InvalidArgumentError, "fail"):
check_ops.assert_equal(static_big, static_small, message="fail")
# Dynamic check
if context.in_graph_mode():
with self.test_session():
small = array_ops.placeholder(dtypes.int32, name="small")
big = array_ops.placeholder(dtypes.int32, name="big")
with ops.control_dependencies([check_ops.assert_equal(small, big)]):
out = array_ops.identity(small)
with self.assertRaisesOpError("small.*big"):
out.eval(feed_dict={small: [3, 1], big: [4, 2]})
def test_raises_when_less_dynamic(self):
with self.test_session():
small = array_ops.placeholder(dtypes.int32, name="small")
big = array_ops.placeholder(dtypes.int32, name="big")
with ops.control_dependencies([check_ops.assert_equal(small, big)]):
out = array_ops.identity(small)
with self.assertRaisesOpError("small.*big"):
out.eval(feed_dict={small: [3, 1], big: [4, 2]})
@test_util.run_in_graph_and_eager_modes()
def test_doesnt_raise_when_equal_and_broadcastable_shapes(self):

View File

@ -360,7 +360,7 @@ class PyFuncTest(test.TestCase):
raise py_exp("blah") # pylint: disable=not-callable
if eager:
if context.in_eager_mode():
if context.executing_eagerly():
with self.assertRaisesRegexp(tf_exp, "blah"):
f = script_ops.eager_py_func(raise_exception, [], [])
return
@ -432,7 +432,7 @@ class PyFuncTest(test.TestCase):
output = script_ops.eager_py_func(no_return_value, inp=[], Tout=[])
ret = self.evaluate(output)
if context.in_eager_mode():
if context.executing_eagerly():
self.assertEquals(len(ret), 0)
else:
self.assertIsNone(ret)

View File

@ -279,15 +279,12 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
# Tests for the 'read_value' argument:
assign_with_read = v.assign(3.0, read_value=True)
if context.in_graph_mode():
self.assertEqual(3.0, assign_with_read.eval())
else:
self.assertEqual(3.0, self.evaluate(assign_with_read))
self.assertEqual(3.0, self.evaluate(assign_with_read))
assign_without_read = v.assign(4.0, read_value=False)
if context.in_graph_mode():
self.assertIsInstance(assign_without_read, ops.Operation)
else:
if context.executing_eagerly():
self.assertIsNone(assign_without_read)
else:
self.assertIsInstance(assign_without_read, ops.Operation)
self.evaluate(assign_without_read)
self.assertEqual(4.0, self.evaluate(v.value()))
@ -355,15 +352,12 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
# Tests for the 'read_value' argument:
assign_with_read = v.assign_add(1.0, read_value=True)
if context.in_graph_mode():
self.assertEqual(3.0, assign_with_read.eval())
else:
self.assertEqual(3.0, self.evaluate(assign_with_read))
self.assertEqual(3.0, self.evaluate(assign_with_read))
assign_without_read = v.assign_add(1.0, read_value=False)
if context.in_graph_mode():
self.assertIsInstance(assign_without_read, ops.Operation)
else:
if context.executing_eagerly():
self.assertIsNone(assign_without_read)
else:
self.assertIsInstance(assign_without_read, ops.Operation)
self.evaluate(assign_without_read)
self.assertEqual(4.0, self.evaluate(v.value()))
@ -376,15 +370,12 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
# Tests for the 'read_value' argument:
assign_with_read = v.assign_sub(1.0, read_value=True)
if context.in_graph_mode():
self.assertEqual(1.0, assign_with_read.eval())
else:
self.assertEqual(1.0, self.evaluate(assign_with_read))
self.assertEqual(1.0, self.evaluate(assign_with_read))
assign_without_read = v.assign_sub(1.0, read_value=False)
if context.in_graph_mode():
self.assertIsInstance(assign_without_read, ops.Operation)
else:
if context.executing_eagerly():
self.assertIsNone(assign_without_read)
else:
self.assertIsInstance(assign_without_read, ops.Operation)
self.evaluate(assign_without_read)
self.assertEqual(0.0, self.evaluate(v.value()))
@ -485,7 +476,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
self.assertEqual("(10, 20, 35)", str(v.get_shape()))
self.assertEqual("(10, 20, 35)", str(v.value().shape))
self.assertEqual("(3, 20, 35)", str(v.sparse_read([0, 1, 2]).shape))
if context.in_graph_mode():
if not context.executing_eagerly():
self.assertEqual(
"<unknown>",
str(v.sparse_read(array_ops.placeholder(dtypes.int32)).shape))

View File

@ -111,10 +111,10 @@ class RNNTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
def testInvalidSequenceLengthShape(self):
cell = Plus1RNNCell()
if context.in_graph_mode():
inputs = [array_ops.placeholder(dtypes.float32, shape=(3, 4))]
else:
if context.executing_eagerly():
inputs = [constant_op.constant(np.ones((3, 4)))]
else:
inputs = [array_ops.placeholder(dtypes.float32, shape=(3, 4))]
with self.assertRaisesRegexp(ValueError, "must be a vector"):
rnn.dynamic_rnn(
cell,
@ -125,38 +125,30 @@ class RNNTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
def testBatchSizeFromInput(self):
cell = Plus1RNNCell()
in_graph_mode = context.in_graph_mode()
in_eager_mode = context.executing_eagerly()
# With static batch size
if in_graph_mode:
inputs = array_ops.placeholder(dtypes.float32, shape=(3, 4, 5))
initial_state = array_ops.placeholder(dtypes.float32, shape=(3, 5))
else:
if in_eager_mode:
inputs = np.zeros((3, 4, 5), dtype=np.float32)
initial_state = np.zeros((3, 5), dtype=np.float32)
else:
inputs = array_ops.placeholder(dtypes.float32, shape=(3, 4, 5))
initial_state = array_ops.placeholder(dtypes.float32, shape=(3, 5))
# - Without initial_state
outputs, state = rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32)
if in_graph_mode:
self.assertEqual(3, outputs.shape[0].value)
self.assertEqual(3, state.shape[0].value)
else:
self.assertEqual(3, outputs.shape[0])
self.assertEqual(3, state.shape[0])
self.assertEqual(3, outputs.shape[0])
self.assertEqual(3, state.shape[0])
# - With initial_state
outputs, state = rnn.dynamic_rnn(
cell, inputs, initial_state=initial_state)
if in_graph_mode:
self.assertEqual(3, outputs.shape[0].value)
self.assertEqual(3, state.shape[0].value)
else:
self.assertEqual(3, outputs.shape[0])
self.assertEqual(3, state.shape[0])
self.assertEqual(3, outputs.shape[0])
self.assertEqual(3, state.shape[0])
# Without static batch size
# Tensor shapes are fully determined in Eager mode, so only run this
# test in graph mode.
if in_graph_mode:
# Tensor shapes are fully determined with eager execution enabled,
# so only run this test for graph construction.
if not in_eager_mode:
inputs = array_ops.placeholder(dtypes.float32, shape=(None, 4, 5))
# - Without initial_state
outputs, state = rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32)
@ -173,56 +165,46 @@ class RNNTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
def testScalarStateIsAccepted(self):
cell = ScalarStateRNNCell()
in_graph_mode = context.in_graph_mode()
in_eager_mode = context.executing_eagerly()
if in_graph_mode:
inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))
else:
if in_eager_mode:
inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32)
else:
inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))
with self.test_session() as sess:
outputs, state = rnn.dynamic_rnn(
cell, inputs, dtype=dtypes.float32, sequence_length=[4])
if in_graph_mode:
if not in_eager_mode:
outputs, state = sess.run(
[outputs, state], feed_dict={inputs: [[[1], [2], [3], [4]]]})
if in_graph_mode:
self.assertAllEqual(outputs, np.array([[[1], [2], [3], [4]]]))
self.assertEqual(state, 4)
else:
self.assertAllEqual(outputs.numpy(), np.array([[[1], [2], [3], [4]]]))
self.assertEqual(state.numpy(), 4)
self.assertAllEqual([[[1], [2], [3], [4]]], outputs)
self.assertAllEqual(4, state)
@test_util.run_in_graph_and_eager_modes()
def testTensorArrayStateIsAccepted(self):
cell = TensorArrayStateRNNCell()
in_graph_mode = context.in_graph_mode()
in_eager_mode = context.executing_eagerly()
if in_graph_mode:
inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))
else:
if in_eager_mode:
inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32)
else:
inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))
with self.test_session() as sess:
outputs, state = rnn.dynamic_rnn(
cell, inputs, dtype=dtypes.float32, sequence_length=[4])
state = (state[0], state[1].stack())
if in_graph_mode:
if not in_eager_mode:
outputs, state = sess.run(
[outputs, state], feed_dict={
inputs: [[[1], [2], [3], [4]]]
})
if in_graph_mode:
self.assertAllEqual(outputs, np.array([[[1], [2], [3], [4]]]))
self.assertEqual(state[0], 4)
self.assertAllEqual(state[1], np.array([[[1]], [[2]], [[3]], [[4]]]))
else:
self.assertAllEqual(outputs.numpy(), np.array([[[1], [2], [3], [4]]]))
self.assertEqual(state[0].numpy(), 4)
self.assertAllEqual(state[1].numpy(),
np.array([[[1]], [[2]], [[3]], [[4]]]))
self.assertAllEqual([[[1], [2], [3], [4]]], outputs)
self.assertAllEqual(4, state[0])
self.assertAllEqual([[[1]], [[2]], [[3]], [[4]]], state[1])
######### Benchmarking RNN code

View File

@ -283,7 +283,7 @@ class SliceTest(test.TestCase):
# unintended behavior is prevented.
c = constant_op.constant(5.0)
with self.assertRaisesWithPredicateMatch(
TypeError, lambda e: "`Tensor` objects are not iterable" in str(e)):
TypeError, lambda e: "Tensor objects are not iterable" in str(e)):
for _ in c:
pass

View File

@ -562,7 +562,7 @@ class TemplateTest(test.TestCase):
outputs_b, _ = linear1(inputs)
self.assertEquals("foo", linear1.variable_scope.name)
self.assertEquals("foo/w:0", w1.name)
if context.in_graph_mode():
if not context.executing_eagerly():
self.assertEquals("foo/add:0", outputs_a.name,
"First application of template should get "
"same name scope as variables.")
@ -577,7 +577,7 @@ class TemplateTest(test.TestCase):
"New template gets a freshly uniquified variable scope "
"because 'foo' is already taken.")
self.assertEquals("foo_1/w:0", w2.name)
if context.in_graph_mode():
if not context.executing_eagerly():
self.assertEquals("foo_1_1/add:0", outputs_c.name,
"First application of template would get "
"same name scope as variables, but 'foo_1' is already "
@ -592,7 +592,7 @@ class TemplateTest(test.TestCase):
with variable_scope.variable_scope("foo"):
# Create two templates with the same name, ensure scopes are made unique.
ta = template.make_template("bar", variable_scoped_function, True)
if context.in_eager_mode():
if context.executing_eagerly():
tb = template.make_template("s", function_with_side_create,
trainable=False)
else:

View File

@ -399,28 +399,14 @@ class TensorArrayTest(test.TestCase):
def testTensorArrayWriteWrongIndexOrDataTypeFails(self):
with self.test_session(use_gpu=True):
ta = _make_ta(3, "foo", dtype=dtypes.float32)
in_graph_mode = context.in_graph_mode()
# Test writing the wrong datatype
if in_graph_mode:
with self.assertRaisesOpError(
"TensorArray dtype is float but Op is trying to write "
"dtype string"):
self.evaluate(ta.write(0, "wrong_type_scalar").flow)
else:
with self.assertRaisesOpError(
"TensorArray dtype is float32 but Op is trying to write "
"dtype string"):
self.evaluate(ta.write(0, "wrong_type_scalar").flow)
with self.assertRaisesOpError(
"TensorArray dtype is (float|float32) but Op is trying to write "
"dtype string"):
self.evaluate(ta.write(0, "wrong_type_scalar").flow)
if context.in_graph_mode():
with self.assertRaisesOpError(
"Tried to write to index -1 but array is not "
"resizeable and size is: 3"):
self.evaluate(ta.write(-1, 3.0).flow)
else:
with self.assertRaisesOpError(
r"Writing to negative indices \(index -1\) is not allowed."):
self.evaluate(ta.write(-1, 3.0).flow)
with self.assertRaisesOpError("index -1"):
self.evaluate(ta.write(-1, 3.0).flow)
# Test reading from too large an index
with self.assertRaisesOpError(
@ -435,8 +421,8 @@ class TensorArrayTest(test.TestCase):
w0 = ta.write(0, [[4.0, 5.0]])
# Test reading wrong datatype, which is only possible in graph mode
if context.in_graph_mode():
# Test reading wrong datatype (only possible when constructing graphs).
if not context.executing_eagerly():
r0_bad = gen_data_flow_ops.tensor_array_read_v3(
handle=w0.handle, index=0, dtype=dtypes.float64, flow_in=w0.flow)
with self.assertRaisesOpError(
@ -444,14 +430,8 @@ class TensorArrayTest(test.TestCase):
r0_bad.eval()
# Test reading from a negative index, which is not allowed
if context.in_graph_mode():
with self.assertRaisesOpError(
r"Tried to read from index -1 but array size is: 3"):
self.evaluate(ta.read(-1))
else:
with self.assertRaisesOpError(
r"Reading from negative indices \(index -1\) is not allowed."):
self.evaluate(ta.read(-1))
with self.assertRaisesOpError("index -1"):
self.evaluate(ta.read(-1))
# Test reading from too large an index
with self.assertRaisesOpError(
@ -467,10 +447,7 @@ class TensorArrayTest(test.TestCase):
with self.assertRaisesOpError(
"Could not write to TensorArray index 2 because "
"it has already been written to."):
if context.in_graph_mode():
self.evaluate(ta.write(2, 3.0).write(2, 3.0).flow)
else:
self.evaluate(ta.write(2, 3.0).write(2, 3.0))
self.evaluate(ta.write(2, 3.0).write(2, 3.0).flow)
@test_util.run_in_graph_and_eager_modes()
def testTensorArrayConcatIncompatibleShapesFails(self):
@ -499,58 +476,40 @@ class TensorArrayTest(test.TestCase):
w2 = w1.write(1, [4.0])
w3 = w2.write(2, [[3.0]])
# The eager-mode implementation just passes up array_op.concat's error
# message.
if context.in_graph_mode():
with self.assertRaisesOpError(
r"TensorArray has inconsistent shapes. Index 0 has "
r"\(excepting dimension 0\) shape: \[\] but index 2 has "
r"\(excepting dimension 0\) shape: \[1\]"):
self.evaluate(w3.concat())
else:
with self.assertRaisesOpError(
r".*Ranks of all input tensors should match: shape\[0\] "
r"= \[1\] vs\. shape\[2\] = \[1,1\].*"):
self.evaluate(w3.concat())
# The exact error messages differ between eager execution and graph
# construction as the former bubbles up the error from array_op.concat.
with self.assertRaisesOpError("shape"):
self.evaluate(w3.concat())
@test_util.run_in_graph_and_eager_modes()
def testTensorArraySplitIncompatibleShapesFails(self):
with self.test_session(use_gpu=True):
in_graph_mode = context.in_graph_mode()
in_eager_mode = context.executing_eagerly()
ta = _make_ta(3, "foo")
with self.assertRaisesOpError(
r"Expected lengths to be a vector, received shape: \[\]"):
if in_graph_mode:
if in_eager_mode:
self.evaluate(ta.split([1.0, 2.0, 3.0], 1))
else:
lengths = array_ops.placeholder(dtypes.int64)
ta.split([1.0, 2.0, 3.0], lengths).flow.eval(feed_dict={lengths: 1})
else:
self.evaluate(ta.split([1.0, 2.0, 3.0], 1))
with self.assertRaisesOpError(
r"Expected sum of lengths to be equal to values.shape\[0\], "
r"but sum of lengths is 1 and value's shape is: \[3\]"):
if in_graph_mode:
self.evaluate(ta.split([1.0, 2.0, 3.0], [1]).flow)
else:
self.evaluate(ta.split([1.0, 2.0, 3.0], [1]))
self.evaluate(ta.split([1.0, 2.0, 3.0], [1]).flow)
ta = _make_ta(1, "baz")
with self.assertRaisesOpError(
r"Expected value to be at least a vector, but received shape: \[\]"):
if in_graph_mode:
self.evaluate(ta.split(1.0, [1]).flow)
else:
self.evaluate(ta.split(1.0, [1]))
self.evaluate(ta.split(1.0, [1]).flow)
ta = _make_ta(2, "buz")
with self.assertRaisesOpError(
r"TensorArray's size is not equal to the size of lengths "
r"\(2 vs. 1\), and the TensorArray is not marked as "
r"dynamically resizeable"):
if in_graph_mode:
self.evaluate(ta.split([1.0], [1]).flow)
else:
self.evaluate(ta.split([1.0], [1]))
self.evaluate(ta.split([1.0], [1]).flow)
def _testTensorArrayWriteGradientAddMultipleAdds(self, dtype):
with self.test_session(use_gpu=True):
@ -868,14 +827,14 @@ class TensorArrayTest(test.TestCase):
vout = func(v0, state0, var)
grad_val = -np.arange(3 * 5, dtype=np_dtype).reshape(3, 5)
if context.in_graph_mode():
if context.executing_eagerly():
grad_fn = backprop.gradients_function(func)
v0_grad, state0_grad, var_grad = grad_fn(v0, state0, var, dy=grad_val)
else:
v0_grad = gradients_impl.gradients([vout], [v0], [grad_val])[0]
state0_grad = gradients_impl.gradients([vout], [state0], [grad_val])[0]
var_grad = gradients_impl.gradients([vout], [var], [grad_val])[0]
variables.global_variables_initializer().run()
else:
grad_fn = backprop.gradients_function(func)
v0_grad, state0_grad, var_grad = grad_fn(v0, state0, var, dy=grad_val)
state0_t, var_t, v0_t, vout_t, v0_grad_t, var_grad_t, state0_grad_t = (
self.evaluate(
@ -959,10 +918,10 @@ class TensorArrayTest(test.TestCase):
return r
x = constant_op.constant(2.0, name="x")
if context.in_graph_mode():
grad = gradients_impl.gradients(loop(x), [x])[0]
else:
if context.executing_eagerly():
grad = backprop.gradients_function(loop)(x)[0]
else:
grad = gradients_impl.gradients(loop(x), [x])[0]
self.assertAllClose(31.0, self.evaluate(grad))
def testSumOfTwoReadVariablesWithoutRepeatGrad(self):
@ -1158,14 +1117,14 @@ class TensorArrayTest(test.TestCase):
infer_shape=True)
w0 = ta1.split(value, [1, 2])
r0 = w0.read(0)
if context.in_graph_mode():
if context.executing_eagerly():
self.assertEqual((1, 2), r0.get_shape())
self.assertEqual((2, 2), w0.read(1).get_shape())
else:
self.assertEqual(r0.get_shape().ndims, None)
self.assertEqual(
tensor_shape.TensorShape(
ta1.handle.op.get_attr("element_shape")).ndims, None)
else:
self.assertEqual((1, 2), r0.get_shape())
self.assertEqual((2, 2), w0.read(1).get_shape())
def testWriteUnknownShape(self):
with self.test_session(use_gpu=True):
@ -1297,13 +1256,13 @@ class TensorArrayTest(test.TestCase):
g = func(values)
grad_ys = [[[2.0, 3.0], [4.0, 5.0]]]
# Test combined gradients + aggregation of read(0)
if context.in_graph_mode():
grad = gradients_impl.gradients(ys=[g], xs=[values], grad_ys=grad_ys)
g_vals, grad_vals = session.run([[g], grad])
else:
if context.executing_eagerly():
g_vals = [g]
grad_vals = backprop.gradients_function(func)(
values, dy=constant_op.constant(grad_ys[0], dtype=dtypes.float32))
else:
grad = gradients_impl.gradients(ys=[g], xs=[values], grad_ys=grad_ys)
g_vals, grad_vals = session.run([[g], grad])
# Gradients for 8 of the 10 unread components are zero.
expected_grad = np.zeros((10, 2))
@ -1453,13 +1412,13 @@ class TensorArrayTest(test.TestCase):
# Tests correct properties on new TensorArrays.
self.assertEqual(dtypes.float32, ta0.dtype)
self.assertEqual(dtypes.int32, ta1.dtype)
if context.in_graph_mode():
self.assertEqual(tensor_shape.unknown_shape(), read0.get_shape())
if context.executing_eagerly():
self.assertEqual(tensor_shape.scalar(), read0.get_shape())
else:
self.assertEqual(tensor_shape.scalar(), read1.get_shape())
self.assertEqual(tensor_shape.unknown_shape(), read0.get_shape())
self.assertEqual(tensor_shape.scalar(), read1.get_shape())
if context.in_graph_mode():
if not context.executing_eagerly():
variables.global_variables_initializer().run()
read0_v, read1_v, size0_v, size1_v = self.evaluate((read0, read1, size0,

View File

@ -166,12 +166,10 @@ class VariableScopeTest(test.TestCase):
self.evaluate(variables_lib.variables_initializer([w]))
self.assertAllClose(self.evaluate(w.value()), [1, 2, 3])
if context.in_graph_mode():
with self.assertRaises(TypeError):
variable_scope.get_variable("x4", initializer={})
else:
with self.assertRaises(ValueError):
variable_scope.get_variable("x4", initializer={})
# A quirk to be revisited?
error = ValueError if context.executing_eagerly() else TypeError
with self.assertRaises(error):
variable_scope.get_variable("x4", initializer={})
@test_util.run_in_graph_and_eager_modes()
def testInitFromNonInitializer(self):
@ -267,7 +265,7 @@ class VariableScopeTest(test.TestCase):
self.assertAllClose(self.evaluate(losses[2]), 0.5)
with variable_scope.variable_scope("foo", reuse=True):
# reuse=True is for now only supported when eager execution is disabled.
if context.in_graph_mode():
if not context.executing_eagerly():
v = variable_scope.get_variable("v",
[]) # "v" is alredy there, reused
losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
@ -374,7 +372,7 @@ class VariableScopeTest(test.TestCase):
v = variable_scope.get_variable("v", [])
self.evaluate(variables_lib.variables_initializer([v]))
self.assertAllClose(self.evaluate(v.value()), 0.3)
if context.in_graph_mode():
if not context.executing_eagerly():
# Check that we can set reuse.
variable_scope.get_variable_scope().reuse_variables()
with self.assertRaises(ValueError): # Fail, w does not exist yet.
@ -408,7 +406,7 @@ class VariableScopeTest(test.TestCase):
with variable_scope.variable_scope("tower") as tower:
with ops.name_scope("scope2") as sc2:
self.assertEqual(sc2, "testVarScopeNameScope1/tower/scope2/")
if context.in_graph_mode():
if not context.executing_eagerly():
with variable_scope.variable_scope(
tower): # Re-entering acts like another "tower".
with ops.name_scope("scope2") as sc2:
@ -422,7 +420,7 @@ class VariableScopeTest(test.TestCase):
with variable_scope.variable_scope("tower"):
with ops.name_scope("scope2") as sc2:
self.assertEqual(sc2, "testVarScopeNameScope2/tower/scope2/")
if context.in_graph_mode():
if not context.executing_eagerly():
with variable_scope.variable_scope(tower):
with ops.name_scope("scope2") as sc2:
self.assertEqual(sc2, "testVarScopeNameScope2/tower_1/scope2/")
@ -903,17 +901,15 @@ class VariableScopeTest(test.TestCase):
"w", [], collections=["foo"])
self.assertEqual(local_var.name, "outer/w:0")
# Since variable is local, it should be in the local variable collection
# but not the trainable collection.
if context.in_graph_mode():
if not context.executing_eagerly():
# Since variable is local, it should be in the local variable collection
# but not the trainable collection.
self.assertIn(local_var,
ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
self.assertIn(local_var, ops.get_collection("foo"))
self.assertNotIn(local_var,
ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
# Check that local variable respects `reuse`.
if context.in_graph_mode():
# Check that local variable respects `reuse`.
with variable_scope.variable_scope(outer, "default", reuse=True):
self.assertEqual(
variable_scope.get_local_variable("w", []).name, "outer/w:0")

View File

@ -115,7 +115,7 @@ class Layer(checkpointable.CheckpointableBase):
# Provides information about which inputs are compatible with the layer.
self.input_spec = None
if activity_regularizer and context.in_eager_mode():
if activity_regularizer and context.executing_eagerly():
raise ValueError(
('Activity regularization is not supported when executing eagerly. '
'Got activity_regularizer=%s') % (activity_regularizer,))
@ -228,7 +228,7 @@ class Layer(checkpointable.CheckpointableBase):
@property
def updates(self):
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('Layer.updates not supported in Eager mode.')
if not self.trainable and not self.stateful:
return []
@ -260,7 +260,7 @@ class Layer(checkpointable.CheckpointableBase):
have is available at runtime.
A step counter might fall into this category.
"""
if context.in_eager_mode():
if context.executing_eagerly():
return # Updates already applied when in eager mode.
updates = _to_list(updates)
@ -286,7 +286,7 @@ class Layer(checkpointable.CheckpointableBase):
Raises:
RuntimeError: If called in Eager mode.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('`get_updates_for()` not supported in Eager mode.')
# Updates disabled if layer is not trainable and not explicitly stateful.
@ -317,7 +317,7 @@ class Layer(checkpointable.CheckpointableBase):
Returns:
A list of tensors.
"""
if context.in_eager_mode():
if context.executing_eagerly():
# _losses may only contain variable regularization losses when executing
# eagerly, and they have been saved as lambdas to be executed when
# requested.
@ -355,7 +355,7 @@ class Layer(checkpointable.CheckpointableBase):
Raises:
RuntimeError: If called in Eager mode.
"""
if context.in_eager_mode():
if context.executing_eagerly():
# TODO(fchollet): it should be possible (and highly desirable) to support
# `add_loss` in eager mode. This allows great convenience and flexibility
# in defining custom losses on the fly (e.g. in VAEs).
@ -389,7 +389,7 @@ class Layer(checkpointable.CheckpointableBase):
Raises:
RuntimeError: If called in Eager mode.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('Layer.get_losses_for not supported in Eager mode.')
if inputs is None:
@ -509,7 +509,7 @@ class Layer(checkpointable.CheckpointableBase):
# will occur; it should be None if and only if initialization will take
# place in the eager context.
init_graph = None
if context.in_graph_mode():
if not context.executing_eagerly():
default_graph = ops.get_default_graph()
if default_graph.building_function:
with ops.init_scope():
@ -517,7 +517,7 @@ class Layer(checkpointable.CheckpointableBase):
# will be lifted; if initialization ops will be lifted into
# the eager context, then there is nothing to retrieve, since variable
# collections are not supported when eager execution is enabled.
if context.in_graph_mode():
if not context.executing_eagerly():
init_graph = ops.get_default_graph()
existing_variables = set(tf_variables.global_variables())
else:
@ -624,17 +624,17 @@ class Layer(checkpointable.CheckpointableBase):
self._set_scope(kwargs.pop('scope', None))
input_list = nest.flatten(inputs)
in_graph_mode = context.in_graph_mode()
build_graph = not context.executing_eagerly()
in_deferred_mode = isinstance(input_list[0], _DeferredTensor)
# Ensure the Layer, if being reused, is working with inputs from
# the same graph as where it was created.
if in_graph_mode:
if build_graph:
try:
# Set layer's "graph" at build time
self._graph = ops._get_graph_from_inputs(input_list, graph=self._graph) # pylint: disable=protected-access
except ValueError as e:
raise ValueError('Input graph and Layer graph are not the same: %s' % e)
if in_graph_mode or in_deferred_mode:
if build_graph or in_deferred_mode:
user_kwargs = copy.copy(kwargs)
# Handle Keras mask propagation from previous layer to current layer.
@ -669,13 +669,14 @@ class Layer(checkpointable.CheckpointableBase):
with scope_context_manager as scope:
with ops.name_scope(self._name_scope_name(scope)):
if not self.built:
if not in_graph_mode:
if not build_graph:
# Activity regularization is currently unsupported in Eager mode.
if self._activity_regularizer:
raise ValueError('activity_regularizer currently unsupported in '
'Eager mode. Found an activity_regularizer in '
'%s(%s).' % (self.__class__.__name__, self))
if not in_graph_mode and not in_deferred_mode:
raise ValueError(
'activity_regularizer currently unsupported with '
'eager execution enabled. Found an activity_regularizer in '
'%s(%s).' % (self.__class__.__name__, self))
if not build_graph and not in_deferred_mode:
# TODO(agarwal): support _keras_history in Eager mode.
for x in input_list:
if hasattr(x, '_keras_history'):
@ -706,7 +707,7 @@ class Layer(checkpointable.CheckpointableBase):
if call_has_scope_arg:
kwargs['scope'] = scope
# Check input assumptions set after layer building, e.g. input shape.
if in_graph_mode or in_deferred_mode:
if build_graph or in_deferred_mode:
self._assert_input_compatibility(inputs)
if not in_deferred_mode:
@ -730,7 +731,7 @@ class Layer(checkpointable.CheckpointableBase):
if len(outputs) == 1:
outputs = outputs[0]
if in_graph_mode:
if build_graph:
# Apply activity regularization.
# Note that it should be applied every time the layer creates a new
# output, since it is output-specific.
@ -752,7 +753,7 @@ class Layer(checkpointable.CheckpointableBase):
else:
outputs._keras_mask = output_mask # pylint: disable=protected-access
if in_graph_mode:
if build_graph:
# If all input tensors have history metadata,
# we update the output tensors
# with corresponding history metadata, thus eventually allowing to use
@ -775,7 +776,7 @@ class Layer(checkpointable.CheckpointableBase):
# Update global default collections.
_add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
if in_deferred_mode or in_graph_mode:
if in_deferred_mode or build_graph:
if _have_all_keras_metadata(inputs):
# Add an inbound node to the layer, so it can keep track of this call.
# This updates the layer history of the output tensor(s).
@ -787,7 +788,7 @@ class Layer(checkpointable.CheckpointableBase):
@property
def graph(self):
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('Layer.graph not supported in Eager mode.')
return self._graph
@ -891,7 +892,7 @@ class Layer(checkpointable.CheckpointableBase):
mode.
ValueError: If the index provided does not match any node.
"""
assert context.in_graph_mode()
assert not context.executing_eagerly()
if not self._inbound_nodes:
raise RuntimeError('The layer has never been called '
'and thus has no defined ' + attr_name + '.')
@ -921,7 +922,7 @@ class Layer(checkpointable.CheckpointableBase):
Raises:
RuntimeError: If called in Eager mode.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError(
'Layer.get_input_shape_at not supported in Eager mode.')
return self._get_node_attribute_at_index(node_index, 'input_shapes',
@ -943,7 +944,7 @@ class Layer(checkpointable.CheckpointableBase):
Raises:
RuntimeError: If called in Eager mode.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError(
'Layer.get_output_shape_at not supported in Eager mode.')
return self._get_node_attribute_at_index(node_index, 'output_shapes',
@ -964,7 +965,7 @@ class Layer(checkpointable.CheckpointableBase):
Raises:
RuntimeError: If called in Eager mode.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('Layer.get_input_at not supported in Eager mode.')
return self._get_node_attribute_at_index(node_index, 'input_tensors',
'input')
@ -984,7 +985,7 @@ class Layer(checkpointable.CheckpointableBase):
Raises:
RuntimeError: If called in Eager mode.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('Layer.get_output_at not supported in Eager mode.')
return self._get_node_attribute_at_index(node_index, 'output_tensors',
'output')
@ -1007,7 +1008,7 @@ class Layer(checkpointable.CheckpointableBase):
RuntimeError: If called in Eager mode.
AttributeError: If no inbound nodes are found.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('Layer.input not supported in Eager mode.')
if not self._inbound_nodes:
raise AttributeError('Layer ' + self.name +
@ -1029,7 +1030,7 @@ class Layer(checkpointable.CheckpointableBase):
layers.
RuntimeError: if called in Eager mode.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('Layer.output not supported in Eager mode.')
if not self._inbound_nodes:
raise AttributeError('Layer ' + self.name + ' has no inbound nodes.')
@ -1051,7 +1052,7 @@ class Layer(checkpointable.CheckpointableBase):
AttributeError: if the layer has no defined input_shape.
RuntimeError: if called in Eager mode.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('Layer.input_shape not supported in Eager mode.')
if not self._inbound_nodes:
raise AttributeError('The layer has never been called '
@ -1112,7 +1113,7 @@ class Layer(checkpointable.CheckpointableBase):
AttributeError: if the layer has no defined output shape.
RuntimeError: if called in Eager mode.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('Layer.output_shape not supported in Eager mode.')
if not self._inbound_nodes:
raise AttributeError('The layer has never been called '
@ -1470,7 +1471,7 @@ def _to_list(x):
def _add_elements_to_collection(elements, collection_list):
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('Using collections from Layers not supported in Eager '
'mode. Tried to add %s to %s' % (elements,
collection_list))

View File

@ -44,7 +44,7 @@ class BaseLayerTest(test.TestCase):
self.assertEqual(layer.variables, [])
self.assertEqual(layer.trainable_variables, [])
self.assertEqual(layer.non_trainable_variables, [])
if context.in_graph_mode():
if not context.executing_eagerly():
# updates, losses only supported in GRAPH mode
self.assertEqual(layer.updates, [])
self.assertEqual(layer.losses, [])
@ -63,7 +63,7 @@ class BaseLayerTest(test.TestCase):
self.assertEqual(layer.variables, [variable])
self.assertEqual(layer.trainable_variables, [variable])
self.assertEqual(layer.non_trainable_variables, [])
if context.in_graph_mode():
if not context.executing_eagerly():
self.assertEqual(
layer.variables,
ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
@ -77,7 +77,7 @@ class BaseLayerTest(test.TestCase):
self.assertEqual(layer.variables, [variable, variable_2])
self.assertEqual(layer.trainable_variables, [variable])
self.assertEqual(layer.non_trainable_variables, [variable_2])
if context.in_graph_mode():
if not context.executing_eagerly():
self.assertEqual(
len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 1)
@ -161,7 +161,7 @@ class BaseLayerTest(test.TestCase):
inputs = random_ops.random_uniform((5,), seed=1)
outputs = layer.apply(inputs)
self.assertEqual(layer.built, True)
if context.in_graph_mode():
if not context.executing_eagerly():
# op is only supported in GRAPH mode
self.assertEqual(outputs.op.name, 'my_layer/Square')
@ -210,7 +210,7 @@ class BaseLayerTest(test.TestCase):
inputs = random_ops.random_uniform((5,), seed=1)
outputs = layer.apply(inputs)
self.assertEqual(layer.built, True)
if context.in_graph_mode():
if not context.executing_eagerly():
# op only supported in GRAPH mode.
self.assertEqual(outputs.op.name, 'my_layer/Square')
@ -280,7 +280,7 @@ class BaseLayerTest(test.TestCase):
def call(self, inputs):
return inputs
if context.in_graph_mode():
if not context.executing_eagerly():
layer = CustomerLayer()
with self.assertRaisesRegexp(ValueError, r'requires a defined rank'):
layer.apply(array_ops.placeholder('int32'))
@ -307,7 +307,7 @@ class BaseLayerTest(test.TestCase):
def call(self, inputs):
return inputs
if context.in_graph_mode():
if not context.executing_eagerly():
layer = CustomerLayer()
with self.assertRaisesRegexp(ValueError, r'requires a defined rank'):
layer.apply(array_ops.placeholder('int32'))
@ -335,7 +335,7 @@ class BaseLayerTest(test.TestCase):
def call(self, inputs):
return inputs
if context.in_graph_mode():
if not context.executing_eagerly():
layer = CustomerLayer()
with self.assertRaisesRegexp(ValueError, r'requires a defined rank'):
layer.apply(array_ops.placeholder('int32'))
@ -430,7 +430,7 @@ class BaseLayerTest(test.TestCase):
layer.apply(constant_op.constant(1))
# Works
if context.in_graph_mode():
if not context.executing_eagerly():
layer.apply(array_ops.placeholder('int32'))
layer.apply(array_ops.placeholder('int32', shape=(2, 3)))
@ -453,13 +453,7 @@ class BaseLayerTest(test.TestCase):
return {'l' + key: inputs[key] for key in inputs}
layer = DictLayer()
if context.in_graph_mode():
i1 = array_ops.placeholder('int32')
i2 = array_ops.placeholder('float32')
result = layer.apply({'abel': i1, 'ogits': i2})
self.assertTrue(isinstance(result, dict))
self.assertEqual(set(['label', 'logits']), set(result.keys()))
else:
if context.executing_eagerly():
i1 = constant_op.constant(3)
i2 = constant_op.constant(4.0)
result = layer.apply({'abel': i1, 'ogits': i2})
@ -467,6 +461,12 @@ class BaseLayerTest(test.TestCase):
self.assertEqual(set(['label', 'logits']), set(result.keys()))
self.assertEqual(3, result['label'].numpy())
self.assertEqual(4.0, result['logits'].numpy())
else:
i1 = array_ops.placeholder('int32')
i2 = array_ops.placeholder('float32')
result = layer.apply({'abel': i1, 'ogits': i2})
self.assertTrue(isinstance(result, dict))
self.assertEqual(set(['label', 'logits']), set(result.keys()))
def testActivityRegularizer(self):
regularizer = math_ops.reduce_sum

View File

@ -1664,7 +1664,7 @@ class Conv2DTranspose(Conv2D):
padding=self.padding.upper(),
data_format=utils.convert_data_format(self.data_format, ndim=4))
if context.in_graph_mode():
if not context.executing_eagerly():
# Infer the static output shape:
out_shape = inputs.get_shape().as_list()
out_shape[c_axis] = self.filters
@ -1969,7 +1969,7 @@ class Conv3DTranspose(Conv3D):
data_format=utils.convert_data_format(self.data_format, ndim=5),
padding=self.padding.upper())
if context.in_graph_mode():
if not context.executing_eagerly():
# Infer the static output shape:
out_shape = inputs.get_shape().as_list()
out_shape[c_axis] = self.filters

View File

@ -156,7 +156,7 @@ class Dense(base.Layer):
outputs = standard_ops.tensordot(inputs, self.kernel, [[len(shape) - 1],
[0]])
# Reshape the output back to the original ndim of the input.
if context.in_graph_mode():
if not context.executing_eagerly():
output_shape = shape[:-1] + [self.units]
outputs.set_shape(output_shape)
else:
@ -374,7 +374,7 @@ class Flatten(base.Layer):
def call(self, inputs):
outputs = array_ops.reshape(inputs, (array_ops.shape(inputs)[0], -1))
if context.in_graph_mode():
if not context.executing_eagerly():
outputs.set_shape(self.compute_output_shape(inputs.get_shape()))
return outputs

View File

@ -77,7 +77,7 @@ class DenseTest(test.TestCase):
self.assertListEqual(dense.trainable_variables,
[dense.kernel, dense.bias])
self.assertListEqual(dense.non_trainable_variables, [])
if context.in_graph_mode():
if not context.executing_eagerly():
self.assertEqual(
len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 2)
self.assertEqual(dense.kernel.name, 'my_dense/kernel:0')
@ -98,7 +98,7 @@ class DenseTest(test.TestCase):
self.assertListEqual(dense.variables, [dense.kernel])
self.assertListEqual(dense.trainable_variables, [dense.kernel])
self.assertListEqual(dense.non_trainable_variables, [])
if context.in_graph_mode():
if not context.executing_eagerly():
self.assertEqual(
len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 1)
self.assertEqual(dense.kernel.name, 'my_dense/kernel:0')
@ -113,7 +113,7 @@ class DenseTest(test.TestCase):
self.assertListEqual(dense.non_trainable_variables,
[dense.kernel, dense.bias])
self.assertListEqual(dense.trainable_variables, [])
if context.in_graph_mode():
if not context.executing_eagerly():
self.assertEqual(
len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 0)
@ -162,13 +162,13 @@ class DenseTest(test.TestCase):
dense = core_layers.Dense(2, activation=nn_ops.relu, name='dense1')
inputs = random_ops.random_uniform((5, 3), seed=1)
outputs = dense(inputs)
if context.in_graph_mode():
if not context.executing_eagerly():
self.assertEqual(outputs.op.name, 'dense1/Relu')
dense = core_layers.Dense(2, name='dense2')
inputs = random_ops.random_uniform((5, 3), seed=1)
outputs = dense(inputs)
if context.in_graph_mode():
if not context.executing_eagerly():
self.assertEqual(outputs.op.name, 'dense2/BiasAdd')
def testActivityRegularizer(self):
@ -374,7 +374,7 @@ class DropoutTest(test.TestCase):
dp = core_layers.Dropout(0.5)
inputs = array_ops.ones((5, 3))
dropped = dp.apply(inputs, training=True)
if context.in_graph_mode():
if not context.executing_eagerly():
self.evaluate(variables.global_variables_initializer())
np_output = self.evaluate(dropped)
self.assertAlmostEqual(0., np_output.min())

View File

@ -338,8 +338,9 @@ class BatchNormalization(base.Layer):
return var
with ops.device(None):
device = ((lambda _: self.moving_mean.device)
if context.in_graph_mode() else self.moving_mean.device)
device = (
self.moving_mean.device if context.executing_eagerly() else
(lambda _: self.moving_mean.device))
with ops.device(device):
self.renorm_mean = _renorm_variable('renorm_mean', param_shape)
self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ())
@ -347,8 +348,9 @@ class BatchNormalization(base.Layer):
# renorm_stddev_weight. This allows us to (1) mix the average
# stddev with the minibatch stddev early in training, and (2) compute
# the unbiased average stddev by dividing renorm_stddev by the weight.
device = ((lambda _: self.moving_variance.device)
if context.in_graph_mode() else self.moving_variance.device)
device = (
self.moving_variance.device if context.executing_eagerly() else
(lambda _: self.moving_variance.device))
with ops.device(device):
self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape)
self.renorm_stddev_weight = _renorm_variable(
@ -420,7 +422,7 @@ class BatchNormalization(base.Layer):
one_minus_decay)
variance_update = self._assign_moving_average(self.moving_variance,
variance, one_minus_decay)
if context.in_graph_mode():
if not context.executing_eagerly():
# Note that in Eager mode, the updates are already executed when running
# assign_moving_averages. So we do not need to put them into
# collections.
@ -493,7 +495,7 @@ class BatchNormalization(base.Layer):
return (r, d, new_mean, new_variance)
def call(self, inputs, training=False):
in_eager_mode = context.in_eager_mode()
in_eager_mode = context.executing_eagerly()
if self.virtual_batch_size is not None:
# Virtual batches (aka ghost batches) can be simulated by reshaping the
# Tensor and reusing the existing batch norm implementation
@ -610,7 +612,7 @@ class BatchNormalization(base.Layer):
training,
lambda: _do_update(self.moving_variance, new_variance),
lambda: self.moving_variance)
if context.in_graph_mode():
if not context.executing_eagerly():
self.add_update(mean_update, inputs=inputs)
self.add_update(variance_update, inputs=inputs)

View File

@ -80,7 +80,7 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
def _ExtractInputShapes(inputs):
"""Extract the shapes of a set of input tensors."""
if not context.in_graph_mode():
if context.executing_eagerly():
return array_ops.shape_n(inputs)
sizes = []
fully_known = True
@ -106,7 +106,7 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
out_grads = []
if isinstance(grad, ops.Tensor):
if context.in_eager_mode():
if context.executing_eagerly():
# Using mod here for convenience since concat_dim is already verified
# in concat implementation to be within the allowed [-rank, rank) range.
non_neg_concat_dim = (
@ -428,7 +428,7 @@ def _GatherV2Grad(op, grad):
# For axis 0 gathers, build an appropriately shaped IndexedSlices.
if axis_static == 0:
if context.in_eager_mode():
if context.executing_eagerly():
params_tail_shape = params_shape.cpu()[1:]
else:
params_tail_shape = params_shape[1:]
@ -578,7 +578,7 @@ def _TileGrad(op, grad):
axes = math_ops.range(0, array_ops.size(split_shape), 2)
input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes)
# Fix shape inference
if context.in_graph_mode():
if not context.executing_eagerly():
input_grad.set_shape(op.inputs[0].get_shape())
return [input_grad, None]

View File

@ -128,9 +128,7 @@ def identity(input, name=None): # pylint: disable=redefined-builtin
Returns:
A `Tensor`. Has the same type as `input`.
"""
if context.in_graph_mode():
return gen_array_ops.identity(input, name=name)
else:
if context.executing_eagerly():
input = ops.convert_to_tensor(input)
in_device = input.device
# TODO(ashankar): Does 'identity' need to invoke execution callbacks?
@ -140,6 +138,8 @@ def identity(input, name=None): # pylint: disable=redefined-builtin
if context_device != in_device:
return input._copy() # pylint: disable=protected-access
return input
else:
return gen_array_ops.identity(input, name=name)
# pylint: disable=redefined-builtin,protected-access
@ -305,7 +305,7 @@ def shape_internal(input, name=None, optimize=True, out_type=dtypes.int32):
sparse_tensor.SparseTensorValue)):
return gen_math_ops.cast(input.dense_shape, out_type)
else:
if context.in_graph_mode():
if not context.executing_eagerly():
input_tensor = ops.convert_to_tensor(input)
input_shape = input_tensor.get_shape()
if optimize and input_shape.is_fully_defined():
@ -330,7 +330,7 @@ def shape_n(input, out_type=dtypes.int32, name=None):
"""
output = gen_array_ops.shape_n(input, out_type=out_type, name=name)
if context.in_graph_mode():
if not context.executing_eagerly():
for i, input_tensor in enumerate(input):
input_tensor = ops.convert_to_tensor(input_tensor)
input_shape = input_tensor.get_shape()
@ -385,9 +385,8 @@ def size_internal(input, name=None, optimize=True, out_type=dtypes.int32):
Returns:
A `Tensor` of type `out_type`. Defaults to `tf.int32`.
"""
if context.in_eager_mode() and not isinstance(
input, (sparse_tensor.SparseTensor,
sparse_tensor.SparseTensorValue)):
if context.executing_eagerly() and not isinstance(
input, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
return np.prod(ops.convert_to_tensor(input)._shape_tuple()) # pylint: disable=protected-access
with ops.name_scope(name, "Size", [input]) as name:
if isinstance(input, (sparse_tensor.SparseTensor,
@ -783,7 +782,7 @@ def strided_slice(input_,
new_axis_mask=new_axis_mask,
shrink_axis_mask=shrink_axis_mask)
if context.in_graph_mode():
if not context.executing_eagerly():
# TODO(apassos) In eager mode assignment will be done by overriding
# __setitem__ instead.
op.assign = assign
@ -1457,7 +1456,7 @@ def transpose(a, perm=None, name="transpose", conjugate=False):
ret = transpose_fn(a, perm, name=name)
# NOTE(mrry): Setting the shape explicitly because
# reverse is not handled by the shape function.
if context.in_graph_mode():
if not context.executing_eagerly():
input_shape = ret.op.inputs[0].get_shape().dims
if input_shape is not None:
ret.set_shape(input_shape[::-1])
@ -1622,7 +1621,7 @@ def zeros_like(tensor, dtype=None, name=None, optimize=True):
with ops.name_scope(name, "zeros_like", [tensor]) as name:
tensor = ops.convert_to_tensor(tensor, name="tensor")
if context.in_eager_mode():
if context.executing_eagerly():
if dtype is not None and dtype != tensor.dtype:
return zeros(
shape_internal(tensor, optimize=optimize), dtype=dtype, name=name)
@ -1678,7 +1677,7 @@ def ones_like(tensor, dtype=None, name=None, optimize=True):
if dtype is None:
dtype = tensor.dtype
ret = ones(ones_shape, dtype=dtype, name=name)
if context.in_graph_mode():
if not context.executing_eagerly():
ret.set_shape(tensor.get_shape())
return ret
@ -1759,7 +1758,7 @@ def placeholder(dtype, shape=None, name=None):
Raises:
RuntimeError: if eager execution is enabled
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError("tf.placeholder() is not compatible with "
"eager execution.")
@ -1822,7 +1821,7 @@ def sparse_placeholder(dtype, shape=None, name=None):
Raises:
RuntimeError: if eager execution is enabled
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError("tf.placeholder() is not compatible with "
"eager execution.")
@ -1921,7 +1920,7 @@ def pad(tensor, paddings, mode="CONSTANT", name=None, constant_values=0): # pyl
raise ValueError("Unknown padding mode: %s" % mode)
# Restore shape information where possible.
if context.in_graph_mode():
if not context.executing_eagerly():
paddings_constant = tensor_util.constant_value(
result.op.inputs[1], partial=True)
input_shape = result.op.inputs[0].shape

View File

@ -169,7 +169,7 @@ def assert_negative(x, data=None, summarize=None, message=None, name=None):
with ops.name_scope(name, 'assert_negative', [x, data]):
x = ops.convert_to_tensor(x, name='x')
if data is None:
if context.in_eager_mode():
if context.executing_eagerly():
name = _shape_and_dtype_str(x)
else:
name = x.name
@ -210,7 +210,7 @@ def assert_positive(x, data=None, summarize=None, message=None, name=None):
with ops.name_scope(name, 'assert_positive', [x, data]):
x = ops.convert_to_tensor(x, name='x')
if data is None:
if context.in_eager_mode():
if context.executing_eagerly():
name = _shape_and_dtype_str(x)
else:
name = x.name
@ -251,7 +251,7 @@ def assert_non_negative(x, data=None, summarize=None, message=None, name=None):
with ops.name_scope(name, 'assert_non_negative', [x, data]):
x = ops.convert_to_tensor(x, name='x')
if data is None:
if context.in_eager_mode():
if context.executing_eagerly():
name = _shape_and_dtype_str(x)
else:
name = x.name
@ -293,7 +293,7 @@ def assert_non_positive(x, data=None, summarize=None, message=None, name=None):
with ops.name_scope(name, 'assert_non_positive', [x, data]):
x = ops.convert_to_tensor(x, name='x')
if data is None:
if context.in_eager_mode():
if context.executing_eagerly():
name = _shape_and_dtype_str(x)
else:
name = x.name
@ -343,7 +343,7 @@ def assert_equal(x, y, data=None, summarize=None, message=None, name=None):
x = ops.convert_to_tensor(x, name='x')
y = ops.convert_to_tensor(y, name='y')
if context.in_eager_mode():
if context.executing_eagerly():
eq = math_ops.equal(x, y)
condition = math_ops.reduce_all(eq)
if not condition:
@ -435,7 +435,7 @@ def assert_none_equal(
with ops.name_scope(name, 'assert_none_equal', [x, y, data]):
x = ops.convert_to_tensor(x, name='x')
y = ops.convert_to_tensor(y, name='y')
if context.in_eager_mode():
if context.executing_eagerly():
x_name = _shape_and_dtype_str(x)
y_name = _shape_and_dtype_str(y)
else:
@ -512,7 +512,7 @@ def assert_near(
rtol = ops.convert_to_tensor(rtol, name='rtol', dtype=x.dtype)
atol = ops.convert_to_tensor(atol, name='atol', dtype=x.dtype)
if context.in_eager_mode():
if context.executing_eagerly():
x_name = _shape_and_dtype_str(x)
y_name = _shape_and_dtype_str(y)
else:
@ -562,7 +562,7 @@ def assert_less(x, y, data=None, summarize=None, message=None, name=None):
with ops.name_scope(name, 'assert_less', [x, y, data]):
x = ops.convert_to_tensor(x, name='x')
y = ops.convert_to_tensor(y, name='y')
if context.in_eager_mode():
if context.executing_eagerly():
x_name = _shape_and_dtype_str(x)
y_name = _shape_and_dtype_str(y)
else:
@ -610,7 +610,7 @@ def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None):
with ops.name_scope(name, 'assert_less_equal', [x, y, data]):
x = ops.convert_to_tensor(x, name='x')
y = ops.convert_to_tensor(y, name='y')
if context.in_eager_mode():
if context.executing_eagerly():
x_name = _shape_and_dtype_str(x)
y_name = _shape_and_dtype_str(y)
else:
@ -658,7 +658,7 @@ def assert_greater(x, y, data=None, summarize=None, message=None, name=None):
with ops.name_scope(name, 'assert_greater', [x, y, data]):
x = ops.convert_to_tensor(x, name='x')
y = ops.convert_to_tensor(y, name='y')
if context.in_eager_mode():
if context.executing_eagerly():
x_name = _shape_and_dtype_str(x)
y_name = _shape_and_dtype_str(y)
else:
@ -708,7 +708,7 @@ def assert_greater_equal(x, y, data=None, summarize=None, message=None,
with ops.name_scope(name, 'assert_greater_equal', [x, y, data]):
x = ops.convert_to_tensor(x, name='x')
y = ops.convert_to_tensor(y, name='y')
if context.in_eager_mode():
if context.executing_eagerly():
x_name = _shape_and_dtype_str(x)
y_name = _shape_and_dtype_str(y)
else:
@ -808,7 +808,7 @@ def assert_rank(x, rank, data=None, summarize=None, message=None, name=None):
static_condition = lambda actual_rank, given_rank: actual_rank == given_rank
dynamic_condition = math_ops.equal
if context.in_eager_mode():
if context.executing_eagerly():
name = ''
else:
name = x.name
@ -873,7 +873,7 @@ def assert_rank_at_least(
static_condition = lambda actual_rank, given_rank: actual_rank >= given_rank
dynamic_condition = math_ops.greater_equal
if context.in_eager_mode():
if context.executing_eagerly():
name = ''
else:
name = x.name
@ -1001,7 +1001,7 @@ def assert_rank_in(
ranks = tuple([ops.convert_to_tensor(rank, name='rank') for rank in ranks])
message = message or ''
if context.in_eager_mode():
if context.executing_eagerly():
name = ''
else:
name = x.name
@ -1054,7 +1054,7 @@ def assert_integer(x, message=None, name=None):
with ops.name_scope(name, 'assert_integer', [x]):
x = ops.convert_to_tensor(x, name='x')
if not x.dtype.is_integer:
if context.in_eager_mode():
if context.executing_eagerly():
name = 'tensor'
else:
name = x.name
@ -1087,12 +1087,11 @@ def assert_type(tensor, tf_type, message=None, name=None):
with ops.name_scope(name, 'assert_type', [tensor]):
tensor = ops.convert_to_tensor(tensor, name='tensor')
if tensor.dtype != tf_type:
if context.in_graph_mode():
raise TypeError(
'%s %s must be of type %s' % (message, tensor.name, tf_type))
if context.executing_eagerly():
raise TypeError('%s tensor must be of type %s' % (message, tf_type))
else:
raise TypeError(
'%s tensor must be of type %s' % (message, tf_type))
raise TypeError('%s %s must be of type %s' % (message, tensor.name,
tf_type))
return control_flow_ops.no_op('statically_determined_correct_type')
@ -1240,7 +1239,7 @@ def assert_scalar(tensor, name=None):
tensor = ops.convert_to_tensor(tensor, name=name_scope)
shape = tensor.get_shape()
if shape.ndims != 0:
if context.in_eager_mode():
if context.executing_eagerly():
raise ValueError('Expected scalar shape, saw shape: %s.'
% (shape,))
else:

View File

@ -152,7 +152,7 @@ def Assert(condition, data, summarize=None, name=None):
@compatibility{eager} `tf.errors.InvalidArgumentError` if `condition`
is not true
"""
if context.in_eager_mode():
if context.executing_eagerly():
if not condition:
xs = ops.convert_n_to_tensor(data)
data_str = [_summarize_eager(x, summarize) for x in xs]
@ -178,7 +178,7 @@ def Assert(condition, data, summarize=None, name=None):
condition, data, summarize, name="Assert")
guarded_assert = cond(condition, no_op, true_assert, name="AssertGuard")
if context.in_eager_mode():
if context.executing_eagerly():
return
return guarded_assert.op
@ -2025,7 +2025,7 @@ def cond(pred,
raise TypeError("false_fn must be callable.")
with ops.name_scope(name, "cond", [pred]):
if context.in_eager_mode():
if context.executing_eagerly():
if pred:
return _UnpackIfSingleton(true_fn())
return _UnpackIfSingleton(false_fn())
@ -3177,7 +3177,7 @@ def while_loop(cond,
math_ops.logical_and(i < maximum_iterations, orig_cond(*lv)))
body = lambda i, lv: (i + 1, orig_body(*lv))
if context.in_eager_mode():
if context.executing_eagerly():
while cond(*loop_vars):
loop_vars = body(*loop_vars)
if maximum_iterations is not None:
@ -3271,7 +3271,7 @@ def with_dependencies(dependencies, output_tensor, name=None):
Raises:
TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`.
"""
if context.in_eager_mode():
if context.executing_eagerly():
return output_tensor
with ops.name_scope(name, "control_dependency",
list(dependencies) + [output_tensor]) as name:
@ -3316,7 +3316,7 @@ def group(*inputs, **kwargs):
Raises:
ValueError: If an unknown keyword argument is provided.
"""
if context.in_eager_mode():
if context.executing_eagerly():
return None
name = kwargs.pop("name", None)
if kwargs:
@ -3396,7 +3396,7 @@ def tuple(tensors, name=None, control_inputs=None): # pylint: disable=redefined
objects.
"""
if context.in_eager_mode():
if context.executing_eagerly():
return tensors
with ops.name_scope(name, "tuple", tensors) as name:
tensors = [t if (isinstance(t, ops.Operation)

View File

@ -92,7 +92,7 @@ def custom_gradient(f):
def decorated(*args, **kwargs):
"""Decorated function with custom gradient."""
if context.in_graph_mode():
if not context.executing_eagerly():
if kwargs:
raise ValueError(
"The custom_gradient decorator currently suports keywords "

View File

@ -159,7 +159,7 @@ class QueueBase(object):
ValueError: If one of the arguments is invalid.
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError(
"Queues are not supported when eager execution is enabled. "
"Instead, please use tf.data to get data into your model.")
@ -177,10 +177,10 @@ class QueueBase(object):
else:
self._names = None
self._queue_ref = queue_ref
if context.in_graph_mode():
self._name = self._queue_ref.op.name.split("/")[-1]
else:
if context.executing_eagerly():
self._name = context.context().scope_name
else:
self._name = self._queue_ref.op.name.split("/")[-1]
@staticmethod
def from_list(index, queues):
@ -231,9 +231,9 @@ class QueueBase(object):
@property
def name(self):
"""The name of the underlying queue."""
if context.in_graph_mode():
return self._queue_ref.op.name
return self._name
if context.executing_eagerly():
return self._name
return self._queue_ref.op.name
@property
def dtypes(self):
@ -444,7 +444,7 @@ class QueueBase(object):
# NOTE(mrry): Not using a shape function because we need access to
# the `QueueBase` object.
if context.in_graph_mode():
if not context.executing_eagerly():
op = ret[0].op
for output, shape in zip(op.values(), self._shapes):
output.set_shape(shape)
@ -484,7 +484,7 @@ class QueueBase(object):
# NOTE(mrry): Not using a shape function because we need access to
# the Queue object.
if context.in_graph_mode():
if not context.executing_eagerly():
op = ret[0].op
batch_dim = tensor_shape.Dimension(
tensor_util.constant_value(op.inputs[1]))
@ -528,7 +528,7 @@ class QueueBase(object):
# NOTE(mrry): Not using a shape function because we need access to
# the Queue object.
if context.in_graph_mode():
if not context.executing_eagerly():
op = ret[0].op
for output, shape in zip(op.values(), self._shapes):
output.set_shape(tensor_shape.TensorShape([None]).concatenate(shape))
@ -990,10 +990,10 @@ class Barrier(object):
shapes=self._shapes,
shared_name=shared_name,
name=name)
if context.in_graph_mode():
self._name = self._barrier_ref.op.name.split("/")[-1]
else:
if context.executing_eagerly():
self._name = context.context().scope_name
else:
self._name = self._barrier_ref.op.name.split("/")[-1]
@property
def barrier_ref(self):
@ -1003,9 +1003,9 @@ class Barrier(object):
@property
def name(self):
"""The name of the underlying barrier."""
if context.in_graph_mode():
return self._barrier_ref.op.name
return self._name
if context.executing_eagerly():
return self._name
return self._barrier_ref.op.name
def insert_many(self, component_index, keys, values, name=None):
"""For each key, assigns the respective value to the specified component.
@ -1083,7 +1083,7 @@ class Barrier(object):
# NOTE(mrry): Not using a shape function because we need access to
# the Barrier object.
if context.in_graph_mode():
if not context.executing_eagerly():
op = ret[0].op
if allow_small_batch:
batch_dim = None
@ -1183,10 +1183,10 @@ class ConditionalAccumulatorBase(object):
else:
self._shape = tensor_shape.unknown_shape()
self._accumulator_ref = accumulator_ref
if context.in_graph_mode():
self._name = self._accumulator_ref.op.name.split("/")[-1]
else:
if context.executing_eagerly():
self._name = context.context().scope_name
else:
self._name = self._accumulator_ref.op.name.split("/")[-1]
@property
def accumulator_ref(self):

View File

@ -90,7 +90,7 @@ def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
if not callable(fn):
raise TypeError("fn must be callable.")
in_graph_mode = context.in_graph_mode()
in_graph_mode = not context.executing_eagerly()
with ops.name_scope(name, "foldl", [elems]):
# TODO(akshayka): Remove the in_graph_mode check once caching devices are
# supported in Eager
@ -178,7 +178,7 @@ def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
if not callable(fn):
raise TypeError("fn must be callable.")
in_graph_mode = context.in_graph_mode()
in_graph_mode = not context.executing_eagerly()
with ops.name_scope(name, "foldr", [elems]):
# TODO(akshayka): Remove the in_graph_mode check once caching devices are
# supported in Eager
@ -343,7 +343,7 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
elems_flat = input_flatten(elems)
in_graph_mode = context.in_graph_mode()
in_graph_mode = not context.executing_eagerly()
with ops.name_scope(name, "map", elems_flat):
# TODO(akshayka): Remove the in_graph_mode check once caching devices are
# supported in Eager
@ -536,7 +536,7 @@ def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
elems_flat = input_flatten(elems)
in_graph_mode = context.in_graph_mode()
in_graph_mode = not context.executing_eagerly()
with ops.name_scope(name, "scan", elems_flat):
# TODO(akshayka): Remove the in_graph_mode check once caching devices are
# supported in Eager

View File

@ -86,7 +86,7 @@ def _IndexedSlicesToTensor(value, dtype=None, name=None, as_ref=False):
% str(value))
# TODO(mrry): Consider adding static shape information to
# IndexedSlices, to avoid using numpy here.
if context.in_graph_mode():
if not context.executing_eagerly():
dense_shape_value = tensor_util.constant_value(value.dense_shape)
if dense_shape_value is not None:
num_elements = np.prod(dense_shape_value)
@ -491,9 +491,10 @@ def gradients(ys,
def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
gate_gradients, aggregation_method, stop_gradients):
"""Implementation of gradients()."""
if context.in_eager_mode():
raise RuntimeError("tf.gradients not supported in EAGER mode. Use "
"functions in tf.contrib.eager.backprop instead.")
if context.executing_eagerly():
raise RuntimeError("tf.gradients not supported when eager execution "
"is enabled. Use tf.contrib.eager.GradientTape "
"instead.")
ys = _AsList(ys)
xs = _AsList(xs)
stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)

View File

@ -173,7 +173,7 @@ class ReaderBase(object):
Raises:
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError(
"Readers are not supported when eager execution is enabled. "
"Instead, please use tf.data to get data into your model.")

View File

@ -157,10 +157,10 @@ class InitializableLookupTableBase(LookupInterface):
default_value: The value to use if a key is missing in the table.
initializer: The table initializer to use.
"""
if context.in_graph_mode():
name = table_ref.op.name.split("/")[-1]
else:
if context.executing_eagerly():
name = context.context().scope_name
else:
name = table_ref.op.name.split("/")[-1]
super(InitializableLookupTableBase,
self).__init__(initializer.key_dtype, initializer.value_dtype,
name)
@ -521,7 +521,7 @@ class TextFileInitializer(TableInitializerBase):
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
# If the filename tensor is anything other than a string constant (e.g., if
# it is a placeholder) then it does not make sense to track it as an asset.
if context.in_graph_mode() and constant_op.is_constant(filename):
if not context.executing_eagerly() and constant_op.is_constant(filename):
ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, filename)
return init_op

View File

@ -136,7 +136,7 @@ def _num_present(losses, weights, per_batch=False):
`[batch_size]`. Otherwise, a single scalar tensor is returned.
"""
if ((isinstance(weights, float) and weights != 0.0) or
(context.in_eager_mode() and weights._rank() == 0 # pylint: disable=protected-access
(context.executing_eagerly() and weights._rank() == 0 # pylint: disable=protected-access
and not math_ops.equal(weights, 0.0))):
return _num_elements(losses)
with ops.name_scope(None, "num_present", (losses, weights)) as scope:

View File

@ -52,14 +52,14 @@ def _SumGrad(op, grad):
if axes is not None:
rank = len(input_0_shape)
if np.array_equal(axes, np.arange(rank)): # Reduce all dims.
if context.in_graph_mode():
new_shape = [1] * rank
else:
if context.executing_eagerly():
ctx = context.context()
new_shape = ctx.ones_rank_cache().get(rank)
if new_shape is None:
new_shape = constant_op.constant([1] * rank, dtype=dtypes.int32)
ctx.ones_rank_cache().put(rank, new_shape)
else:
new_shape = [1] * rank
grad = array_ops.reshape(grad, new_shape)
# If shape is not fully defined (but rank is), we use Shape.
if None not in input_0_shape:
@ -997,7 +997,7 @@ def _SparseMatMulGrad(op, grad):
op.inputs[0]: op.get_attr("a_is_sparse"),
op.inputs[1]: op.get_attr("b_is_sparse"),
# Use heuristic to figure out if grad might be sparse
grad: context.in_graph_mode() and (grad.op.type == "ReluGrad")
grad: not context.executing_eagerly() and (grad.op.type == "ReluGrad")
}
def _SparseMatMul(t1, t2, out_dtype, transpose_a=False, transpose_b=False):

View File

@ -2007,14 +2007,14 @@ def matmul(a,
if transpose_b and adjoint_b:
raise ValueError("Only one of transpose_b and adjoint_b can be True.")
if context.in_graph_mode():
a = ops.convert_to_tensor(a, name="a")
b = ops.convert_to_tensor(b, name="b")
else:
if context.executing_eagerly():
if not isinstance(a, (ops.EagerTensor, _resource_variable_type)):
a = ops.convert_to_tensor(a, name="a")
if not isinstance(b, (ops.EagerTensor, _resource_variable_type)):
b = ops.convert_to_tensor(b, name="b")
else:
a = ops.convert_to_tensor(a, name="a")
b = ops.convert_to_tensor(b, name="b")
# TODO(apassos) remove _shape_tuple here when it is not needed.
a_shape = a._shape_tuple() # pylint: disable=protected-access
@ -2249,7 +2249,7 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None):
return inputs[0]
elif len(inputs) == 1 and name is not None:
return array_ops.identity(inputs[0], name=name)
elif context.in_eager_mode():
elif context.executing_eagerly():
# TemporaryVariable not currently supported in eager mode; fall back
# onto AddN for now.
# TODO(frreiss) remove this once the lifetime of eager variables gets

View File

@ -60,7 +60,7 @@ class ReduceTest(test_util.TensorFlowTestCase):
@test_util.run_in_graph_and_eager_modes()
def testReduceInvalidAxis(self):
if context.in_eager_mode():
if context.executing_eagerly():
# The shape check is in run a graph construction time. In eager mode,
# it misses the check, magically return result given wrong shape.
return
@ -249,7 +249,7 @@ class ScalarMulTest(test_util.TensorFlowTestCase):
@test_util.run_in_graph_and_eager_modes()
def testAcceptsRefs(self):
if context.in_eager_mode():
if context.executing_eagerly():
var = resource_variable_ops.ResourceVariable(10, name="var")
else:
var = variables.Variable(10)

View File

@ -308,7 +308,7 @@ def mean(values,
or tuple.
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('tf.metrics.mean is not supported when eager execution '
'is enabled.')
@ -394,7 +394,7 @@ def accuracy(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('tf.metrics.accuracy is not supported when eager '
'execution is enabled.')
@ -644,7 +644,7 @@ def auc(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('tf.metrics.auc is not supported when eager execution '
'is enabled.')
@ -758,7 +758,7 @@ def mean_absolute_error(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('tf.metrics.mean_absolute_error is not supported '
'when eager execution is enabled.')
@ -818,7 +818,7 @@ def mean_cosine_distance(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('tf.metrics.mean_cosine_distance is not supported when '
'eager execution is enabled.')
@ -891,7 +891,7 @@ def mean_per_class_accuracy(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('tf.metrics.mean_per_class_accuracy is not supported '
'when eager execution is enabled.')
@ -996,7 +996,7 @@ def mean_iou(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('tf.metrics.mean_iou is not supported when '
'eager execution is enabled.')
@ -1098,7 +1098,7 @@ def mean_relative_error(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('tf.metrics.mean_relative_error is not supported when '
'eager execution is enabled.')
@ -1165,7 +1165,7 @@ def mean_squared_error(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('tf.metrics.mean_squared_error is not supported when '
'eager execution is enabled.')
@ -1223,7 +1223,7 @@ def mean_tensor(values,
or tuple.
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('tf.metrics.mean_tensor is not supported when '
'eager execution is enabled.')
@ -1304,7 +1304,7 @@ def percentage_below(values,
or tuple.
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('tf.metrics.percentage_below is not supported when '
'eager execution is enabled.')
@ -1397,7 +1397,7 @@ def false_negatives(labels,
or tuple.
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('tf.metrics.false_negatives is not supported when '
'eager execution is enabled.')
@ -1453,7 +1453,7 @@ def false_negatives_at_thresholds(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('tf.metrics.false_negatives_at_thresholds is not '
'supported when eager execution is enabled.')
@ -1507,7 +1507,7 @@ def false_positives(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('tf.metrics.false_positives is not supported when '
'eager execution is enabled.')
@ -1563,7 +1563,7 @@ def false_positives_at_thresholds(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('tf.metrics.false_positives_at_thresholds is not '
'supported when eager execution is enabled.')
@ -1617,7 +1617,7 @@ def true_negatives(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('tf.metrics.true_negatives is not '
'supported when eager execution is enabled.')
@ -1673,7 +1673,7 @@ def true_negatives_at_thresholds(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('tf.metrics.true_negatives_at_thresholds is not '
'supported when eager execution is enabled.')
@ -1727,7 +1727,7 @@ def true_positives(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('tf.metrics.true_positives is not '
'supported when eager execution is enabled.')
@ -1783,7 +1783,7 @@ def true_positives_at_thresholds(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('tf.metrics.true_positives_at_thresholds is not '
'supported when eager execution is enabled.')
@ -1851,7 +1851,7 @@ def precision(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('tf.metrics.precision is not '
'supported when eager execution is enabled.')
@ -1947,7 +1947,7 @@ def precision_at_thresholds(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('tf.metrics.precision_at_thresholds is not '
'supported when eager execution is enabled.')
@ -2023,7 +2023,7 @@ def recall(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('tf.metrics.recall is not supported is not '
'supported when eager execution is enabled.')
@ -2400,7 +2400,7 @@ def recall_at_k(labels,
are not a list or tuple.
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('tf.metrics.recall_at_k is not '
'supported when eager execution is enabled.')
@ -2549,7 +2549,7 @@ def recall_at_thresholds(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('tf.metrics.recall_at_thresholds is not '
'supported when eager execution is enabled.')
@ -2626,7 +2626,7 @@ def root_mean_squared_error(labels,
tuple.
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('tf.metrics.root_mean_squared_error is not '
'supported when eager execution is enabled.')
@ -2707,7 +2707,7 @@ def sensitivity_at_specificity(labels,
or `updates_collections` are not a list or tuple.
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('tf.metrics.sensitivity_at_specificity is not '
'supported when eager execution is enabled.')
@ -3098,7 +3098,7 @@ def average_precision_at_k(labels,
ValueError: if k is invalid.
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('tf.metrics.sparse_average_precision_at_k is not '
'supported when eager execution is enabled.')
@ -3267,7 +3267,7 @@ def precision_at_top_k(labels,
are not a list or tuple.
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('tf.metrics.precision_at_top_k is not '
'supported when eager execution is enabled.')
@ -3396,7 +3396,7 @@ def precision_at_k(labels,
are not a list or tuple.
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('tf.metrics.sparse_precision_at_k is not '
'supported when eager execution is enabled.')
@ -3473,7 +3473,7 @@ def specificity_at_sensitivity(labels,
or `updates_collections` are not a list or tuple.
RuntimeError: If eager execution is enabled.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError('tf.metrics.specificity_at_sensitivity is not '
'supported when eager execution is enabled.')

View File

@ -456,7 +456,7 @@ def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
def IsZero(g):
# Some introspection to check if the gradient is feeding zeros
if context.in_eager_mode():
if context.executing_eagerly():
# TODO(apassos) add an efficient way to detect eager zeros here.
return False
if g.op.type in ("ZerosLike", "Zeros"):

View File

@ -1504,7 +1504,7 @@ def bias_add(value, bias, data_format=None, name=None):
A `Tensor` with the same type as `value`.
"""
with ops.name_scope(name, "BiasAdd", [value, bias]) as name:
if context.in_graph_mode():
if not context.executing_eagerly():
value = ops.convert_to_tensor(value, name="input")
bias = ops.convert_to_tensor(bias, dtype=value.dtype, name="bias")
return gen_nn_ops.bias_add(value, bias, data_format=data_format, name=name)
@ -1616,7 +1616,7 @@ def _flatten_outer_dims(logits):
output = array_ops.reshape(logits, array_ops.concat([[-1], last_dim_size], 0))
# Set output shape if known.
if context.in_graph_mode():
if not context.executing_eagerly():
shape = logits.get_shape()
if shape is not None and shape.dims is not None:
shape = shape.as_list()
@ -1881,7 +1881,8 @@ def softmax_cross_entropy_with_logits_v2(
# Make shape inference work since reshape and transpose may erase its static
# shape.
if context.in_graph_mode() and shape is not None and shape.dims is not None:
if not context.executing_eagerly(
) and shape is not None and shape.dims is not None:
shape = shape.as_list()
del shape[dim]
cost.set_shape(shape)
@ -2318,7 +2319,7 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: di
# 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
binary_tensor = math_ops.floor(random_tensor)
ret = math_ops.div(x, keep_prob) * binary_tensor
if context.in_graph_mode():
if not context.executing_eagerly():
ret.set_shape(x.get_shape())
return ret

View File

@ -74,7 +74,7 @@ def add_check_numerics_ops():
the checked operations.
@enc_compatibility
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError(
"add_check_numerics_ops() is not compatible with eager execution. "
"To check for Inf's and NaN's under eager execution, call "

View File

@ -135,10 +135,10 @@ class EagerResourceDeleter(object):
# valid, and so on. Printing warnings in these cases is silly
# (exceptions raised from __del__ are printed as warnings to stderr).
pass # 'NoneType' object is not callable when the handle has been
# partially unloaded.
# partially unloaded.
except AttributeError:
pass # 'NoneType' object has no attribute 'eager_mode' when context has
# been unloaded. Will catch other module unloads as well.
# been unloaded. Will catch other module unloads as well.
def shape_safe_assign_variable_handle(handle, shape, value, name=None):
@ -267,9 +267,9 @@ class ResourceVariable(variables.Variable):
if initial_value is not None:
raise ValueError("variable_def and initial_value are mutually "
"exclusive.")
if not context.in_graph_mode():
raise ValueError("Creating ResourceVariable from variable_def"
" only supported in GRAPH mode.")
if context.executing_eagerly():
raise ValueError("Creating ResourceVariable from variable_def is "
"not supported when eager execution is enabled.")
self._init_from_proto(variable_def, import_scope=import_scope)
else:
self._init_from_args(
@ -363,7 +363,7 @@ class ResourceVariable(variables.Variable):
# this graph.
self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
with ops.init_scope():
self._in_graph_mode = context.in_graph_mode()
self._in_graph_mode = not context.executing_eagerly()
with ops.name_scope(name, "Variable", []
if init_from_fn else [initial_value]) as name:
# pylint: disable=protected-access
@ -470,7 +470,7 @@ class ResourceVariable(variables.Variable):
self._cached_value = self._read_variable_op()
else:
self._cached_value = None
if context.in_graph_mode():
if not context.executing_eagerly():
ops.add_to_collections(collections, self)
elif ops.GraphKeys.GLOBAL_STEP in collections:
ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self)
@ -489,7 +489,7 @@ class ResourceVariable(variables.Variable):
def _init_from_proto(self, variable_def, import_scope=None):
"""Initializes from `VariableDef` proto."""
# Note that init_from_proto is currently not supported in Eager mode.
assert context.in_graph_mode()
assert not context.executing_eagerly()
self._in_graph_mode = True
assert isinstance(variable_def, variable_pb2.VariableDef)
if not variable_def.is_resource:
@ -582,7 +582,8 @@ class ResourceVariable(variables.Variable):
def create(self):
"""The op responsible for initializing this variable."""
if not self._in_graph_mode:
raise RuntimeError("Calling create in EAGER mode not supported.")
raise RuntimeError("Calling create is not supported when eager execution"
" is enabled.")
return self._initializer_op
@property
@ -610,7 +611,7 @@ class ResourceVariable(variables.Variable):
@property
def initial_value(self):
"""Returns the Tensor used as the initial value for the variable."""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError("initial_value not supported in EAGER mode.")
return self._initial_value
@ -631,15 +632,15 @@ class ResourceVariable(variables.Variable):
def eval(self, session=None):
"""Evaluates and returns the value of this variable."""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError("Trying to eval in EAGER mode")
return self._graph_element.eval(session=session)
def numpy(self):
if context.in_graph_mode():
raise NotImplementedError(
"numpy() is only available when eager execution is enabled.")
return self.read_value().numpy()
if context.executing_eagerly():
return self.read_value().numpy()
raise NotImplementedError(
"numpy() is only available when eager execution is enabled.")
def count_up_to(self, limit):
"""Increments this variable until it reaches `limit`.
@ -720,7 +721,7 @@ class ResourceVariable(variables.Variable):
A `VariableDef` protocol buffer, or `None` if the `Variable` is not
in the specified name scope.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError("to_proto not supported in EAGER mode.")
if export_scope is None or self.handle.name.startswith(export_scope):
var_def = variable_pb2.VariableDef()
@ -747,7 +748,7 @@ class ResourceVariable(variables.Variable):
@staticmethod
def from_proto(variable_def, import_scope=None):
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError("from_proto not supported in EAGER mode.")
return ResourceVariable(
variable_def=variable_def, import_scope=import_scope)
@ -984,10 +985,10 @@ class _UnreadVariable(ResourceVariable):
self._is_initialized_op = None
self._initializer_op = None
self._parent_op = parent_op
if context.in_graph_mode():
self._graph_element = self.read_value()
else:
if context.executing_eagerly():
self._graph_element = None
else:
self._graph_element = self.read_value()
self._handle_deleter = deleter
def value(self):

View File

@ -575,7 +575,7 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
# Create a new scope in which the caching device is either
# determined by the parent scope, or is set to place the cached
# Variable using the same placement as for the rest of the RNN.
if context.in_graph_mode():
if not context.executing_eagerly():
if varscope.caching_device is None:
varscope.set_caching_device(lambda op: op.device)
@ -616,7 +616,7 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
["Expected shape for Tensor %s is " % x.name,
packed_shape, " but saw shape: ", x_shape])
if context.in_graph_mode() and sequence_length is not None:
if not context.executing_eagerly() and sequence_length is not None:
# Perform some shape validation
with ops.control_dependencies(
[_assert_has_shape(sequence_length, [batch_size])]):
@ -742,7 +742,7 @@ def _dynamic_rnn_loop(cell,
element_shape=element_shape,
tensor_array_name=base_name + name)
in_graph_mode = context.in_graph_mode()
in_graph_mode = not context.executing_eagerly()
if in_graph_mode:
output_ta = tuple(
_create_ta(
@ -1027,7 +1027,7 @@ def raw_rnn(cell, loop_fn,
# determined by the parent scope, or is set to place the cached
# Variable using the same placement as for the rest of the RNN.
with vs.variable_scope(scope or "rnn") as varscope:
if context.in_graph_mode():
if not context.executing_eagerly():
if varscope.caching_device is None:
varscope.set_caching_device(lambda op: op.device)
@ -1242,7 +1242,7 @@ def static_rnn(cell,
# determined by the parent scope, or is set to place the cached
# Variable using the same placement as for the rest of the RNN.
with vs.variable_scope(scope or "rnn") as varscope:
if context.in_graph_mode():
if not context.executing_eagerly():
if varscope.caching_device is None:
varscope.set_caching_device(lambda op: op.device)

View File

@ -128,7 +128,7 @@ def _zero_state_tensors(state_size, batch_size, dtype):
"""Combine s with batch_size to get a proper tensor shape."""
c = _concat(batch_size, s)
size = array_ops.zeros(c, dtype=dtype)
if context.in_graph_mode():
if not context.executing_eagerly():
c_static = _concat(batch_size, s, static=True)
size.set_shape(c_static)
return size
@ -192,12 +192,13 @@ class RNNCell(base_layer.Layer):
def _rnn_get_variable(self, getter, *args, **kwargs):
variable = getter(*args, **kwargs)
if context.in_graph_mode():
trainable = (variable in tf_variables.trainable_variables() or
(isinstance(variable, tf_variables.PartitionedVariable) and
list(variable)[0] in tf_variables.trainable_variables()))
else:
if context.executing_eagerly():
trainable = variable._trainable # pylint: disable=protected-access
else:
trainable = (
variable in tf_variables.trainable_variables() or
(isinstance(variable, tf_variables.PartitionedVariable) and
list(variable)[0] in tf_variables.trainable_variables()))
if trainable and variable not in self._trainable_weights:
self._trainable_weights.append(variable)
elif not trainable and variable not in self._non_trainable_weights:
@ -241,7 +242,7 @@ class RNNCell(base_layer.Layer):
# Try to use the last cached zero_state. This is done to avoid recreating
# zeros, especially when eager execution is enabled.
state_size = self.state_size
is_eager = context.in_eager_mode()
is_eager = context.executing_eagerly()
if is_eager and hasattr(self, "_last_zero_state"):
(last_state_size, last_batch_size, last_dtype,
last_output) = getattr(self, "_last_zero_state")

View File

@ -317,7 +317,7 @@ def py_func(func, inp, Tout, stateful=True, name=None):
Returns:
A list of `Tensor` or a single `Tensor` which `func` computes.
"""
if context.in_eager_mode():
if context.executing_eagerly():
result = func(*[x.numpy() for x in inp])
result = nest.flatten(result)

View File

@ -186,7 +186,7 @@ def is_variable_initialized(ref, name=None):
if ref.dtype._is_ref_dtype:
return gen_state_ops.is_variable_initialized(ref=ref, name=name)
# Handle resource variables.
if context.in_eager_mode() or ref.op.type == "VarHandleOp":
if context.executing_eagerly() or ref.op.type == "VarHandleOp":
return gen_resource_variable_ops.var_is_initialized_op(ref.handle,
name=name)

View File

@ -204,7 +204,7 @@ def make_template_internal(name_,
if kwargs:
func_ = tf_decorator.make_decorator(func_, functools.partial(
func_, **kwargs))
if context.in_eager_mode():
if context.executing_eagerly():
if unique_name_ is not None:
raise ValueError(
"unique_name_ cannot be used when eager exeuction is enabled.")
@ -364,7 +364,7 @@ class Template(checkpointable.CheckpointableBase):
"""
def _call_next_creator_renaming_initializer(initializer, **inner_kwargs):
inner_kwargs.pop("name") # Ignored; this is the scope-stripped name which
# we don't want to propagate.
# we don't want to propagate.
return next_creator(
initial_value=initializer,
name=name,
@ -647,7 +647,7 @@ class EagerTemplate(Template):
Raises:
RuntimeError: if eager execution is not enabled.
"""
if not context.in_eager_mode():
if not context.executing_eagerly():
raise RuntimeError(
"{} objects can only be used when eager execution is enabled, use "
"tf.Template for graph construction".

View File

@ -338,7 +338,7 @@ class _GraphTensorArray(object):
with ops.name_scope(name, "TensorArrayScatter",
[self._handle, value, indices]):
value = ops.convert_to_tensor(value, name="value")
if self._infer_shape and context.in_graph_mode():
if self._infer_shape and not context.executing_eagerly():
self._merge_element_shape(value.shape[1:])
with self._maybe_colocate_with(value):
flow_out = gen_data_flow_ops.tensor_array_scatter_v3(
@ -363,7 +363,7 @@ class _GraphTensorArray(object):
value = ops.convert_to_tensor(value, name="value")
with self._maybe_colocate_with(value):
lengths_64 = math_ops.to_int64(lengths)
if self._infer_shape and context.in_graph_mode():
if self._infer_shape and not context.executing_eagerly():
clengths = tensor_util.constant_value(lengths_64)
if value.shape.dims is not None:
if clengths is not None and clengths.max() == clengths.min():
@ -774,10 +774,10 @@ class TensorArray(object):
ValueError: if both handle and tensor_array_name are provided.
TypeError: if handle is provided but is not a Tensor.
"""
if context.in_graph_mode():
implementation = _GraphTensorArray
else:
if context.executing_eagerly():
implementation = _EagerTensorArray
else:
implementation = _GraphTensorArray
self._implementation = implementation(
dtype,

View File

@ -321,7 +321,7 @@ class _VariableStore(object):
raise ValueError(
"Passed a custom_getter which is not callable: %s" % custom_getter)
if context.in_eager_mode():
if context.executing_eagerly():
if not self._store_eager_variables and reuse:
raise RuntimeError(
"When eager execution is enabled variable reuse is only supported"
@ -518,7 +518,7 @@ class _VariableStore(object):
when violating reuse during variable creation, or if an existing
sharded variable exists for the given name but with different sharding.
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise NotImplementedError("Partitioned variables are not yet supported "
"when eager execution is enabled.")
@ -798,7 +798,7 @@ class _VariableStore(object):
validate_shape=validate_shape,
constraint=constraint,
use_resource=use_resource)
if context.in_graph_mode() or self._store_eager_variables:
if not context.executing_eagerly() or self._store_eager_variables:
# In eager mode we do not want to keep default references to Variable
# objects as this will prevent their memory from being released.
self._vars[name] = v
@ -811,12 +811,12 @@ class _VariableStore(object):
with ops.name_scope(name + "/Regularizer/"):
loss = regularizer(v)
if loss is not None:
if context.in_graph_mode():
v_name = v.name
loss_name = loss.name
else:
if context.executing_eagerly():
v_name = "v_%s" % type(v)
loss_name = "loss_%s" % type(loss)
else:
v_name = v.name
loss_name = loss.name
logging.vlog(1, "Applied regularizer to %s and added the result %s "
"to REGULARIZATION_LOSSES.", v_name, loss_name)
ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, loss)
@ -920,7 +920,7 @@ class VariableScope(object):
self._dtype = dtype
self._use_resource = use_resource
self._constraint = constraint
if context.in_eager_mode():
if context.executing_eagerly():
if self._caching_device is not None:
raise NotImplementedError("Caching devices is not yet supported "
"when eager execution is enabled.")
@ -988,7 +988,7 @@ class VariableScope(object):
def set_use_resource(self, use_resource):
"""Sets whether to use ResourceVariables for this scope."""
if context.in_eager_mode() and not use_resource:
if context.executing_eagerly() and not use_resource:
raise ValueError("When eager execution is enabled, "
"use_resource cannot be set to false.")
self._use_resource = use_resource
@ -999,14 +999,14 @@ class VariableScope(object):
def set_caching_device(self, caching_device):
"""Set caching_device for this scope."""
if context.in_eager_mode():
if context.executing_eagerly():
raise NotImplementedError("Caching devices are not yet supported "
"when eager execution is enabled.")
self._caching_device = caching_device
def set_partitioner(self, partitioner):
"""Set partitioner for this scope."""
if partitioner and context.in_eager_mode():
if partitioner and context.executing_eagerly():
raise NotImplementedError("Partitioned variables are not yet supported "
"when eager execution is enabled.")
self._partitioner = partitioner
@ -1057,14 +1057,14 @@ class VariableScope(object):
partitioner = self._partitioner
if custom_getter is None:
custom_getter = self._custom_getter
if context.in_graph_mode():
if context.executing_eagerly():
reuse = False
use_resource = True
else:
if reuse is None:
reuse = self._reuse
if use_resource is None:
use_resource = self._use_resource
else:
reuse = False
use_resource = True
full_name = self.name + "/" + name if self.name else name
# Variable names only depend on variable_scope (full_name here),
@ -1107,7 +1107,7 @@ class VariableScope(object):
use_resource=None,
constraint=None):
"""Gets an existing variable with this name or create a new one."""
if context.in_eager_mode():
if context.executing_eagerly():
raise NotImplementedError("Partitioned variables are not yet supported "
"when eager execution is enabled.")
if initializer is None:
@ -1871,7 +1871,7 @@ class variable_scope(object):
raise ValueError("The reuse parameter must be True or False or None.")
if self._values is None:
self._values = []
self._in_graph_mode = not context.in_eager_mode()
self._in_graph_mode = not context.executing_eagerly()
if self._in_graph_mode:
self._graph = ops._get_graph_from_inputs(self._values) # pylint: disable=protected-access
self._cached_pure_variable_scope = None
@ -2111,13 +2111,13 @@ def default_variable_creator(next_creator=None, **kwargs):
use_resource = kwargs.get("use_resource", None)
if use_resource is None:
use_resource = get_variable_scope().use_resource
if use_resource or (use_resource is None and context.in_eager_mode()):
if use_resource or (use_resource is None and context.executing_eagerly()):
return resource_variable_ops.ResourceVariable(
initial_value=initial_value, trainable=trainable,
collections=collections, validate_shape=validate_shape,
caching_device=caching_device, name=name, dtype=dtype,
constraint=constraint)
elif not use_resource and context.in_eager_mode():
elif not use_resource and context.executing_eagerly():
raise RuntimeError(
"VariableScope should use resource variable when eager execution is"
" enabled, but use_resource is False."

View File

@ -210,10 +210,11 @@ class Variable(checkpointable.CheckpointableBase):
for details on how variables work in eager execution.
@end_compatibility
"""
if not context.in_graph_mode():
raise RuntimeError("tf.Variable not supported in Eager mode. "
"Please use tfe.Variable instead")
self._in_graph_mode = context.in_graph_mode()
if context.executing_eagerly():
raise RuntimeError(
"tf.Variable not supported when eager execution is enabled. "
"Please use tf.contrib.eager.Variable instead")
self._in_graph_mode = True
if variable_def:
# If variable_def is provided, recreates the variable from its fields.
if initial_value:
@ -234,7 +235,7 @@ class Variable(checkpointable.CheckpointableBase):
constraint=constraint)
def __repr__(self):
if context.in_eager_mode():
if context.executing_eagerly():
return "<tf.Variable '%s' shape=%s dtype=%s, numpy=%s>" % (
self.name, self.get_shape(), self.dtype.name,
ops.numpy_text(self.read_value(), is_repr=True))
@ -740,15 +741,15 @@ class Variable(checkpointable.CheckpointableBase):
Raises:
ValueError: Session is not passed and no default session
"""
if context.in_graph_mode():
if context.executing_eagerly():
self.assign(value)
else:
session = session or ops.get_default_session()
if session is None:
raise ValueError(
"Either session argument should be provided or default session "
"should be established")
session.run(self._initializer_op, {self._initializer_op.inputs[1]: value})
else:
self.assign(value)
# Conversion to tensor.
@staticmethod
@ -1248,9 +1249,9 @@ class PartitionedVariable(object):
information does not match `shape`, or `partitions` has invalid values.
RuntimeError: If eager execution is enabled
"""
if not context.in_graph_mode():
raise RuntimeError("tf.PartitionedVariable not supported in "
"eager mode. Please use tfe.Variable instead")
if context.executing_eagerly():
raise RuntimeError(
"tf.PartitionedVariable not supported with eager execution enabled.")
if not isinstance(variable_list, (list, tuple)):
raise TypeError(
"variable_list is not a list or tuple: %s" % variable_list)
@ -1541,7 +1542,7 @@ def variables_initializer(var_list, name="init"):
Returns:
An Op that run the initializers of all the specified variables.
"""
if var_list and context.in_graph_mode():
if var_list and not context.executing_eagerly():
return control_flow_ops.group(*[v.initializer for v in var_list], name=name)
return control_flow_ops.no_op(name=name)
@ -1563,7 +1564,7 @@ def global_variables_initializer():
Returns:
An Op that initializes global variables in the graph.
"""
if context.in_eager_mode():
if context.executing_eagerly():
return control_flow_ops.no_op(name="global_variables_initializer")
return variables_initializer(global_variables())
@ -1585,7 +1586,7 @@ def local_variables_initializer():
Returns:
An Op that initializes all local variables in the graph.
"""
if context.in_eager_mode():
if context.executing_eagerly():
return control_flow_ops.no_op(name="local_variables_initializer")
return variables_initializer(local_variables())

View File

@ -172,7 +172,7 @@ class Profiler(object):
op_log: optional. tensorflow::tfprof::OpLogProto proto. Used to define
extra op types.
"""
if not graph and context.in_graph_mode():
if not graph and not context.executing_eagerly():
graph = ops.get_default_graph()
self._coverage = 0.0
self._graph = graph
@ -336,7 +336,7 @@ def profile(graph=None,
If cmd is 'op' or 'code', returns MultiGraphNodeProto proto.
Side effect: stdout/file/timeline.json depending on options['output']
"""
if not graph and context.in_graph_mode():
if not graph and not context.executing_eagerly():
graph = ops.get_default_graph()
if options == _DEFAULT_PROFILE_OPTIONS:

View File

@ -156,7 +156,7 @@ def merge_default_with_oplog(graph, op_log=None, run_meta=None,
Returns:
tmp_op_log: Merged OpLogProto proto.
"""
if not graph and context.in_graph_mode():
if not graph and not context.executing_eagerly():
graph = ops.get_default_graph()
tmp_op_log = tfprof_log_pb2.OpLogProto()
@ -210,7 +210,7 @@ def write_op_log(graph, log_dir, op_log=None, run_meta=None, add_trace=True):
add_trace: Whether to add python code trace information.
Used to support "code" view.
"""
if not graph and context.in_graph_mode():
if not graph and not context.executing_eagerly():
graph = ops.get_default_graph()
op_log = merge_default_with_oplog(graph, op_log, run_meta, add_trace)

View File

@ -278,7 +278,7 @@ def merge(inputs, collections=None, name=None):
@end_compatbility
"""
# pylint: enable=line-too-long
if _context.in_eager_mode():
if _context.executing_eagerly():
raise RuntimeError(
'Merging tf.summary.* ops is not compatible with eager execution. '
'Use tf.contrib.summary instead.')
@ -311,7 +311,7 @@ def merge_all(key=_ops.GraphKeys.SUMMARIES, scope=None):
summaries under eager execution, use `tf.contrib.summary` instead.
@end_compatbility
"""
if _context.in_eager_mode():
if _context.executing_eagerly():
raise RuntimeError(
'Merging tf.summary.* ops is not compatible with eager execution. '
'Use tf.contrib.summary instead.')

View File

@ -343,7 +343,7 @@ class FileWriter(SummaryToEventTransformer):
summaries under eager execution, use `tf.contrib.summary` instead.
@end_compatbility
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError(
"tf.summary.FileWriter is not compatible with eager execution. "
"Use tf.contrib.summary instead.")

View File

@ -106,10 +106,10 @@ class AdamOptimizer(optimizer.Optimizer):
self._updated_lr = None
def _get_beta_accumulators(self):
if context.in_graph_mode():
graph = ops.get_default_graph()
else:
if context.executing_eagerly():
graph = None
else:
graph = ops.get_default_graph()
return (self._get_non_slot_variable("beta1_power", graph=graph),
self._get_non_slot_variable("beta2_power", graph=graph))

View File

@ -184,7 +184,7 @@ class AdamOptimizerTest(test.TestCase):
# Shouldn't return non-slot variables from other graphs.
self.assertEqual(0, len(opt.variables()))
if context.in_graph_mode():
if not context.executing_eagerly():
self.evaluate(variables.global_variables_initializer())
# Fetch params to validate initial values
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
@ -194,7 +194,7 @@ class AdamOptimizerTest(test.TestCase):
# Run 3 steps of Adam
for t in range(1, 4):
if context.in_graph_mode():
if not context.executing_eagerly():
self.evaluate(update)
elif t > 1:
opt.apply_gradients(zip([grads0, grads1], [var0, var1]))

View File

@ -208,7 +208,7 @@ class _CheckpointPosition(object):
# Name saveables based on the name this object had when it was checkpointed.
named_saveables = {}
restore_ops = []
in_graph_mode = context.in_graph_mode()
building_graph = not context.executing_eagerly()
for serialized_tensor in self.object_proto.attributes:
saveable_object = saveables.get(serialized_tensor.name, None)
if saveable_object is None:
@ -219,7 +219,7 @@ class _CheckpointPosition(object):
self._checkpoint.unused_attributes.setdefault(
self.checkpointable, []).append(serialized_tensor.name)
continue
if in_graph_mode:
if building_graph:
existing_ops = self._checkpoint.restore_ops_by_name.get(
serialized_tensor.name, None)
else:
@ -245,7 +245,7 @@ class _CheckpointPosition(object):
saveable_index:saveable_index + num_specs]
saveable_index += num_specs
restore_op = saveable.restore(saveable_tensors, restored_shapes=None)
if in_graph_mode:
if building_graph:
assert saveable.name not in self._checkpoint.restore_ops_by_name
self._checkpoint.restore_ops_by_name[saveable.name] = restore_op
restore_ops.append(restore_op)
@ -388,7 +388,7 @@ class CheckpointableBase(object):
"Checkpointable._add_variable called to create another with "
"that name. Variable names must be unique within a Checkpointable "
"object.") % (name,))
if context.in_eager_mode():
if context.executing_eagerly():
# If this is a variable with a single Tensor stored in the checkpoint, we
# can set that value as an initializer rather than initializing and then
# assigning (when executing eagerly). This call returns None if there is

View File

@ -71,6 +71,6 @@ class GradientDescentOptimizer(optimizer.Optimizer):
return var.scatter_sub(delta, use_locking=self._use_locking)
def _prepare(self):
if context.in_graph_mode() or self._learning_rate_tensor is None:
if not context.executing_eagerly() or self._learning_rate_tensor is None:
self._learning_rate_tensor = ops.convert_to_tensor(self._learning_rate,
name="learning_rate")

View File

@ -159,7 +159,7 @@ def input_producer(input_tensor,
enabled. Please use the `tf.data` API to ingest data under eager execution.
@end_compatibility
"""
if context.in_eager_mode():
if context.executing_eagerly():
raise RuntimeError(
"Input pipelines based on Queues are not supported when eager execution"
" is enabled. Please use tf.data to ingest data into your model"
@ -737,7 +737,7 @@ def _batch(tensors, batch_size, keep_input, num_threads=1, capacity=32,
allow_smaller_final_batch=False, shared_name=None,
name=None):
"""Helper function for `batch` and `maybe_batch`."""
if context.in_eager_mode():
if context.executing_eagerly():
raise ValueError(
"Input pipelines based on Queues are not supported when eager execution"
" is enabled. Please use tf.data to ingest data into your model"
@ -775,7 +775,7 @@ def _batch_join(tensors_list, batch_size, keep_input, capacity=32,
enqueue_many=False, shapes=None, dynamic_pad=False,
allow_smaller_final_batch=False, shared_name=None, name=None):
"""Helper function for `batch_join` and `maybe_batch_join`."""
if context.in_eager_mode():
if context.executing_eagerly():
raise ValueError(
"Input pipelines based on Queues are not supported when eager execution"
" is enabled. Please use tf.data to ingest data into your model"
@ -810,7 +810,7 @@ def _shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
shapes=None, allow_smaller_final_batch=False,
shared_name=None, name=None):
"""Helper function for `shuffle_batch` and `maybe_shuffle_batch`."""
if context.in_eager_mode():
if context.executing_eagerly():
raise ValueError(
"Input pipelines based on Queues are not supported when eager execution"
" is enabled. Please use tf.data to ingest data into your model"
@ -855,7 +855,7 @@ def _shuffle_batch_join(tensors_list, batch_size, capacity,
allow_smaller_final_batch=False, shared_name=None,
name=None):
"""Helper function for `shuffle_batch_join` and `maybe_shuffle_batch_join`."""
if context.in_eager_mode():
if context.executing_eagerly():
raise ValueError(
"Input pipelines based on Queues are not supported when eager execution"
" is enabled. Please use tf.data to ingest data into your model"

View File

@ -113,7 +113,7 @@ class LRDecayTest(test_util.TensorFlowTestCase):
learning_rate_decay.piecewise_constant(x, boundaries, values)
# Test that ref types are valid.
if context.in_graph_mode():
if not context.executing_eagerly():
x = variables.Variable(0.0)
x_ref = x.op.outputs[0] # float32_ref tensor should be accepted
boundaries, values = [1.0, 2.0], [1, 2, 3]

Some files were not shown because too many files have changed in this diff Show More