eager: Rename in_eager_mode to executing_eagerly and get rid of in_graph_mode.

This is in preparation for introducing one public, stable symbol, tf.executing_eagerly()
(i.e., part of moving APIs related to eager execution out of "contrib" and into a namespace
where we provide API stability guarantees).

PiperOrigin-RevId: 188212646
Asim Shankar authored on 2018-03-07 12:03:56 -08:00; committed by TensorFlower Gardener
commit 37cef895bf (parent 808b569e85)
110 changed files with 790 additions and 854 deletions
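
For orientation, the sketch below (not part of the commit) shows how a call site reads once the rename lands; the import path is the internal context module touched throughout the diff, and the helper function is purely illustrative:

    from tensorflow.python.eager import context

    def describe_execution_mode():
      # Call sites that used to branch on context.in_eager_mode() or
      # context.in_graph_mode() now funnel through a single predicate.
      if context.executing_eagerly():
        return "eager execution is enabled for the current thread"
      return "graph construction mode"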

View File

@@ -44,7 +44,7 @@ class PrivateThreadPool(object):
  def __init__(self, num_threads, display_name=None):
    """Creates a `PrivateThreadPool` with the given number of threads."""
-   if context.in_eager_mode():
+   if context.executing_eagerly():
      shared_name = _generate_shared_name("privatethreadpool")
      self._resource = gen_dataset_ops.thread_pool_handle(
          num_threads=num_threads,

View File

@@ -395,7 +395,7 @@ class CheckpointLoadStatus(_LoadStatus):
  def run_restore_ops(self, session=None):
    """Run operations to restore objects in the dependency graph."""
-   if context.in_eager_mode():
+   if context.executing_eagerly():
      return  # Run eagerly
    if session is None:
      session = ops.get_default_session()
@@ -459,7 +459,7 @@ class InitializationOnlyStatus(_LoadStatus):
      session: The session to run initialization ops in. If `None`, uses the
        default session.
    """
-   if context.in_eager_mode():
+   if context.executing_eagerly():
      return  # run eagerly
    if session is None:
      session = ops.get_default_session()
@@ -491,7 +491,7 @@ class NameBasedSaverStatus(_LoadStatus):
      date=None, instructions=_DEPRECATED_RESTORE_INSTRUCTIONS)
  def run_restore_ops(self, session=None):
    """Load the name-based training checkpoint using a new `tf.train.Saver`."""
-   if session is None and context.in_graph_mode():
+   if session is None and not context.executing_eagerly():
      session = ops.get_default_session()
    saver_lib.Saver(self._object_saver._global_variable_names()).restore(  # pylint: disable=protected-access
        sess=session, save_path=self._save_path)
@@ -548,7 +548,7 @@ class CheckpointableSaver(object):
    # Allow passing in a weak reference to avoid reference cycles when
    # `Checkpointable` objects save themselves.
    self._root_checkpointable_ref = root_checkpointable
-   if context.in_graph_mode():
+   if not context.executing_eagerly():
      with ops.device("/cpu:0"):
        self._file_prefix_placeholder = constant_op.constant("model")
    else:
@@ -597,7 +597,7 @@ class CheckpointableSaver(object):
    """
    named_variables, graph_proto = _serialize_object_graph(
        self._root_checkpointable)
-   in_graph_mode = context.in_graph_mode()
+   in_graph_mode = not context.executing_eagerly()
    if in_graph_mode:
      if session is None:
        session = ops.get_default_session()
@@ -714,7 +714,7 @@ class CheckpointableSaver(object):
    """
    if save_path is None:
      return InitializationOnlyStatus(self._root_checkpointable)
-   in_graph_mode = context.in_graph_mode()
+   in_graph_mode = not context.executing_eagerly()
    if in_graph_mode:
      if session is None:
        session = ops.get_default_session()
@@ -850,7 +850,7 @@ class Checkpoint(core_checkpointable.Checkpointable):
  def save(self, file_prefix, session=None):
    """Save a checkpoint. Wraps `tfe.CheckpointableSaver.save`."""
-   in_graph_mode = context.in_graph_mode()
+   in_graph_mode = not context.executing_eagerly()
    if in_graph_mode:
      if session is None:
        session = ops.get_default_session()

View File

@@ -108,14 +108,14 @@ class InterfaceTests(test.TestCase):
         [0., 0.]], self.evaluate(bare_initializer))
    self.assertEqual("a_variable:0", obj.a_variable.name)
    self.assertEqual("duplicate:0", other_duplicate.name)
-   if context.in_graph_mode():
-     # The .name attribute may be globally influenced, but the checkpoint name
-     # won't be (tested below).
-     self.assertEqual("duplicate_1:0", duplicate.name)
-   else:
+   if context.executing_eagerly():
      # When executing eagerly, there's no uniquification of variable names. The
      # checkpoint name will be the same.
      self.assertEqual("duplicate:0", duplicate.name)
+   else:
+     # The .name attribute may be globally influenced, but the checkpoint name
+     # won't be (tested below).
+     self.assertEqual("duplicate_1:0", duplicate.name)
    named_variables, _ = checkpointable_utils._serialize_object_graph(obj)
    expected_checkpoint_names = (
        "a_variable/.ATTRIBUTES/VARIABLE_VALUE",
@@ -165,7 +165,7 @@ class CheckpointingTests(test.TestCase):
    optimizer_step = training_util.get_or_create_global_step()
    root_checkpointable = checkpointable_utils.Checkpoint(
        optimizer=optimizer, model=model, optimizer_step=optimizer_step)
-   if context.in_eager_mode():
+   if context.executing_eagerly():
      optimizer.minimize(
          lambda: model(input_value),
          global_step=optimizer_step)
@@ -268,7 +268,7 @@ class CheckpointingTests(test.TestCase):
    root_checkpointable = checkpointable_utils.Checkpoint(
        optimizer=optimizer, model=model)
    input_value = constant_op.constant([[3.]])
-   if context.in_eager_mode():
+   if context.executing_eagerly():
      optimizer.minimize(
          lambda: model(input_value))
    else:
@@ -293,7 +293,7 @@ class CheckpointingTests(test.TestCase):
    self.assertAllEqual([42.], self.evaluate(model._named_dense.variables[1]))
    self.assertAllEqual(1, self.evaluate(root_checkpointable.save_counter))
    self.assertAllEqual([1.5], self.evaluate(m_bias_slot))
-   if context.in_graph_mode():
+   if not context.executing_eagerly():
      return  # Restore-on-create is only supported when executing eagerly
    on_create_model = MyModel()
    on_create_optimizer = adam.AdamOptimizer(0.001)
@@ -400,7 +400,7 @@ class CheckpointingTests(test.TestCase):
        optimizer.minimize,
        functools.partial(model, input_value),
        global_step=root.global_step)
-   if context.in_graph_mode():
+   if not context.executing_eagerly():
      train_fn = functools.partial(self.evaluate, train_fn())
    status.initialize_or_restore()
    for _ in range(num_training_steps):
@@ -524,7 +524,9 @@ class CheckpointingTests(test.TestCase):
    root.var = checkpointable_utils.add_variable(
        root, name="var", initializer=0.)
    optimizer = adam.AdamOptimizer(0.1)
-   if context.in_graph_mode():
+   if context.executing_eagerly():
+     optimizer.minimize(root.var.read_value)
+   else:
      train_op = optimizer.minimize(root.var)
      # Note that `optimizer` has not been added as a dependency of
      # `root`. Create a one-off grouping so that slot variables for `root.var`
@@ -532,8 +534,6 @@ class CheckpointingTests(test.TestCase):
      self.evaluate(checkpointable_utils.gather_initializers(
          checkpointable_utils.Checkpoint(root=root, optimizer=optimizer)))
      self.evaluate(train_op)
-   else:
-     optimizer.minimize(root.var.read_value)
    self.evaluate(state_ops.assign(root.var, 12.))
    no_slots_path = checkpointable_utils.CheckpointableSaver(root).save(
        os.path.join(checkpoint_directory, "no_slots"))
@@ -561,7 +561,7 @@ class CheckpointingTests(test.TestCase):
    with self.assertRaisesRegexp(AssertionError, "beta1_power"):
      slot_status.assert_consumed()
    self.assertEqual(12., self.evaluate(new_root.var))
-   if context.in_eager_mode():
+   if context.executing_eagerly():
      # Slot variables are only created with restoring initializers when
      # executing eagerly.
      self.assertEqual(14., self.evaluate(
@@ -569,7 +569,9 @@ class CheckpointingTests(test.TestCase):
    else:
      self.assertIs(new_root.optimizer.get_slot(name="m", var=new_root.var),
                    None)
-   if context.in_graph_mode():
+   if context.executing_eagerly():
+     new_root.optimizer.minimize(new_root.var.read_value)
+   else:
      train_op = new_root.optimizer.minimize(new_root.var)
      # The slot variable now exists; restore() didn't create it, but we should
      # now have a restore op for it.
@@ -577,8 +579,6 @@ class CheckpointingTests(test.TestCase):
      self.assertEqual(14., self.evaluate(
          new_root.optimizer.get_slot(name="m", var=new_root.var)))
      self.evaluate(train_op)
-   else:
-     new_root.optimizer.minimize(new_root.var.read_value)
    slot_status.assert_consumed()

  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)

View File

@@ -68,7 +68,7 @@ class Iterator(object):
      RuntimeError: When invoked without eager execution enabled.
    """
-   if not context.in_eager_mode():
+   if not context.executing_eagerly():
      raise RuntimeError(
          "{} objects can only be used when eager execution is enabled, use "
          "tf.data.Dataset.make_initializable_iterator or "

View File

@@ -57,7 +57,7 @@ class Evaluator(object):
    self._model = model
    self._metrics = {}
    self._evaluators = {}
-   if context.in_graph_mode():
+   if not context.executing_eagerly():
      self.call = function.defun(self.call)

  # ---- API for users ----
@@ -90,7 +90,7 @@ class Evaluator(object):
    Only for graph execution.
    @end_compatibility
    """
-   if context.in_eager_mode():
+   if context.executing_eagerly():
      raise RuntimeError("Evaluator.init_variables() not needed when "
                         "eager execution is enabled.")
    return control_flow_ops.group([m.init_variables() for _, m in self.metrics])
@@ -113,7 +113,8 @@ class Evaluator(object):
      with summary_ops.create_file_writer(
          summary_logdir).as_default(), summary_ops.always_record_summaries():
        return self._all_metric_results()
-   if context.in_eager_mode():
+   if context.executing_eagerly():
      return f()
    else:
      return function.defun(f)()
@@ -158,16 +159,16 @@ class Evaluator(object):
    @end_compatibility
    """
    summary_logdir = kwargs.pop("summary_logdir", None)
-   if context.in_graph_mode():
-     call_op = self.__call__(dataset.make_one_shot_iterator().get_next(),
-                             *args, **kwargs)
-     init_op = self.init_variables()
-     results_op = self.all_metric_results(summary_logdir)
-     return (init_op, call_op, results_op)
-   # Eager case
+   if context.executing_eagerly():
      for example in datasets.Iterator(dataset):
        self.__call__(example, *args, **kwargs)
      return self.all_metric_results(summary_logdir)
+   # Graph construction
+   call_op = self.__call__(dataset.make_one_shot_iterator().get_next(), *args,
+                           **kwargs)
+   init_op = self.init_variables()
+   results_op = self.all_metric_results(summary_logdir)
+   return (init_op, call_op, results_op)

  @staticmethod
  def run_evaluation(init_op, call_op, results_op, sess=None):
@@ -192,7 +193,7 @@ class Evaluator(object):
    Only for graph execution.
    @end_compatibility
    """
-   if context.in_eager_mode():
+   if context.executing_eagerly():
      raise RuntimeError("Evaluator.run_evaluation() not supported when "
                         "eager execution is enabled.")
    sess = sess or ops.get_default_session()

View File

@@ -109,13 +109,13 @@ class Metric(checkpointable.CheckpointableBase):
      pos = scope.name.rfind(scope_name)
      self._name = name + scope.name[pos + len(scope_name):]
      self._scope = scope
-     if context.in_graph_mode():
+     if context.executing_eagerly():
+       self._construction_scope = context.eager_mode
+     else:
        # We make self.call() into a graph callable here, so that we can
        # return a single op that performs all of the variable updates.
        self._construction_scope = ops.get_default_graph().as_default
        self.call = function.defun(self.call)
-     else:
-       self._construction_scope = context.eager_mode

  # ---- API for users ----
  def __call__(self, *args, **kwargs):
@@ -156,10 +156,11 @@ class Metric(checkpointable.CheckpointableBase):
      initialization. Under eager execution, the variables are reset to their
      initial values as a side effect and this function returns None.
    """
-   if context.in_graph_mode():
-     return control_flow_ops.group([v.initializer for v in self._vars])
+   if context.executing_eagerly():
      for v in self._vars:
        v.assign(self._initial_values[v])
+   else:
+     return control_flow_ops.group([v.initializer for v in self._vars])

  # ---- To be implemented by descendants ---
  def build(self, *args, **kwargs):
@@ -201,10 +202,10 @@ class Metric(checkpointable.CheckpointableBase):
  def value(self):
    """In graph mode returns the result Tensor while in eager the callable."""
-   if context.in_graph_mode():
-     return self.result()
-   else:
+   if context.executing_eagerly():
      return self.result
+   else:
+     return self.result()

  # We can support two different strategies of for doing data-parallel
  # distributed metric computations:
@@ -246,7 +247,7 @@ class Metric(checkpointable.CheckpointableBase):
    """***Only for use by descendants of Metric***."""
    if self._built:
      raise RuntimeError("Can't call add_variable() except in build().")
-   if context.in_eager_mode():
+   if context.executing_eagerly():
      collections = None
    else:
      if self._use_global_variables:
@@ -270,7 +271,7 @@ class Metric(checkpointable.CheckpointableBase):
        # Checkpointable.
        overwrite=True)
    self._vars.append(v)
-   if context.in_eager_mode():
+   if context.executing_eagerly():
      self._initial_values[v] = v.value()
    return v

View File

@@ -639,7 +639,7 @@ def _make_custom_getter_for_deferred_restorations():
        # Mark as already restored from this checkpoint.
        delayed_restoration.checkpointed_variables_to_restore[
            checkpoint_name] = None
-       if context.in_graph_mode():
+       if not context.executing_eagerly():
          delayed_restoration.session.run(variable.initializer)
    if found_value:
      # Error checking should run even if we've already restored a value.
@@ -772,7 +772,7 @@ def save_network_checkpoint(
              variable_map[mapped_name]._shared_name,
              variable._shared_name,
              network.scope_name))
- if context.in_eager_mode():
+ if context.executing_eagerly():
    sess = None
  else:
    sess = ops.get_default_session()
@@ -853,7 +853,7 @@ def _restore_existing_variables(network, save_path, map_func, user_map_func):
            network_name=network.name,
            network_scope_name=network.scope_name))
  if existing_variables_by_checkpoint_name:
-   if context.in_eager_mode():
+   if context.executing_eagerly():
      sess = None
    else:
      sess = ops.get_default_session()
@@ -880,7 +880,7 @@ def _set_restore_on_create(network, save_path, map_func, user_map_func,
  # _DeferredRestoration objects once a Network has been built (so that
  # restoring in a loop does not take increasing amounts of memory).
  if checkpointed_variables_to_restore:
-   if context.in_eager_mode():
+   if context.executing_eagerly():
      sess = None
    else:
      sess = ops.get_default_session()

View File

@@ -73,7 +73,7 @@ def restore_variables_on_create(save_path, map_func=None):
    NotFoundError: If the variable is not found in checkpoint.
    ValueError: If not used in eager mode or map_func is not callable.
  """
- if context.in_graph_mode():
+ if not context.executing_eagerly():
    raise ValueError(
        "Currently, restore_variables_on_create can only be used with "
        "eager execution enabled.")
@@ -131,7 +131,7 @@ class Saver(object):
    Raises:
      RuntimeError: if invoked when eager execution has not been enabled.
    """
-   if context.in_graph_mode():
+   if not context.executing_eagerly():
      raise RuntimeError("tfe.Saver can only be used when eager "
                         "execution is enabled. Use tf.train.Saver when "
                         "building graphs.")

View File

@@ -60,8 +60,8 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
 @@Checkpointable
 @@CheckpointableSaver
+@@executing_eagerly
 @@in_eager_mode
-@@in_graph_mode

 @@run_test_in_graph_and_eager_modes
@@ -93,8 +93,7 @@ from tensorflow.python.eager import function
from tensorflow.python.eager.context import DEVICE_PLACEMENT_EXPLICIT
from tensorflow.python.eager.context import DEVICE_PLACEMENT_WARN
from tensorflow.python.eager.context import DEVICE_PLACEMENT_SILENT
-from tensorflow.python.eager.context import in_eager_mode
-from tensorflow.python.eager.context import in_graph_mode
+from tensorflow.python.eager.context import executing_eagerly
from tensorflow.python.eager.context import list_devices
from tensorflow.python.eager.context import num_gpus
from tensorflow.python.eager.execution_callbacks import add_execution_callback
@@ -122,5 +121,6 @@ implicit_value_and_gradients = backprop.implicit_val_and_grad
gradients_function = backprop.gradients_function
value_and_gradients_function = backprop.val_and_grad_function
GradientTape = backprop.GradientTape  # pylint: disable=invalid-name
+in_eager_mode = executing_eagerly

remove_undocumented(__name__)
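
As a quick compatibility check (a minimal usage sketch, assuming a TensorFlow build that contains this change), the old tfe symbol keeps working as an alias of the new one:

    import tensorflow.contrib.eager as tfe

    tfe.enable_eager_execution()
    # executing_eagerly is the newly documented symbol; in_eager_mode is kept
    # as an alias (see the assignment above) so existing user code keeps running.
    assert tfe.executing_eagerly()
    assert tfe.in_eager_mode()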

View File

@@ -47,7 +47,8 @@ class TFETest(test_util.TensorFlowTestCase):
  def testVariableError(self):
    with self.assertRaisesRegexp(
-       RuntimeError, r'Variable not supported in Eager mode'):
+       RuntimeError,
+       r'Variable not supported when eager execution is enabled'):
      variables.Variable(initial_value=1.0)

  def testGradients(self):

View File

@@ -154,7 +154,7 @@ class CriticalSection(object):
      self._handle = gen_resource_variable_ops.mutex_v2(
          shared_name=shared_name, container=container, name=name)
-   if context.in_graph_mode():
+   if not context.executing_eagerly():
      ops.add_to_collections(CRITICAL_SECTIONS, self)

  @property
@@ -221,7 +221,7 @@ class CriticalSection(object):
            "This is illegal and would cause deadlocks. "
            "CriticalSection: %s." % self._handle)
-   if context.in_graph_mode():
+   if not context.executing_eagerly():
      # Collections and op introspection does not work in eager
      # mode. This is generally ok; since eager mode (as of
      # writing) executes sequentially anyway.
@@ -250,7 +250,7 @@ class CriticalSection(object):
        return x.identity()
      elif isinstance(x, ops.Operation):
        return control_flow_ops.group(x)
-     elif context.in_eager_mode() and x is None:
+     elif context.executing_eagerly() and x is None:
        return None
      else:
        return array_ops.identity(x)
@@ -274,7 +274,7 @@ class CriticalSection(object):
    with ops.control_dependencies([ensure_lock_exists]):
      outputs = nest.map_structure(identity, r)
-   if context.in_graph_mode():
+   if not context.executing_eagerly():
      signature = _ExecutionSignature(
          op=lock.op,
          handle=self._handle,

View File

@@ -2746,7 +2746,7 @@ def softmax(logits, scope=None):
    logits_2d = array_ops.reshape(logits, [-1, num_logits])
    predictions = nn.softmax(logits_2d)
    predictions = array_ops.reshape(predictions, array_ops.shape(logits))
-   if context.in_graph_mode():
+   if not context.executing_eagerly():
      predictions.set_shape(logits.get_shape())
    return predictions

View File

@@ -3646,7 +3646,7 @@ def cohen_kappa(labels,
      `updates_collections` are not a list or tuple.
    RuntimeError: If eager execution is enabled.
  """
- if context.in_eager_mode():
+ if context.executing_eagerly():
    raise RuntimeError('tf.contrib.metrics.cohen_kappa is not supported'
                       'when eager execution is enabled.')
  if num_classes < 2:

View File

@@ -267,5 +267,5 @@ def _check_device(tensor, expected=None):
def _check_graph_mode():
- if context.in_eager_mode():
+ if context.executing_eagerly():
    raise ValueError('Nccl ops are not supported in eager mode')

View File

@@ -97,7 +97,7 @@ class AddSignTest(test.TestCase):
          global_step=global_step)
      neg_update = opt.apply_gradients(zip([-grads0, -grads1], [var0, var1]),
                                       global_step=global_step)
-     if context.in_graph_mode():
+     if not context.executing_eagerly():
        self.evaluate(variables.global_variables_initializer())
        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
@@ -108,13 +108,13 @@ class AddSignTest(test.TestCase):
      # last 3 steps with negative gradient (sign(gm) should be -1)
      for t in range(1, 8):
        if t < 5:
-         if context.in_graph_mode():
+         if not context.executing_eagerly():
            self.evaluate(update)
          elif t > 1:
            opt.apply_gradients(zip([grads0, grads1], [var0, var1]),
                                global_step=global_step)
        else:
-         if context.in_graph_mode():
+         if not context.executing_eagerly():
            self.evaluate(neg_update)
          elif t > 1:
            opt.apply_gradients(zip([-grads0, -grads1], [var0, var1]),

View File

@@ -99,7 +99,7 @@ class PowerSignTest(test.TestCase):
      neg_update = opt.apply_gradients(zip([-grads0, -grads1], [var0, var1]),
                                       global_step=global_step)
-     if context.in_graph_mode():
+     if not context.executing_eagerly():
        self.evaluate(variables.global_variables_initializer())
        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
@@ -110,13 +110,13 @@ class PowerSignTest(test.TestCase):
      # last 3 steps with negative gradient (sign(gm) should be -1)
      for t in range(1, 8):
        if t < 5:
-         if context.in_graph_mode():
+         if not context.executing_eagerly():
            self.evaluate(update)
          elif t > 1:
            opt.apply_gradients(zip([grads0, grads1], [var0, var1]),
                                global_step=global_step)
        else:
-         if context.in_graph_mode():
+         if not context.executing_eagerly():
            self.evaluate(neg_update)
          elif t > 1:
            opt.apply_gradients(zip([-grads0, -grads1], [var0, var1]),

View File

@@ -869,7 +869,7 @@ class LSTMTest(test.TestCase):
    num_proj = 4
    max_length = 8
    sequence_length = [4, 6]
-   in_graph_mode = context.in_graph_mode()
+   in_graph_mode = not context.executing_eagerly()
    with self.test_session(graph=ops_lib.Graph()) as sess:
      initializer = init_ops.random_uniform_initializer(
          -0.01, 0.01, seed=self._seed)
@@ -934,8 +934,7 @@ class LSTMTest(test.TestCase):
      if in_graph_mode:
        self.assertAllEqual(outputs_static, outputs_dynamic)
      else:
-       self.assertAllEqual(
-           array_ops.stack(outputs_static).numpy(), outputs_dynamic.numpy())
+       self.assertAllEqual(array_ops.stack(outputs_static), outputs_dynamic)
      self.assertAllEqual(np.hstack(state_static), np.hstack(state_dynamic))

  @test_util.run_in_graph_and_eager_modes()
@@ -946,7 +945,7 @@ class LSTMTest(test.TestCase):
    num_proj = 4
    max_length = 8
    sequence_length = [4, 6]
-   in_graph_mode = context.in_graph_mode()
+   in_graph_mode = not context.executing_eagerly()
    with self.test_session(graph=ops_lib.Graph()) as sess:
      initializer = init_ops.random_uniform_initializer(
          -0.01, 0.01, seed=self._seed)
@@ -1022,10 +1021,9 @@ class LSTMTest(test.TestCase):
      if in_graph_mode:
        self.assertAllEqual(outputs_static, outputs_dynamic)
      else:
-       self.assertAllEqual(
-           array_ops.stack(outputs_static).numpy(), outputs_dynamic.numpy())
-       state_static = [s.numpy() for s in nest.flatten(state_static)]
-       state_dynamic = [s.numpy() for s in nest.flatten(state_dynamic)]
+       self.assertAllEqual(array_ops.stack(outputs_static), outputs_dynamic)
+       state_static = nest.flatten(state_static)
+       state_dynamic = nest.flatten(state_dynamic)
      self.assertAllEqual(np.hstack(state_static), np.hstack(state_dynamic))

  def _testDynamicEquivalentToStaticRNN(self, use_sequence_length):
@@ -1043,7 +1041,7 @@ class LSTMTest(test.TestCase):
    else:
      sequence_length = None
-   in_graph_mode = context.in_graph_mode()
+   in_graph_mode = not context.executing_eagerly()

    # TODO(b/68017812): Eager ignores operation seeds, so we need to create a
    # single cell and reuse it across the static and dynamic RNNs. Remove this

View File

@@ -110,7 +110,7 @@ class SummaryWriter(object):
  def __init__(self, resource):
    self._resource = resource
-   if context.in_eager_mode() and self._resource is not None:
+   if context.executing_eagerly() and self._resource is not None:
      self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
          handle=self._resource, handle_device="cpu:0")
@@ -158,7 +158,7 @@ def initialize(
      @{tf.contrib.summary.SummaryWriter}.
    ValueError: If session wasn't passed and no default session.
  """
- if context.in_eager_mode():
+ if context.executing_eagerly():
    return
  if context.context().summary_writer_resource is None:
    raise RuntimeError("No default tf.contrib.summary.SummaryWriter found")
@@ -269,7 +269,7 @@ def _make_summary_writer(name, factory, **kwargs):
  resource = gen_summary_ops.summary_writer(shared_name=name)
  # TODO(apassos): Consider doing this instead.
  # node = factory(resource, **kwargs)
- # if not context.in_eager_mode():
+ # if not context.executing_eagerly():
  #   ops.get_default_session().run(node)
  ops.add_to_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME,
                        factory(resource, **kwargs))
@@ -295,7 +295,7 @@ def all_summary_ops():
  Returns:
    The summary ops.
  """
- if context.in_eager_mode():
+ if context.executing_eagerly():
    return None
  return ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
@@ -309,7 +309,7 @@ def summary_writer_initializer_op():
  Raises:
    RuntimeError: If in Eager mode.
  """
- if context.in_eager_mode():
+ if context.executing_eagerly():
    raise RuntimeError(
        "tf.contrib.summary.summary_writer_initializer_op is only "
        "supported in graph mode.")
@@ -477,7 +477,7 @@ def graph(param, step=None, name=None):
  Raises:
    TypeError: If `param` isn't already a @{tf.Tensor} in graph mode.
  """
- if not context.in_eager_mode() and not isinstance(param, ops.Tensor):
+ if not context.executing_eagerly() and not isinstance(param, ops.Tensor):
    raise TypeError("graph() needs a tf.Tensor (e.g. tf.placeholder) in graph "
                    "mode, but was: %s" % type(param))
  writer = context.context().summary_writer_resource

View File

@@ -91,7 +91,7 @@ class Dataset(object):
    Raises:
      RuntimeError: If eager execution is enabled.
    """
-   if context.in_eager_mode():
+   if context.executing_eagerly():
      raise RuntimeError(
          "dataset.make_initializable_iterator is not supported when eager "
          "execution is enabled.")
@@ -123,7 +123,7 @@ class Dataset(object):
    Raises:
      RuntimeError: If eager execution is enabled.
    """
-   if context.in_eager_mode():
+   if context.executing_eagerly():
      raise RuntimeError(
          "dataset.make_one_shot_iterator is not supported when eager "
          "execution is enabled.")

View File

@@ -65,7 +65,7 @@ class RandomSeedTest(test.TestCase):
      self.assertEqual((g_seed, op_seed), toutput, msg=msg)
      random_seed.set_random_seed(None)
-   if context.in_graph_mode():
+   if not context.executing_eagerly():
      random_seed.set_random_seed(1)
      tinput = (1, None)
      toutput = (1, ops.get_default_graph()._last_id)  # pylint: disable=protected-access

View File

@@ -55,7 +55,8 @@ def c_tfe_py_fastpath_execute(a,
                              transpose_b=False,
                              name=None):
  ctx = context.context()
- assert not ctx.in_graph_mode(
+ assert ctx.in_eager_mode(
  ), "The prototype doesn't contain C code for graph construction"
  try:
    return pywrap_tensorflow.TFE_Py_FastPathExecute(

View File

@@ -260,12 +260,8 @@ class Context(object):
    if mode == EAGER_MODE:
      context_stack.pop()

- def in_graph_mode(self):
-   """Returns True if current thread is in GRAPH mode."""
-   return self._eager_context.mode == GRAPH_MODE
-
- def in_eager_mode(self):
-   """Returns True if current thread is in EAGER mode."""
+ def executing_eagerly(self):
+   """Returns True if current thread has eager executing enabled."""
    return self._eager_context.mode == EAGER_MODE

  def scalar_cache(self):
@@ -522,23 +518,23 @@ def internal_operation_seed():
  return context()._internal_operation_seed()  # pylint: disable=protected-access

-def in_graph_mode():
-  """Returns True if current thread is in GRAPH mode for default context."""
-  return context().in_graph_mode()
+def executing_eagerly():
+  """Returns True if the current thread has eager execution enabled."""
+  return context().executing_eagerly()

def in_eager_mode():
-  """Returns True if current thread is in EAGER mode for default context."""
-  return context().in_eager_mode()
+  """Use executing_eagerly() instead. This function will be removed."""
+  return executing_eagerly()

def graph_mode():
-  """Context-manager to enable GRAPH mode for current thread."""
+  """Context-manager to disable eager execution for the current thread."""
  return context()._mode(GRAPH_MODE)  # pylint: disable=protected-access

def eager_mode():
-  """Context-manager to enable EAGER mode for current thread."""
+  """Context-manager to enable eager execution for the current thread."""
  return context()._mode(EAGER_MODE)  # pylint: disable=protected-access
@@ -631,4 +627,8 @@ def export_run_metadata():
# (for example, enable_eager_execution in python/framework/ops.py),
# but they do all import this file. Note that IS_IN_GRAPH_MODE and
# in_graph_mode are both parameterless functions.
-is_in_graph_mode.IS_IN_GRAPH_MODE = in_graph_mode
+def _tmp_in_graph_mode():
+  return not executing_eagerly()
+
+is_in_graph_mode.IS_IN_GRAPH_MODE = _tmp_in_graph_mode
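
The renamed helpers compose with the existing graph_mode()/eager_mode() context managers; a small hedged sketch of the expected behaviour (illustrative only, using just the internal names shown in this hunk):

    from tensorflow.python.eager import context

    # executing_eagerly() reflects whichever mode these context managers push
    # for the current thread.
    with context.graph_mode():
      assert not context.executing_eagerly()
    with context.eager_mode():
      assert context.executing_eagerly()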

View File

@@ -57,8 +57,7 @@ class TFETest(test_util.TensorFlowTestCase):
  def testContext(self):
    ctx = context.Context()
-   self.assertFalse(ctx.in_graph_mode())
-   self.assertTrue(ctx.in_eager_mode())
+   self.assertTrue(ctx.executing_eagerly())

    self.assertEqual('', ctx.scope_name)
    ctx.scope_name = 'foo'
@@ -150,9 +149,9 @@ class TFETest(test_util.TensorFlowTestCase):
    def get_context_values(ctx):
      return [
-         ctx.in_graph_mode(),
-         ctx.in_eager_mode(), ctx.scope_name, ctx.summary_writer_resource,
-         ctx.device_name, ctx.num_gpus()
+         ctx.executing_eagerly(), ctx.scope_name, ctx.summary_writer_resource,
+         ctx.device_name,
+         ctx.num_gpus()
      ]

    def get_values(ctx, values):

View File

@@ -112,7 +112,7 @@ def _convert_to_graph_tensor(value, dtype=None, name=None, as_ref=False):
  """
  del as_ref  # Unused.
- if context.in_eager_mode():
+ if context.executing_eagerly():
    return value
  default_graph = ops.get_default_graph()
@@ -295,7 +295,7 @@ class _EagerDefinedFunction(object):
    proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
    function_def = function_pb2.FunctionDef()
    function_def.ParseFromString(compat.as_bytes(proto_data))
-   if context.in_eager_mode():
+   if context.executing_eagerly():
      _register(fn)
    self.definition = function_def
    self.name = function_def.signature.name
@@ -438,7 +438,14 @@ class GraphModeFunction(object):
    all_args = args + self._extra_inputs
    signature = self._forward_fdef.signature
    ctx = context.context()
-   if ctx.in_graph_mode():
+   if ctx.executing_eagerly():
+     outputs = execute.execute(
+         str(signature.name),
+         num_outputs=len(signature.output_arg),
+         inputs=all_args,
+         attrs=None,
+         ctx=ctx)
+   else:
      g = ops.get_default_graph()
      g._add_function(self._forward_fdef)  # pylint: disable=protected-access
      op = g.create_op(
@@ -453,13 +460,6 @@ class GraphModeFunction(object):
          outputs, (ops.Tensor, type(None))) else list(outputs)
      for i, s in enumerate(self._output_shapes):
        outputs[i].set_shape(s)
-   else:
-     outputs = execute.execute(
-         str(signature.name),
-         num_outputs=len(signature.output_arg),
-         inputs=all_args,
-         attrs=None,
-         ctx=ctx)
    real_outputs = outputs[:len(self._returns)]
    side_outputs = outputs[len(self._returns):]
@@ -530,7 +530,14 @@ class GraphModeFunction(object):
      return self._backprop_call(tensor_inputs)
    ctx = context.context()
-   if ctx.in_graph_mode():
+   if ctx.executing_eagerly():
+     result = execute.execute(
+         str(self._func_name),
+         num_outputs=self._num_outputs,
+         inputs=tensor_inputs + self._extra_inputs,
+         attrs=None,
+         ctx=ctx)
+   else:
      g = ops.get_default_graph()
      self.add_to_graph(g)
      signature = self._function_def.definition.signature
@@ -547,13 +554,6 @@ class GraphModeFunction(object):
        return op
      for i, s in enumerate(self._output_shapes):
        result[i].set_shape(s)
-   else:
-     result = execute.execute(
-         str(self._func_name),
-         num_outputs=self._num_outputs,
-         inputs=tensor_inputs + self._extra_inputs,
-         attrs=None,
-         ctx=ctx)
    return self._build_call_outputs(result)
@@ -666,7 +666,7 @@ def _defun_internal(name, func, args, kwds):
      if x not in all_ignored_ops)
  # Register any other functions defined in the graph
  # TODO(ashankar): Oh lord, forgive me for this lint travesty.
- if context.in_eager_mode():
+ if context.executing_eagerly():
    for f in tmp_graph._functions.values():  # pylint: disable=protected-access
      # TODO(ashankar): What about the gradient registry?
      _register(f._c_func)  # pylint: disable=protected-access
@@ -906,7 +906,7 @@ class AutomaticControlDependencies(object):
    return tensor

  def __enter__(self):
-   if context.in_eager_mode():
+   if context.executing_eagerly():
      return self
    # This code assumes no other thread is adding ops to the graph while
    # we're adding ops to the graph.
@@ -977,7 +977,7 @@ class AutomaticControlDependencies(object):
        merge_for_resource[o] = new_merge[0].op

  def __exit__(self, unused_type, unused_value, unused_traceback):
-   if context.in_eager_mode():
+   if context.executing_eagerly():
      return
    if self._graph is not ops.get_default_graph():

View File

@@ -406,7 +406,7 @@ def graph_callable(shape_and_dtypes):
    A callable graph object.
  """
  # TODO(alive,apassos): support initialized_value and friends from tf.Variable.
- assert context.in_eager_mode(), (
+ assert context.executing_eagerly(), (
      "graph_callable can only be used when Eager execution is enabled.")

  def decorator(func):
    return tf_decorator.make_decorator(func,

View File

@@ -367,7 +367,7 @@ void GenEagerPythonOp::HandleGraphMode(const string& function_setup) {
  // Handle graph-mode case
  strings::StrAppend(&result_,
                     " _ctx = _context.context()\n"
-                    " if _ctx.in_graph_mode():\n",
+                    " if not _ctx.executing_eagerly():\n",
                     function_setup,
                     " _, _, _op = _op_def_lib._apply_op_helper(\n");
  AddBodyNoReturn(" ");

View File

@@ -169,7 +169,7 @@ class Tests(test.TestCase):
  def testFastpathExecute_InvalidInputs(self):
    a_2_by_2 = random_ops.random_uniform((2, 2))
    ctx = context.context()
-   assert not ctx.in_graph_mode(
+   assert ctx.executing_eagerly(
    ), "The prototype doesn't contain C code for graph construction"
    ctx_handle = ctx._handle  # pylint: disable=protected-access

View File

@@ -166,7 +166,7 @@ class Estimator(object):
      ValueError: if this is called via a subclass and if that class overrides
        a member of `Estimator`.
    """
-   if context.in_eager_mode():
+   if context.executing_eagerly():
      raise RuntimeError(
          'Estimators are not supported when eager execution is enabled.')

View File

@@ -181,7 +181,7 @@ def constant(value, dtype=None, shape=None, name="Const", verify_shape=False):
    TypeError: if shape is incorrectly specified or unsupported.
  """
  ctx = context.context()
- if not ctx.in_graph_mode():
+ if ctx.executing_eagerly():
    t = convert_to_eager_tensor(value, ctx, dtype)
    if shape is None:
      return t

View File

@@ -489,10 +489,10 @@ class _DefinedFunction(object):
    # Adds this function into 'g'.
    # pylint: disable=protected-access
-   if context.in_graph_mode():
-     g._add_function(self)
-   else:
+   if context.executing_eagerly():
      context.context().add_function_def(self.definition)
+   else:
+     g._add_function(self)
    # pylint: enable=protected-access

    # Ensures related sub-routines are defined in 'g', too.

View File

@@ -695,7 +695,7 @@ def import_scoped_meta_graph(meta_graph_or_file,
  Raises:
    ValueError: If the graph_def contains unbound inputs.
  """
- if context.in_eager_mode():
+ if context.executing_eagerly():
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "eager execution is enabled.")
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
@@ -856,7 +856,7 @@ def export_scoped_meta_graph(filename=None,
  Raises:
    ValueError: When the `GraphDef` is larger than 2GB.
  """
- if context.in_eager_mode():
+ if context.executing_eagerly():
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "Eager Execution is enabled.")
  graph = graph or ops.get_default_graph()

View File

@@ -395,10 +395,10 @@ class Tensor(_TensorLike):
        "Tensor._shape cannot be assigned, use Tensor.set_shape instead.")

  def __iter__(self):
-   if context.in_graph_mode():
+   if not context.executing_eagerly():
      raise TypeError(
-         "`Tensor` objects are not iterable when eager execution is not "
-         "enabled. To iterate over this tensor use `tf.map_fn`.")
+         "Tensor objects are not iterable when eager execution is not "
+         "enabled. To iterate over this tensor use tf.map_fn.")
    shape = self._shape_tuple()
    if shape is None:
      raise TypeError("Cannot iterate over a tensor with unknown shape.")
@@ -772,7 +772,7 @@ class _EagerTensorBase(Tensor):
      six.raise_from(core._status_to_exception(e.code, e.message), None)
    # Record the copy on tape and define backprop copy as well.
-   if not context.in_graph_mode():
+   if context.executing_eagerly():
      self_device = self.device
      def grad_fun(dresult):
        return [dresult._copy(device_name=self_device)]
@@ -993,7 +993,7 @@ def internal_convert_to_tensor(value,
  """
  if ctx is None: ctx = context.context()
- if ctx.in_eager_mode():
+ if ctx.executing_eagerly():
    # Fast path for EagerTensors that don't need any conversion.
    if isinstance(value, EagerTensor):
      # Note that we don't check that value's dtype matches the dtype
@@ -4797,15 +4797,15 @@ def device(device_name_or_function):
  Raises:
    RuntimeError: If eager execution is enabled and a function is passed in.
  """
- if context.in_graph_mode():
-   return get_default_graph().device(device_name_or_function)
- else:
+ if context.executing_eagerly():
    # TODO(agarwal): support device functions in EAGER mode.
    if callable(device_name_or_function):
      raise RuntimeError(
          "tf.device does not support functions when eager execution "
          "is enabled.")
    return context.device(device_name_or_function)
+ else:
+   return get_default_graph().device(device_name_or_function)

@tf_export("container")
@@ -4824,7 +4824,12 @@
@tf_export("colocate_with")
def colocate_with(op, ignore_existing=False):
- if context.in_graph_mode():
+ if context.executing_eagerly():
+   if op is not None:
+     return device(op.device)
+   else:
+     return _NullContextmanager()
+ else:
    default_graph = get_default_graph()
    if isinstance(op, EagerTensor):
      if default_graph.building_function:
@@ -4833,11 +4838,6 @@ def colocate_with(op, ignore_existing=False):
        raise ValueError("Encountered an Eager-defined Tensor during graph "
                         "construction, but a function was not being built.")
    return default_graph.colocate_with(op, ignore_existing)
- else:
-   if op is not None:
-     return device(op.device)
-   else:
-     return _NullContextmanager()

@tf_export("control_dependencies")
@@ -4857,10 +4857,10 @@ def control_dependencies(control_inputs):
    A context manager that specifies control dependencies for all
    operations constructed within the context.
  """
- if context.in_graph_mode():
-   return get_default_graph().control_dependencies(control_inputs)
- else:
+ if context.executing_eagerly():
    return _NullContextmanager()
+ else:
+   return get_default_graph().control_dependencies(control_inputs)

class _DefaultStack(threading.local):
@@ -5123,7 +5123,7 @@ def init_scope():
  """
  # pylint: enable=g-doc-return-or-yield,line-too-long
- if context.in_eager_mode():
+ if context.executing_eagerly():
    # Fastpath.
    with tape.stop_recording():
      yield
@@ -5705,7 +5705,7 @@ class name_scope(object):  # pylint: disable=invalid-name
    self._default_name = default_name
    self._values = values
    self._ctx = context.context()
-   self._in_eager_mode = self._ctx.in_eager_mode()
+   self._in_eager_mode = self._ctx.executing_eagerly()

  def __enter__(self):
    """Start the scope block.
@@ -5884,7 +5884,7 @@ def get_from_proto_function(collection_name):

def _assert_collection_is_ok(collection_name):
- if context.in_eager_mode():
+ if context.executing_eagerly():
    if collection_name in GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
      raise ValueError("When Eager Execution is enabled, variable "
                       "collections are not supported.")


@ -1763,7 +1763,13 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase):
return constant_op.constant(2.0) return constant_op.constant(2.0)
future.calls = 0 future.calls = 0
if context.in_graph_mode(): if context.executing_eagerly():
a = constant_op.constant(1.0)
b = future()
with ops.control_dependencies([a, b]):
c = constant_op.constant(3.0)
self.assertEqual(future.calls, 1)
else:
g = ops.Graph() g = ops.Graph()
with g.as_default(): with g.as_default():
a = constant_op.constant(1.0) a = constant_op.constant(1.0)
@ -1772,12 +1778,6 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase):
c = constant_op.constant(3.0) c = constant_op.constant(3.0)
self.assertEqual(c.op.control_inputs, [a.op, b.op]) self.assertEqual(c.op.control_inputs, [a.op, b.op])
self.assertEqual(future.calls, 1) self.assertEqual(future.calls, 1)
else:
a = constant_op.constant(1.0)
b = future()
with ops.control_dependencies([a, b]):
c = constant_op.constant(3.0)
self.assertEqual(future.calls, 1)
def testBasicWithConversion(self): def testBasicWithConversion(self):
g = ops.Graph() g = ops.Graph()
@ -2150,11 +2150,11 @@ class InitScopeTest(test_util.TensorFlowTestCase):
with ops.init_scope(): with ops.init_scope():
# Because g is building a function, init_scope should # Because g is building a function, init_scope should
# escape out to the eager context. # escape out to the eager context.
self.assertTrue(context.in_eager_mode()) self.assertTrue(context.executing_eagerly())
# g should be reinstated as the default graph, and the # g should be reinstated as the default graph, and the
# graph context should be re-entered. # graph context should be re-entered.
self.assertIs(g, ops.get_default_graph()) self.assertIs(g, ops.get_default_graph())
self.assertTrue(context.in_graph_mode()) self.assertFalse(context.executing_eagerly())
def testStaysInEagerWhenOnlyEagerContextActive(self): def testStaysInEagerWhenOnlyEagerContextActive(self):
with context.eager_mode(): with context.eager_mode():
@ -2277,12 +2277,13 @@ class InitScopeTest(test_util.TensorFlowTestCase):
with context.eager_mode(): with context.eager_mode():
def foo(): def foo():
with ops.name_scope("inner"), ops.init_scope(): with ops.name_scope("inner"), ops.init_scope():
if context.in_graph_mode(): if context.executing_eagerly():
self.assertEqual(ops.get_name_scope(), "inner")
else:
# A trailing slash is always appended when eager execution is # A trailing slash is always appended when eager execution is
# enabled. # enabled.
self.assertEqual(context.context().scope_name, "inner/") self.assertEqual(context.context().scope_name, "inner/")
else:
self.assertEqual(ops.get_name_scope(), "inner")
foo() foo()
self.assertEqual(ops.get_name_scope(), "") self.assertEqual(ops.get_name_scope(), "")
foo_compiled = eager_function.defun(foo) foo_compiled = eager_function.defun(foo)


@ -52,20 +52,20 @@ def get_seed(op_seed):
A tuple of two integers that should be used for the local seed of this A tuple of two integers that should be used for the local seed of this
operation. operation.
""" """
is_graph_mode = context.in_graph_mode() eager = context.executing_eagerly()
if is_graph_mode: if eager:
global_seed = ops.get_default_graph().seed
else:
global_seed = context.global_seed() global_seed = context.global_seed()
else:
global_seed = ops.get_default_graph().seed
if global_seed is not None: if global_seed is not None:
if op_seed is None: if op_seed is None:
# pylint: disable=protected-access # pylint: disable=protected-access
if is_graph_mode: if eager:
op_seed = ops.get_default_graph()._last_id
else:
op_seed = context.internal_operation_seed() op_seed = context.internal_operation_seed()
else:
op_seed = ops.get_default_graph()._last_id
seeds = _truncate_seed(global_seed), _truncate_seed(op_seed) seeds = _truncate_seed(global_seed), _truncate_seed(op_seed)
else: else:
@ -176,7 +176,7 @@ def set_random_seed(seed):
Args: Args:
seed: integer. seed: integer.
""" """
if context.in_graph_mode(): if context.executing_eagerly():
ops.get_default_graph().seed = seed
else:
context.set_global_seed(seed) context.set_global_seed(seed)
else:
ops.get_default_graph().seed = seed
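
random_seed.get_seed() keeps its two-level seeding scheme; only where each seed comes from now hangs off the one predicate (graph.seed vs. the context's global seed, and graph._last_id vs. an internal per-context operation seed). A condensed, hypothetical sketch of that selection logic, with plain dicts standing in for the graph and the eager context:

_MAXINT32 = 2**31 - 1

def _truncate_seed(seed):
  # Keep seeds inside the valid int32 range, as the real helper does.
  return seed % _MAXINT32

def get_seed(op_seed, eager, ctx, graph):
  # ctx and graph are stand-ins: dicts with "global_seed"/"op_seed" and
  # "seed"/"last_id" keys respectively.
  global_seed = ctx["global_seed"] if eager else graph["seed"]
  if global_seed is None:
    # Simplification: the real function has one more fallback branch here.
    return None, None
  if op_seed is None:
    op_seed = ctx["op_seed"] if eager else graph["last_id"]
  return _truncate_seed(global_seed), _truncate_seed(op_seed)

# Example: graph-style resolution with a fixed global seed.
print(get_seed(None, False, {"global_seed": None, "op_seed": 7},
               {"seed": 42, "last_id": 3}))  # -> (42, 3)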


@ -40,13 +40,13 @@ class RandomSeedTest(test.TestCase):
((2**31 - 1, 0), (0, 2**31 - 1)), # Don't wrap to (0, 0) either ((2**31 - 1, 0), (0, 2**31 - 1)), # Don't wrap to (0, 0) either
((0, 2**31 - 1), (0, 2**31 - 1)), # Wrapping for the other argument ((0, 2**31 - 1), (0, 2**31 - 1)), # Wrapping for the other argument
] ]
if context.in_graph_mode(): if context.executing_eagerly():
# 0 will be the default_graph._lastid.
test_cases.append(((1, None), (1, 0)))
else:
# operation seed is random number generated based on global seed. # operation seed is random number generated based on global seed.
# it's not tested due to possibility of platform or version difference. # it's not tested due to possibility of platform or version difference.
pass pass
else:
# 0 will be the default_graph._lastid.
test_cases.append(((1, None), (1, 0)))
for tc in test_cases: for tc in test_cases:
tinput, toutput = tc[0], tc[1] tinput, toutput = tc[0], tc[1]
random_seed.set_random_seed(tinput[0]) random_seed.set_random_seed(tinput[0])


@ -828,7 +828,7 @@ def constant_value_as_shape(tensor): # pylint: disable=invalid-name
Returns: Returns:
A `TensorShape` based on the constant value of the given `tensor`. A `TensorShape` based on the constant value of the given `tensor`.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
return tensor_shape.as_shape( return tensor_shape.as_shape(
[dim if dim != -1 else None for dim in tensor.numpy()]) [dim if dim != -1 else None for dim in tensor.numpy()])


@ -816,7 +816,7 @@ class TensorFlowTestCase(googletest.TestCase):
Returns: Returns:
tensors numpy values. tensors numpy values.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
return self._eval_helper(tensors) return self._eval_helper(tensors)
else: else:
sess = ops.get_default_session() sess = ops.get_default_session()


@ -343,7 +343,7 @@ def learning_phase():
Returns: Returns:
Learning phase (scalar integer tensor or Python integer). Learning phase (scalar integer tensor or Python integer).
""" """
if context.in_eager_mode(): if context.executing_eagerly():
if 'eager' not in _GRAPH_LEARNING_PHASES: if 'eager' not in _GRAPH_LEARNING_PHASES:
# Fallback to inference mode as default. # Fallback to inference mode as default.
return 0 return 0
@ -370,7 +370,7 @@ def set_learning_phase(value):
global _GRAPH_LEARNING_PHASES # pylint: disable=global-variable-not-assigned global _GRAPH_LEARNING_PHASES # pylint: disable=global-variable-not-assigned
if value not in {0, 1}: if value not in {0, 1}:
raise ValueError('Expected learning phase to be 0 or 1.') raise ValueError('Expected learning phase to be 0 or 1.')
if context.in_eager_mode(): if context.executing_eagerly():
_GRAPH_LEARNING_PHASES['eager'] = value _GRAPH_LEARNING_PHASES['eager'] = value
else: else:
_GRAPH_LEARNING_PHASES[ops.get_default_graph()] = value _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = value
@ -399,7 +399,7 @@ def learning_phase_scope(value):
yield value yield value
finally: finally:
# Restore learning phase to initial value. # Restore learning phase to initial value.
if context.in_eager_mode(): if context.executing_eagerly():
_GRAPH_LEARNING_PHASES['eager'] = previous_value _GRAPH_LEARNING_PHASES['eager'] = previous_value
else: else:
_GRAPH_LEARNING_PHASES[ops.get_default_graph()] = previous_value _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = previous_value
@ -2625,7 +2625,7 @@ def get_value(x):
Returns: Returns:
A Numpy array. A Numpy array.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
return x.numpy() return x.numpy()
return x.eval(session=get_session()) return x.eval(session=get_session())
@ -2640,7 +2640,7 @@ def batch_get_value(tensors):
Returns: Returns:
A list of Numpy arrays. A list of Numpy arrays.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
return [x.numpy() for x in tensors] return [x.numpy() for x in tensors]
if tensors: if tensors:
return get_session().run(tensors) return get_session().run(tensors)
@ -2658,7 +2658,7 @@ def set_value(x, value):
(of the same shape). (of the same shape).
""" """
value = np.asarray(value, dtype=dtype(x)) value = np.asarray(value, dtype=dtype(x))
if context.in_eager_mode(): if context.executing_eagerly():
x.assign(value) x.assign(value)
else: else:
tf_dtype = dtypes_module.as_dtype(x.dtype.name.split('_')[0]) tf_dtype = dtypes_module.as_dtype(x.dtype.name.split('_')[0])
@ -2681,7 +2681,7 @@ def batch_set_value(tuples):
tuples: a list of tuples `(tensor, value)`. tuples: a list of tuples `(tensor, value)`.
`value` should be a Numpy array. `value` should be a Numpy array.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
for x, value in tuples: for x, value in tuples:
x.assign(np.asarray(value, dtype=dtype(x))) x.assign(np.asarray(value, dtype=dtype(x)))
else: else:
@ -3123,7 +3123,7 @@ def rnn(step_function,
outputs_shape[1] = inputs_shape[1] outputs_shape[1] = inputs_shape[1]
outputs.set_shape(outputs_shape) outputs.set_shape(outputs_shape)
if not context.in_eager_mode(): if not context.executing_eagerly():
last_output._uses_learning_phase = uses_learning_phase last_output._uses_learning_phase = uses_learning_phase
return last_output, outputs, new_states return last_output, outputs, new_states
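
In the Keras backend the predicate gates value access: with eager execution enabled, variables are read via .numpy() and written via .assign() directly, while graph mode has to round-trip through a session. A runnable stand-in sketch of that get/set split (the variable class below is hypothetical, not tf.Variable):

class _EagerVariable(object):
  # Minimal stand-in for a resource variable with eager semantics.

  def __init__(self, value):
    self._value = value

  def numpy(self):
    return self._value

  def assign(self, value):
    self._value = value

def get_value(x, eager=True, session=None):
  # Eager path reads the value directly; graph mode needs a session.run().
  if eager:
    return x.numpy()
  return session.run(x)

def batch_get_value(tensors, eager=True, session=None):
  if eager:
    return [t.numpy() for t in tensors]
  return session.run(tensors) if tensors else []

v = _EagerVariable(3.0)
v.assign(4.0)
assert get_value(v) == 4.0
assert batch_get_value([v]) == [4.0]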


@ -237,7 +237,7 @@ class Layer(tf_base_layers.Layer):
""" """
# Actually call the layer (optionally building it). # Actually call the layer (optionally building it).
output = super(Layer, self).__call__(inputs, **kwargs) output = super(Layer, self).__call__(inputs, **kwargs)
if context.in_eager_mode(): if context.executing_eagerly():
return output return output
if hasattr(self, '_symbolic_set_inputs') and not self.inputs: if hasattr(self, '_symbolic_set_inputs') and not self.inputs:


@ -92,7 +92,7 @@ class InputLayer(base_layer.Layer):
else: else:
batch_input_shape = None batch_input_shape = None
if context.in_eager_mode(): if context.executing_eagerly():
# In eager mode, create a temporary placeholder to call the layer on. # In eager mode, create a temporary placeholder to call the layer on.
input_tensor = tf_base_layers._DeferredTensor( # pylint: disable=protected-access input_tensor = tf_base_layers._DeferredTensor( # pylint: disable=protected-access
shape=batch_input_shape, shape=batch_input_shape,


@ -99,7 +99,7 @@ class Network(base_layer.Layer):
self._losses = [] # Used in symbolic mode only. self._losses = [] # Used in symbolic mode only.
self._scope = None # Never used. self._scope = None # Never used.
self._reuse = None # Never used. self._reuse = None # Never used.
if context.in_eager_mode(): if context.executing_eagerly():
self._graph = None self._graph = None
else: else:
self._graph = ops.get_default_graph() # Used in symbolic mode only. self._graph = ops.get_default_graph() # Used in symbolic mode only.
@ -126,7 +126,7 @@ class Network(base_layer.Layer):
self.outputs = [outputs] self.outputs = [outputs]
# User-provided argument validation. # User-provided argument validation.
if context.in_eager_mode(): if context.executing_eagerly():
# Check that all inputs/outputs are DeferredTensors. # Check that all inputs/outputs are DeferredTensors.
for tensor in self.inputs: for tensor in self.inputs:
if not isinstance(tensor, tf_base_layers._DeferredTensor): # pylint: disable=protected-access if not isinstance(tensor, tf_base_layers._DeferredTensor): # pylint: disable=protected-access
@ -275,7 +275,7 @@ class Network(base_layer.Layer):
self._feed_input_names.append(layer.name) self._feed_input_names.append(layer.name)
self._feed_input_shapes.append(K.int_shape(self.inputs[i])) self._feed_input_shapes.append(K.int_shape(self.inputs[i]))
# layer.input gives an error in eager mode # layer.input gives an error in eager mode
if context.in_graph_mode(): if not context.executing_eagerly():
self._feed_inputs.append(layer.input) self._feed_inputs.append(layer.input)
for layer in self._output_layers: for layer in self._output_layers:
self.output_names.append(layer.name) self.output_names.append(layer.name)
@ -317,7 +317,7 @@ class Network(base_layer.Layer):
raise NotImplementedError('`add_variable` is not supported on Networks.') raise NotImplementedError('`add_variable` is not supported on Networks.')
def add_loss(self, *args, **kwargs): def add_loss(self, *args, **kwargs):
if context.in_eager_mode(): if context.executing_eagerly():
raise NotImplementedError('`add_loss` is not supported on Networks ' raise NotImplementedError('`add_loss` is not supported on Networks '
'when eager execution is enabled.') 'when eager execution is enabled.')
super(Network, self).add_loss(*args, **kwargs) super(Network, self).add_loss(*args, **kwargs)
@ -483,7 +483,7 @@ class Network(base_layer.Layer):
Returns: Returns:
A list of update ops. A list of update ops.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
return [] return []
if not self.trainable and not self.stateful: if not self.trainable and not self.stateful:
@ -530,7 +530,7 @@ class Network(base_layer.Layer):
losses = [] losses = []
for layer in self.layers: for layer in self.layers:
losses += layer.losses losses += layer.losses
if context.in_eager_mode(): if context.executing_eagerly():
return losses return losses
if self.inputs: if self.inputs:
@ -623,7 +623,7 @@ class Network(base_layer.Layer):
else: else:
masks = nest.flatten(mask) masks = nest.flatten(mask)
if context.in_graph_mode(): if not context.executing_eagerly():
# Try to retrieve cached outputs if the layer has already been called # Try to retrieve cached outputs if the layer has already been called
# on these exact inputs. # on these exact inputs.
cache_key = (tf_layers_util.object_list_uid(inputs) cache_key = (tf_layers_util.object_list_uid(inputs)
@ -829,7 +829,7 @@ class Network(base_layer.Layer):
else: else:
output_masks = [None for _ in range(len(output_tensors))] output_masks = [None for _ in range(len(output_tensors))]
if context.in_graph_mode(): if not context.executing_eagerly():
if layer.activity_regularizer is not None: if layer.activity_regularizer is not None:
regularization_losses = [ regularization_losses = [
layer.activity_regularizer(x) for x in output_tensors layer.activity_regularizer(x) for x in output_tensors
@ -859,7 +859,7 @@ class Network(base_layer.Layer):
if output_masks is not None: if output_masks is not None:
output_masks = output_masks[0] output_masks = output_masks[0]
if context.in_graph_mode(): if not context.executing_eagerly():
# Update cache; # Update cache;
# keys are based on ids on input tensors and inputs masks. # keys are based on ids on input tensors and inputs masks.
cache_key = (tf_layers_util.object_list_uid(inputs) cache_key = (tf_layers_util.object_list_uid(inputs)


@ -755,7 +755,17 @@ class TopologyConstructionTest(test.TestCase):
def compute_mask(self, inputs, mask=None): def compute_mask(self, inputs, mask=None):
return array_ops.ones_like(inputs) return array_ops.ones_like(inputs)
if context.in_graph_mode(): if context.executing_eagerly():
a = constant_op.constant([2] * 32)
mask = constant_op.constant([0, 1] * 16)
a._keras_mask = mask
b = MaskedLayer().apply(a)
self.assertTrue(hasattr(b, '_keras_mask'))
self.assertAllEqual(
self.evaluate(array_ops.ones_like(mask)),
self.evaluate(getattr(b, '_keras_mask')))
self.assertAllEqual(self.evaluate(a * mask), self.evaluate(b))
else:
x = keras.Input(shape=(32,)) x = keras.Input(shape=(32,))
y = MaskedLayer()(x) # pylint: disable=not-callable y = MaskedLayer()(x) # pylint: disable=not-callable
network = keras.engine.Network(x, y) network = keras.engine.Network(x, y)
@ -769,15 +779,6 @@ class TopologyConstructionTest(test.TestCase):
x_2 = array_ops.placeholder(dtype='float32', shape=(None, 32)) x_2 = array_ops.placeholder(dtype='float32', shape=(None, 32))
y_2 = network(x_2) y_2 = network(x_2)
self.assertEqual(y_2.get_shape().as_list(), [None, 32]) self.assertEqual(y_2.get_shape().as_list(), [None, 32])
else:
a = constant_op.constant([2] * 32)
mask = constant_op.constant([0, 1] * 16)
a._keras_mask = mask
b = MaskedLayer().apply(a)
self.assertTrue(hasattr(b, '_keras_mask'))
self.assertAllEqual(self.evaluate(array_ops.ones_like(mask)),
self.evaluate(getattr(b, '_keras_mask')))
self.assertAllEqual(self.evaluate(a * mask), self.evaluate(b))
def test_activity_regularization_with_model_composition(self): def test_activity_regularization_with_model_composition(self):
@ -885,13 +886,13 @@ class DeferredModeTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes() @test_util.run_in_graph_and_eager_modes()
def testSimpleNetworkBuilding(self): def testSimpleNetworkBuilding(self):
inputs = keras.engine.Input(shape=(32,)) inputs = keras.engine.Input(shape=(32,))
if context.in_eager_mode(): if context.executing_eagerly():
self.assertIsInstance(inputs, tf_base_layers._DeferredTensor) self.assertIsInstance(inputs, tf_base_layers._DeferredTensor)
self.assertEqual(inputs.dtype.name, 'float32') self.assertEqual(inputs.dtype.name, 'float32')
self.assertEqual(inputs.shape.as_list(), [None, 32]) self.assertEqual(inputs.shape.as_list(), [None, 32])
x = keras.layers.Dense(2)(inputs) x = keras.layers.Dense(2)(inputs)
if context.in_eager_mode(): if context.executing_eagerly():
self.assertIsInstance(x, tf_base_layers._DeferredTensor) self.assertIsInstance(x, tf_base_layers._DeferredTensor)
self.assertEqual(x.dtype.name, 'float32') self.assertEqual(x.dtype.name, 'float32')
self.assertEqual(x.shape.as_list(), [None, 2]) self.assertEqual(x.shape.as_list(), [None, 2])
@ -900,7 +901,7 @@ class DeferredModeTest(test.TestCase):
network = keras.engine.Network(inputs, outputs) network = keras.engine.Network(inputs, outputs)
self.assertIsInstance(network, keras.engine.Network) self.assertIsInstance(network, keras.engine.Network)
if context.in_eager_mode(): if context.executing_eagerly():
# It should be possible to call such a network on EagerTensors. # It should be possible to call such a network on EagerTensors.
inputs = constant_op.constant( inputs = constant_op.constant(
np.random.random((10, 32)).astype('float32')) np.random.random((10, 32)).astype('float32'))
@ -925,7 +926,7 @@ class DeferredModeTest(test.TestCase):
c = keras.layers.Dense(2)(c) c = keras.layers.Dense(2)(c)
network = keras.engine.Network([input_a, input_b], [a, c]) network = keras.engine.Network([input_a, input_b], [a, c])
if context.in_eager_mode(): if context.executing_eagerly():
a_val = constant_op.constant( a_val = constant_op.constant(
np.random.random((10, 32)).astype('float32')) np.random.random((10, 32)).astype('float32'))
b_val = constant_op.constant( b_val = constant_op.constant(


@ -162,7 +162,7 @@ class Model(Network):
`optimizer`, `loss`, `metrics` or `sample_weight_mode`. `optimizer`, `loss`, `metrics` or `sample_weight_mode`.
""" """
loss = loss or {} loss = loss or {}
if context.in_eager_mode() and not isinstance( if context.executing_eagerly() and not isinstance(
optimizer, (tf_optimizer_module.Optimizer, optimizers.TFOptimizer)): optimizer, (tf_optimizer_module.Optimizer, optimizers.TFOptimizer)):
raise ValueError('Only TF native optimizers are supported in Eager mode.') raise ValueError('Only TF native optimizers are supported in Eager mode.')
@ -170,13 +170,13 @@ class Model(Network):
self.loss = loss self.loss = loss
self.metrics = metrics or [] self.metrics = metrics or []
self.loss_weights = loss_weights self.loss_weights = loss_weights
if context.in_eager_mode() and sample_weight_mode is not None: if context.executing_eagerly() and sample_weight_mode is not None:
raise ValueError('sample_weight_mode is not supported in Eager mode.') raise ValueError('sample_weight_mode is not supported in Eager mode.')
self.sample_weight_mode = sample_weight_mode self.sample_weight_mode = sample_weight_mode
if context.in_eager_mode() and weighted_metrics is not None: if context.executing_eagerly() and weighted_metrics is not None:
raise ValueError('weighted_metrics is not supported in Eager mode.') raise ValueError('weighted_metrics is not supported in Eager mode.')
self.weighted_metrics = weighted_metrics self.weighted_metrics = weighted_metrics
if context.in_eager_mode() and target_tensors is not None: if context.executing_eagerly() and target_tensors is not None:
raise ValueError('target_tensors is not supported in Eager mode.') raise ValueError('target_tensors is not supported in Eager mode.')
self.target_tensors = target_tensors self.target_tensors = target_tensors
@ -230,7 +230,7 @@ class Model(Network):
skip_target_weighing_indices.append(i) skip_target_weighing_indices.append(i)
# Prepare output masks. # Prepare output masks.
if context.in_graph_mode(): if not context.executing_eagerly():
masks = self.compute_mask(self.inputs, mask=None) masks = self.compute_mask(self.inputs, mask=None)
if masks is None: if masks is None:
masks = [None for _ in self.outputs] masks = [None for _ in self.outputs]
@ -264,7 +264,7 @@ class Model(Network):
self.loss_weights_list = loss_weights_list self.loss_weights_list = loss_weights_list
# initialization for Eager mode execution # initialization for Eager mode execution
if context.in_eager_mode(): if context.executing_eagerly():
if target_tensors is not None: if target_tensors is not None:
raise ValueError('target_tensors are not currently supported in Eager' raise ValueError('target_tensors are not currently supported in Eager'
'mode.') 'mode.')
@ -738,13 +738,13 @@ class Model(Network):
'TensorFlow tensors. ' 'TensorFlow tensors. '
'You passed: x=' + str(x) + '; y=' + str(y)) 'You passed: x=' + str(x) + '; y=' + str(y))
if context.in_graph_mode(): if context.executing_eagerly():
target_tensors = None
else:
# Handle target tensors if any passed. # Handle target tensors if any passed.
if not isinstance(y, (list, tuple)): if not isinstance(y, (list, tuple)):
y = [y] y = [y]
target_tensors = [v for v in y if tensor_util.is_tensor(v)] target_tensors = [v for v in y if tensor_util.is_tensor(v)]
else:
target_tensors = None
self.compile(optimizer=self.optimizer, self.compile(optimizer=self.optimizer,
loss=self.loss, loss=self.loss,
metrics=self.metrics, metrics=self.metrics,
@ -761,7 +761,7 @@ class Model(Network):
# What follows is input validation and standardization to list format, # What follows is input validation and standardization to list format,
# in the case where all inputs are value arrays. # in the case where all inputs are value arrays.
if context.in_eager_mode(): if context.executing_eagerly():
# In eager mode, do not do shape validation. # In eager mode, do not do shape validation.
feed_input_names = self.input_names feed_input_names = self.input_names
feed_input_shapes = None feed_input_shapes = None
@ -784,7 +784,7 @@ class Model(Network):
exception_prefix='input') exception_prefix='input')
if y is not None: if y is not None:
if context.in_eager_mode(): if context.executing_eagerly():
feed_output_names = self.output_names feed_output_names = self.output_names
feed_output_shapes = None feed_output_shapes = None
# Sample weighting not supported in this case. # Sample weighting not supported in this case.
@ -835,7 +835,7 @@ class Model(Network):
] ]
# Check that all arrays have the same length. # Check that all arrays have the same length.
training_utils.check_array_lengths(x, y, sample_weights) training_utils.check_array_lengths(x, y, sample_weights)
if self._is_graph_network and not context.in_eager_mode(): if self._is_graph_network and not context.executing_eagerly():
# Additional checks to avoid users mistakenly using improper loss fns. # Additional checks to avoid users mistakenly using improper loss fns.
training_utils.check_loss_and_target_compatibility( training_utils.check_loss_and_target_compatibility(
y, self._feed_loss_fns, feed_output_shapes) y, self._feed_loss_fns, feed_output_shapes)
@ -874,7 +874,7 @@ class Model(Network):
whether to build the model's graph in inference mode (False), training whether to build the model's graph in inference mode (False), training
mode (True), or using the Keras learning phase (None). mode (True), or using the Keras learning phase (None).
""" """
if context.in_eager_mode(): if context.executing_eagerly():
self._eager_set_inputs(inputs) self._eager_set_inputs(inputs)
else: else:
self._symbolic_set_inputs(inputs, training=training) self._symbolic_set_inputs(inputs, training=training)
@ -903,7 +903,7 @@ class Model(Network):
Raises: Raises:
ValueError: If the model's inputs are already set. ValueError: If the model's inputs are already set.
""" """
assert context.in_eager_mode() assert context.executing_eagerly()
if self.inputs: if self.inputs:
raise ValueError('Model inputs are already set.') raise ValueError('Model inputs are already set.')
# On-the-fly setting of model inputs/outputs as DeferredTensors, # On-the-fly setting of model inputs/outputs as DeferredTensors,
@ -950,7 +950,7 @@ class Model(Network):
Raises: Raises:
ValueError: If the model's inputs are already set. ValueError: If the model's inputs are already set.
""" """
assert context.in_graph_mode() assert not context.executing_eagerly()
if self.inputs: if self.inputs:
raise ValueError('Model inputs are already set.') raise ValueError('Model inputs are already set.')
@ -1186,7 +1186,7 @@ class Model(Network):
val_y = None val_y = None
val_sample_weights = None val_sample_weights = None
if context.in_eager_mode(): if context.executing_eagerly():
return training_eager.fit_loop( return training_eager.fit_loop(
self, self,
inputs=x, inputs=x,
@ -1289,7 +1289,7 @@ class Model(Network):
sample_weight=sample_weight, sample_weight=sample_weight,
batch_size=batch_size) batch_size=batch_size)
if context.in_eager_mode(): if context.executing_eagerly():
return training_eager.test_loop( return training_eager.test_loop(
self, inputs=x, targets=y, sample_weights=sample_weights, self, inputs=x, targets=y, sample_weights=sample_weights,
batch_size=batch_size, verbose=verbose, steps=steps) batch_size=batch_size, verbose=verbose, steps=steps)
@ -1330,7 +1330,7 @@ class Model(Network):
'argument.') 'argument.')
x, _, _ = self._standardize_user_data(x) x, _, _ = self._standardize_user_data(x)
if context.in_eager_mode(): if context.executing_eagerly():
return training_eager.predict_loop( return training_eager.predict_loop(
self, x, batch_size=batch_size, verbose=verbose, steps=steps) self, x, batch_size=batch_size, verbose=verbose, steps=steps)
else: else:
@ -1381,7 +1381,7 @@ class Model(Network):
sample_weight=sample_weight, sample_weight=sample_weight,
class_weight=class_weight) class_weight=class_weight)
if context.in_eager_mode(): if context.executing_eagerly():
outputs = training_eager.train_on_batch( outputs = training_eager.train_on_batch(
self, x, y, sample_weights=sample_weights) self, x, y, sample_weights=sample_weights)
else: else:
@ -1431,7 +1431,7 @@ class Model(Network):
x, y, sample_weights = self._standardize_user_data( x, y, sample_weights = self._standardize_user_data(
x, y, sample_weight=sample_weight) x, y, sample_weight=sample_weight)
if context.in_eager_mode(): if context.executing_eagerly():
outputs = training_eager.test_on_batch( outputs = training_eager.test_on_batch(
self, x, y, sample_weights=sample_weights) self, x, y, sample_weights=sample_weights)
else: else:
@ -1458,11 +1458,11 @@ class Model(Network):
""" """
x, _, _ = self._standardize_user_data(x) x, _, _ = self._standardize_user_data(x)
if context.in_eager_mode(): if context.executing_eagerly():
inputs = [ops.convert_to_tensor(val, dtype=K.floatx()) for val in x] inputs = [ops.convert_to_tensor(val, dtype=K.floatx()) for val in x]
return self(inputs) # pylint: disable=not-callable return self(inputs) # pylint: disable=not-callable
if context.in_graph_mode(): if not context.executing_eagerly():
if self.uses_learning_phase and not isinstance(K.learning_phase(), int): if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
ins = x + [0] ins = x + [0]
else: else:
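
Model.compile() turns the same predicate into up-front validation: several graph-only arguments are rejected as soon as eager execution is detected instead of being silently ignored later. A hedged sketch of just that validation step (argument names match the diff; everything else about compile is elided):

def _validate_eager_compile_args(executing_eagerly, sample_weight_mode=None,
                                 weighted_metrics=None, target_tensors=None):
  # With eager execution enabled these compile() arguments have no effect,
  # so they are rejected early, mirroring the checks above.
  if not executing_eagerly:
    return
  for name, value in (("sample_weight_mode", sample_weight_mode),
                      ("weighted_metrics", weighted_metrics),
                      ("target_tensors", target_tensors)):
    if value is not None:
      raise ValueError("%s is not supported in Eager mode." % name)

# Example: passes silently in graph mode, raises under eager execution.
_validate_eager_compile_args(False, sample_weight_mode="temporal")
try:
  _validate_eager_compile_args(True, sample_weight_mode="temporal")
except ValueError as e:
  print(e)  # sample_weight_mode is not supported in Eager mode.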


@ -553,7 +553,7 @@ class ZeroPaddingTest(test.TestCase):
layer = keras.layers.ZeroPadding1D(padding=2) layer = keras.layers.ZeroPadding1D(padding=2)
layer.build(shape) layer.build(shape)
output = layer(keras.backend.variable(inputs)) output = layer(keras.backend.variable(inputs))
if context.in_eager_mode(): if context.executing_eagerly():
np_output = output.numpy() np_output = output.numpy()
else: else:
np_output = keras.backend.eval(output) np_output = keras.backend.eval(output)
@ -564,7 +564,7 @@ class ZeroPaddingTest(test.TestCase):
layer = keras.layers.ZeroPadding1D(padding=(1, 2)) layer = keras.layers.ZeroPadding1D(padding=(1, 2))
layer.build(shape) layer.build(shape)
output = layer(keras.backend.variable(inputs)) output = layer(keras.backend.variable(inputs))
if context.in_eager_mode(): if context.executing_eagerly():
np_output = output.numpy() np_output = output.numpy()
else: else:
np_output = keras.backend.eval(output) np_output = keras.backend.eval(output)
@ -610,7 +610,7 @@ class ZeroPaddingTest(test.TestCase):
padding=(2, 2), data_format=data_format) padding=(2, 2), data_format=data_format)
layer.build(inputs.shape) layer.build(inputs.shape)
output = layer(keras.backend.variable(inputs)) output = layer(keras.backend.variable(inputs))
if context.in_eager_mode(): if context.executing_eagerly():
np_output = output.numpy() np_output = output.numpy()
else: else:
np_output = keras.backend.eval(output) np_output = keras.backend.eval(output)
@ -629,7 +629,7 @@ class ZeroPaddingTest(test.TestCase):
padding=((1, 2), (3, 4)), data_format=data_format) padding=((1, 2), (3, 4)), data_format=data_format)
layer.build(inputs.shape) layer.build(inputs.shape)
output = layer(keras.backend.variable(inputs)) output = layer(keras.backend.variable(inputs))
if context.in_eager_mode(): if context.executing_eagerly():
np_output = output.numpy() np_output = output.numpy()
else: else:
np_output = keras.backend.eval(output) np_output = keras.backend.eval(output)
@ -683,7 +683,7 @@ class ZeroPaddingTest(test.TestCase):
layer = keras.layers.ZeroPadding3D(padding=(2, 2, 2)) layer = keras.layers.ZeroPadding3D(padding=(2, 2, 2))
layer.build(inputs.shape) layer.build(inputs.shape)
output = layer(keras.backend.variable(inputs)) output = layer(keras.backend.variable(inputs))
if context.in_eager_mode(): if context.executing_eagerly():
np_output = output.numpy() np_output = output.numpy()
else: else:
np_output = keras.backend.eval(output) np_output = keras.backend.eval(output)
@ -737,7 +737,7 @@ class UpSamplingTest(test.TestCase):
size=(length_row, length_col), data_format=data_format) size=(length_row, length_col), data_format=data_format)
layer.build(inputs.shape) layer.build(inputs.shape)
output = layer(keras.backend.variable(inputs)) output = layer(keras.backend.variable(inputs))
if context.in_eager_mode(): if context.executing_eagerly():
np_output = output.numpy() np_output = output.numpy()
else: else:
np_output = keras.backend.eval(output) np_output = keras.backend.eval(output)
@ -790,7 +790,7 @@ class UpSamplingTest(test.TestCase):
data_format=data_format) data_format=data_format)
layer.build(inputs.shape) layer.build(inputs.shape)
output = layer(keras.backend.variable(inputs)) output = layer(keras.backend.variable(inputs))
if context.in_eager_mode(): if context.executing_eagerly():
np_output = output.numpy() np_output = output.numpy()
else: else:
np_output = keras.backend.eval(output) np_output = keras.backend.eval(output)
@ -865,7 +865,7 @@ class CroppingTest(test.TestCase):
cropping=cropping, data_format=data_format) cropping=cropping, data_format=data_format)
layer.build(inputs.shape) layer.build(inputs.shape)
output = layer(keras.backend.variable(inputs)) output = layer(keras.backend.variable(inputs))
if context.in_eager_mode(): if context.executing_eagerly():
np_output = output.numpy() np_output = output.numpy()
else: else:
np_output = keras.backend.eval(output) np_output = keras.backend.eval(output)
@ -892,7 +892,7 @@ class CroppingTest(test.TestCase):
cropping=cropping, data_format=data_format) cropping=cropping, data_format=data_format)
layer.build(inputs.shape) layer.build(inputs.shape)
output = layer(keras.backend.variable(inputs)) output = layer(keras.backend.variable(inputs))
if context.in_eager_mode(): if context.executing_eagerly():
np_output = output.numpy() np_output = output.numpy()
else: else:
np_output = keras.backend.eval(output) np_output = keras.backend.eval(output)
@ -937,7 +937,7 @@ class CroppingTest(test.TestCase):
cropping=cropping, data_format=data_format) cropping=cropping, data_format=data_format)
layer.build(inputs.shape) layer.build(inputs.shape)
output = layer(keras.backend.variable(inputs)) output = layer(keras.backend.variable(inputs))
if context.in_eager_mode(): if context.executing_eagerly():
np_output = output.numpy() np_output = output.numpy()
else: else:
np_output = keras.backend.eval(output) np_output = keras.backend.eval(output)


@ -124,7 +124,7 @@ class Dropout(tf_core_layers.Dropout, Layer):
training = K.learning_phase() training = K.learning_phase()
output = super(Dropout, self).call(inputs, training=training) output = super(Dropout, self).call(inputs, training=training)
# EagerTensor object has no attribute _uses_learning_phase # EagerTensor object has no attribute _uses_learning_phase
if not context.in_eager_mode() and training is K.learning_phase(): if not context.executing_eagerly() and training is K.learning_phase():
output._uses_learning_phase = True # pylint: disable=protected-access output._uses_learning_phase = True # pylint: disable=protected-access
return output return output


@ -111,7 +111,7 @@ class BatchNormalization(tf_normalization_layers.BatchNormalization, Layer):
if training is None: if training is None:
training = K.learning_phase() training = K.learning_phase()
output = super(BatchNormalization, self).call(inputs, training=training) output = super(BatchNormalization, self).call(inputs, training=training)
if context.in_graph_mode() and training is K.learning_phase(): if not context.executing_eagerly() and training is K.learning_phase():
output._uses_learning_phase = True # pylint: disable=protected-access output._uses_learning_phase = True # pylint: disable=protected-access
return output return output


@ -105,7 +105,7 @@ class Pooling2DTest(test.TestCase):
# This part of the test can only run on GPU but doesn't appear # This part of the test can only run on GPU but doesn't appear
# to be properly assigned to a GPU when running in eager mode. # to be properly assigned to a GPU when running in eager mode.
if not context.in_eager_mode(): if not context.executing_eagerly():
# Only runs on GPU with CUDA, channels_first is not supported on CPU. # Only runs on GPU with CUDA, channels_first is not supported on CPU.
# TODO(b/62340061): Support channels_first on CPU. # TODO(b/62340061): Support channels_first on CPU.
if test.is_gpu_available(cuda_only=True): if test.is_gpu_available(cuda_only=True):


@ -936,7 +936,7 @@ class SimpleRNNCell(Layer):
# Properly set learning phase on output tensor. # Properly set learning phase on output tensor.
if 0 < self.dropout + self.recurrent_dropout: if 0 < self.dropout + self.recurrent_dropout:
if training is None and not context.in_eager_mode(): if training is None and not context.executing_eagerly():
# This would be harmless to set in eager mode, but eager tensors # This would be harmless to set in eager mode, but eager tensors
# disallow setting arbitrary attributes. # disallow setting arbitrary attributes.
output._uses_learning_phase = True output._uses_learning_phase = True
@ -1384,7 +1384,7 @@ class GRUCell(Layer):
hh = self.activation(x_h + recurrent_h) hh = self.activation(x_h + recurrent_h)
h = z * h_tm1 + (1 - z) * hh h = z * h_tm1 + (1 - z) * hh
if 0 < self.dropout + self.recurrent_dropout: if 0 < self.dropout + self.recurrent_dropout:
if training is None and not context.in_eager_mode(): if training is None and not context.executing_eagerly():
# This would be harmless to set in eager mode, but eager tensors # This would be harmless to set in eager mode, but eager tensors
# disallow setting arbitrary attributes. # disallow setting arbitrary attributes.
h._uses_learning_phase = True h._uses_learning_phase = True
@ -1877,7 +1877,7 @@ class LSTMCell(Layer):
h = o * self.activation(c) h = o * self.activation(c)
if 0 < self.dropout + self.recurrent_dropout: if 0 < self.dropout + self.recurrent_dropout:
if training is None and not context.in_eager_mode(): if training is None and not context.executing_eagerly():
# This would be harmless to set in eager mode, but eager tensors # This would be harmless to set in eager mode, but eager tensors
# disallow setting arbitrary attributes. # disallow setting arbitrary attributes.
h._uses_learning_phase = True h._uses_learning_phase = True


@ -83,14 +83,14 @@ class AtrousConvolutionTest(test.TestCase):
checks = [] checks = []
def add_check(check, *args, **kwargs): def add_check(check, *args, **kwargs):
if context.in_eager_mode(): if context.executing_eagerly():
args_val, kwargs_val = self.evaluate([args, kwargs]) args_val, kwargs_val = self.evaluate([args, kwargs])
check(*args_val, **kwargs_val) check(*args_val, **kwargs_val)
else: else:
checks.append((check, args, kwargs)) checks.append((check, args, kwargs))
yield add_check yield add_check
if context.in_graph_mode(): if not context.executing_eagerly():
all_values = self.evaluate([[args, kwargs] for _, args, kwargs in checks]) all_values = self.evaluate([[args, kwargs] for _, args, kwargs in checks])
for (check, _, _), (args, kwargs) in zip(checks, all_values): for (check, _, _), (args, kwargs) in zip(checks, all_values):
check(*args, **kwargs) check(*args, **kwargs)


@ -102,14 +102,12 @@ class AssertEqualTest(test.TestCase):
with self.assertRaisesRegexp(errors.InvalidArgumentError, "fail"): with self.assertRaisesRegexp(errors.InvalidArgumentError, "fail"):
check_ops.assert_equal(static_big, static_small, message="fail") check_ops.assert_equal(static_big, static_small, message="fail")
# Dynamic check def test_raises_when_greater_dynamic(self):
if context.in_graph_mode():
with self.test_session(): with self.test_session():
small = array_ops.placeholder(dtypes.int32, name="small") small = array_ops.placeholder(dtypes.int32, name="small")
big = array_ops.placeholder(dtypes.int32, name="big") big = array_ops.placeholder(dtypes.int32, name="big")
with ops.control_dependencies( with ops.control_dependencies(
[check_ops.assert_equal( [check_ops.assert_equal(big, small, message="fail")]):
big, small, message="fail")]):
out = array_ops.identity(small) out = array_ops.identity(small)
with self.assertRaisesOpError("fail.*big.*small"): with self.assertRaisesOpError("fail.*big.*small"):
out.eval(feed_dict={small: [1, 2], big: [3, 4]}) out.eval(feed_dict={small: [1, 2], big: [3, 4]})
@ -182,8 +180,7 @@ First 2 elements of y:
with self.assertRaisesRegexp(errors.InvalidArgumentError, "fail"): with self.assertRaisesRegexp(errors.InvalidArgumentError, "fail"):
check_ops.assert_equal(static_big, static_small, message="fail") check_ops.assert_equal(static_big, static_small, message="fail")
# Dynamic check def test_raises_when_less_dynamic(self):
if context.in_graph_mode():
with self.test_session(): with self.test_session():
small = array_ops.placeholder(dtypes.int32, name="small") small = array_ops.placeholder(dtypes.int32, name="small")
big = array_ops.placeholder(dtypes.int32, name="big") big = array_ops.placeholder(dtypes.int32, name="big")


@ -360,7 +360,7 @@ class PyFuncTest(test.TestCase):
raise py_exp("blah") # pylint: disable=not-callable raise py_exp("blah") # pylint: disable=not-callable
if eager: if eager:
if context.in_eager_mode(): if context.executing_eagerly():
with self.assertRaisesRegexp(tf_exp, "blah"): with self.assertRaisesRegexp(tf_exp, "blah"):
f = script_ops.eager_py_func(raise_exception, [], []) f = script_ops.eager_py_func(raise_exception, [], [])
return return
@ -432,7 +432,7 @@ class PyFuncTest(test.TestCase):
output = script_ops.eager_py_func(no_return_value, inp=[], Tout=[]) output = script_ops.eager_py_func(no_return_value, inp=[], Tout=[])
ret = self.evaluate(output) ret = self.evaluate(output)
if context.in_eager_mode(): if context.executing_eagerly():
self.assertEquals(len(ret), 0) self.assertEquals(len(ret), 0)
else: else:
self.assertIsNone(ret) self.assertIsNone(ret)


@ -279,15 +279,12 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
# Tests for the 'read_value' argument: # Tests for the 'read_value' argument:
assign_with_read = v.assign(3.0, read_value=True) assign_with_read = v.assign(3.0, read_value=True)
if context.in_graph_mode():
self.assertEqual(3.0, assign_with_read.eval())
else:
self.assertEqual(3.0, self.evaluate(assign_with_read)) self.assertEqual(3.0, self.evaluate(assign_with_read))
assign_without_read = v.assign(4.0, read_value=False) assign_without_read = v.assign(4.0, read_value=False)
if context.in_graph_mode(): if context.executing_eagerly():
self.assertIsInstance(assign_without_read, ops.Operation)
else:
self.assertIsNone(assign_without_read) self.assertIsNone(assign_without_read)
else:
self.assertIsInstance(assign_without_read, ops.Operation)
self.evaluate(assign_without_read) self.evaluate(assign_without_read)
self.assertEqual(4.0, self.evaluate(v.value())) self.assertEqual(4.0, self.evaluate(v.value()))
@ -355,15 +352,12 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
# Tests for the 'read_value' argument: # Tests for the 'read_value' argument:
assign_with_read = v.assign_add(1.0, read_value=True) assign_with_read = v.assign_add(1.0, read_value=True)
if context.in_graph_mode():
self.assertEqual(3.0, assign_with_read.eval())
else:
self.assertEqual(3.0, self.evaluate(assign_with_read)) self.assertEqual(3.0, self.evaluate(assign_with_read))
assign_without_read = v.assign_add(1.0, read_value=False) assign_without_read = v.assign_add(1.0, read_value=False)
if context.in_graph_mode(): if context.executing_eagerly():
self.assertIsInstance(assign_without_read, ops.Operation)
else:
self.assertIsNone(assign_without_read) self.assertIsNone(assign_without_read)
else:
self.assertIsInstance(assign_without_read, ops.Operation)
self.evaluate(assign_without_read) self.evaluate(assign_without_read)
self.assertEqual(4.0, self.evaluate(v.value())) self.assertEqual(4.0, self.evaluate(v.value()))
@ -376,15 +370,12 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
# Tests for the 'read_value' argument: # Tests for the 'read_value' argument:
assign_with_read = v.assign_sub(1.0, read_value=True) assign_with_read = v.assign_sub(1.0, read_value=True)
if context.in_graph_mode():
self.assertEqual(1.0, assign_with_read.eval())
else:
self.assertEqual(1.0, self.evaluate(assign_with_read)) self.assertEqual(1.0, self.evaluate(assign_with_read))
assign_without_read = v.assign_sub(1.0, read_value=False) assign_without_read = v.assign_sub(1.0, read_value=False)
if context.in_graph_mode(): if context.executing_eagerly():
self.assertIsInstance(assign_without_read, ops.Operation)
else:
self.assertIsNone(assign_without_read) self.assertIsNone(assign_without_read)
else:
self.assertIsInstance(assign_without_read, ops.Operation)
self.evaluate(assign_without_read) self.evaluate(assign_without_read)
self.assertEqual(0.0, self.evaluate(v.value())) self.assertEqual(0.0, self.evaluate(v.value()))
@ -485,7 +476,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
self.assertEqual("(10, 20, 35)", str(v.get_shape())) self.assertEqual("(10, 20, 35)", str(v.get_shape()))
self.assertEqual("(10, 20, 35)", str(v.value().shape)) self.assertEqual("(10, 20, 35)", str(v.value().shape))
self.assertEqual("(3, 20, 35)", str(v.sparse_read([0, 1, 2]).shape)) self.assertEqual("(3, 20, 35)", str(v.sparse_read([0, 1, 2]).shape))
if context.in_graph_mode(): if not context.executing_eagerly():
self.assertEqual( self.assertEqual(
"<unknown>", "<unknown>",
str(v.sparse_read(array_ops.placeholder(dtypes.int32)).shape)) str(v.sparse_read(array_ops.placeholder(dtypes.int32)).shape))
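
The rewritten assertions pin down a behavioral difference that the old mode checks obscured: with read_value=False, assign() returns None under eager execution but an Operation during graph construction. A toy stand-in that encodes just that contract (not the real resource-variable implementation):

class _SketchVariable(object):
  # Hypothetical stand-in capturing only the read_value return contract.

  def __init__(self, value, executing_eagerly):
    self._value = value
    self._eager = executing_eagerly

  def assign(self, value, read_value=True):
    self._value = value
    if read_value:
      return self._value  # eager: the new value; graph: a read tensor
    # Graph construction hands back the assign op; eager has nothing to return.
    return None if self._eager else "assign_op_placeholder"

v = _SketchVariable(2.0, executing_eagerly=True)
assert v.assign(3.0) == 3.0
assert v.assign(4.0, read_value=False) is None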


@ -111,10 +111,10 @@ class RNNTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes() @test_util.run_in_graph_and_eager_modes()
def testInvalidSequenceLengthShape(self): def testInvalidSequenceLengthShape(self):
cell = Plus1RNNCell() cell = Plus1RNNCell()
if context.in_graph_mode(): if context.executing_eagerly():
inputs = [array_ops.placeholder(dtypes.float32, shape=(3, 4))]
else:
inputs = [constant_op.constant(np.ones((3, 4)))] inputs = [constant_op.constant(np.ones((3, 4)))]
else:
inputs = [array_ops.placeholder(dtypes.float32, shape=(3, 4))]
with self.assertRaisesRegexp(ValueError, "must be a vector"): with self.assertRaisesRegexp(ValueError, "must be a vector"):
rnn.dynamic_rnn( rnn.dynamic_rnn(
cell, cell,
@ -125,38 +125,30 @@ class RNNTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes() @test_util.run_in_graph_and_eager_modes()
def testBatchSizeFromInput(self): def testBatchSizeFromInput(self):
cell = Plus1RNNCell() cell = Plus1RNNCell()
in_graph_mode = context.in_graph_mode() in_eager_mode = context.executing_eagerly()
# With static batch size # With static batch size
if in_graph_mode: if in_eager_mode:
inputs = array_ops.placeholder(dtypes.float32, shape=(3, 4, 5))
initial_state = array_ops.placeholder(dtypes.float32, shape=(3, 5))
else:
inputs = np.zeros((3, 4, 5), dtype=np.float32) inputs = np.zeros((3, 4, 5), dtype=np.float32)
initial_state = np.zeros((3, 5), dtype=np.float32) initial_state = np.zeros((3, 5), dtype=np.float32)
else:
inputs = array_ops.placeholder(dtypes.float32, shape=(3, 4, 5))
initial_state = array_ops.placeholder(dtypes.float32, shape=(3, 5))
# - Without initial_state # - Without initial_state
outputs, state = rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32) outputs, state = rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32)
if in_graph_mode:
self.assertEqual(3, outputs.shape[0].value)
self.assertEqual(3, state.shape[0].value)
else:
self.assertEqual(3, outputs.shape[0]) self.assertEqual(3, outputs.shape[0])
self.assertEqual(3, state.shape[0]) self.assertEqual(3, state.shape[0])
# - With initial_state # - With initial_state
outputs, state = rnn.dynamic_rnn( outputs, state = rnn.dynamic_rnn(
cell, inputs, initial_state=initial_state) cell, inputs, initial_state=initial_state)
if in_graph_mode:
self.assertEqual(3, outputs.shape[0].value)
self.assertEqual(3, state.shape[0].value)
else:
self.assertEqual(3, outputs.shape[0]) self.assertEqual(3, outputs.shape[0])
self.assertEqual(3, state.shape[0]) self.assertEqual(3, state.shape[0])
# Without static batch size # Without static batch size
# Tensor shapes are fully determined in Eager mode, so only run this # Tensor shapes are fully determined with eager execution enabled,
# test in graph mode. # so only run this test for graph construction.
if in_graph_mode: if not in_eager_mode:
inputs = array_ops.placeholder(dtypes.float32, shape=(None, 4, 5)) inputs = array_ops.placeholder(dtypes.float32, shape=(None, 4, 5))
# - Without initial_state # - Without initial_state
outputs, state = rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32) outputs, state = rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32)
@ -173,56 +165,46 @@ class RNNTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes() @test_util.run_in_graph_and_eager_modes()
def testScalarStateIsAccepted(self): def testScalarStateIsAccepted(self):
cell = ScalarStateRNNCell() cell = ScalarStateRNNCell()
in_graph_mode = context.in_graph_mode() in_eager_mode = context.executing_eagerly()
if in_graph_mode: if in_eager_mode:
inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))
else:
inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32) inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32)
else:
inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))
with self.test_session() as sess: with self.test_session() as sess:
outputs, state = rnn.dynamic_rnn( outputs, state = rnn.dynamic_rnn(
cell, inputs, dtype=dtypes.float32, sequence_length=[4]) cell, inputs, dtype=dtypes.float32, sequence_length=[4])
if in_graph_mode: if not in_eager_mode:
outputs, state = sess.run( outputs, state = sess.run(
[outputs, state], feed_dict={inputs: [[[1], [2], [3], [4]]]}) [outputs, state], feed_dict={inputs: [[[1], [2], [3], [4]]]})
if in_graph_mode: self.assertAllEqual([[[1], [2], [3], [4]]], outputs)
self.assertAllEqual(outputs, np.array([[[1], [2], [3], [4]]])) self.assertAllEqual(4, state)
self.assertEqual(state, 4)
else:
self.assertAllEqual(outputs.numpy(), np.array([[[1], [2], [3], [4]]]))
self.assertEqual(state.numpy(), 4)
@test_util.run_in_graph_and_eager_modes() @test_util.run_in_graph_and_eager_modes()
def testTensorArrayStateIsAccepted(self): def testTensorArrayStateIsAccepted(self):
cell = TensorArrayStateRNNCell() cell = TensorArrayStateRNNCell()
in_graph_mode = context.in_graph_mode() in_eager_mode = context.executing_eagerly()
if in_graph_mode: if in_eager_mode:
inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))
else:
inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32) inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32)
else:
inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))
with self.test_session() as sess: with self.test_session() as sess:
outputs, state = rnn.dynamic_rnn( outputs, state = rnn.dynamic_rnn(
cell, inputs, dtype=dtypes.float32, sequence_length=[4]) cell, inputs, dtype=dtypes.float32, sequence_length=[4])
state = (state[0], state[1].stack()) state = (state[0], state[1].stack())
if in_graph_mode: if not in_eager_mode:
outputs, state = sess.run( outputs, state = sess.run(
[outputs, state], feed_dict={ [outputs, state], feed_dict={
inputs: [[[1], [2], [3], [4]]] inputs: [[[1], [2], [3], [4]]]
}) })
if in_graph_mode: self.assertAllEqual([[[1], [2], [3], [4]]], outputs)
self.assertAllEqual(outputs, np.array([[[1], [2], [3], [4]]])) self.assertAllEqual(4, state[0])
self.assertEqual(state[0], 4) self.assertAllEqual([[[1]], [[2]], [[3]], [[4]]], state[1])
self.assertAllEqual(state[1], np.array([[[1]], [[2]], [[3]], [[4]]]))
else:
self.assertAllEqual(outputs.numpy(), np.array([[[1], [2], [3], [4]]]))
self.assertEqual(state[0].numpy(), 4)
self.assertAllEqual(state[1].numpy(),
np.array([[[1]], [[2]], [[3]], [[4]]]))
######### Benchmarking RNN code ######### Benchmarking RNN code


@ -283,7 +283,7 @@ class SliceTest(test.TestCase):
# unintended behavior is prevented. # unintended behavior is prevented.
c = constant_op.constant(5.0) c = constant_op.constant(5.0)
with self.assertRaisesWithPredicateMatch( with self.assertRaisesWithPredicateMatch(
TypeError, lambda e: "`Tensor` objects are not iterable" in str(e)): TypeError, lambda e: "Tensor objects are not iterable" in str(e)):
for _ in c: for _ in c:
pass pass


@ -562,7 +562,7 @@ class TemplateTest(test.TestCase):
outputs_b, _ = linear1(inputs) outputs_b, _ = linear1(inputs)
self.assertEquals("foo", linear1.variable_scope.name) self.assertEquals("foo", linear1.variable_scope.name)
self.assertEquals("foo/w:0", w1.name) self.assertEquals("foo/w:0", w1.name)
if context.in_graph_mode(): if not context.executing_eagerly():
self.assertEquals("foo/add:0", outputs_a.name, self.assertEquals("foo/add:0", outputs_a.name,
"First application of template should get " "First application of template should get "
"same name scope as variables.") "same name scope as variables.")
@ -577,7 +577,7 @@ class TemplateTest(test.TestCase):
"New template gets a freshly uniquified variable scope " "New template gets a freshly uniquified variable scope "
"because 'foo' is already taken.") "because 'foo' is already taken.")
self.assertEquals("foo_1/w:0", w2.name) self.assertEquals("foo_1/w:0", w2.name)
if context.in_graph_mode(): if not context.executing_eagerly():
self.assertEquals("foo_1_1/add:0", outputs_c.name, self.assertEquals("foo_1_1/add:0", outputs_c.name,
"First application of template would get " "First application of template would get "
"same name scope as variables, but 'foo_1' is already " "same name scope as variables, but 'foo_1' is already "
@ -592,7 +592,7 @@ class TemplateTest(test.TestCase):
with variable_scope.variable_scope("foo"): with variable_scope.variable_scope("foo"):
# Create two templates with the same name, ensure scopes are made unique. # Create two templates with the same name, ensure scopes are made unique.
ta = template.make_template("bar", variable_scoped_function, True) ta = template.make_template("bar", variable_scoped_function, True)
if context.in_eager_mode(): if context.executing_eagerly():
tb = template.make_template("s", function_with_side_create, tb = template.make_template("s", function_with_side_create,
trainable=False) trainable=False)
else: else:


@ -399,27 +399,13 @@ class TensorArrayTest(test.TestCase):
def testTensorArrayWriteWrongIndexOrDataTypeFails(self): def testTensorArrayWriteWrongIndexOrDataTypeFails(self):
with self.test_session(use_gpu=True): with self.test_session(use_gpu=True):
ta = _make_ta(3, "foo", dtype=dtypes.float32) ta = _make_ta(3, "foo", dtype=dtypes.float32)
in_graph_mode = context.in_graph_mode()
# Test writing the wrong datatype # Test writing the wrong datatype
if in_graph_mode:
with self.assertRaisesOpError( with self.assertRaisesOpError(
"TensorArray dtype is float but Op is trying to write " "TensorArray dtype is (float|float32) but Op is trying to write "
"dtype string"):
self.evaluate(ta.write(0, "wrong_type_scalar").flow)
else:
with self.assertRaisesOpError(
"TensorArray dtype is float32 but Op is trying to write "
"dtype string"): "dtype string"):
self.evaluate(ta.write(0, "wrong_type_scalar").flow) self.evaluate(ta.write(0, "wrong_type_scalar").flow)
if context.in_graph_mode(): with self.assertRaisesOpError("index -1"):
with self.assertRaisesOpError(
"Tried to write to index -1 but array is not "
"resizeable and size is: 3"):
self.evaluate(ta.write(-1, 3.0).flow)
else:
with self.assertRaisesOpError(
r"Writing to negative indices \(index -1\) is not allowed."):
self.evaluate(ta.write(-1, 3.0).flow) self.evaluate(ta.write(-1, 3.0).flow)
# Test reading from too large an index # Test reading from too large an index
@ -435,8 +421,8 @@ class TensorArrayTest(test.TestCase):
w0 = ta.write(0, [[4.0, 5.0]]) w0 = ta.write(0, [[4.0, 5.0]])
# Test reading wrong datatype, which is only possible in graph mode # Test reading wrong datatype (only possible when constructing graphs).
if context.in_graph_mode(): if not context.executing_eagerly():
r0_bad = gen_data_flow_ops.tensor_array_read_v3( r0_bad = gen_data_flow_ops.tensor_array_read_v3(
handle=w0.handle, index=0, dtype=dtypes.float64, flow_in=w0.flow) handle=w0.handle, index=0, dtype=dtypes.float64, flow_in=w0.flow)
with self.assertRaisesOpError( with self.assertRaisesOpError(
@ -444,13 +430,7 @@ class TensorArrayTest(test.TestCase):
r0_bad.eval() r0_bad.eval()
# Test reading from a negative index, which is not allowed # Test reading from a negative index, which is not allowed
if context.in_graph_mode(): with self.assertRaisesOpError("index -1"):
with self.assertRaisesOpError(
r"Tried to read from index -1 but array size is: 3"):
self.evaluate(ta.read(-1))
else:
with self.assertRaisesOpError(
r"Reading from negative indices \(index -1\) is not allowed."):
self.evaluate(ta.read(-1)) self.evaluate(ta.read(-1))
# Test reading from too large an index # Test reading from too large an index
@ -467,10 +447,7 @@ class TensorArrayTest(test.TestCase):
with self.assertRaisesOpError( with self.assertRaisesOpError(
"Could not write to TensorArray index 2 because " "Could not write to TensorArray index 2 because "
"it has already been written to."): "it has already been written to."):
if context.in_graph_mode():
self.evaluate(ta.write(2, 3.0).write(2, 3.0).flow) self.evaluate(ta.write(2, 3.0).write(2, 3.0).flow)
else:
self.evaluate(ta.write(2, 3.0).write(2, 3.0))
@test_util.run_in_graph_and_eager_modes() @test_util.run_in_graph_and_eager_modes()
def testTensorArrayConcatIncompatibleShapesFails(self): def testTensorArrayConcatIncompatibleShapesFails(self):
@ -499,58 +476,40 @@ class TensorArrayTest(test.TestCase):
w2 = w1.write(1, [4.0]) w2 = w1.write(1, [4.0])
w3 = w2.write(2, [[3.0]]) w3 = w2.write(2, [[3.0]])
# The eager-mode implementation just passes up array_op.concat's error # The exact error messages differ between eager execution and graph
# message. # construction as the former bubbles up the error from array_op.concat.
if context.in_graph_mode(): with self.assertRaisesOpError("shape"):
with self.assertRaisesOpError(
r"TensorArray has inconsistent shapes. Index 0 has "
r"\(excepting dimension 0\) shape: \[\] but index 2 has "
r"\(excepting dimension 0\) shape: \[1\]"):
self.evaluate(w3.concat())
else:
with self.assertRaisesOpError(
r".*Ranks of all input tensors should match: shape\[0\] "
r"= \[1\] vs\. shape\[2\] = \[1,1\].*"):
self.evaluate(w3.concat()) self.evaluate(w3.concat())
@test_util.run_in_graph_and_eager_modes() @test_util.run_in_graph_and_eager_modes()
def testTensorArraySplitIncompatibleShapesFails(self): def testTensorArraySplitIncompatibleShapesFails(self):
with self.test_session(use_gpu=True): with self.test_session(use_gpu=True):
in_graph_mode = context.in_graph_mode() in_eager_mode = context.executing_eagerly()
ta = _make_ta(3, "foo") ta = _make_ta(3, "foo")
with self.assertRaisesOpError( with self.assertRaisesOpError(
r"Expected lengths to be a vector, received shape: \[\]"): r"Expected lengths to be a vector, received shape: \[\]"):
if in_graph_mode: if in_eager_mode:
self.evaluate(ta.split([1.0, 2.0, 3.0], 1))
else:
lengths = array_ops.placeholder(dtypes.int64) lengths = array_ops.placeholder(dtypes.int64)
ta.split([1.0, 2.0, 3.0], lengths).flow.eval(feed_dict={lengths: 1}) ta.split([1.0, 2.0, 3.0], lengths).flow.eval(feed_dict={lengths: 1})
else:
self.evaluate(ta.split([1.0, 2.0, 3.0], 1))
with self.assertRaisesOpError( with self.assertRaisesOpError(
r"Expected sum of lengths to be equal to values.shape\[0\], " r"Expected sum of lengths to be equal to values.shape\[0\], "
r"but sum of lengths is 1 and value's shape is: \[3\]"): r"but sum of lengths is 1 and value's shape is: \[3\]"):
if in_graph_mode:
self.evaluate(ta.split([1.0, 2.0, 3.0], [1]).flow) self.evaluate(ta.split([1.0, 2.0, 3.0], [1]).flow)
else:
self.evaluate(ta.split([1.0, 2.0, 3.0], [1]))
ta = _make_ta(1, "baz") ta = _make_ta(1, "baz")
with self.assertRaisesOpError( with self.assertRaisesOpError(
r"Expected value to be at least a vector, but received shape: \[\]"): r"Expected value to be at least a vector, but received shape: \[\]"):
if in_graph_mode:
self.evaluate(ta.split(1.0, [1]).flow) self.evaluate(ta.split(1.0, [1]).flow)
else:
self.evaluate(ta.split(1.0, [1]))
ta = _make_ta(2, "buz") ta = _make_ta(2, "buz")
with self.assertRaisesOpError( with self.assertRaisesOpError(
r"TensorArray's size is not equal to the size of lengths " r"TensorArray's size is not equal to the size of lengths "
r"\(2 vs. 1\), and the TensorArray is not marked as " r"\(2 vs. 1\), and the TensorArray is not marked as "
r"dynamically resizeable"): r"dynamically resizeable"):
if in_graph_mode:
self.evaluate(ta.split([1.0], [1]).flow) self.evaluate(ta.split([1.0], [1]).flow)
else:
self.evaluate(ta.split([1.0], [1]))
def _testTensorArrayWriteGradientAddMultipleAdds(self, dtype): def _testTensorArrayWriteGradientAddMultipleAdds(self, dtype):
with self.test_session(use_gpu=True): with self.test_session(use_gpu=True):
@ -868,14 +827,14 @@ class TensorArrayTest(test.TestCase):
vout = func(v0, state0, var) vout = func(v0, state0, var)
grad_val = -np.arange(3 * 5, dtype=np_dtype).reshape(3, 5) grad_val = -np.arange(3 * 5, dtype=np_dtype).reshape(3, 5)
if context.in_graph_mode(): if context.executing_eagerly():
grad_fn = backprop.gradients_function(func)
v0_grad, state0_grad, var_grad = grad_fn(v0, state0, var, dy=grad_val)
else:
v0_grad = gradients_impl.gradients([vout], [v0], [grad_val])[0] v0_grad = gradients_impl.gradients([vout], [v0], [grad_val])[0]
state0_grad = gradients_impl.gradients([vout], [state0], [grad_val])[0] state0_grad = gradients_impl.gradients([vout], [state0], [grad_val])[0]
var_grad = gradients_impl.gradients([vout], [var], [grad_val])[0] var_grad = gradients_impl.gradients([vout], [var], [grad_val])[0]
variables.global_variables_initializer().run() variables.global_variables_initializer().run()
else:
grad_fn = backprop.gradients_function(func)
v0_grad, state0_grad, var_grad = grad_fn(v0, state0, var, dy=grad_val)
state0_t, var_t, v0_t, vout_t, v0_grad_t, var_grad_t, state0_grad_t = ( state0_t, var_t, v0_t, vout_t, v0_grad_t, var_grad_t, state0_grad_t = (
self.evaluate( self.evaluate(
@ -959,10 +918,10 @@ class TensorArrayTest(test.TestCase):
return r return r
x = constant_op.constant(2.0, name="x") x = constant_op.constant(2.0, name="x")
if context.in_graph_mode(): if context.executing_eagerly():
grad = gradients_impl.gradients(loop(x), [x])[0]
else:
grad = backprop.gradients_function(loop)(x)[0] grad = backprop.gradients_function(loop)(x)[0]
else:
grad = gradients_impl.gradients(loop(x), [x])[0]
self.assertAllClose(31.0, self.evaluate(grad)) self.assertAllClose(31.0, self.evaluate(grad))
def testSumOfTwoReadVariablesWithoutRepeatGrad(self): def testSumOfTwoReadVariablesWithoutRepeatGrad(self):
@ -1158,14 +1117,14 @@ class TensorArrayTest(test.TestCase):
infer_shape=True) infer_shape=True)
w0 = ta1.split(value, [1, 2]) w0 = ta1.split(value, [1, 2])
r0 = w0.read(0) r0 = w0.read(0)
if context.in_graph_mode(): if context.executing_eagerly():
self.assertEqual((1, 2), r0.get_shape())
self.assertEqual((2, 2), w0.read(1).get_shape())
else:
self.assertEqual(r0.get_shape().ndims, None) self.assertEqual(r0.get_shape().ndims, None)
self.assertEqual( self.assertEqual(
tensor_shape.TensorShape( tensor_shape.TensorShape(
ta1.handle.op.get_attr("element_shape")).ndims, None) ta1.handle.op.get_attr("element_shape")).ndims, None)
else:
self.assertEqual((1, 2), r0.get_shape())
self.assertEqual((2, 2), w0.read(1).get_shape())
def testWriteUnknownShape(self): def testWriteUnknownShape(self):
with self.test_session(use_gpu=True): with self.test_session(use_gpu=True):
@ -1297,13 +1256,13 @@ class TensorArrayTest(test.TestCase):
g = func(values) g = func(values)
grad_ys = [[[2.0, 3.0], [4.0, 5.0]]] grad_ys = [[[2.0, 3.0], [4.0, 5.0]]]
# Test combined gradients + aggregation of read(0) # Test combined gradients + aggregation of read(0)
if context.in_graph_mode(): if context.executing_eagerly():
grad = gradients_impl.gradients(ys=[g], xs=[values], grad_ys=grad_ys)
g_vals, grad_vals = session.run([[g], grad])
else:
g_vals = [g] g_vals = [g]
grad_vals = backprop.gradients_function(func)( grad_vals = backprop.gradients_function(func)(
values, dy=constant_op.constant(grad_ys[0], dtype=dtypes.float32)) values, dy=constant_op.constant(grad_ys[0], dtype=dtypes.float32))
else:
grad = gradients_impl.gradients(ys=[g], xs=[values], grad_ys=grad_ys)
g_vals, grad_vals = session.run([[g], grad])
# Gradients for 8 of the 10 unread components are zero. # Gradients for 8 of the 10 unread components are zero.
expected_grad = np.zeros((10, 2)) expected_grad = np.zeros((10, 2))
@ -1453,13 +1412,13 @@ class TensorArrayTest(test.TestCase):
# Tests correct properties on new TensorArrays. # Tests correct properties on new TensorArrays.
self.assertEqual(dtypes.float32, ta0.dtype) self.assertEqual(dtypes.float32, ta0.dtype)
self.assertEqual(dtypes.int32, ta1.dtype) self.assertEqual(dtypes.int32, ta1.dtype)
if context.in_graph_mode(): if context.executing_eagerly():
self.assertEqual(tensor_shape.unknown_shape(), read0.get_shape()) self.assertEqual(tensor_shape.scalar(), read0.get_shape())
else: else:
self.assertEqual(tensor_shape.scalar(), read1.get_shape()) self.assertEqual(tensor_shape.unknown_shape(), read0.get_shape())
self.assertEqual(tensor_shape.scalar(), read1.get_shape()) self.assertEqual(tensor_shape.scalar(), read1.get_shape())
if context.in_graph_mode(): if not context.executing_eagerly():
variables.global_variables_initializer().run() variables.global_variables_initializer().run()
read0_v, read1_v, size0_v, size1_v = self.evaluate((read0, read1, size0, read0_v, read1_v, size0_v, size1_v = self.evaluate((read0, read1, size0,
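
The gradient hunks in this test all reduce to the same branch once the predicate is renamed. A minimal sketch of that branch, assuming the helper name gradient_of and the placeholder callable func are illustrative rather than part of the change:

from tensorflow.python.eager import backprop, context
from tensorflow.python.ops import gradients_impl

def gradient_of(func, x, dy):
  if context.executing_eagerly():
    # Eager execution: differentiate the Python callable directly.
    return backprop.gradients_function(func)(x, dy=dy)[0]
  # Graph construction: build the op, then ask for symbolic gradients.
  y = func(x)
  return gradients_impl.gradients(ys=[y], xs=[x], grad_ys=[dy])[0]

Either result can be handed to self.evaluate(), so the surrounding assertions stay mode-agnostic.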

View File

@ -166,11 +166,9 @@ class VariableScopeTest(test.TestCase):
self.evaluate(variables_lib.variables_initializer([w])) self.evaluate(variables_lib.variables_initializer([w]))
self.assertAllClose(self.evaluate(w.value()), [1, 2, 3]) self.assertAllClose(self.evaluate(w.value()), [1, 2, 3])
if context.in_graph_mode(): # A quirk to be revisited?
with self.assertRaises(TypeError): error = ValueError if context.executing_eagerly() else TypeError
variable_scope.get_variable("x4", initializer={}) with self.assertRaises(error):
else:
with self.assertRaises(ValueError):
variable_scope.get_variable("x4", initializer={}) variable_scope.get_variable("x4", initializer={})
@test_util.run_in_graph_and_eager_modes() @test_util.run_in_graph_and_eager_modes()
@ -267,7 +265,7 @@ class VariableScopeTest(test.TestCase):
self.assertAllClose(self.evaluate(losses[2]), 0.5) self.assertAllClose(self.evaluate(losses[2]), 0.5)
with variable_scope.variable_scope("foo", reuse=True): with variable_scope.variable_scope("foo", reuse=True):
# reuse=True is for now only supported when eager execution is disabled. # reuse=True is for now only supported when eager execution is disabled.
if context.in_graph_mode(): if not context.executing_eagerly():
v = variable_scope.get_variable("v", v = variable_scope.get_variable("v",
[]) # "v" is alredy there, reused []) # "v" is alredy there, reused
losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
@ -374,7 +372,7 @@ class VariableScopeTest(test.TestCase):
v = variable_scope.get_variable("v", []) v = variable_scope.get_variable("v", [])
self.evaluate(variables_lib.variables_initializer([v])) self.evaluate(variables_lib.variables_initializer([v]))
self.assertAllClose(self.evaluate(v.value()), 0.3) self.assertAllClose(self.evaluate(v.value()), 0.3)
if context.in_graph_mode(): if not context.executing_eagerly():
# Check that we can set reuse. # Check that we can set reuse.
variable_scope.get_variable_scope().reuse_variables() variable_scope.get_variable_scope().reuse_variables()
with self.assertRaises(ValueError): # Fail, w does not exist yet. with self.assertRaises(ValueError): # Fail, w does not exist yet.
@ -408,7 +406,7 @@ class VariableScopeTest(test.TestCase):
with variable_scope.variable_scope("tower") as tower: with variable_scope.variable_scope("tower") as tower:
with ops.name_scope("scope2") as sc2: with ops.name_scope("scope2") as sc2:
self.assertEqual(sc2, "testVarScopeNameScope1/tower/scope2/") self.assertEqual(sc2, "testVarScopeNameScope1/tower/scope2/")
if context.in_graph_mode(): if not context.executing_eagerly():
with variable_scope.variable_scope( with variable_scope.variable_scope(
tower): # Re-entering acts like another "tower". tower): # Re-entering acts like another "tower".
with ops.name_scope("scope2") as sc2: with ops.name_scope("scope2") as sc2:
@ -422,7 +420,7 @@ class VariableScopeTest(test.TestCase):
with variable_scope.variable_scope("tower"): with variable_scope.variable_scope("tower"):
with ops.name_scope("scope2") as sc2: with ops.name_scope("scope2") as sc2:
self.assertEqual(sc2, "testVarScopeNameScope2/tower/scope2/") self.assertEqual(sc2, "testVarScopeNameScope2/tower/scope2/")
if context.in_graph_mode(): if not context.executing_eagerly():
with variable_scope.variable_scope(tower): with variable_scope.variable_scope(tower):
with ops.name_scope("scope2") as sc2: with ops.name_scope("scope2") as sc2:
self.assertEqual(sc2, "testVarScopeNameScope2/tower_1/scope2/") self.assertEqual(sc2, "testVarScopeNameScope2/tower_1/scope2/")
@ -903,17 +901,15 @@ class VariableScopeTest(test.TestCase):
"w", [], collections=["foo"]) "w", [], collections=["foo"])
self.assertEqual(local_var.name, "outer/w:0") self.assertEqual(local_var.name, "outer/w:0")
if not context.executing_eagerly():
# Since variable is local, it should be in the local variable collection # Since variable is local, it should be in the local variable collection
# but not the trainable collection. # but not the trainable collection.
if context.in_graph_mode():
self.assertIn(local_var, self.assertIn(local_var,
ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES)) ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
self.assertIn(local_var, ops.get_collection("foo")) self.assertIn(local_var, ops.get_collection("foo"))
self.assertNotIn(local_var, self.assertNotIn(local_var,
ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)) ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
# Check that local variable respects `reuse`. # Check that local variable respects `reuse`.
if context.in_graph_mode():
with variable_scope.variable_scope(outer, "default", reuse=True): with variable_scope.variable_scope(outer, "default", reuse=True):
self.assertEqual( self.assertEqual(
variable_scope.get_local_variable("w", []).name, "outer/w:0") variable_scope.get_local_variable("w", []).name, "outer/w:0")
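
The remaining graph-only checks in this test, scope reuse and the collection assertions, share one shape. A rough sketch under the assumption that assert_local is an illustrative helper, not code from the commit:

from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope

def assert_local(name):
  v = variable_scope.get_local_variable(name, [])
  if not context.executing_eagerly():
    # Graph collections (and scope-based reuse) only exist while a graph is
    # being built; under eager execution there is nothing to look up.
    assert v in ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES)
  return v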

View File

@ -115,7 +115,7 @@ class Layer(checkpointable.CheckpointableBase):
# Provides information about which inputs are compatible with the layer. # Provides information about which inputs are compatible with the layer.
self.input_spec = None self.input_spec = None
if activity_regularizer and context.in_eager_mode(): if activity_regularizer and context.executing_eagerly():
raise ValueError( raise ValueError(
('Activity regularization is not supported when executing eagerly. ' ('Activity regularization is not supported when executing eagerly. '
'Got activity_regularizer=%s') % (activity_regularizer,)) 'Got activity_regularizer=%s') % (activity_regularizer,))
@ -228,7 +228,7 @@ class Layer(checkpointable.CheckpointableBase):
@property @property
def updates(self): def updates(self):
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('Layer.updates not supported in Eager mode.') raise RuntimeError('Layer.updates not supported in Eager mode.')
if not self.trainable and not self.stateful: if not self.trainable and not self.stateful:
return [] return []
@ -260,7 +260,7 @@ class Layer(checkpointable.CheckpointableBase):
have is available at runtime. have is available at runtime.
A step counter might fall into this category. A step counter might fall into this category.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
return # Updates already applied when in eager mode. return # Updates already applied when in eager mode.
updates = _to_list(updates) updates = _to_list(updates)
@ -286,7 +286,7 @@ class Layer(checkpointable.CheckpointableBase):
Raises: Raises:
RuntimeError: If called in Eager mode. RuntimeError: If called in Eager mode.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('`get_updates_for()` not supported in Eager mode.') raise RuntimeError('`get_updates_for()` not supported in Eager mode.')
# Updates disabled if layer is not trainable and not explicitly stateful. # Updates disabled if layer is not trainable and not explicitly stateful.
@ -317,7 +317,7 @@ class Layer(checkpointable.CheckpointableBase):
Returns: Returns:
A list of tensors. A list of tensors.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
# _losses may only contain variable regularization losses when executing # _losses may only contain variable regularization losses when executing
# eagerly, and they have been saved as lambdas to be executed when # eagerly, and they have been saved as lambdas to be executed when
# requested. # requested.
@ -355,7 +355,7 @@ class Layer(checkpointable.CheckpointableBase):
Raises: Raises:
RuntimeError: If called in Eager mode. RuntimeError: If called in Eager mode.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
# TODO(fchollet): it should be possible (and highly desirable) to support # TODO(fchollet): it should be possible (and highly desirable) to support
# `add_loss` in eager mode. This allows great convenience and flexibility # `add_loss` in eager mode. This allows great convenience and flexibility
# in defining custom losses on the fly (e.g. in VAEs). # in defining custom losses on the fly (e.g. in VAEs).
@ -389,7 +389,7 @@ class Layer(checkpointable.CheckpointableBase):
Raises: Raises:
RuntimeError: If called in Eager mode. RuntimeError: If called in Eager mode.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('Layer.get_losses_for not supported in Eager mode.') raise RuntimeError('Layer.get_losses_for not supported in Eager mode.')
if inputs is None: if inputs is None:
@ -509,7 +509,7 @@ class Layer(checkpointable.CheckpointableBase):
# will occur; it should be None if and only if initialization will take # will occur; it should be None if and only if initialization will take
# place in the eager context. # place in the eager context.
init_graph = None init_graph = None
if context.in_graph_mode(): if not context.executing_eagerly():
default_graph = ops.get_default_graph() default_graph = ops.get_default_graph()
if default_graph.building_function: if default_graph.building_function:
with ops.init_scope(): with ops.init_scope():
@ -517,7 +517,7 @@ class Layer(checkpointable.CheckpointableBase):
# will be lifted; if initialization ops will be lifted into # will be lifted; if initialization ops will be lifted into
# the eager context, then there is nothing to retrieve, since variable # the eager context, then there is nothing to retrieve, since variable
# collections are not supported when eager execution is enabled. # collections are not supported when eager execution is enabled.
if context.in_graph_mode(): if not context.executing_eagerly():
init_graph = ops.get_default_graph() init_graph = ops.get_default_graph()
existing_variables = set(tf_variables.global_variables()) existing_variables = set(tf_variables.global_variables())
else: else:
@ -624,17 +624,17 @@ class Layer(checkpointable.CheckpointableBase):
self._set_scope(kwargs.pop('scope', None)) self._set_scope(kwargs.pop('scope', None))
input_list = nest.flatten(inputs) input_list = nest.flatten(inputs)
in_graph_mode = context.in_graph_mode() build_graph = not context.executing_eagerly()
in_deferred_mode = isinstance(input_list[0], _DeferredTensor) in_deferred_mode = isinstance(input_list[0], _DeferredTensor)
# Ensure the Layer, if being reused, is working with inputs from # Ensure the Layer, if being reused, is working with inputs from
# the same graph as where it was created. # the same graph as where it was created.
if in_graph_mode: if build_graph:
try: try:
# Set layer's "graph" at build time # Set layer's "graph" at build time
self._graph = ops._get_graph_from_inputs(input_list, graph=self._graph) # pylint: disable=protected-access self._graph = ops._get_graph_from_inputs(input_list, graph=self._graph) # pylint: disable=protected-access
except ValueError as e: except ValueError as e:
raise ValueError('Input graph and Layer graph are not the same: %s' % e) raise ValueError('Input graph and Layer graph are not the same: %s' % e)
if in_graph_mode or in_deferred_mode: if build_graph or in_deferred_mode:
user_kwargs = copy.copy(kwargs) user_kwargs = copy.copy(kwargs)
# Handle Keras mask propagation from previous layer to current layer. # Handle Keras mask propagation from previous layer to current layer.
@ -669,13 +669,14 @@ class Layer(checkpointable.CheckpointableBase):
with scope_context_manager as scope: with scope_context_manager as scope:
with ops.name_scope(self._name_scope_name(scope)): with ops.name_scope(self._name_scope_name(scope)):
if not self.built: if not self.built:
if not in_graph_mode: if not build_graph:
# Activity regularization is currently unsupported in Eager mode. # Activity regularization is currently unsupported in Eager mode.
if self._activity_regularizer: if self._activity_regularizer:
raise ValueError('activity_regularizer currently unsupported in ' raise ValueError(
'Eager mode. Found an activity_regularizer in ' 'activity_regularizer currently unsupported with '
'eager execution enabled. Found an activity_regularizer in '
'%s(%s).' % (self.__class__.__name__, self)) '%s(%s).' % (self.__class__.__name__, self))
if not in_graph_mode and not in_deferred_mode: if not build_graph and not in_deferred_mode:
# TODO(agarwal): support _keras_history in Eager mode. # TODO(agarwal): support _keras_history in Eager mode.
for x in input_list: for x in input_list:
if hasattr(x, '_keras_history'): if hasattr(x, '_keras_history'):
@ -706,7 +707,7 @@ class Layer(checkpointable.CheckpointableBase):
if call_has_scope_arg: if call_has_scope_arg:
kwargs['scope'] = scope kwargs['scope'] = scope
# Check input assumptions set after layer building, e.g. input shape. # Check input assumptions set after layer building, e.g. input shape.
if in_graph_mode or in_deferred_mode: if build_graph or in_deferred_mode:
self._assert_input_compatibility(inputs) self._assert_input_compatibility(inputs)
if not in_deferred_mode: if not in_deferred_mode:
@ -730,7 +731,7 @@ class Layer(checkpointable.CheckpointableBase):
if len(outputs) == 1: if len(outputs) == 1:
outputs = outputs[0] outputs = outputs[0]
if in_graph_mode: if build_graph:
# Apply activity regularization. # Apply activity regularization.
# Note that it should be applied every time the layer creates a new # Note that it should be applied every time the layer creates a new
# output, since it is output-specific. # output, since it is output-specific.
@ -752,7 +753,7 @@ class Layer(checkpointable.CheckpointableBase):
else: else:
outputs._keras_mask = output_mask # pylint: disable=protected-access outputs._keras_mask = output_mask # pylint: disable=protected-access
if in_graph_mode: if build_graph:
# If all input tensors have history metadata, # If all input tensors have history metadata,
# we update the output tensors # we update the output tensors
# with corresponding history metadata, thus eventually allowing to use # with corresponding history metadata, thus eventually allowing to use
@ -775,7 +776,7 @@ class Layer(checkpointable.CheckpointableBase):
# Update global default collections. # Update global default collections.
_add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS) _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
if in_deferred_mode or in_graph_mode: if in_deferred_mode or build_graph:
if _have_all_keras_metadata(inputs): if _have_all_keras_metadata(inputs):
# Add an inbound node to the layer, so it can keep track of this call. # Add an inbound node to the layer, so it can keep track of this call.
# This updates the layer history of the output tensor(s). # This updates the layer history of the output tensor(s).
@ -787,7 +788,7 @@ class Layer(checkpointable.CheckpointableBase):
@property @property
def graph(self): def graph(self):
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('Layer.graph not supported in Eager mode.') raise RuntimeError('Layer.graph not supported in Eager mode.')
return self._graph return self._graph
@ -891,7 +892,7 @@ class Layer(checkpointable.CheckpointableBase):
mode. mode.
ValueError: If the index provided does not match any node. ValueError: If the index provided does not match any node.
""" """
assert context.in_graph_mode() assert not context.executing_eagerly()
if not self._inbound_nodes: if not self._inbound_nodes:
raise RuntimeError('The layer has never been called ' raise RuntimeError('The layer has never been called '
'and thus has no defined ' + attr_name + '.') 'and thus has no defined ' + attr_name + '.')
@ -921,7 +922,7 @@ class Layer(checkpointable.CheckpointableBase):
Raises: Raises:
RuntimeError: If called in Eager mode. RuntimeError: If called in Eager mode.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError( raise RuntimeError(
'Layer.get_input_shape_at not supported in Eager mode.') 'Layer.get_input_shape_at not supported in Eager mode.')
return self._get_node_attribute_at_index(node_index, 'input_shapes', return self._get_node_attribute_at_index(node_index, 'input_shapes',
@ -943,7 +944,7 @@ class Layer(checkpointable.CheckpointableBase):
Raises: Raises:
RuntimeError: If called in Eager mode. RuntimeError: If called in Eager mode.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError( raise RuntimeError(
'Layer.get_output_shape_at not supported in Eager mode.') 'Layer.get_output_shape_at not supported in Eager mode.')
return self._get_node_attribute_at_index(node_index, 'output_shapes', return self._get_node_attribute_at_index(node_index, 'output_shapes',
@ -964,7 +965,7 @@ class Layer(checkpointable.CheckpointableBase):
Raises: Raises:
RuntimeError: If called in Eager mode. RuntimeError: If called in Eager mode.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('Layer.get_input_at not supported in Eager mode.') raise RuntimeError('Layer.get_input_at not supported in Eager mode.')
return self._get_node_attribute_at_index(node_index, 'input_tensors', return self._get_node_attribute_at_index(node_index, 'input_tensors',
'input') 'input')
@ -984,7 +985,7 @@ class Layer(checkpointable.CheckpointableBase):
Raises: Raises:
RuntimeError: If called in Eager mode. RuntimeError: If called in Eager mode.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('Layer.get_output_at not supported in Eager mode.') raise RuntimeError('Layer.get_output_at not supported in Eager mode.')
return self._get_node_attribute_at_index(node_index, 'output_tensors', return self._get_node_attribute_at_index(node_index, 'output_tensors',
'output') 'output')
@ -1007,7 +1008,7 @@ class Layer(checkpointable.CheckpointableBase):
RuntimeError: If called in Eager mode. RuntimeError: If called in Eager mode.
AttributeError: If no inbound nodes are found. AttributeError: If no inbound nodes are found.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('Layer.input not supported in Eager mode.') raise RuntimeError('Layer.input not supported in Eager mode.')
if not self._inbound_nodes: if not self._inbound_nodes:
raise AttributeError('Layer ' + self.name + raise AttributeError('Layer ' + self.name +
@ -1029,7 +1030,7 @@ class Layer(checkpointable.CheckpointableBase):
layers. layers.
RuntimeError: if called in Eager mode. RuntimeError: if called in Eager mode.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('Layer.output not supported in Eager mode.') raise RuntimeError('Layer.output not supported in Eager mode.')
if not self._inbound_nodes: if not self._inbound_nodes:
raise AttributeError('Layer ' + self.name + ' has no inbound nodes.') raise AttributeError('Layer ' + self.name + ' has no inbound nodes.')
@ -1051,7 +1052,7 @@ class Layer(checkpointable.CheckpointableBase):
AttributeError: if the layer has no defined input_shape. AttributeError: if the layer has no defined input_shape.
RuntimeError: if called in Eager mode. RuntimeError: if called in Eager mode.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('Layer.input_shape not supported in Eager mode.') raise RuntimeError('Layer.input_shape not supported in Eager mode.')
if not self._inbound_nodes: if not self._inbound_nodes:
raise AttributeError('The layer has never been called ' raise AttributeError('The layer has never been called '
@ -1112,7 +1113,7 @@ class Layer(checkpointable.CheckpointableBase):
AttributeError: if the layer has no defined output shape. AttributeError: if the layer has no defined output shape.
RuntimeError: if called in Eager mode. RuntimeError: if called in Eager mode.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('Layer.output_shape not supported in Eager mode.') raise RuntimeError('Layer.output_shape not supported in Eager mode.')
if not self._inbound_nodes: if not self._inbound_nodes:
raise AttributeError('The layer has never been called ' raise AttributeError('The layer has never been called '
@ -1470,7 +1471,7 @@ def _to_list(x):
def _add_elements_to_collection(elements, collection_list): def _add_elements_to_collection(elements, collection_list):
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('Using collections from Layers not supported in Eager ' raise RuntimeError('Using collections from Layers not supported in Eager '
'mode. Tried to add %s to %s' % (elements, 'mode. Tried to add %s to %s' % (elements,
collection_list)) collection_list))
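
The base-layer hunks are where the inverted sense of the new predicate matters most, since the old code cached context.in_graph_mode() in a local variable. A one-line sketch of the substitution, with building_graph as our own name for it:

from tensorflow.python.eager import context

def building_graph():
  # Old spelling:  context.in_graph_mode()
  # New spelling:  not context.executing_eagerly()
  # Because the predicate flips, call sites either negate it (as here) or
  # swap their if/else bodies, which is why many hunks above reorder
  # branches rather than only renaming the call.
  return not context.executing_eagerly()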

View File

@ -44,7 +44,7 @@ class BaseLayerTest(test.TestCase):
self.assertEqual(layer.variables, []) self.assertEqual(layer.variables, [])
self.assertEqual(layer.trainable_variables, []) self.assertEqual(layer.trainable_variables, [])
self.assertEqual(layer.non_trainable_variables, []) self.assertEqual(layer.non_trainable_variables, [])
if context.in_graph_mode(): if not context.executing_eagerly():
# updates, losses only supported in GRAPH mode # updates, losses only supported in GRAPH mode
self.assertEqual(layer.updates, []) self.assertEqual(layer.updates, [])
self.assertEqual(layer.losses, []) self.assertEqual(layer.losses, [])
@ -63,7 +63,7 @@ class BaseLayerTest(test.TestCase):
self.assertEqual(layer.variables, [variable]) self.assertEqual(layer.variables, [variable])
self.assertEqual(layer.trainable_variables, [variable]) self.assertEqual(layer.trainable_variables, [variable])
self.assertEqual(layer.non_trainable_variables, []) self.assertEqual(layer.non_trainable_variables, [])
if context.in_graph_mode(): if not context.executing_eagerly():
self.assertEqual( self.assertEqual(
layer.variables, layer.variables,
ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)) ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
@ -77,7 +77,7 @@ class BaseLayerTest(test.TestCase):
self.assertEqual(layer.variables, [variable, variable_2]) self.assertEqual(layer.variables, [variable, variable_2])
self.assertEqual(layer.trainable_variables, [variable]) self.assertEqual(layer.trainable_variables, [variable])
self.assertEqual(layer.non_trainable_variables, [variable_2]) self.assertEqual(layer.non_trainable_variables, [variable_2])
if context.in_graph_mode(): if not context.executing_eagerly():
self.assertEqual( self.assertEqual(
len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 1) len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 1)
@ -161,7 +161,7 @@ class BaseLayerTest(test.TestCase):
inputs = random_ops.random_uniform((5,), seed=1) inputs = random_ops.random_uniform((5,), seed=1)
outputs = layer.apply(inputs) outputs = layer.apply(inputs)
self.assertEqual(layer.built, True) self.assertEqual(layer.built, True)
if context.in_graph_mode(): if not context.executing_eagerly():
# op is only supported in GRAPH mode # op is only supported in GRAPH mode
self.assertEqual(outputs.op.name, 'my_layer/Square') self.assertEqual(outputs.op.name, 'my_layer/Square')
@ -210,7 +210,7 @@ class BaseLayerTest(test.TestCase):
inputs = random_ops.random_uniform((5,), seed=1) inputs = random_ops.random_uniform((5,), seed=1)
outputs = layer.apply(inputs) outputs = layer.apply(inputs)
self.assertEqual(layer.built, True) self.assertEqual(layer.built, True)
if context.in_graph_mode(): if not context.executing_eagerly():
# op only supported in GRAPH mode. # op only supported in GRAPH mode.
self.assertEqual(outputs.op.name, 'my_layer/Square') self.assertEqual(outputs.op.name, 'my_layer/Square')
@ -280,7 +280,7 @@ class BaseLayerTest(test.TestCase):
def call(self, inputs): def call(self, inputs):
return inputs return inputs
if context.in_graph_mode(): if not context.executing_eagerly():
layer = CustomerLayer() layer = CustomerLayer()
with self.assertRaisesRegexp(ValueError, r'requires a defined rank'): with self.assertRaisesRegexp(ValueError, r'requires a defined rank'):
layer.apply(array_ops.placeholder('int32')) layer.apply(array_ops.placeholder('int32'))
@ -307,7 +307,7 @@ class BaseLayerTest(test.TestCase):
def call(self, inputs): def call(self, inputs):
return inputs return inputs
if context.in_graph_mode(): if not context.executing_eagerly():
layer = CustomerLayer() layer = CustomerLayer()
with self.assertRaisesRegexp(ValueError, r'requires a defined rank'): with self.assertRaisesRegexp(ValueError, r'requires a defined rank'):
layer.apply(array_ops.placeholder('int32')) layer.apply(array_ops.placeholder('int32'))
@ -335,7 +335,7 @@ class BaseLayerTest(test.TestCase):
def call(self, inputs): def call(self, inputs):
return inputs return inputs
if context.in_graph_mode(): if not context.executing_eagerly():
layer = CustomerLayer() layer = CustomerLayer()
with self.assertRaisesRegexp(ValueError, r'requires a defined rank'): with self.assertRaisesRegexp(ValueError, r'requires a defined rank'):
layer.apply(array_ops.placeholder('int32')) layer.apply(array_ops.placeholder('int32'))
@ -430,7 +430,7 @@ class BaseLayerTest(test.TestCase):
layer.apply(constant_op.constant(1)) layer.apply(constant_op.constant(1))
# Works # Works
if context.in_graph_mode(): if not context.executing_eagerly():
layer.apply(array_ops.placeholder('int32')) layer.apply(array_ops.placeholder('int32'))
layer.apply(array_ops.placeholder('int32', shape=(2, 3))) layer.apply(array_ops.placeholder('int32', shape=(2, 3)))
@ -453,13 +453,7 @@ class BaseLayerTest(test.TestCase):
return {'l' + key: inputs[key] for key in inputs} return {'l' + key: inputs[key] for key in inputs}
layer = DictLayer() layer = DictLayer()
if context.in_graph_mode(): if context.executing_eagerly():
i1 = array_ops.placeholder('int32')
i2 = array_ops.placeholder('float32')
result = layer.apply({'abel': i1, 'ogits': i2})
self.assertTrue(isinstance(result, dict))
self.assertEqual(set(['label', 'logits']), set(result.keys()))
else:
i1 = constant_op.constant(3) i1 = constant_op.constant(3)
i2 = constant_op.constant(4.0) i2 = constant_op.constant(4.0)
result = layer.apply({'abel': i1, 'ogits': i2}) result = layer.apply({'abel': i1, 'ogits': i2})
@ -467,6 +461,12 @@ class BaseLayerTest(test.TestCase):
self.assertEqual(set(['label', 'logits']), set(result.keys())) self.assertEqual(set(['label', 'logits']), set(result.keys()))
self.assertEqual(3, result['label'].numpy()) self.assertEqual(3, result['label'].numpy())
self.assertEqual(4.0, result['logits'].numpy()) self.assertEqual(4.0, result['logits'].numpy())
else:
i1 = array_ops.placeholder('int32')
i2 = array_ops.placeholder('float32')
result = layer.apply({'abel': i1, 'ogits': i2})
self.assertTrue(isinstance(result, dict))
self.assertEqual(set(['label', 'logits']), set(result.keys()))
def testActivityRegularizer(self): def testActivityRegularizer(self):
regularizer = math_ops.reduce_sum regularizer = math_ops.reduce_sum

View File

@ -1664,7 +1664,7 @@ class Conv2DTranspose(Conv2D):
padding=self.padding.upper(), padding=self.padding.upper(),
data_format=utils.convert_data_format(self.data_format, ndim=4)) data_format=utils.convert_data_format(self.data_format, ndim=4))
if context.in_graph_mode(): if not context.executing_eagerly():
# Infer the static output shape: # Infer the static output shape:
out_shape = inputs.get_shape().as_list() out_shape = inputs.get_shape().as_list()
out_shape[c_axis] = self.filters out_shape[c_axis] = self.filters
@ -1969,7 +1969,7 @@ class Conv3DTranspose(Conv3D):
data_format=utils.convert_data_format(self.data_format, ndim=5), data_format=utils.convert_data_format(self.data_format, ndim=5),
padding=self.padding.upper()) padding=self.padding.upper())
if context.in_graph_mode(): if not context.executing_eagerly():
# Infer the static output shape: # Infer the static output shape:
out_shape = inputs.get_shape().as_list() out_shape = inputs.get_shape().as_list()
out_shape[c_axis] = self.filters out_shape[c_axis] = self.filters

View File

@ -156,7 +156,7 @@ class Dense(base.Layer):
outputs = standard_ops.tensordot(inputs, self.kernel, [[len(shape) - 1], outputs = standard_ops.tensordot(inputs, self.kernel, [[len(shape) - 1],
[0]]) [0]])
# Reshape the output back to the original ndim of the input. # Reshape the output back to the original ndim of the input.
if context.in_graph_mode(): if not context.executing_eagerly():
output_shape = shape[:-1] + [self.units] output_shape = shape[:-1] + [self.units]
outputs.set_shape(output_shape) outputs.set_shape(output_shape)
else: else:
@ -374,7 +374,7 @@ class Flatten(base.Layer):
def call(self, inputs): def call(self, inputs):
outputs = array_ops.reshape(inputs, (array_ops.shape(inputs)[0], -1)) outputs = array_ops.reshape(inputs, (array_ops.shape(inputs)[0], -1))
if context.in_graph_mode(): if not context.executing_eagerly():
outputs.set_shape(self.compute_output_shape(inputs.get_shape())) outputs.set_shape(self.compute_output_shape(inputs.get_shape()))
return outputs return outputs
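
As in the transposed-convolution hunks of the previous file, the Dense and Flatten changes guard only static-shape bookkeeping. A small sketch of the idea, written as a free-standing flatten helper rather than the layer's actual call method:

from tensorflow.python.eager import context
from tensorflow.python.ops import array_ops

def flatten(inputs):
  outputs = array_ops.reshape(inputs, (array_ops.shape(inputs)[0], -1))
  if not context.executing_eagerly():
    # Shape annotations only help graph-mode shape inference; eager tensors
    # already carry fully defined shapes, so there is nothing to set.
    outputs.set_shape([inputs.get_shape()[0], None])
  return outputs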

View File

@ -77,7 +77,7 @@ class DenseTest(test.TestCase):
self.assertListEqual(dense.trainable_variables, self.assertListEqual(dense.trainable_variables,
[dense.kernel, dense.bias]) [dense.kernel, dense.bias])
self.assertListEqual(dense.non_trainable_variables, []) self.assertListEqual(dense.non_trainable_variables, [])
if context.in_graph_mode(): if not context.executing_eagerly():
self.assertEqual( self.assertEqual(
len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 2) len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 2)
self.assertEqual(dense.kernel.name, 'my_dense/kernel:0') self.assertEqual(dense.kernel.name, 'my_dense/kernel:0')
@ -98,7 +98,7 @@ class DenseTest(test.TestCase):
self.assertListEqual(dense.variables, [dense.kernel]) self.assertListEqual(dense.variables, [dense.kernel])
self.assertListEqual(dense.trainable_variables, [dense.kernel]) self.assertListEqual(dense.trainable_variables, [dense.kernel])
self.assertListEqual(dense.non_trainable_variables, []) self.assertListEqual(dense.non_trainable_variables, [])
if context.in_graph_mode(): if not context.executing_eagerly():
self.assertEqual( self.assertEqual(
len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 1) len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 1)
self.assertEqual(dense.kernel.name, 'my_dense/kernel:0') self.assertEqual(dense.kernel.name, 'my_dense/kernel:0')
@ -113,7 +113,7 @@ class DenseTest(test.TestCase):
self.assertListEqual(dense.non_trainable_variables, self.assertListEqual(dense.non_trainable_variables,
[dense.kernel, dense.bias]) [dense.kernel, dense.bias])
self.assertListEqual(dense.trainable_variables, []) self.assertListEqual(dense.trainable_variables, [])
if context.in_graph_mode(): if not context.executing_eagerly():
self.assertEqual( self.assertEqual(
len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 0) len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 0)
@ -162,13 +162,13 @@ class DenseTest(test.TestCase):
dense = core_layers.Dense(2, activation=nn_ops.relu, name='dense1') dense = core_layers.Dense(2, activation=nn_ops.relu, name='dense1')
inputs = random_ops.random_uniform((5, 3), seed=1) inputs = random_ops.random_uniform((5, 3), seed=1)
outputs = dense(inputs) outputs = dense(inputs)
if context.in_graph_mode(): if not context.executing_eagerly():
self.assertEqual(outputs.op.name, 'dense1/Relu') self.assertEqual(outputs.op.name, 'dense1/Relu')
dense = core_layers.Dense(2, name='dense2') dense = core_layers.Dense(2, name='dense2')
inputs = random_ops.random_uniform((5, 3), seed=1) inputs = random_ops.random_uniform((5, 3), seed=1)
outputs = dense(inputs) outputs = dense(inputs)
if context.in_graph_mode(): if not context.executing_eagerly():
self.assertEqual(outputs.op.name, 'dense2/BiasAdd') self.assertEqual(outputs.op.name, 'dense2/BiasAdd')
def testActivityRegularizer(self): def testActivityRegularizer(self):
@ -374,7 +374,7 @@ class DropoutTest(test.TestCase):
dp = core_layers.Dropout(0.5) dp = core_layers.Dropout(0.5)
inputs = array_ops.ones((5, 3)) inputs = array_ops.ones((5, 3))
dropped = dp.apply(inputs, training=True) dropped = dp.apply(inputs, training=True)
if context.in_graph_mode(): if not context.executing_eagerly():
self.evaluate(variables.global_variables_initializer()) self.evaluate(variables.global_variables_initializer())
np_output = self.evaluate(dropped) np_output = self.evaluate(dropped)
self.assertAlmostEqual(0., np_output.min()) self.assertAlmostEqual(0., np_output.min())

View File

@ -338,8 +338,9 @@ class BatchNormalization(base.Layer):
return var return var
with ops.device(None): with ops.device(None):
device = ((lambda _: self.moving_mean.device) device = (
if context.in_graph_mode() else self.moving_mean.device) self.moving_mean.device if context.executing_eagerly() else
(lambda _: self.moving_mean.device))
with ops.device(device): with ops.device(device):
self.renorm_mean = _renorm_variable('renorm_mean', param_shape) self.renorm_mean = _renorm_variable('renorm_mean', param_shape)
self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ()) self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ())
@ -347,8 +348,9 @@ class BatchNormalization(base.Layer):
# renorm_stddev_weight. This allows us to (1) mix the average # renorm_stddev_weight. This allows us to (1) mix the average
# stddev with the minibatch stddev early in training, and (2) compute # stddev with the minibatch stddev early in training, and (2) compute
# the unbiased average stddev by dividing renorm_stddev by the weight. # the unbiased average stddev by dividing renorm_stddev by the weight.
device = ((lambda _: self.moving_variance.device) device = (
if context.in_graph_mode() else self.moving_variance.device) self.moving_variance.device if context.executing_eagerly() else
(lambda _: self.moving_variance.device))
with ops.device(device): with ops.device(device):
self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape) self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape)
self.renorm_stddev_weight = _renorm_variable( self.renorm_stddev_weight = _renorm_variable(
@ -420,7 +422,7 @@ class BatchNormalization(base.Layer):
one_minus_decay) one_minus_decay)
variance_update = self._assign_moving_average(self.moving_variance, variance_update = self._assign_moving_average(self.moving_variance,
variance, one_minus_decay) variance, one_minus_decay)
if context.in_graph_mode(): if not context.executing_eagerly():
# Note that in Eager mode, the updates are already executed when running # Note that in Eager mode, the updates are already executed when running
# assign_moving_averages. So we do not need to put them into # assign_moving_averages. So we do not need to put them into
# collections. # collections.
@ -493,7 +495,7 @@ class BatchNormalization(base.Layer):
return (r, d, new_mean, new_variance) return (r, d, new_mean, new_variance)
def call(self, inputs, training=False): def call(self, inputs, training=False):
in_eager_mode = context.in_eager_mode() in_eager_mode = context.executing_eagerly()
if self.virtual_batch_size is not None: if self.virtual_batch_size is not None:
# Virtual batches (aka ghost batches) can be simulated by reshaping the # Virtual batches (aka ghost batches) can be simulated by reshaping the
# Tensor and reusing the existing batch norm implementation # Tensor and reusing the existing batch norm implementation
@ -610,7 +612,7 @@ class BatchNormalization(base.Layer):
training, training,
lambda: _do_update(self.moving_variance, new_variance), lambda: _do_update(self.moving_variance, new_variance),
lambda: self.moving_variance) lambda: self.moving_variance)
if context.in_graph_mode(): if not context.executing_eagerly():
self.add_update(mean_update, inputs=inputs) self.add_update(mean_update, inputs=inputs)
self.add_update(variance_update, inputs=inputs) self.add_update(variance_update, inputs=inputs)
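
Besides moving the update bookkeeping behind the new predicate, the batch-normalization hunks swap the argument handed to ops.device(): a device function while building a graph, a plain device string under eager execution. A hedged sketch of that branch, with colocate_with as our own helper name:

from tensorflow.python.eager import context
from tensorflow.python.framework import ops

def colocate_with(variable):
  # Graph mode accepts a device function; eager mode expects the device string.
  device = (variable.device if context.executing_eagerly()
            else (lambda _: variable.device))
  return ops.device(device)

A layer would then write: with colocate_with(self.moving_mean): ..., mirroring the renorm variable creation above.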

View File

@ -80,7 +80,7 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
def _ExtractInputShapes(inputs): def _ExtractInputShapes(inputs):
"""Extract the shapes of a set of input tensors.""" """Extract the shapes of a set of input tensors."""
if not context.in_graph_mode(): if context.executing_eagerly():
return array_ops.shape_n(inputs) return array_ops.shape_n(inputs)
sizes = [] sizes = []
fully_known = True fully_known = True
@ -106,7 +106,7 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
out_grads = [] out_grads = []
if isinstance(grad, ops.Tensor): if isinstance(grad, ops.Tensor):
if context.in_eager_mode(): if context.executing_eagerly():
# Using mod here for convenience since concat_dim is already verified # Using mod here for convenience since concat_dim is already verified
# in concat implementation to be within the allowed [-rank, rank) range. # in concat implementation to be within the allowed [-rank, rank) range.
non_neg_concat_dim = ( non_neg_concat_dim = (
@ -428,7 +428,7 @@ def _GatherV2Grad(op, grad):
# For axis 0 gathers, build an appropriately shaped IndexedSlices. # For axis 0 gathers, build an appropriately shaped IndexedSlices.
if axis_static == 0: if axis_static == 0:
if context.in_eager_mode(): if context.executing_eagerly():
params_tail_shape = params_shape.cpu()[1:] params_tail_shape = params_shape.cpu()[1:]
else: else:
params_tail_shape = params_shape[1:] params_tail_shape = params_shape[1:]
@ -578,7 +578,7 @@ def _TileGrad(op, grad):
axes = math_ops.range(0, array_ops.size(split_shape), 2) axes = math_ops.range(0, array_ops.size(split_shape), 2)
input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes) input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes)
# Fix shape inference # Fix shape inference
if context.in_graph_mode(): if not context.executing_eagerly():
input_grad.set_shape(op.inputs[0].get_shape()) input_grad.set_shape(op.inputs[0].get_shape())
return [input_grad, None] return [input_grad, None]
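
_ExtractInputShapes keeps its static-shape fast path for graph construction and falls back to shape_n otherwise. A simplified sketch of that logic, with the caveat that the real helper mixes static and dynamic shapes per input rather than deciding all-or-nothing:

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op, dtypes
from tensorflow.python.ops import array_ops

def input_shapes(inputs):
  if context.executing_eagerly():
    # Shapes are always concrete under eager execution.
    return array_ops.shape_n(inputs)
  static = [x.get_shape() for x in inputs]
  if all(s.is_fully_defined() for s in static):
    return [constant_op.constant(s.as_list(), dtype=dtypes.int32)
            for s in static]
  return array_ops.shape_n(inputs)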

View File

@ -128,9 +128,7 @@ def identity(input, name=None): # pylint: disable=redefined-builtin
Returns: Returns:
A `Tensor`. Has the same type as `input`. A `Tensor`. Has the same type as `input`.
""" """
if context.in_graph_mode(): if context.executing_eagerly():
return gen_array_ops.identity(input, name=name)
else:
input = ops.convert_to_tensor(input) input = ops.convert_to_tensor(input)
in_device = input.device in_device = input.device
# TODO(ashankar): Does 'identity' need to invoke execution callbacks? # TODO(ashankar): Does 'identity' need to invoke execution callbacks?
@ -140,6 +138,8 @@ def identity(input, name=None): # pylint: disable=redefined-builtin
if context_device != in_device: if context_device != in_device:
return input._copy() # pylint: disable=protected-access return input._copy() # pylint: disable=protected-access
return input return input
else:
return gen_array_ops.identity(input, name=name)
# pylint: disable=redefined-builtin,protected-access # pylint: disable=redefined-builtin,protected-access
@ -305,7 +305,7 @@ def shape_internal(input, name=None, optimize=True, out_type=dtypes.int32):
sparse_tensor.SparseTensorValue)): sparse_tensor.SparseTensorValue)):
return gen_math_ops.cast(input.dense_shape, out_type) return gen_math_ops.cast(input.dense_shape, out_type)
else: else:
if context.in_graph_mode(): if not context.executing_eagerly():
input_tensor = ops.convert_to_tensor(input) input_tensor = ops.convert_to_tensor(input)
input_shape = input_tensor.get_shape() input_shape = input_tensor.get_shape()
if optimize and input_shape.is_fully_defined(): if optimize and input_shape.is_fully_defined():
@ -330,7 +330,7 @@ def shape_n(input, out_type=dtypes.int32, name=None):
""" """
output = gen_array_ops.shape_n(input, out_type=out_type, name=name) output = gen_array_ops.shape_n(input, out_type=out_type, name=name)
if context.in_graph_mode(): if not context.executing_eagerly():
for i, input_tensor in enumerate(input): for i, input_tensor in enumerate(input):
input_tensor = ops.convert_to_tensor(input_tensor) input_tensor = ops.convert_to_tensor(input_tensor)
input_shape = input_tensor.get_shape() input_shape = input_tensor.get_shape()
@ -385,9 +385,8 @@ def size_internal(input, name=None, optimize=True, out_type=dtypes.int32):
Returns: Returns:
A `Tensor` of type `out_type`. Defaults to `tf.int32`. A `Tensor` of type `out_type`. Defaults to `tf.int32`.
""" """
if context.in_eager_mode() and not isinstance( if context.executing_eagerly() and not isinstance(
input, (sparse_tensor.SparseTensor, input, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
sparse_tensor.SparseTensorValue)):
return np.prod(ops.convert_to_tensor(input)._shape_tuple()) # pylint: disable=protected-access return np.prod(ops.convert_to_tensor(input)._shape_tuple()) # pylint: disable=protected-access
with ops.name_scope(name, "Size", [input]) as name: with ops.name_scope(name, "Size", [input]) as name:
if isinstance(input, (sparse_tensor.SparseTensor, if isinstance(input, (sparse_tensor.SparseTensor,
@ -783,7 +782,7 @@ def strided_slice(input_,
new_axis_mask=new_axis_mask, new_axis_mask=new_axis_mask,
shrink_axis_mask=shrink_axis_mask) shrink_axis_mask=shrink_axis_mask)
if context.in_graph_mode(): if not context.executing_eagerly():
# TODO(apassos) In eager mode assignment will be done by overriding # TODO(apassos) In eager mode assignment will be done by overriding
# __setitem__ instead. # __setitem__ instead.
op.assign = assign op.assign = assign
@ -1457,7 +1456,7 @@ def transpose(a, perm=None, name="transpose", conjugate=False):
ret = transpose_fn(a, perm, name=name) ret = transpose_fn(a, perm, name=name)
# NOTE(mrry): Setting the shape explicitly because # NOTE(mrry): Setting the shape explicitly because
# reverse is not handled by the shape function. # reverse is not handled by the shape function.
if context.in_graph_mode(): if not context.executing_eagerly():
input_shape = ret.op.inputs[0].get_shape().dims input_shape = ret.op.inputs[0].get_shape().dims
if input_shape is not None: if input_shape is not None:
ret.set_shape(input_shape[::-1]) ret.set_shape(input_shape[::-1])
@ -1622,7 +1621,7 @@ def zeros_like(tensor, dtype=None, name=None, optimize=True):
with ops.name_scope(name, "zeros_like", [tensor]) as name: with ops.name_scope(name, "zeros_like", [tensor]) as name:
tensor = ops.convert_to_tensor(tensor, name="tensor") tensor = ops.convert_to_tensor(tensor, name="tensor")
if context.in_eager_mode(): if context.executing_eagerly():
if dtype is not None and dtype != tensor.dtype: if dtype is not None and dtype != tensor.dtype:
return zeros( return zeros(
shape_internal(tensor, optimize=optimize), dtype=dtype, name=name) shape_internal(tensor, optimize=optimize), dtype=dtype, name=name)
@ -1678,7 +1677,7 @@ def ones_like(tensor, dtype=None, name=None, optimize=True):
if dtype is None: if dtype is None:
dtype = tensor.dtype dtype = tensor.dtype
ret = ones(ones_shape, dtype=dtype, name=name) ret = ones(ones_shape, dtype=dtype, name=name)
if context.in_graph_mode(): if not context.executing_eagerly():
ret.set_shape(tensor.get_shape()) ret.set_shape(tensor.get_shape())
return ret return ret
@ -1759,7 +1758,7 @@ def placeholder(dtype, shape=None, name=None):
Raises: Raises:
RuntimeError: if eager execution is enabled RuntimeError: if eager execution is enabled
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError("tf.placeholder() is not compatible with " raise RuntimeError("tf.placeholder() is not compatible with "
"eager execution.") "eager execution.")
@ -1822,7 +1821,7 @@ def sparse_placeholder(dtype, shape=None, name=None):
Raises: Raises:
RuntimeError: if eager execution is enabled RuntimeError: if eager execution is enabled
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError("tf.placeholder() is not compatible with " raise RuntimeError("tf.placeholder() is not compatible with "
"eager execution.") "eager execution.")
@ -1921,7 +1920,7 @@ def pad(tensor, paddings, mode="CONSTANT", name=None, constant_values=0): # pyl
raise ValueError("Unknown padding mode: %s" % mode) raise ValueError("Unknown padding mode: %s" % mode)
# Restore shape information where possible. # Restore shape information where possible.
if context.in_graph_mode(): if not context.executing_eagerly():
paddings_constant = tensor_util.constant_value( paddings_constant = tensor_util.constant_value(
result.op.inputs[1], partial=True) result.op.inputs[1], partial=True)
input_shape = result.op.inputs[0].shape input_shape = result.op.inputs[0].shape
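
The placeholder APIs in this file cannot be bridged to eager execution, so their guards raise rather than branch. A tiny usage sketch, where maybe_placeholder is an illustrative wrapper and not an API of the library:

from tensorflow.python.eager import context
from tensorflow.python.ops import array_ops

def maybe_placeholder(dtype, shape=None):
  if context.executing_eagerly():
    # Feeding via placeholders only makes sense while building a graph.
    raise RuntimeError("tf.placeholder() is not compatible with "
                       "eager execution.")
  return array_ops.placeholder(dtype, shape=shape)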

View File

@ -169,7 +169,7 @@ def assert_negative(x, data=None, summarize=None, message=None, name=None):
with ops.name_scope(name, 'assert_negative', [x, data]): with ops.name_scope(name, 'assert_negative', [x, data]):
x = ops.convert_to_tensor(x, name='x') x = ops.convert_to_tensor(x, name='x')
if data is None: if data is None:
if context.in_eager_mode(): if context.executing_eagerly():
name = _shape_and_dtype_str(x) name = _shape_and_dtype_str(x)
else: else:
name = x.name name = x.name
@ -210,7 +210,7 @@ def assert_positive(x, data=None, summarize=None, message=None, name=None):
with ops.name_scope(name, 'assert_positive', [x, data]): with ops.name_scope(name, 'assert_positive', [x, data]):
x = ops.convert_to_tensor(x, name='x') x = ops.convert_to_tensor(x, name='x')
if data is None: if data is None:
if context.in_eager_mode(): if context.executing_eagerly():
name = _shape_and_dtype_str(x) name = _shape_and_dtype_str(x)
else: else:
name = x.name name = x.name
@ -251,7 +251,7 @@ def assert_non_negative(x, data=None, summarize=None, message=None, name=None):
with ops.name_scope(name, 'assert_non_negative', [x, data]): with ops.name_scope(name, 'assert_non_negative', [x, data]):
x = ops.convert_to_tensor(x, name='x') x = ops.convert_to_tensor(x, name='x')
if data is None: if data is None:
if context.in_eager_mode(): if context.executing_eagerly():
name = _shape_and_dtype_str(x) name = _shape_and_dtype_str(x)
else: else:
name = x.name name = x.name
@ -293,7 +293,7 @@ def assert_non_positive(x, data=None, summarize=None, message=None, name=None):
with ops.name_scope(name, 'assert_non_positive', [x, data]): with ops.name_scope(name, 'assert_non_positive', [x, data]):
x = ops.convert_to_tensor(x, name='x') x = ops.convert_to_tensor(x, name='x')
if data is None: if data is None:
if context.in_eager_mode(): if context.executing_eagerly():
name = _shape_and_dtype_str(x) name = _shape_and_dtype_str(x)
else: else:
name = x.name name = x.name
@ -343,7 +343,7 @@ def assert_equal(x, y, data=None, summarize=None, message=None, name=None):
x = ops.convert_to_tensor(x, name='x') x = ops.convert_to_tensor(x, name='x')
y = ops.convert_to_tensor(y, name='y') y = ops.convert_to_tensor(y, name='y')
if context.in_eager_mode(): if context.executing_eagerly():
eq = math_ops.equal(x, y) eq = math_ops.equal(x, y)
condition = math_ops.reduce_all(eq) condition = math_ops.reduce_all(eq)
if not condition: if not condition:
@ -435,7 +435,7 @@ def assert_none_equal(
with ops.name_scope(name, 'assert_none_equal', [x, y, data]): with ops.name_scope(name, 'assert_none_equal', [x, y, data]):
x = ops.convert_to_tensor(x, name='x') x = ops.convert_to_tensor(x, name='x')
y = ops.convert_to_tensor(y, name='y') y = ops.convert_to_tensor(y, name='y')
if context.in_eager_mode(): if context.executing_eagerly():
x_name = _shape_and_dtype_str(x) x_name = _shape_and_dtype_str(x)
y_name = _shape_and_dtype_str(y) y_name = _shape_and_dtype_str(y)
else: else:
@ -512,7 +512,7 @@ def assert_near(
rtol = ops.convert_to_tensor(rtol, name='rtol', dtype=x.dtype) rtol = ops.convert_to_tensor(rtol, name='rtol', dtype=x.dtype)
atol = ops.convert_to_tensor(atol, name='atol', dtype=x.dtype) atol = ops.convert_to_tensor(atol, name='atol', dtype=x.dtype)
if context.in_eager_mode(): if context.executing_eagerly():
x_name = _shape_and_dtype_str(x) x_name = _shape_and_dtype_str(x)
y_name = _shape_and_dtype_str(y) y_name = _shape_and_dtype_str(y)
else: else:
@ -562,7 +562,7 @@ def assert_less(x, y, data=None, summarize=None, message=None, name=None):
with ops.name_scope(name, 'assert_less', [x, y, data]): with ops.name_scope(name, 'assert_less', [x, y, data]):
x = ops.convert_to_tensor(x, name='x') x = ops.convert_to_tensor(x, name='x')
y = ops.convert_to_tensor(y, name='y') y = ops.convert_to_tensor(y, name='y')
if context.in_eager_mode(): if context.executing_eagerly():
x_name = _shape_and_dtype_str(x) x_name = _shape_and_dtype_str(x)
y_name = _shape_and_dtype_str(y) y_name = _shape_and_dtype_str(y)
else: else:
@ -610,7 +610,7 @@ def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None):
with ops.name_scope(name, 'assert_less_equal', [x, y, data]): with ops.name_scope(name, 'assert_less_equal', [x, y, data]):
x = ops.convert_to_tensor(x, name='x') x = ops.convert_to_tensor(x, name='x')
y = ops.convert_to_tensor(y, name='y') y = ops.convert_to_tensor(y, name='y')
if context.in_eager_mode(): if context.executing_eagerly():
x_name = _shape_and_dtype_str(x) x_name = _shape_and_dtype_str(x)
y_name = _shape_and_dtype_str(y) y_name = _shape_and_dtype_str(y)
else: else:
@ -658,7 +658,7 @@ def assert_greater(x, y, data=None, summarize=None, message=None, name=None):
with ops.name_scope(name, 'assert_greater', [x, y, data]): with ops.name_scope(name, 'assert_greater', [x, y, data]):
x = ops.convert_to_tensor(x, name='x') x = ops.convert_to_tensor(x, name='x')
y = ops.convert_to_tensor(y, name='y') y = ops.convert_to_tensor(y, name='y')
if context.in_eager_mode(): if context.executing_eagerly():
x_name = _shape_and_dtype_str(x) x_name = _shape_and_dtype_str(x)
y_name = _shape_and_dtype_str(y) y_name = _shape_and_dtype_str(y)
else: else:
@ -708,7 +708,7 @@ def assert_greater_equal(x, y, data=None, summarize=None, message=None,
with ops.name_scope(name, 'assert_greater_equal', [x, y, data]): with ops.name_scope(name, 'assert_greater_equal', [x, y, data]):
x = ops.convert_to_tensor(x, name='x') x = ops.convert_to_tensor(x, name='x')
y = ops.convert_to_tensor(y, name='y') y = ops.convert_to_tensor(y, name='y')
if context.in_eager_mode(): if context.executing_eagerly():
x_name = _shape_and_dtype_str(x) x_name = _shape_and_dtype_str(x)
y_name = _shape_and_dtype_str(y) y_name = _shape_and_dtype_str(y)
else: else:
@ -808,7 +808,7 @@ def assert_rank(x, rank, data=None, summarize=None, message=None, name=None):
static_condition = lambda actual_rank, given_rank: actual_rank == given_rank static_condition = lambda actual_rank, given_rank: actual_rank == given_rank
dynamic_condition = math_ops.equal dynamic_condition = math_ops.equal
if context.in_eager_mode(): if context.executing_eagerly():
name = '' name = ''
else: else:
name = x.name name = x.name
@ -873,7 +873,7 @@ def assert_rank_at_least(
static_condition = lambda actual_rank, given_rank: actual_rank >= given_rank static_condition = lambda actual_rank, given_rank: actual_rank >= given_rank
dynamic_condition = math_ops.greater_equal dynamic_condition = math_ops.greater_equal
if context.in_eager_mode(): if context.executing_eagerly():
name = '' name = ''
else: else:
name = x.name name = x.name
@ -1001,7 +1001,7 @@ def assert_rank_in(
ranks = tuple([ops.convert_to_tensor(rank, name='rank') for rank in ranks]) ranks = tuple([ops.convert_to_tensor(rank, name='rank') for rank in ranks])
message = message or '' message = message or ''
if context.in_eager_mode(): if context.executing_eagerly():
name = '' name = ''
else: else:
name = x.name name = x.name
@ -1054,7 +1054,7 @@ def assert_integer(x, message=None, name=None):
with ops.name_scope(name, 'assert_integer', [x]): with ops.name_scope(name, 'assert_integer', [x]):
x = ops.convert_to_tensor(x, name='x') x = ops.convert_to_tensor(x, name='x')
if not x.dtype.is_integer: if not x.dtype.is_integer:
if context.in_eager_mode(): if context.executing_eagerly():
name = 'tensor' name = 'tensor'
else: else:
name = x.name name = x.name
@ -1087,12 +1087,11 @@ def assert_type(tensor, tf_type, message=None, name=None):
with ops.name_scope(name, 'assert_type', [tensor]): with ops.name_scope(name, 'assert_type', [tensor]):
tensor = ops.convert_to_tensor(tensor, name='tensor') tensor = ops.convert_to_tensor(tensor, name='tensor')
if tensor.dtype != tf_type: if tensor.dtype != tf_type:
if context.in_graph_mode(): if context.executing_eagerly():
raise TypeError( raise TypeError('%s tensor must be of type %s' % (message, tf_type))
'%s %s must be of type %s' % (message, tensor.name, tf_type))
else: else:
raise TypeError( raise TypeError('%s %s must be of type %s' % (message, tensor.name,
'%s tensor must be of type %s' % (message, tf_type)) tf_type))
return control_flow_ops.no_op('statically_determined_correct_type') return control_flow_ops.no_op('statically_determined_correct_type')
@ -1240,7 +1239,7 @@ def assert_scalar(tensor, name=None):
tensor = ops.convert_to_tensor(tensor, name=name_scope) tensor = ops.convert_to_tensor(tensor, name=name_scope)
shape = tensor.get_shape() shape = tensor.get_shape()
if shape.ndims != 0: if shape.ndims != 0:
if context.in_eager_mode(): if context.executing_eagerly():
raise ValueError('Expected scalar shape, saw shape: %s.' raise ValueError('Expected scalar shape, saw shape: %s.'
% (shape,)) % (shape,))
else: else:
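
The assert_* helpers above all share one naming rule: eager tensors carry no graph name, so error messages fall back to a shape/dtype string. A minimal sketch of that fallback, using hypothetical stand-ins rather than the real context module:

# Hypothetical illustration only, not TensorFlow internals.
def shape_and_dtype_str(x):
    return 'shape=%s dtype=%s' % (x.shape, x.dtype)

def display_name(x, executing_eagerly):
    # Mirrors the branches above: graph tensors are identified by x.name,
    # eager tensors by their shape and dtype.
    return shape_and_dtype_str(x) if executing_eagerly else x.name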

View File

@ -152,7 +152,7 @@ def Assert(condition, data, summarize=None, name=None):
@compatibility{eager} `tf.errors.InvalidArgumentError` if `condition` @compatibility{eager} `tf.errors.InvalidArgumentError` if `condition`
is not true is not true
""" """
if context.in_eager_mode(): if context.executing_eagerly():
if not condition: if not condition:
xs = ops.convert_n_to_tensor(data) xs = ops.convert_n_to_tensor(data)
data_str = [_summarize_eager(x, summarize) for x in xs] data_str = [_summarize_eager(x, summarize) for x in xs]
@ -178,7 +178,7 @@ def Assert(condition, data, summarize=None, name=None):
condition, data, summarize, name="Assert") condition, data, summarize, name="Assert")
guarded_assert = cond(condition, no_op, true_assert, name="AssertGuard") guarded_assert = cond(condition, no_op, true_assert, name="AssertGuard")
if context.in_eager_mode(): if context.executing_eagerly():
return return
return guarded_assert.op return guarded_assert.op
@ -2025,7 +2025,7 @@ def cond(pred,
raise TypeError("false_fn must be callable.") raise TypeError("false_fn must be callable.")
with ops.name_scope(name, "cond", [pred]): with ops.name_scope(name, "cond", [pred]):
if context.in_eager_mode(): if context.executing_eagerly():
if pred: if pred:
return _UnpackIfSingleton(true_fn()) return _UnpackIfSingleton(true_fn())
return _UnpackIfSingleton(false_fn()) return _UnpackIfSingleton(false_fn())
@ -3177,7 +3177,7 @@ def while_loop(cond,
math_ops.logical_and(i < maximum_iterations, orig_cond(*lv))) math_ops.logical_and(i < maximum_iterations, orig_cond(*lv)))
body = lambda i, lv: (i + 1, orig_body(*lv)) body = lambda i, lv: (i + 1, orig_body(*lv))
if context.in_eager_mode(): if context.executing_eagerly():
while cond(*loop_vars): while cond(*loop_vars):
loop_vars = body(*loop_vars) loop_vars = body(*loop_vars)
if maximum_iterations is not None: if maximum_iterations is not None:
@ -3271,7 +3271,7 @@ def with_dependencies(dependencies, output_tensor, name=None):
Raises: Raises:
TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`. TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
return output_tensor return output_tensor
with ops.name_scope(name, "control_dependency", with ops.name_scope(name, "control_dependency",
list(dependencies) + [output_tensor]) as name: list(dependencies) + [output_tensor]) as name:
@ -3316,7 +3316,7 @@ def group(*inputs, **kwargs):
Raises: Raises:
ValueError: If an unknown keyword argument is provided. ValueError: If an unknown keyword argument is provided.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
return None return None
name = kwargs.pop("name", None) name = kwargs.pop("name", None)
if kwargs: if kwargs:
@ -3396,7 +3396,7 @@ def tuple(tensors, name=None, control_inputs=None): # pylint: disable=redefined
objects. objects.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
return tensors return tensors
with ops.name_scope(name, "tuple", tensors) as name: with ops.name_scope(name, "tuple", tensors) as name:
tensors = [t if (isinstance(t, ops.Operation) tensors = [t if (isinstance(t, ops.Operation)
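
The control-flow changes above all reduce to one rule: when eager execution is enabled, Assert, cond, while_loop, group, tuple and with_dependencies either run as plain Python or become no-ops. A hedged usage sketch, assuming a TF 1.x build in which tf.enable_eager_execution() and these symbols are available:

import tensorflow as tf

tf.enable_eager_execution()

x = tf.constant(3)
# cond runs the chosen branch immediately instead of building a graph.
y = tf.cond(x > 2, lambda: x * 2, lambda: x)           # EagerTensor(6)

# while_loop is just a Python loop over the body under eager execution.
i, total = tf.while_loop(lambda i, t: i < 5,
                         lambda i, t: (i + 1, t + i),
                         (tf.constant(0), tf.constant(0)))
print(int(i), int(total))                               # 5 10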

View File

@ -92,7 +92,7 @@ def custom_gradient(f):
def decorated(*args, **kwargs): def decorated(*args, **kwargs):
"""Decorated function with custom gradient.""" """Decorated function with custom gradient."""
if context.in_graph_mode(): if not context.executing_eagerly():
if kwargs: if kwargs:
raise ValueError( raise ValueError(
"The custom_gradient decorator currently suports keywords " "The custom_gradient decorator currently suports keywords "

View File

@ -159,7 +159,7 @@ class QueueBase(object):
ValueError: If one of the arguments is invalid. ValueError: If one of the arguments is invalid.
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError( raise RuntimeError(
"Queues are not supported when eager execution is enabled. " "Queues are not supported when eager execution is enabled. "
"Instead, please use tf.data to get data into your model.") "Instead, please use tf.data to get data into your model.")
@ -177,10 +177,10 @@ class QueueBase(object):
else: else:
self._names = None self._names = None
self._queue_ref = queue_ref self._queue_ref = queue_ref
if context.in_graph_mode(): if context.executing_eagerly():
self._name = self._queue_ref.op.name.split("/")[-1]
else:
self._name = context.context().scope_name self._name = context.context().scope_name
else:
self._name = self._queue_ref.op.name.split("/")[-1]
@staticmethod @staticmethod
def from_list(index, queues): def from_list(index, queues):
@ -231,9 +231,9 @@ class QueueBase(object):
@property @property
def name(self): def name(self):
"""The name of the underlying queue.""" """The name of the underlying queue."""
if context.in_graph_mode(): if context.executing_eagerly():
return self._queue_ref.op.name
return self._name return self._name
return self._queue_ref.op.name
@property @property
def dtypes(self): def dtypes(self):
@ -444,7 +444,7 @@ class QueueBase(object):
# NOTE(mrry): Not using a shape function because we need access to # NOTE(mrry): Not using a shape function because we need access to
# the `QueueBase` object. # the `QueueBase` object.
if context.in_graph_mode(): if not context.executing_eagerly():
op = ret[0].op op = ret[0].op
for output, shape in zip(op.values(), self._shapes): for output, shape in zip(op.values(), self._shapes):
output.set_shape(shape) output.set_shape(shape)
@ -484,7 +484,7 @@ class QueueBase(object):
# NOTE(mrry): Not using a shape function because we need access to # NOTE(mrry): Not using a shape function because we need access to
# the Queue object. # the Queue object.
if context.in_graph_mode(): if not context.executing_eagerly():
op = ret[0].op op = ret[0].op
batch_dim = tensor_shape.Dimension( batch_dim = tensor_shape.Dimension(
tensor_util.constant_value(op.inputs[1])) tensor_util.constant_value(op.inputs[1]))
@ -528,7 +528,7 @@ class QueueBase(object):
# NOTE(mrry): Not using a shape function because we need access to # NOTE(mrry): Not using a shape function because we need access to
# the Queue object. # the Queue object.
if context.in_graph_mode(): if not context.executing_eagerly():
op = ret[0].op op = ret[0].op
for output, shape in zip(op.values(), self._shapes): for output, shape in zip(op.values(), self._shapes):
output.set_shape(tensor_shape.TensorShape([None]).concatenate(shape)) output.set_shape(tensor_shape.TensorShape([None]).concatenate(shape))
@ -990,10 +990,10 @@ class Barrier(object):
shapes=self._shapes, shapes=self._shapes,
shared_name=shared_name, shared_name=shared_name,
name=name) name=name)
if context.in_graph_mode(): if context.executing_eagerly():
self._name = self._barrier_ref.op.name.split("/")[-1]
else:
self._name = context.context().scope_name self._name = context.context().scope_name
else:
self._name = self._barrier_ref.op.name.split("/")[-1]
@property @property
def barrier_ref(self): def barrier_ref(self):
@ -1003,9 +1003,9 @@ class Barrier(object):
@property @property
def name(self): def name(self):
"""The name of the underlying barrier.""" """The name of the underlying barrier."""
if context.in_graph_mode(): if context.executing_eagerly():
return self._barrier_ref.op.name
return self._name return self._name
return self._barrier_ref.op.name
def insert_many(self, component_index, keys, values, name=None): def insert_many(self, component_index, keys, values, name=None):
"""For each key, assigns the respective value to the specified component. """For each key, assigns the respective value to the specified component.
@ -1083,7 +1083,7 @@ class Barrier(object):
# NOTE(mrry): Not using a shape function because we need access to # NOTE(mrry): Not using a shape function because we need access to
# the Barrier object. # the Barrier object.
if context.in_graph_mode(): if not context.executing_eagerly():
op = ret[0].op op = ret[0].op
if allow_small_batch: if allow_small_batch:
batch_dim = None batch_dim = None
@ -1183,10 +1183,10 @@ class ConditionalAccumulatorBase(object):
else: else:
self._shape = tensor_shape.unknown_shape() self._shape = tensor_shape.unknown_shape()
self._accumulator_ref = accumulator_ref self._accumulator_ref = accumulator_ref
if context.in_graph_mode(): if context.executing_eagerly():
self._name = self._accumulator_ref.op.name.split("/")[-1]
else:
self._name = context.context().scope_name self._name = context.context().scope_name
else:
self._name = self._accumulator_ref.op.name.split("/")[-1]
@property @property
def accumulator_ref(self): def accumulator_ref(self):
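
QueueBase, Barrier and ConditionalAccumulatorBase above all resolve their name the same way: from the op when a graph is being built, from the active scope when executing eagerly. A compact sketch of that shared pattern (hypothetical helper, not TensorFlow code):

def resource_name(resource_ref, ctx):
    # Eager resources have no op, so the name comes from the context's
    # current scope; graph resources use the last segment of the op name.
    if ctx.executing_eagerly():
        return ctx.scope_name
    return resource_ref.op.name.split("/")[-1]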

View File

@ -90,7 +90,7 @@ def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
if not callable(fn): if not callable(fn):
raise TypeError("fn must be callable.") raise TypeError("fn must be callable.")
in_graph_mode = context.in_graph_mode() in_graph_mode = not context.executing_eagerly()
with ops.name_scope(name, "foldl", [elems]): with ops.name_scope(name, "foldl", [elems]):
# TODO(akshayka): Remove the in_graph_mode check once caching devices are # TODO(akshayka): Remove the in_graph_mode check once caching devices are
# supported in Eager # supported in Eager
@ -178,7 +178,7 @@ def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
if not callable(fn): if not callable(fn):
raise TypeError("fn must be callable.") raise TypeError("fn must be callable.")
in_graph_mode = context.in_graph_mode() in_graph_mode = not context.executing_eagerly()
with ops.name_scope(name, "foldr", [elems]): with ops.name_scope(name, "foldr", [elems]):
# TODO(akshayka): Remove the in_graph_mode check once caching devices are # TODO(akshayka): Remove the in_graph_mode check once caching devices are
# supported in Eager # supported in Eager
@ -343,7 +343,7 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
elems_flat = input_flatten(elems) elems_flat = input_flatten(elems)
in_graph_mode = context.in_graph_mode() in_graph_mode = not context.executing_eagerly()
with ops.name_scope(name, "map", elems_flat): with ops.name_scope(name, "map", elems_flat):
# TODO(akshayka): Remove the in_graph_mode check once caching devices are # TODO(akshayka): Remove the in_graph_mode check once caching devices are
# supported in Eager # supported in Eager
@ -536,7 +536,7 @@ def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
elems_flat = input_flatten(elems) elems_flat = input_flatten(elems)
in_graph_mode = context.in_graph_mode() in_graph_mode = not context.executing_eagerly()
with ops.name_scope(name, "scan", elems_flat): with ops.name_scope(name, "scan", elems_flat):
# TODO(akshayka): Remove the in_graph_mode check once caching devices are # TODO(akshayka): Remove the in_graph_mode check once caching devices are
# supported in Eager # supported in Eager
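
foldl, foldr, map_fn and scan work in both modes; the in_graph_mode flag above only gates graph-specific setup such as caching devices. A hedged usage sketch that behaves the same whether or not eager execution is enabled:

import tensorflow as tf

elems = tf.constant([1, 2, 3, 4])
squares = tf.map_fn(lambda x: x * x, elems)      # [1, 4, 9, 16]
total = tf.foldl(lambda acc, x: acc + x, elems)  # 10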

View File

@ -86,7 +86,7 @@ def _IndexedSlicesToTensor(value, dtype=None, name=None, as_ref=False):
% str(value)) % str(value))
# TODO(mrry): Consider adding static shape information to # TODO(mrry): Consider adding static shape information to
# IndexedSlices, to avoid using numpy here. # IndexedSlices, to avoid using numpy here.
if context.in_graph_mode(): if not context.executing_eagerly():
dense_shape_value = tensor_util.constant_value(value.dense_shape) dense_shape_value = tensor_util.constant_value(value.dense_shape)
if dense_shape_value is not None: if dense_shape_value is not None:
num_elements = np.prod(dense_shape_value) num_elements = np.prod(dense_shape_value)
@ -491,9 +491,10 @@ def gradients(ys,
def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
gate_gradients, aggregation_method, stop_gradients): gate_gradients, aggregation_method, stop_gradients):
"""Implementation of gradients().""" """Implementation of gradients()."""
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError("tf.gradients not supported in EAGER mode. Use " raise RuntimeError("tf.gradients not supported when eager execution "
"functions in tf.contrib.eager.backprop instead.") "is enabled. Use tf.contrib.eager.GradientTape "
"instead.")
ys = _AsList(ys) ys = _AsList(ys)
xs = _AsList(xs) xs = _AsList(xs)
stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients) stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)
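
The rewritten error above points callers at GradientTape. A hedged sketch of that replacement, assuming eager execution is enabled and tf.contrib.eager.GradientTape is available in the build in use:

import tensorflow as tf
tfe = tf.contrib.eager

tf.enable_eager_execution()

x = tf.constant(3.0)
with tfe.GradientTape() as tape:
    tape.watch(x)               # constants are not watched automatically
    y = x * x
print(tape.gradient(y, x))      # 6.0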

View File

@ -173,7 +173,7 @@ class ReaderBase(object):
Raises: Raises:
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError( raise RuntimeError(
"Readers are not supported when eager execution is enabled. " "Readers are not supported when eager execution is enabled. "
"Instead, please use tf.data to get data into your model.") "Instead, please use tf.data to get data into your model.")

View File

@ -157,10 +157,10 @@ class InitializableLookupTableBase(LookupInterface):
default_value: The value to use if a key is missing in the table. default_value: The value to use if a key is missing in the table.
initializer: The table initializer to use. initializer: The table initializer to use.
""" """
if context.in_graph_mode(): if context.executing_eagerly():
name = table_ref.op.name.split("/")[-1]
else:
name = context.context().scope_name name = context.context().scope_name
else:
name = table_ref.op.name.split("/")[-1]
super(InitializableLookupTableBase, super(InitializableLookupTableBase,
self).__init__(initializer.key_dtype, initializer.value_dtype, self).__init__(initializer.key_dtype, initializer.value_dtype,
name) name)
@ -521,7 +521,7 @@ class TextFileInitializer(TableInitializerBase):
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
# If the filename tensor is anything other than a string constant (e.g., if # If the filename tensor is anything other than a string constant (e.g., if
# it is a placeholder) then it does not make sense to track it as an asset. # it is a placeholder) then it does not make sense to track it as an asset.
if context.in_graph_mode() and constant_op.is_constant(filename): if not context.executing_eagerly() and constant_op.is_constant(filename):
ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, filename) ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, filename)
return init_op return init_op

View File

@ -136,7 +136,7 @@ def _num_present(losses, weights, per_batch=False):
`[batch_size]`. Otherwise, a single scalar tensor is returned. `[batch_size]`. Otherwise, a single scalar tensor is returned.
""" """
if ((isinstance(weights, float) and weights != 0.0) or if ((isinstance(weights, float) and weights != 0.0) or
(context.in_eager_mode() and weights._rank() == 0 # pylint: disable=protected-access (context.executing_eagerly() and weights._rank() == 0 # pylint: disable=protected-access
and not math_ops.equal(weights, 0.0))): and not math_ops.equal(weights, 0.0))):
return _num_elements(losses) return _num_elements(losses)
with ops.name_scope(None, "num_present", (losses, weights)) as scope: with ops.name_scope(None, "num_present", (losses, weights)) as scope:

View File

@ -52,14 +52,14 @@ def _SumGrad(op, grad):
if axes is not None: if axes is not None:
rank = len(input_0_shape) rank = len(input_0_shape)
if np.array_equal(axes, np.arange(rank)): # Reduce all dims. if np.array_equal(axes, np.arange(rank)): # Reduce all dims.
if context.in_graph_mode(): if context.executing_eagerly():
new_shape = [1] * rank
else:
ctx = context.context() ctx = context.context()
new_shape = ctx.ones_rank_cache().get(rank) new_shape = ctx.ones_rank_cache().get(rank)
if new_shape is None: if new_shape is None:
new_shape = constant_op.constant([1] * rank, dtype=dtypes.int32) new_shape = constant_op.constant([1] * rank, dtype=dtypes.int32)
ctx.ones_rank_cache().put(rank, new_shape) ctx.ones_rank_cache().put(rank, new_shape)
else:
new_shape = [1] * rank
grad = array_ops.reshape(grad, new_shape) grad = array_ops.reshape(grad, new_shape)
# If shape is not fully defined (but rank is), we use Shape. # If shape is not fully defined (but rank is), we use Shape.
if None not in input_0_shape: if None not in input_0_shape:
@ -997,7 +997,7 @@ def _SparseMatMulGrad(op, grad):
op.inputs[0]: op.get_attr("a_is_sparse"), op.inputs[0]: op.get_attr("a_is_sparse"),
op.inputs[1]: op.get_attr("b_is_sparse"), op.inputs[1]: op.get_attr("b_is_sparse"),
# Use heuristic to figure out if grad might be sparse # Use heuristic to figure out if grad might be sparse
grad: context.in_graph_mode() and (grad.op.type == "ReluGrad") grad: not context.executing_eagerly() and (grad.op.type == "ReluGrad")
} }
def _SparseMatMul(t1, t2, out_dtype, transpose_a=False, transpose_b=False): def _SparseMatMul(t1, t2, out_dtype, transpose_a=False, transpose_b=False):

View File

@ -2007,14 +2007,14 @@ def matmul(a,
if transpose_b and adjoint_b: if transpose_b and adjoint_b:
raise ValueError("Only one of transpose_b and adjoint_b can be True.") raise ValueError("Only one of transpose_b and adjoint_b can be True.")
if context.in_graph_mode(): if context.executing_eagerly():
a = ops.convert_to_tensor(a, name="a")
b = ops.convert_to_tensor(b, name="b")
else:
if not isinstance(a, (ops.EagerTensor, _resource_variable_type)): if not isinstance(a, (ops.EagerTensor, _resource_variable_type)):
a = ops.convert_to_tensor(a, name="a") a = ops.convert_to_tensor(a, name="a")
if not isinstance(b, (ops.EagerTensor, _resource_variable_type)): if not isinstance(b, (ops.EagerTensor, _resource_variable_type)):
b = ops.convert_to_tensor(b, name="b") b = ops.convert_to_tensor(b, name="b")
else:
a = ops.convert_to_tensor(a, name="a")
b = ops.convert_to_tensor(b, name="b")
# TODO(apassos) remove _shape_tuple here when it is not needed. # TODO(apassos) remove _shape_tuple here when it is not needed.
a_shape = a._shape_tuple() # pylint: disable=protected-access a_shape = a._shape_tuple() # pylint: disable=protected-access
@ -2249,7 +2249,7 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None):
return inputs[0] return inputs[0]
elif len(inputs) == 1 and name is not None: elif len(inputs) == 1 and name is not None:
return array_ops.identity(inputs[0], name=name) return array_ops.identity(inputs[0], name=name)
elif context.in_eager_mode(): elif context.executing_eagerly():
# TemporaryVariable not currently supported in eager mode; fall back # TemporaryVariable not currently supported in eager mode; fall back
# onto AddN for now. # onto AddN for now.
# TODO(frreiss) remove this once the lifetime of eager variables gets # TODO(frreiss) remove this once the lifetime of eager variables gets

View File

@ -60,7 +60,7 @@ class ReduceTest(test_util.TensorFlowTestCase):
@test_util.run_in_graph_and_eager_modes() @test_util.run_in_graph_and_eager_modes()
def testReduceInvalidAxis(self): def testReduceInvalidAxis(self):
if context.in_eager_mode(): if context.executing_eagerly():
      # The shape check is run at graph construction time. In eager mode,       # The shape check is run at graph construction time. In eager mode,
      # the check is skipped and a result is returned despite the wrong shape.       # the check is skipped and a result is returned despite the wrong shape.
return return
@ -249,7 +249,7 @@ class ScalarMulTest(test_util.TensorFlowTestCase):
@test_util.run_in_graph_and_eager_modes() @test_util.run_in_graph_and_eager_modes()
def testAcceptsRefs(self): def testAcceptsRefs(self):
if context.in_eager_mode(): if context.executing_eagerly():
var = resource_variable_ops.ResourceVariable(10, name="var") var = resource_variable_ops.ResourceVariable(10, name="var")
else: else:
var = variables.Variable(10) var = variables.Variable(10)

View File

@ -308,7 +308,7 @@ def mean(values,
or tuple. or tuple.
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('tf.metrics.mean is not supported when eager execution ' raise RuntimeError('tf.metrics.mean is not supported when eager execution '
'is enabled.') 'is enabled.')
@ -394,7 +394,7 @@ def accuracy(labels,
tuple. tuple.
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('tf.metrics.accuracy is not supported when eager ' raise RuntimeError('tf.metrics.accuracy is not supported when eager '
'execution is enabled.') 'execution is enabled.')
@ -644,7 +644,7 @@ def auc(labels,
tuple. tuple.
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('tf.metrics.auc is not supported when eager execution ' raise RuntimeError('tf.metrics.auc is not supported when eager execution '
'is enabled.') 'is enabled.')
@ -758,7 +758,7 @@ def mean_absolute_error(labels,
tuple. tuple.
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('tf.metrics.mean_absolute_error is not supported ' raise RuntimeError('tf.metrics.mean_absolute_error is not supported '
'when eager execution is enabled.') 'when eager execution is enabled.')
@ -818,7 +818,7 @@ def mean_cosine_distance(labels,
tuple. tuple.
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('tf.metrics.mean_cosine_distance is not supported when ' raise RuntimeError('tf.metrics.mean_cosine_distance is not supported when '
'eager execution is enabled.') 'eager execution is enabled.')
@ -891,7 +891,7 @@ def mean_per_class_accuracy(labels,
tuple. tuple.
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('tf.metrics.mean_per_class_accuracy is not supported ' raise RuntimeError('tf.metrics.mean_per_class_accuracy is not supported '
'when eager execution is enabled.') 'when eager execution is enabled.')
@ -996,7 +996,7 @@ def mean_iou(labels,
tuple. tuple.
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('tf.metrics.mean_iou is not supported when ' raise RuntimeError('tf.metrics.mean_iou is not supported when '
'eager execution is enabled.') 'eager execution is enabled.')
@ -1098,7 +1098,7 @@ def mean_relative_error(labels,
tuple. tuple.
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('tf.metrics.mean_relative_error is not supported when ' raise RuntimeError('tf.metrics.mean_relative_error is not supported when '
'eager execution is enabled.') 'eager execution is enabled.')
@ -1165,7 +1165,7 @@ def mean_squared_error(labels,
tuple. tuple.
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('tf.metrics.mean_squared_error is not supported when ' raise RuntimeError('tf.metrics.mean_squared_error is not supported when '
'eager execution is enabled.') 'eager execution is enabled.')
@ -1223,7 +1223,7 @@ def mean_tensor(values,
or tuple. or tuple.
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('tf.metrics.mean_tensor is not supported when ' raise RuntimeError('tf.metrics.mean_tensor is not supported when '
'eager execution is enabled.') 'eager execution is enabled.')
@ -1304,7 +1304,7 @@ def percentage_below(values,
or tuple. or tuple.
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('tf.metrics.percentage_below is not supported when ' raise RuntimeError('tf.metrics.percentage_below is not supported when '
'eager execution is enabled.') 'eager execution is enabled.')
@ -1397,7 +1397,7 @@ def false_negatives(labels,
or tuple. or tuple.
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('tf.metrics.false_negatives is not supported when ' raise RuntimeError('tf.metrics.false_negatives is not supported when '
'eager execution is enabled.') 'eager execution is enabled.')
@ -1453,7 +1453,7 @@ def false_negatives_at_thresholds(labels,
tuple. tuple.
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('tf.metrics.false_negatives_at_thresholds is not ' raise RuntimeError('tf.metrics.false_negatives_at_thresholds is not '
'supported when eager execution is enabled.') 'supported when eager execution is enabled.')
@ -1507,7 +1507,7 @@ def false_positives(labels,
tuple. tuple.
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('tf.metrics.false_positives is not supported when ' raise RuntimeError('tf.metrics.false_positives is not supported when '
'eager execution is enabled.') 'eager execution is enabled.')
@ -1563,7 +1563,7 @@ def false_positives_at_thresholds(labels,
tuple. tuple.
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('tf.metrics.false_positives_at_thresholds is not ' raise RuntimeError('tf.metrics.false_positives_at_thresholds is not '
'supported when eager execution is enabled.') 'supported when eager execution is enabled.')
@ -1617,7 +1617,7 @@ def true_negatives(labels,
tuple. tuple.
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('tf.metrics.true_negatives is not ' raise RuntimeError('tf.metrics.true_negatives is not '
'supported when eager execution is enabled.') 'supported when eager execution is enabled.')
@ -1673,7 +1673,7 @@ def true_negatives_at_thresholds(labels,
tuple. tuple.
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('tf.metrics.true_negatives_at_thresholds is not ' raise RuntimeError('tf.metrics.true_negatives_at_thresholds is not '
'supported when eager execution is enabled.') 'supported when eager execution is enabled.')
@ -1727,7 +1727,7 @@ def true_positives(labels,
tuple. tuple.
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('tf.metrics.true_positives is not ' raise RuntimeError('tf.metrics.true_positives is not '
'supported when eager execution is enabled.') 'supported when eager execution is enabled.')
@ -1783,7 +1783,7 @@ def true_positives_at_thresholds(labels,
tuple. tuple.
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('tf.metrics.true_positives_at_thresholds is not ' raise RuntimeError('tf.metrics.true_positives_at_thresholds is not '
'supported when eager execution is enabled.') 'supported when eager execution is enabled.')
@ -1851,7 +1851,7 @@ def precision(labels,
tuple. tuple.
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('tf.metrics.precision is not ' raise RuntimeError('tf.metrics.precision is not '
'supported when eager execution is enabled.') 'supported when eager execution is enabled.')
@ -1947,7 +1947,7 @@ def precision_at_thresholds(labels,
tuple. tuple.
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('tf.metrics.precision_at_thresholds is not ' raise RuntimeError('tf.metrics.precision_at_thresholds is not '
'supported when eager execution is enabled.') 'supported when eager execution is enabled.')
@ -2023,7 +2023,7 @@ def recall(labels,
tuple. tuple.
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
    raise RuntimeError('tf.metrics.recall is not '     raise RuntimeError('tf.metrics.recall is not '
'supported when eager execution is enabled.') 'supported when eager execution is enabled.')
@ -2400,7 +2400,7 @@ def recall_at_k(labels,
are not a list or tuple. are not a list or tuple.
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('tf.metrics.recall_at_k is not ' raise RuntimeError('tf.metrics.recall_at_k is not '
'supported when eager execution is enabled.') 'supported when eager execution is enabled.')
@ -2549,7 +2549,7 @@ def recall_at_thresholds(labels,
tuple. tuple.
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('tf.metrics.recall_at_thresholds is not ' raise RuntimeError('tf.metrics.recall_at_thresholds is not '
'supported when eager execution is enabled.') 'supported when eager execution is enabled.')
@ -2626,7 +2626,7 @@ def root_mean_squared_error(labels,
tuple. tuple.
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('tf.metrics.root_mean_squared_error is not ' raise RuntimeError('tf.metrics.root_mean_squared_error is not '
'supported when eager execution is enabled.') 'supported when eager execution is enabled.')
@ -2707,7 +2707,7 @@ def sensitivity_at_specificity(labels,
or `updates_collections` are not a list or tuple. or `updates_collections` are not a list or tuple.
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('tf.metrics.sensitivity_at_specificity is not ' raise RuntimeError('tf.metrics.sensitivity_at_specificity is not '
'supported when eager execution is enabled.') 'supported when eager execution is enabled.')
@ -3098,7 +3098,7 @@ def average_precision_at_k(labels,
ValueError: if k is invalid. ValueError: if k is invalid.
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('tf.metrics.sparse_average_precision_at_k is not ' raise RuntimeError('tf.metrics.sparse_average_precision_at_k is not '
'supported when eager execution is enabled.') 'supported when eager execution is enabled.')
@ -3267,7 +3267,7 @@ def precision_at_top_k(labels,
are not a list or tuple. are not a list or tuple.
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('tf.metrics.precision_at_top_k is not ' raise RuntimeError('tf.metrics.precision_at_top_k is not '
'supported when eager execution is enabled.') 'supported when eager execution is enabled.')
@ -3396,7 +3396,7 @@ def precision_at_k(labels,
are not a list or tuple. are not a list or tuple.
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('tf.metrics.sparse_precision_at_k is not ' raise RuntimeError('tf.metrics.sparse_precision_at_k is not '
'supported when eager execution is enabled.') 'supported when eager execution is enabled.')
@ -3473,7 +3473,7 @@ def specificity_at_sensitivity(labels,
or `updates_collections` are not a list or tuple. or `updates_collections` are not a list or tuple.
RuntimeError: If eager execution is enabled. RuntimeError: If eager execution is enabled.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError('tf.metrics.specificity_at_sensitivity is not ' raise RuntimeError('tf.metrics.specificity_at_sensitivity is not '
'supported when eager execution is enabled.') 'supported when eager execution is enabled.')
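
Every metric in this file carries the same guard, so the contract is uniform: with eager execution enabled, calling any tf.metrics function raises RuntimeError rather than silently creating local variables. A hedged illustration of what a caller sees:

import tensorflow as tf

tf.enable_eager_execution()

try:
    tf.metrics.mean(tf.constant([1.0, 2.0]))
except RuntimeError as err:
    print(err)  # tf.metrics.mean is not supported when eager execution is enabled.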

View File

@ -456,7 +456,7 @@ def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
def IsZero(g): def IsZero(g):
# Some introspection to check if the gradient is feeding zeros # Some introspection to check if the gradient is feeding zeros
if context.in_eager_mode(): if context.executing_eagerly():
# TODO(apassos) add an efficient way to detect eager zeros here. # TODO(apassos) add an efficient way to detect eager zeros here.
return False return False
if g.op.type in ("ZerosLike", "Zeros"): if g.op.type in ("ZerosLike", "Zeros"):

View File

@ -1504,7 +1504,7 @@ def bias_add(value, bias, data_format=None, name=None):
A `Tensor` with the same type as `value`. A `Tensor` with the same type as `value`.
""" """
with ops.name_scope(name, "BiasAdd", [value, bias]) as name: with ops.name_scope(name, "BiasAdd", [value, bias]) as name:
if context.in_graph_mode(): if not context.executing_eagerly():
value = ops.convert_to_tensor(value, name="input") value = ops.convert_to_tensor(value, name="input")
bias = ops.convert_to_tensor(bias, dtype=value.dtype, name="bias") bias = ops.convert_to_tensor(bias, dtype=value.dtype, name="bias")
return gen_nn_ops.bias_add(value, bias, data_format=data_format, name=name) return gen_nn_ops.bias_add(value, bias, data_format=data_format, name=name)
@ -1616,7 +1616,7 @@ def _flatten_outer_dims(logits):
output = array_ops.reshape(logits, array_ops.concat([[-1], last_dim_size], 0)) output = array_ops.reshape(logits, array_ops.concat([[-1], last_dim_size], 0))
# Set output shape if known. # Set output shape if known.
if context.in_graph_mode(): if not context.executing_eagerly():
shape = logits.get_shape() shape = logits.get_shape()
if shape is not None and shape.dims is not None: if shape is not None and shape.dims is not None:
shape = shape.as_list() shape = shape.as_list()
@ -1881,7 +1881,8 @@ def softmax_cross_entropy_with_logits_v2(
# Make shape inference work since reshape and transpose may erase its static # Make shape inference work since reshape and transpose may erase its static
# shape. # shape.
if context.in_graph_mode() and shape is not None and shape.dims is not None: if not context.executing_eagerly(
) and shape is not None and shape.dims is not None:
shape = shape.as_list() shape = shape.as_list()
del shape[dim] del shape[dim]
cost.set_shape(shape) cost.set_shape(shape)
@ -2318,7 +2319,7 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: di
# 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob) # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
binary_tensor = math_ops.floor(random_tensor) binary_tensor = math_ops.floor(random_tensor)
ret = math_ops.div(x, keep_prob) * binary_tensor ret = math_ops.div(x, keep_prob) * binary_tensor
if context.in_graph_mode(): if not context.executing_eagerly():
ret.set_shape(x.get_shape()) ret.set_shape(x.get_shape())
return ret return ret

View File

@ -74,7 +74,7 @@ def add_check_numerics_ops():
the checked operations. the checked operations.
  @end_compatibility   @end_compatibility
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError( raise RuntimeError(
"add_check_numerics_ops() is not compatible with eager execution. " "add_check_numerics_ops() is not compatible with eager execution. "
"To check for Inf's and NaN's under eager execution, call " "To check for Inf's and NaN's under eager execution, call "

View File

@ -267,9 +267,9 @@ class ResourceVariable(variables.Variable):
if initial_value is not None: if initial_value is not None:
raise ValueError("variable_def and initial_value are mutually " raise ValueError("variable_def and initial_value are mutually "
"exclusive.") "exclusive.")
if not context.in_graph_mode(): if context.executing_eagerly():
raise ValueError("Creating ResourceVariable from variable_def" raise ValueError("Creating ResourceVariable from variable_def is "
" only supported in GRAPH mode.") "not supported when eager execution is enabled.")
self._init_from_proto(variable_def, import_scope=import_scope) self._init_from_proto(variable_def, import_scope=import_scope)
else: else:
self._init_from_args( self._init_from_args(
@ -363,7 +363,7 @@ class ResourceVariable(variables.Variable):
# this graph. # this graph.
self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
with ops.init_scope(): with ops.init_scope():
self._in_graph_mode = context.in_graph_mode() self._in_graph_mode = not context.executing_eagerly()
with ops.name_scope(name, "Variable", [] with ops.name_scope(name, "Variable", []
if init_from_fn else [initial_value]) as name: if init_from_fn else [initial_value]) as name:
# pylint: disable=protected-access # pylint: disable=protected-access
@ -470,7 +470,7 @@ class ResourceVariable(variables.Variable):
self._cached_value = self._read_variable_op() self._cached_value = self._read_variable_op()
else: else:
self._cached_value = None self._cached_value = None
if context.in_graph_mode(): if not context.executing_eagerly():
ops.add_to_collections(collections, self) ops.add_to_collections(collections, self)
elif ops.GraphKeys.GLOBAL_STEP in collections: elif ops.GraphKeys.GLOBAL_STEP in collections:
ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self) ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self)
@ -489,7 +489,7 @@ class ResourceVariable(variables.Variable):
def _init_from_proto(self, variable_def, import_scope=None): def _init_from_proto(self, variable_def, import_scope=None):
"""Initializes from `VariableDef` proto.""" """Initializes from `VariableDef` proto."""
# Note that init_from_proto is currently not supported in Eager mode. # Note that init_from_proto is currently not supported in Eager mode.
assert context.in_graph_mode() assert not context.executing_eagerly()
self._in_graph_mode = True self._in_graph_mode = True
assert isinstance(variable_def, variable_pb2.VariableDef) assert isinstance(variable_def, variable_pb2.VariableDef)
if not variable_def.is_resource: if not variable_def.is_resource:
@ -582,7 +582,8 @@ class ResourceVariable(variables.Variable):
def create(self): def create(self):
"""The op responsible for initializing this variable.""" """The op responsible for initializing this variable."""
if not self._in_graph_mode: if not self._in_graph_mode:
raise RuntimeError("Calling create in EAGER mode not supported.") raise RuntimeError("Calling create is not supported when eager execution"
" is enabled.")
return self._initializer_op return self._initializer_op
@property @property
@ -610,7 +611,7 @@ class ResourceVariable(variables.Variable):
@property @property
def initial_value(self): def initial_value(self):
"""Returns the Tensor used as the initial value for the variable.""" """Returns the Tensor used as the initial value for the variable."""
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError("initial_value not supported in EAGER mode.") raise RuntimeError("initial_value not supported in EAGER mode.")
return self._initial_value return self._initial_value
@ -631,15 +632,15 @@ class ResourceVariable(variables.Variable):
def eval(self, session=None): def eval(self, session=None):
"""Evaluates and returns the value of this variable.""" """Evaluates and returns the value of this variable."""
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError("Trying to eval in EAGER mode") raise RuntimeError("Trying to eval in EAGER mode")
return self._graph_element.eval(session=session) return self._graph_element.eval(session=session)
def numpy(self): def numpy(self):
if context.in_graph_mode(): if context.executing_eagerly():
return self.read_value().numpy()
raise NotImplementedError( raise NotImplementedError(
"numpy() is only available when eager execution is enabled.") "numpy() is only available when eager execution is enabled.")
return self.read_value().numpy()
def count_up_to(self, limit): def count_up_to(self, limit):
"""Increments this variable until it reaches `limit`. """Increments this variable until it reaches `limit`.
@ -720,7 +721,7 @@ class ResourceVariable(variables.Variable):
A `VariableDef` protocol buffer, or `None` if the `Variable` is not A `VariableDef` protocol buffer, or `None` if the `Variable` is not
in the specified name scope. in the specified name scope.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError("to_proto not supported in EAGER mode.") raise RuntimeError("to_proto not supported in EAGER mode.")
if export_scope is None or self.handle.name.startswith(export_scope): if export_scope is None or self.handle.name.startswith(export_scope):
var_def = variable_pb2.VariableDef() var_def = variable_pb2.VariableDef()
@ -747,7 +748,7 @@ class ResourceVariable(variables.Variable):
@staticmethod @staticmethod
def from_proto(variable_def, import_scope=None): def from_proto(variable_def, import_scope=None):
if context.in_eager_mode(): if context.executing_eagerly():
raise RuntimeError("from_proto not supported in EAGER mode.") raise RuntimeError("from_proto not supported in EAGER mode.")
return ResourceVariable( return ResourceVariable(
variable_def=variable_def, import_scope=import_scope) variable_def=variable_def, import_scope=import_scope)
@ -984,10 +985,10 @@ class _UnreadVariable(ResourceVariable):
self._is_initialized_op = None self._is_initialized_op = None
self._initializer_op = None self._initializer_op = None
self._parent_op = parent_op self._parent_op = parent_op
if context.in_graph_mode(): if context.executing_eagerly():
self._graph_element = self.read_value()
else:
self._graph_element = None self._graph_element = None
else:
self._graph_element = self.read_value()
self._handle_deleter = deleter self._handle_deleter = deleter
def value(self): def value(self):
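
Two user-visible consequences of the ResourceVariable changes above: numpy() is defined only under eager execution, while graph-only entry points (initial_value, eval, to_proto, from_proto) raise RuntimeError when eager execution is enabled. A hedged usage sketch, assuming eager execution is enabled and tf.contrib.eager.Variable is available:

import tensorflow as tf

tf.enable_eager_execution()

v = tf.contrib.eager.Variable(1.0)
print(v.numpy())     # 1.0 -- only defined when executing eagerly
v.assign_add(2.0)
print(v.numpy())     # 3.0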

View File

@ -575,7 +575,7 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
# Create a new scope in which the caching device is either # Create a new scope in which the caching device is either
# determined by the parent scope, or is set to place the cached # determined by the parent scope, or is set to place the cached
# Variable using the same placement as for the rest of the RNN. # Variable using the same placement as for the rest of the RNN.
if context.in_graph_mode(): if not context.executing_eagerly():
if varscope.caching_device is None: if varscope.caching_device is None:
varscope.set_caching_device(lambda op: op.device) varscope.set_caching_device(lambda op: op.device)
@ -616,7 +616,7 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
["Expected shape for Tensor %s is " % x.name, ["Expected shape for Tensor %s is " % x.name,
packed_shape, " but saw shape: ", x_shape]) packed_shape, " but saw shape: ", x_shape])
if context.in_graph_mode() and sequence_length is not None: if not context.executing_eagerly() and sequence_length is not None:
# Perform some shape validation # Perform some shape validation
with ops.control_dependencies( with ops.control_dependencies(
[_assert_has_shape(sequence_length, [batch_size])]): [_assert_has_shape(sequence_length, [batch_size])]):
@ -742,7 +742,7 @@ def _dynamic_rnn_loop(cell,
element_shape=element_shape, element_shape=element_shape,
tensor_array_name=base_name + name) tensor_array_name=base_name + name)
in_graph_mode = context.in_graph_mode() in_graph_mode = not context.executing_eagerly()
if in_graph_mode: if in_graph_mode:
output_ta = tuple( output_ta = tuple(
_create_ta( _create_ta(
@ -1027,7 +1027,7 @@ def raw_rnn(cell, loop_fn,
# determined by the parent scope, or is set to place the cached # determined by the parent scope, or is set to place the cached
# Variable using the same placement as for the rest of the RNN. # Variable using the same placement as for the rest of the RNN.
with vs.variable_scope(scope or "rnn") as varscope: with vs.variable_scope(scope or "rnn") as varscope:
if context.in_graph_mode(): if not context.executing_eagerly():
if varscope.caching_device is None: if varscope.caching_device is None:
varscope.set_caching_device(lambda op: op.device) varscope.set_caching_device(lambda op: op.device)
@ -1242,7 +1242,7 @@ def static_rnn(cell,
# determined by the parent scope, or is set to place the cached # determined by the parent scope, or is set to place the cached
# Variable using the same placement as for the rest of the RNN. # Variable using the same placement as for the rest of the RNN.
with vs.variable_scope(scope or "rnn") as varscope: with vs.variable_scope(scope or "rnn") as varscope:
if context.in_graph_mode(): if not context.executing_eagerly():
if varscope.caching_device is None: if varscope.caching_device is None:
varscope.set_caching_device(lambda op: op.device) varscope.set_caching_device(lambda op: op.device)

View File

@ -128,7 +128,7 @@ def _zero_state_tensors(state_size, batch_size, dtype):
"""Combine s with batch_size to get a proper tensor shape.""" """Combine s with batch_size to get a proper tensor shape."""
c = _concat(batch_size, s) c = _concat(batch_size, s)
size = array_ops.zeros(c, dtype=dtype) size = array_ops.zeros(c, dtype=dtype)
if context.in_graph_mode(): if not context.executing_eagerly():
c_static = _concat(batch_size, s, static=True) c_static = _concat(batch_size, s, static=True)
size.set_shape(c_static) size.set_shape(c_static)
return size return size
@ -192,12 +192,13 @@ class RNNCell(base_layer.Layer):
def _rnn_get_variable(self, getter, *args, **kwargs): def _rnn_get_variable(self, getter, *args, **kwargs):
variable = getter(*args, **kwargs) variable = getter(*args, **kwargs)
if context.in_graph_mode(): if context.executing_eagerly():
trainable = (variable in tf_variables.trainable_variables() or trainable = variable._trainable # pylint: disable=protected-access
else:
trainable = (
variable in tf_variables.trainable_variables() or
(isinstance(variable, tf_variables.PartitionedVariable) and (isinstance(variable, tf_variables.PartitionedVariable) and
list(variable)[0] in tf_variables.trainable_variables())) list(variable)[0] in tf_variables.trainable_variables()))
else:
trainable = variable._trainable # pylint: disable=protected-access
if trainable and variable not in self._trainable_weights: if trainable and variable not in self._trainable_weights:
self._trainable_weights.append(variable) self._trainable_weights.append(variable)
elif not trainable and variable not in self._non_trainable_weights: elif not trainable and variable not in self._non_trainable_weights:
@ -241,7 +242,7 @@ class RNNCell(base_layer.Layer):
# Try to use the last cached zero_state. This is done to avoid recreating # Try to use the last cached zero_state. This is done to avoid recreating
# zeros, especially when eager execution is enabled. # zeros, especially when eager execution is enabled.
state_size = self.state_size state_size = self.state_size
is_eager = context.in_eager_mode() is_eager = context.executing_eagerly()
if is_eager and hasattr(self, "_last_zero_state"): if is_eager and hasattr(self, "_last_zero_state"):
(last_state_size, last_batch_size, last_dtype, (last_state_size, last_batch_size, last_dtype,
last_output) = getattr(self, "_last_zero_state") last_output) = getattr(self, "_last_zero_state")

View File

@ -317,7 +317,7 @@ def py_func(func, inp, Tout, stateful=True, name=None):
Returns: Returns:
A list of `Tensor` or a single `Tensor` which `func` computes. A list of `Tensor` or a single `Tensor` which `func` computes.
""" """
if context.in_eager_mode(): if context.executing_eagerly():
result = func(*[x.numpy() for x in inp]) result = func(*[x.numpy() for x in inp])
result = nest.flatten(result) result = nest.flatten(result)
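
Under eager execution, py_func above calls the wrapped Python function directly on the .numpy() values of its inputs. A hedged usage sketch:

import tensorflow as tf

def times_two(x):
    # Receives a NumPy array when eager execution is enabled.
    return x * 2

out = tf.py_func(times_two, [tf.constant([1.0, 2.0, 3.0])], tf.float32)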

View File

@ -186,7 +186,7 @@ def is_variable_initialized(ref, name=None):
if ref.dtype._is_ref_dtype: if ref.dtype._is_ref_dtype:
return gen_state_ops.is_variable_initialized(ref=ref, name=name) return gen_state_ops.is_variable_initialized(ref=ref, name=name)
# Handle resource variables. # Handle resource variables.
if context.in_eager_mode() or ref.op.type == "VarHandleOp": if context.executing_eagerly() or ref.op.type == "VarHandleOp":
return gen_resource_variable_ops.var_is_initialized_op(ref.handle, return gen_resource_variable_ops.var_is_initialized_op(ref.handle,
name=name) name=name)

View File

@ -204,7 +204,7 @@ def make_template_internal(name_,
if kwargs: if kwargs:
func_ = tf_decorator.make_decorator(func_, functools.partial( func_ = tf_decorator.make_decorator(func_, functools.partial(
func_, **kwargs)) func_, **kwargs))
if context.in_eager_mode(): if context.executing_eagerly():
if unique_name_ is not None: if unique_name_ is not None:
raise ValueError( raise ValueError(
"unique_name_ cannot be used when eager exeuction is enabled.") "unique_name_ cannot be used when eager exeuction is enabled.")
@ -647,7 +647,7 @@ class EagerTemplate(Template):
Raises: Raises:
RuntimeError: if eager execution is not enabled. RuntimeError: if eager execution is not enabled.
""" """
if not context.in_eager_mode(): if not context.executing_eagerly():
raise RuntimeError( raise RuntimeError(
"{} objects can only be used when eager execution is enabled, use " "{} objects can only be used when eager execution is enabled, use "
"tf.Template for graph construction". "tf.Template for graph construction".

View File

@ -338,7 +338,7 @@ class _GraphTensorArray(object):
with ops.name_scope(name, "TensorArrayScatter", with ops.name_scope(name, "TensorArrayScatter",
[self._handle, value, indices]): [self._handle, value, indices]):
value = ops.convert_to_tensor(value, name="value") value = ops.convert_to_tensor(value, name="value")
if self._infer_shape and context.in_graph_mode(): if self._infer_shape and not context.executing_eagerly():
self._merge_element_shape(value.shape[1:]) self._merge_element_shape(value.shape[1:])
with self._maybe_colocate_with(value): with self._maybe_colocate_with(value):
flow_out = gen_data_flow_ops.tensor_array_scatter_v3( flow_out = gen_data_flow_ops.tensor_array_scatter_v3(
@ -363,7 +363,7 @@ class _GraphTensorArray(object):
value = ops.convert_to_tensor(value, name="value") value = ops.convert_to_tensor(value, name="value")
with self._maybe_colocate_with(value): with self._maybe_colocate_with(value):
lengths_64 = math_ops.to_int64(lengths) lengths_64 = math_ops.to_int64(lengths)
if self._infer_shape and context.in_graph_mode(): if self._infer_shape and not context.executing_eagerly():
clengths = tensor_util.constant_value(lengths_64) clengths = tensor_util.constant_value(lengths_64)
if value.shape.dims is not None: if value.shape.dims is not None:
if clengths is not None and clengths.max() == clengths.min(): if clengths is not None and clengths.max() == clengths.min():
@@ -774,10 +774,10 @@ class TensorArray(object):
       ValueError: if both handle and tensor_array_name are provided.
       TypeError: if handle is provided but is not a Tensor.
     """
-    if context.in_graph_mode():
-      implementation = _GraphTensorArray
-    else:
+    if context.executing_eagerly():
       implementation = _EagerTensorArray
+    else:
+      implementation = _GraphTensorArray
     self._implementation = implementation(
         dtype,


@@ -321,7 +321,7 @@ class _VariableStore(object):
       raise ValueError(
           "Passed a custom_getter which is not callable: %s" % custom_getter)
-    if context.in_eager_mode():
+    if context.executing_eagerly():
       if not self._store_eager_variables and reuse:
         raise RuntimeError(
             "When eager execution is enabled variable reuse is only supported"
@@ -518,7 +518,7 @@ class _VariableStore(object):
        when violating reuse during variable creation, or if an existing
        sharded variable exists for the given name but with different sharding.
     """
-    if context.in_eager_mode():
+    if context.executing_eagerly():
       raise NotImplementedError("Partitioned variables are not yet supported "
                                 "when eager execution is enabled.")
@@ -798,7 +798,7 @@ class _VariableStore(object):
           validate_shape=validate_shape,
           constraint=constraint,
           use_resource=use_resource)
-    if context.in_graph_mode() or self._store_eager_variables:
+    if not context.executing_eagerly() or self._store_eager_variables:
       # In eager mode we do not want to keep default references to Variable
       # objects as this will prevent their memory from being released.
       self._vars[name] = v
@@ -811,12 +811,12 @@ class _VariableStore(object):
       with ops.name_scope(name + "/Regularizer/"):
         loss = regularizer(v)
       if loss is not None:
-        if context.in_graph_mode():
-          v_name = v.name
-          loss_name = loss.name
-        else:
+        if context.executing_eagerly():
           v_name = "v_%s" % type(v)
           loss_name = "loss_%s" % type(loss)
+        else:
+          v_name = v.name
+          loss_name = loss.name
         logging.vlog(1, "Applied regularizer to %s and added the result %s "
                      "to REGULARIZATION_LOSSES.", v_name, loss_name)
         ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, loss)
@@ -920,7 +920,7 @@ class VariableScope(object):
     self._dtype = dtype
     self._use_resource = use_resource
     self._constraint = constraint
-    if context.in_eager_mode():
+    if context.executing_eagerly():
       if self._caching_device is not None:
         raise NotImplementedError("Caching devices is not yet supported "
                                   "when eager execution is enabled.")
@@ -988,7 +988,7 @@ class VariableScope(object):
   def set_use_resource(self, use_resource):
     """Sets whether to use ResourceVariables for this scope."""
-    if context.in_eager_mode() and not use_resource:
+    if context.executing_eagerly() and not use_resource:
       raise ValueError("When eager execution is enabled, "
                        "use_resource cannot be set to false.")
     self._use_resource = use_resource
@@ -999,14 +999,14 @@ class VariableScope(object):
   def set_caching_device(self, caching_device):
     """Set caching_device for this scope."""
-    if context.in_eager_mode():
+    if context.executing_eagerly():
       raise NotImplementedError("Caching devices are not yet supported "
                                 "when eager execution is enabled.")
     self._caching_device = caching_device
   def set_partitioner(self, partitioner):
     """Set partitioner for this scope."""
-    if partitioner and context.in_eager_mode():
+    if partitioner and context.executing_eagerly():
       raise NotImplementedError("Partitioned variables are not yet supported "
                                 "when eager execution is enabled.")
     self._partitioner = partitioner
@@ -1057,14 +1057,14 @@ class VariableScope(object):
       partitioner = self._partitioner
     if custom_getter is None:
       custom_getter = self._custom_getter
-    if context.in_graph_mode():
+    if context.executing_eagerly():
+      reuse = False
+      use_resource = True
+    else:
       if reuse is None:
         reuse = self._reuse
       if use_resource is None:
         use_resource = self._use_resource
-    else:
-      reuse = False
-      use_resource = True
     full_name = self.name + "/" + name if self.name else name
     # Variable names only depend on variable_scope (full_name here),
@@ -1107,7 +1107,7 @@ class VariableScope(object):
                                 use_resource=None,
                                 constraint=None):
     """Gets an existing variable with this name or create a new one."""
-    if context.in_eager_mode():
+    if context.executing_eagerly():
       raise NotImplementedError("Partitioned variables are not yet supported "
                                 "when eager execution is enabled.")
     if initializer is None:
@@ -1871,7 +1871,7 @@ class variable_scope(object):
       raise ValueError("The reuse parameter must be True or False or None.")
     if self._values is None:
       self._values = []
-    self._in_graph_mode = not context.in_eager_mode()
+    self._in_graph_mode = not context.executing_eagerly()
     if self._in_graph_mode:
       self._graph = ops._get_graph_from_inputs(self._values)  # pylint: disable=protected-access
     self._cached_pure_variable_scope = None
@@ -2111,13 +2111,13 @@ def default_variable_creator(next_creator=None, **kwargs):
   use_resource = kwargs.get("use_resource", None)
   if use_resource is None:
     use_resource = get_variable_scope().use_resource
-  if use_resource or (use_resource is None and context.in_eager_mode()):
+  if use_resource or (use_resource is None and context.executing_eagerly()):
     return resource_variable_ops.ResourceVariable(
         initial_value=initial_value, trainable=trainable,
         collections=collections, validate_shape=validate_shape,
        caching_device=caching_device, name=name, dtype=dtype,
        constraint=constraint)
-  elif not use_resource and context.in_eager_mode():
+  elif not use_resource and context.executing_eagerly():
     raise RuntimeError(
         "VariableScope should use resource variable when eager execution is"
         " enabled, but use_resource is False."


@@ -210,10 +210,11 @@ class Variable(checkpointable.CheckpointableBase):
     for details on how variables work in eager execution.
     @end_compatibility
     """
-    if not context.in_graph_mode():
-      raise RuntimeError("tf.Variable not supported in Eager mode. "
-                         "Please use tfe.Variable instead")
-    self._in_graph_mode = context.in_graph_mode()
+    if context.executing_eagerly():
+      raise RuntimeError(
+          "tf.Variable not supported when eager execution is enabled. "
+          "Please use tf.contrib.eager.Variable instead")
+    self._in_graph_mode = True
     if variable_def:
       # If variable_def is provided, recreates the variable from its fields.
       if initial_value:
@@ -234,7 +235,7 @@ class Variable(checkpointable.CheckpointableBase):
         constraint=constraint)
   def __repr__(self):
-    if context.in_eager_mode():
+    if context.executing_eagerly():
       return "<tf.Variable '%s' shape=%s dtype=%s, numpy=%s>" % (
           self.name, self.get_shape(), self.dtype.name,
           ops.numpy_text(self.read_value(), is_repr=True))
@@ -740,15 +741,15 @@ class Variable(checkpointable.CheckpointableBase):
     Raises:
         ValueError: Session is not passed and no default session
     """
-    if context.in_graph_mode():
+    if context.executing_eagerly():
+      self.assign(value)
+    else:
       session = session or ops.get_default_session()
       if session is None:
         raise ValueError(
             "Either session argument should be provided or default session "
             "should be established")
       session.run(self._initializer_op, {self._initializer_op.inputs[1]: value})
-    else:
-      self.assign(value)
   # Conversion to tensor.
   @staticmethod
@@ -1248,9 +1249,9 @@ class PartitionedVariable(object):
         information does not match `shape`, or `partitions` has invalid values.
       RuntimeError: If eager execution is enabled
     """
-    if not context.in_graph_mode():
-      raise RuntimeError("tf.PartitionedVariable not supported in "
-                         "eager mode. Please use tfe.Variable instead")
+    if context.executing_eagerly():
+      raise RuntimeError(
+          "tf.PartitionedVariable not supported with eager execution enabled.")
     if not isinstance(variable_list, (list, tuple)):
       raise TypeError(
           "variable_list is not a list or tuple: %s" % variable_list)
@@ -1541,7 +1542,7 @@ def variables_initializer(var_list, name="init"):
   Returns:
     An Op that run the initializers of all the specified variables.
   """
-  if var_list and context.in_graph_mode():
+  if var_list and not context.executing_eagerly():
     return control_flow_ops.group(*[v.initializer for v in var_list], name=name)
   return control_flow_ops.no_op(name=name)
@@ -1563,7 +1564,7 @@ def global_variables_initializer():
   Returns:
     An Op that initializes global variables in the graph.
   """
-  if context.in_eager_mode():
+  if context.executing_eagerly():
     return control_flow_ops.no_op(name="global_variables_initializer")
   return variables_initializer(global_variables())
@@ -1585,7 +1586,7 @@ def local_variables_initializer():
   Returns:
     An Op that initializes all local variables in the graph.
   """
-  if context.in_eager_mode():
+  if context.executing_eagerly():
     return control_flow_ops.no_op(name="local_variables_initializer")
   return variables_initializer(local_variables())


@@ -172,7 +172,7 @@ class Profiler(object):
       op_log: optional. tensorflow::tfprof::OpLogProto proto. Used to define
           extra op types.
     """
-    if not graph and context.in_graph_mode():
+    if not graph and not context.executing_eagerly():
       graph = ops.get_default_graph()
     self._coverage = 0.0
     self._graph = graph
@@ -336,7 +336,7 @@ def profile(graph=None,
     If cmd is 'op' or 'code', returns MultiGraphNodeProto proto.
     Side effect: stdout/file/timeline.json depending on options['output']
   """
-  if not graph and context.in_graph_mode():
+  if not graph and not context.executing_eagerly():
     graph = ops.get_default_graph()
   if options == _DEFAULT_PROFILE_OPTIONS:


@@ -156,7 +156,7 @@ def merge_default_with_oplog(graph, op_log=None, run_meta=None,
   Returns:
     tmp_op_log: Merged OpLogProto proto.
   """
-  if not graph and context.in_graph_mode():
+  if not graph and not context.executing_eagerly():
     graph = ops.get_default_graph()
   tmp_op_log = tfprof_log_pb2.OpLogProto()
@@ -210,7 +210,7 @@ def write_op_log(graph, log_dir, op_log=None, run_meta=None, add_trace=True):
     add_trace: Whether to add python code trace information.
         Used to support "code" view.
   """
-  if not graph and context.in_graph_mode():
+  if not graph and not context.executing_eagerly():
     graph = ops.get_default_graph()
   op_log = merge_default_with_oplog(graph, op_log, run_meta, add_trace)


@@ -278,7 +278,7 @@ def merge(inputs, collections=None, name=None):
   @end_compatibility
   """
   # pylint: enable=line-too-long
-  if _context.in_eager_mode():
+  if _context.executing_eagerly():
     raise RuntimeError(
         'Merging tf.summary.* ops is not compatible with eager execution. '
         'Use tf.contrib.summary instead.')
@@ -311,7 +311,7 @@ def merge_all(key=_ops.GraphKeys.SUMMARIES, scope=None):
   summaries under eager execution, use `tf.contrib.summary` instead.
   @end_compatibility
   """
-  if _context.in_eager_mode():
+  if _context.executing_eagerly():
     raise RuntimeError(
         'Merging tf.summary.* ops is not compatible with eager execution. '
         'Use tf.contrib.summary instead.')


@@ -343,7 +343,7 @@ class FileWriter(SummaryToEventTransformer):
     summaries under eager execution, use `tf.contrib.summary` instead.
     @end_compatibility
     """
-    if context.in_eager_mode():
+    if context.executing_eagerly():
      raise RuntimeError(
          "tf.summary.FileWriter is not compatible with eager execution. "
          "Use tf.contrib.summary instead.")


@@ -106,10 +106,10 @@ class AdamOptimizer(optimizer.Optimizer):
     self._updated_lr = None
   def _get_beta_accumulators(self):
-    if context.in_graph_mode():
-      graph = ops.get_default_graph()
-    else:
+    if context.executing_eagerly():
       graph = None
+    else:
+      graph = ops.get_default_graph()
     return (self._get_non_slot_variable("beta1_power", graph=graph),
             self._get_non_slot_variable("beta2_power", graph=graph))


@@ -184,7 +184,7 @@ class AdamOptimizerTest(test.TestCase):
         # Shouldn't return non-slot variables from other graphs.
         self.assertEqual(0, len(opt.variables()))
-      if context.in_graph_mode():
+      if not context.executing_eagerly():
         self.evaluate(variables.global_variables_initializer())
         # Fetch params to validate initial values
         self.assertAllClose([1.0, 2.0], self.evaluate(var0))
@@ -194,7 +194,7 @@ class AdamOptimizerTest(test.TestCase):
       # Run 3 steps of Adam
       for t in range(1, 4):
-        if context.in_graph_mode():
+        if not context.executing_eagerly():
           self.evaluate(update)
         elif t > 1:
           opt.apply_gradients(zip([grads0, grads1], [var0, var1]))


@@ -208,7 +208,7 @@ class _CheckpointPosition(object):
     # Name saveables based on the name this object had when it was checkpointed.
     named_saveables = {}
     restore_ops = []
-    in_graph_mode = context.in_graph_mode()
+    building_graph = not context.executing_eagerly()
    for serialized_tensor in self.object_proto.attributes:
      saveable_object = saveables.get(serialized_tensor.name, None)
      if saveable_object is None:
@@ -219,7 +219,7 @@ class _CheckpointPosition(object):
         self._checkpoint.unused_attributes.setdefault(
             self.checkpointable, []).append(serialized_tensor.name)
         continue
-      if in_graph_mode:
+      if building_graph:
         existing_ops = self._checkpoint.restore_ops_by_name.get(
             serialized_tensor.name, None)
       else:
@@ -245,7 +245,7 @@ class _CheckpointPosition(object):
             saveable_index:saveable_index + num_specs]
         saveable_index += num_specs
         restore_op = saveable.restore(saveable_tensors, restored_shapes=None)
-        if in_graph_mode:
+        if building_graph:
           assert saveable.name not in self._checkpoint.restore_ops_by_name
           self._checkpoint.restore_ops_by_name[saveable.name] = restore_op
           restore_ops.append(restore_op)
@@ -388,7 +388,7 @@ class CheckpointableBase(object):
           "Checkpointable._add_variable called to create another with "
           "that name. Variable names must be unique within a Checkpointable "
           "object.") % (name,))
-    if context.in_eager_mode():
+    if context.executing_eagerly():
       # If this is a variable with a single Tensor stored in the checkpoint, we
       # can set that value as an initializer rather than initializing and then
       # assigning (when executing eagerly). This call returns None if there is


@@ -71,6 +71,6 @@ class GradientDescentOptimizer(optimizer.Optimizer):
     return var.scatter_sub(delta, use_locking=self._use_locking)
   def _prepare(self):
-    if context.in_graph_mode() or self._learning_rate_tensor is None:
+    if not context.executing_eagerly() or self._learning_rate_tensor is None:
       self._learning_rate_tensor = ops.convert_to_tensor(self._learning_rate,
                                                          name="learning_rate")


@@ -159,7 +159,7 @@ def input_producer(input_tensor,
     enabled. Please use the `tf.data` API to ingest data under eager execution.
     @end_compatibility
   """
-  if context.in_eager_mode():
+  if context.executing_eagerly():
     raise RuntimeError(
         "Input pipelines based on Queues are not supported when eager execution"
         " is enabled. Please use tf.data to ingest data into your model"
@@ -737,7 +737,7 @@ def _batch(tensors, batch_size, keep_input, num_threads=1, capacity=32,
            allow_smaller_final_batch=False, shared_name=None,
            name=None):
   """Helper function for `batch` and `maybe_batch`."""
-  if context.in_eager_mode():
+  if context.executing_eagerly():
     raise ValueError(
         "Input pipelines based on Queues are not supported when eager execution"
         " is enabled. Please use tf.data to ingest data into your model"
@@ -775,7 +775,7 @@ def _batch_join(tensors_list, batch_size, keep_input, capacity=32,
                 enqueue_many=False, shapes=None, dynamic_pad=False,
                 allow_smaller_final_batch=False, shared_name=None, name=None):
   """Helper function for `batch_join` and `maybe_batch_join`."""
-  if context.in_eager_mode():
+  if context.executing_eagerly():
     raise ValueError(
         "Input pipelines based on Queues are not supported when eager execution"
         " is enabled. Please use tf.data to ingest data into your model"
@@ -810,7 +810,7 @@ def _shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
                    shapes=None, allow_smaller_final_batch=False,
                    shared_name=None, name=None):
   """Helper function for `shuffle_batch` and `maybe_shuffle_batch`."""
-  if context.in_eager_mode():
+  if context.executing_eagerly():
     raise ValueError(
         "Input pipelines based on Queues are not supported when eager execution"
         " is enabled. Please use tf.data to ingest data into your model"
@@ -855,7 +855,7 @@ def _shuffle_batch_join(tensors_list, batch_size, capacity,
                         allow_smaller_final_batch=False, shared_name=None,
                         name=None):
   """Helper function for `shuffle_batch_join` and `maybe_shuffle_batch_join`."""
-  if context.in_eager_mode():
+  if context.executing_eagerly():
     raise ValueError(
         "Input pipelines based on Queues are not supported when eager execution"
         " is enabled. Please use tf.data to ingest data into your model"


@@ -113,7 +113,7 @@ class LRDecayTest(test_util.TensorFlowTestCase):
       learning_rate_decay.piecewise_constant(x, boundaries, values)
     # Test that ref types are valid.
-    if context.in_graph_mode():
+    if not context.executing_eagerly():
       x = variables.Variable(0.0)
       x_ref = x.op.outputs[0]  # float32_ref tensor should be accepted
       boundaries, values = [1.0, 2.0], [1, 2, 3]

Some files were not shown because too many files have changed in this diff.
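Taken together, the hunks above apply one mechanical rewrite: `context.in_eager_mode()` becomes `context.executing_eagerly()`, and `context.in_graph_mode()` becomes `not context.executing_eagerly()`, with branch bodies swapped where the original code handled the graph-mode case first. A minimal sketch of the post-rename idiom, mirroring the `AdamOptimizer._get_beta_accumulators` hunk above (a hypothetical helper written only for illustration, not code from this commit):

# Hypothetical helper illustrating the renamed predicate; not part of the diff.
from tensorflow.python.eager import context
from tensorflow.python.framework import ops


def _maybe_default_graph():
  """Returns the default graph when building one, or None under eager execution."""
  # Post-rename idiom: test the eager case first and fall back to graph
  # construction, instead of asking context.in_graph_mode().
  if context.executing_eagerly():
    return None
  return ops.get_default_graph()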