eager: Rename in_eager_mode to executing_eagerly and get rid of in_graph_mode.
This is in preparation to introduce one public, stable symbol: tf.executing_eagerly() (i.e., part of moving APIs related to eager execution from "contrib" to a namespace where we provide API stability guarantees) PiperOrigin-RevId: 188212646
This commit is contained in:
parent
808b569e85
commit
37cef895bf
tensorflow
contrib
data/python/ops
eager/python
checkpointable_utils.pycheckpointable_utils_test.pydatasets.pyevaluator.pymetrics_impl.pynetwork.pysaver.pytfe.pytfe_test.py
framework/python/ops
layers/python/layers
metrics/python/ops
nccl/python/ops
opt/python/training
rnn/python/kernel_tests
summary
python
data
eager
benchmarks_test.pycontext.pycore_test.pyfunction.pygraph_callable.pypython_eager_op_gen.ccpywrap_tfe_test.py
estimator
framework
constant_op.pyfunction.pymeta_graph.pyops.pyops_test.pyrandom_seed.pyrandom_seed_test.pytensor_util.pytest_util.py
keras/_impl/keras
kernel_tests
atrous_convolution_test.pycheck_ops_test.pypy_func_test.pyresource_variable_ops_test.pyrnn_test.pyslice_op_test.pytemplate_test.pytensor_array_ops_test.pyvariable_scope_test.py
layers
ops
array_grad.pyarray_ops.pycheck_ops.pycontrol_flow_ops.pycustom_gradient.pydata_flow_ops.pyfunctional_ops.pygradients_impl.pyio_ops.pylookup_ops.py
losses
math_grad.pymath_ops.pymath_ops_test.pymetrics_impl.pynn_grad.pynn_ops.pynumerics.pyresource_variable_ops.pyrnn.pyrnn_cell_impl.pyscript_ops.pystate_ops.pytemplate.pytensor_array_ops.pyvariable_scope.pyvariables.pyprofiler
summary
training
@ -44,7 +44,7 @@ class PrivateThreadPool(object):
|
||||
|
||||
def __init__(self, num_threads, display_name=None):
|
||||
"""Creates a `PrivateThreadPool` with the given number of threads."""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
shared_name = _generate_shared_name("privatethreadpool")
|
||||
self._resource = gen_dataset_ops.thread_pool_handle(
|
||||
num_threads=num_threads,
|
||||
|
@ -395,7 +395,7 @@ class CheckpointLoadStatus(_LoadStatus):
|
||||
|
||||
def run_restore_ops(self, session=None):
|
||||
"""Run operations to restore objects in the dependency graph."""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
return # Run eagerly
|
||||
if session is None:
|
||||
session = ops.get_default_session()
|
||||
@ -459,7 +459,7 @@ class InitializationOnlyStatus(_LoadStatus):
|
||||
session: The session to run initialization ops in. If `None`, uses the
|
||||
default session.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
return # run eagerly
|
||||
if session is None:
|
||||
session = ops.get_default_session()
|
||||
@ -491,7 +491,7 @@ class NameBasedSaverStatus(_LoadStatus):
|
||||
date=None, instructions=_DEPRECATED_RESTORE_INSTRUCTIONS)
|
||||
def run_restore_ops(self, session=None):
|
||||
"""Load the name-based training checkpoint using a new `tf.train.Saver`."""
|
||||
if session is None and context.in_graph_mode():
|
||||
if session is None and not context.executing_eagerly():
|
||||
session = ops.get_default_session()
|
||||
saver_lib.Saver(self._object_saver._global_variable_names()).restore( # pylint: disable=protected-access
|
||||
sess=session, save_path=self._save_path)
|
||||
@ -548,7 +548,7 @@ class CheckpointableSaver(object):
|
||||
# Allow passing in a weak reference to avoid reference cycles when
|
||||
# `Checkpointable` objects save themselves.
|
||||
self._root_checkpointable_ref = root_checkpointable
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
with ops.device("/cpu:0"):
|
||||
self._file_prefix_placeholder = constant_op.constant("model")
|
||||
else:
|
||||
@ -597,7 +597,7 @@ class CheckpointableSaver(object):
|
||||
"""
|
||||
named_variables, graph_proto = _serialize_object_graph(
|
||||
self._root_checkpointable)
|
||||
in_graph_mode = context.in_graph_mode()
|
||||
in_graph_mode = not context.executing_eagerly()
|
||||
if in_graph_mode:
|
||||
if session is None:
|
||||
session = ops.get_default_session()
|
||||
@ -714,7 +714,7 @@ class CheckpointableSaver(object):
|
||||
"""
|
||||
if save_path is None:
|
||||
return InitializationOnlyStatus(self._root_checkpointable)
|
||||
in_graph_mode = context.in_graph_mode()
|
||||
in_graph_mode = not context.executing_eagerly()
|
||||
if in_graph_mode:
|
||||
if session is None:
|
||||
session = ops.get_default_session()
|
||||
@ -850,7 +850,7 @@ class Checkpoint(core_checkpointable.Checkpointable):
|
||||
|
||||
def save(self, file_prefix, session=None):
|
||||
"""Save a checkpoint. Wraps `tfe.CheckpointableSaver.save`."""
|
||||
in_graph_mode = context.in_graph_mode()
|
||||
in_graph_mode = not context.executing_eagerly()
|
||||
if in_graph_mode:
|
||||
if session is None:
|
||||
session = ops.get_default_session()
|
||||
|
@ -108,14 +108,14 @@ class InterfaceTests(test.TestCase):
|
||||
[0., 0.]], self.evaluate(bare_initializer))
|
||||
self.assertEqual("a_variable:0", obj.a_variable.name)
|
||||
self.assertEqual("duplicate:0", other_duplicate.name)
|
||||
if context.in_graph_mode():
|
||||
# The .name attribute may be globally influenced, but the checkpoint name
|
||||
# won't be (tested below).
|
||||
self.assertEqual("duplicate_1:0", duplicate.name)
|
||||
else:
|
||||
if context.executing_eagerly():
|
||||
# When executing eagerly, there's no uniquification of variable names. The
|
||||
# checkpoint name will be the same.
|
||||
self.assertEqual("duplicate:0", duplicate.name)
|
||||
else:
|
||||
# The .name attribute may be globally influenced, but the checkpoint name
|
||||
# won't be (tested below).
|
||||
self.assertEqual("duplicate_1:0", duplicate.name)
|
||||
named_variables, _ = checkpointable_utils._serialize_object_graph(obj)
|
||||
expected_checkpoint_names = (
|
||||
"a_variable/.ATTRIBUTES/VARIABLE_VALUE",
|
||||
@ -165,7 +165,7 @@ class CheckpointingTests(test.TestCase):
|
||||
optimizer_step = training_util.get_or_create_global_step()
|
||||
root_checkpointable = checkpointable_utils.Checkpoint(
|
||||
optimizer=optimizer, model=model, optimizer_step=optimizer_step)
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
optimizer.minimize(
|
||||
lambda: model(input_value),
|
||||
global_step=optimizer_step)
|
||||
@ -268,7 +268,7 @@ class CheckpointingTests(test.TestCase):
|
||||
root_checkpointable = checkpointable_utils.Checkpoint(
|
||||
optimizer=optimizer, model=model)
|
||||
input_value = constant_op.constant([[3.]])
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
optimizer.minimize(
|
||||
lambda: model(input_value))
|
||||
else:
|
||||
@ -293,7 +293,7 @@ class CheckpointingTests(test.TestCase):
|
||||
self.assertAllEqual([42.], self.evaluate(model._named_dense.variables[1]))
|
||||
self.assertAllEqual(1, self.evaluate(root_checkpointable.save_counter))
|
||||
self.assertAllEqual([1.5], self.evaluate(m_bias_slot))
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
return # Restore-on-create is only supported when executing eagerly
|
||||
on_create_model = MyModel()
|
||||
on_create_optimizer = adam.AdamOptimizer(0.001)
|
||||
@ -400,7 +400,7 @@ class CheckpointingTests(test.TestCase):
|
||||
optimizer.minimize,
|
||||
functools.partial(model, input_value),
|
||||
global_step=root.global_step)
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
train_fn = functools.partial(self.evaluate, train_fn())
|
||||
status.initialize_or_restore()
|
||||
for _ in range(num_training_steps):
|
||||
@ -524,7 +524,9 @@ class CheckpointingTests(test.TestCase):
|
||||
root.var = checkpointable_utils.add_variable(
|
||||
root, name="var", initializer=0.)
|
||||
optimizer = adam.AdamOptimizer(0.1)
|
||||
if context.in_graph_mode():
|
||||
if context.executing_eagerly():
|
||||
optimizer.minimize(root.var.read_value)
|
||||
else:
|
||||
train_op = optimizer.minimize(root.var)
|
||||
# Note that `optimizer` has not been added as a dependency of
|
||||
# `root`. Create a one-off grouping so that slot variables for `root.var`
|
||||
@ -532,8 +534,6 @@ class CheckpointingTests(test.TestCase):
|
||||
self.evaluate(checkpointable_utils.gather_initializers(
|
||||
checkpointable_utils.Checkpoint(root=root, optimizer=optimizer)))
|
||||
self.evaluate(train_op)
|
||||
else:
|
||||
optimizer.minimize(root.var.read_value)
|
||||
self.evaluate(state_ops.assign(root.var, 12.))
|
||||
no_slots_path = checkpointable_utils.CheckpointableSaver(root).save(
|
||||
os.path.join(checkpoint_directory, "no_slots"))
|
||||
@ -561,7 +561,7 @@ class CheckpointingTests(test.TestCase):
|
||||
with self.assertRaisesRegexp(AssertionError, "beta1_power"):
|
||||
slot_status.assert_consumed()
|
||||
self.assertEqual(12., self.evaluate(new_root.var))
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
# Slot variables are only created with restoring initializers when
|
||||
# executing eagerly.
|
||||
self.assertEqual(14., self.evaluate(
|
||||
@ -569,7 +569,9 @@ class CheckpointingTests(test.TestCase):
|
||||
else:
|
||||
self.assertIs(new_root.optimizer.get_slot(name="m", var=new_root.var),
|
||||
None)
|
||||
if context.in_graph_mode():
|
||||
if context.executing_eagerly():
|
||||
new_root.optimizer.minimize(new_root.var.read_value)
|
||||
else:
|
||||
train_op = new_root.optimizer.minimize(new_root.var)
|
||||
# The slot variable now exists; restore() didn't create it, but we should
|
||||
# now have a restore op for it.
|
||||
@ -577,8 +579,6 @@ class CheckpointingTests(test.TestCase):
|
||||
self.assertEqual(14., self.evaluate(
|
||||
new_root.optimizer.get_slot(name="m", var=new_root.var)))
|
||||
self.evaluate(train_op)
|
||||
else:
|
||||
new_root.optimizer.minimize(new_root.var.read_value)
|
||||
slot_status.assert_consumed()
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
|
||||
|
@ -68,7 +68,7 @@ class Iterator(object):
|
||||
RuntimeError: When invoked without eager execution enabled.
|
||||
"""
|
||||
|
||||
if not context.in_eager_mode():
|
||||
if not context.executing_eagerly():
|
||||
raise RuntimeError(
|
||||
"{} objects can only be used when eager execution is enabled, use "
|
||||
"tf.data.Dataset.make_initializable_iterator or "
|
||||
|
@ -57,7 +57,7 @@ class Evaluator(object):
|
||||
self._model = model
|
||||
self._metrics = {}
|
||||
self._evaluators = {}
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
self.call = function.defun(self.call)
|
||||
|
||||
# ---- API for users ----
|
||||
@ -90,7 +90,7 @@ class Evaluator(object):
|
||||
Only for graph execution.
|
||||
@end_compatibility
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError("Evaluator.init_variables() not needed when "
|
||||
"eager execution is enabled.")
|
||||
return control_flow_ops.group([m.init_variables() for _, m in self.metrics])
|
||||
@ -113,7 +113,8 @@ class Evaluator(object):
|
||||
with summary_ops.create_file_writer(
|
||||
summary_logdir).as_default(), summary_ops.always_record_summaries():
|
||||
return self._all_metric_results()
|
||||
if context.in_eager_mode():
|
||||
|
||||
if context.executing_eagerly():
|
||||
return f()
|
||||
else:
|
||||
return function.defun(f)()
|
||||
@ -158,16 +159,16 @@ class Evaluator(object):
|
||||
@end_compatibility
|
||||
"""
|
||||
summary_logdir = kwargs.pop("summary_logdir", None)
|
||||
if context.in_graph_mode():
|
||||
call_op = self.__call__(dataset.make_one_shot_iterator().get_next(),
|
||||
*args, **kwargs)
|
||||
init_op = self.init_variables()
|
||||
results_op = self.all_metric_results(summary_logdir)
|
||||
return (init_op, call_op, results_op)
|
||||
# Eager case
|
||||
for example in datasets.Iterator(dataset):
|
||||
self.__call__(example, *args, **kwargs)
|
||||
return self.all_metric_results(summary_logdir)
|
||||
if context.executing_eagerly():
|
||||
for example in datasets.Iterator(dataset):
|
||||
self.__call__(example, *args, **kwargs)
|
||||
return self.all_metric_results(summary_logdir)
|
||||
# Graph construction
|
||||
call_op = self.__call__(dataset.make_one_shot_iterator().get_next(), *args,
|
||||
**kwargs)
|
||||
init_op = self.init_variables()
|
||||
results_op = self.all_metric_results(summary_logdir)
|
||||
return (init_op, call_op, results_op)
|
||||
|
||||
@staticmethod
|
||||
def run_evaluation(init_op, call_op, results_op, sess=None):
|
||||
@ -192,7 +193,7 @@ class Evaluator(object):
|
||||
Only for graph execution.
|
||||
@end_compatibility
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError("Evaluator.run_evaluation() not supported when "
|
||||
"eager execution is enabled.")
|
||||
sess = sess or ops.get_default_session()
|
||||
|
@ -109,13 +109,13 @@ class Metric(checkpointable.CheckpointableBase):
|
||||
pos = scope.name.rfind(scope_name)
|
||||
self._name = name + scope.name[pos + len(scope_name):]
|
||||
self._scope = scope
|
||||
if context.in_graph_mode():
|
||||
if context.executing_eagerly():
|
||||
self._construction_scope = context.eager_mode
|
||||
else:
|
||||
# We make self.call() into a graph callable here, so that we can
|
||||
# return a single op that performs all of the variable updates.
|
||||
self._construction_scope = ops.get_default_graph().as_default
|
||||
self.call = function.defun(self.call)
|
||||
else:
|
||||
self._construction_scope = context.eager_mode
|
||||
|
||||
# ---- API for users ----
|
||||
def __call__(self, *args, **kwargs):
|
||||
@ -156,10 +156,11 @@ class Metric(checkpointable.CheckpointableBase):
|
||||
initialization. Under eager execution, the variables are reset to their
|
||||
initial values as a side effect and this function returns None.
|
||||
"""
|
||||
if context.in_graph_mode():
|
||||
if context.executing_eagerly():
|
||||
for v in self._vars:
|
||||
v.assign(self._initial_values[v])
|
||||
else:
|
||||
return control_flow_ops.group([v.initializer for v in self._vars])
|
||||
for v in self._vars:
|
||||
v.assign(self._initial_values[v])
|
||||
|
||||
# ---- To be implemented by descendants ---
|
||||
def build(self, *args, **kwargs):
|
||||
@ -201,10 +202,10 @@ class Metric(checkpointable.CheckpointableBase):
|
||||
|
||||
def value(self):
|
||||
"""In graph mode returns the result Tensor while in eager the callable."""
|
||||
if context.in_graph_mode():
|
||||
return self.result()
|
||||
else:
|
||||
if context.executing_eagerly():
|
||||
return self.result
|
||||
else:
|
||||
return self.result()
|
||||
|
||||
# We can support two different strategies of for doing data-parallel
|
||||
# distributed metric computations:
|
||||
@ -246,7 +247,7 @@ class Metric(checkpointable.CheckpointableBase):
|
||||
"""***Only for use by descendants of Metric***."""
|
||||
if self._built:
|
||||
raise RuntimeError("Can't call add_variable() except in build().")
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
collections = None
|
||||
else:
|
||||
if self._use_global_variables:
|
||||
@ -270,7 +271,7 @@ class Metric(checkpointable.CheckpointableBase):
|
||||
# Checkpointable.
|
||||
overwrite=True)
|
||||
self._vars.append(v)
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
self._initial_values[v] = v.value()
|
||||
return v
|
||||
|
||||
|
@ -639,7 +639,7 @@ def _make_custom_getter_for_deferred_restorations():
|
||||
# Mark as already restored from this checkpoint.
|
||||
delayed_restoration.checkpointed_variables_to_restore[
|
||||
checkpoint_name] = None
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
delayed_restoration.session.run(variable.initializer)
|
||||
if found_value:
|
||||
# Error checking should run even if we've already restored a value.
|
||||
@ -772,7 +772,7 @@ def save_network_checkpoint(
|
||||
variable_map[mapped_name]._shared_name,
|
||||
variable._shared_name,
|
||||
network.scope_name))
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
sess = None
|
||||
else:
|
||||
sess = ops.get_default_session()
|
||||
@ -853,7 +853,7 @@ def _restore_existing_variables(network, save_path, map_func, user_map_func):
|
||||
network_name=network.name,
|
||||
network_scope_name=network.scope_name))
|
||||
if existing_variables_by_checkpoint_name:
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
sess = None
|
||||
else:
|
||||
sess = ops.get_default_session()
|
||||
@ -880,7 +880,7 @@ def _set_restore_on_create(network, save_path, map_func, user_map_func,
|
||||
# _DeferredRestoration objects once a Network has been built (so that
|
||||
# restoring in a loop does not take increasing amounts of memory).
|
||||
if checkpointed_variables_to_restore:
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
sess = None
|
||||
else:
|
||||
sess = ops.get_default_session()
|
||||
|
@ -73,7 +73,7 @@ def restore_variables_on_create(save_path, map_func=None):
|
||||
NotFoundError: If the variable is not found in checkpoint.
|
||||
ValueError: If not used in eager mode or map_func is not callable.
|
||||
"""
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
raise ValueError(
|
||||
"Currently, restore_variables_on_create can only be used with "
|
||||
"eager execution enabled.")
|
||||
@ -131,7 +131,7 @@ class Saver(object):
|
||||
Raises:
|
||||
RuntimeError: if invoked when eager execution has not been enabled.
|
||||
"""
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
raise RuntimeError("tfe.Saver can only be used when eager "
|
||||
"execution is enabled. Use tf.train.Saver when "
|
||||
"building graphs.")
|
||||
|
@ -60,8 +60,8 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
|
||||
@@Checkpointable
|
||||
@@CheckpointableSaver
|
||||
|
||||
@@executing_eagerly
|
||||
@@in_eager_mode
|
||||
@@in_graph_mode
|
||||
|
||||
@@run_test_in_graph_and_eager_modes
|
||||
|
||||
@ -93,8 +93,7 @@ from tensorflow.python.eager import function
|
||||
from tensorflow.python.eager.context import DEVICE_PLACEMENT_EXPLICIT
|
||||
from tensorflow.python.eager.context import DEVICE_PLACEMENT_WARN
|
||||
from tensorflow.python.eager.context import DEVICE_PLACEMENT_SILENT
|
||||
from tensorflow.python.eager.context import in_eager_mode
|
||||
from tensorflow.python.eager.context import in_graph_mode
|
||||
from tensorflow.python.eager.context import executing_eagerly
|
||||
from tensorflow.python.eager.context import list_devices
|
||||
from tensorflow.python.eager.context import num_gpus
|
||||
from tensorflow.python.eager.execution_callbacks import add_execution_callback
|
||||
@ -122,5 +121,6 @@ implicit_value_and_gradients = backprop.implicit_val_and_grad
|
||||
gradients_function = backprop.gradients_function
|
||||
value_and_gradients_function = backprop.val_and_grad_function
|
||||
GradientTape = backprop.GradientTape # pylint: disable=invalid-name
|
||||
in_eager_mode = executing_eagerly
|
||||
|
||||
remove_undocumented(__name__)
|
||||
|
@ -47,7 +47,8 @@ class TFETest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testVariableError(self):
|
||||
with self.assertRaisesRegexp(
|
||||
RuntimeError, r'Variable not supported in Eager mode'):
|
||||
RuntimeError,
|
||||
r'Variable not supported when eager execution is enabled'):
|
||||
variables.Variable(initial_value=1.0)
|
||||
|
||||
def testGradients(self):
|
||||
|
@ -154,7 +154,7 @@ class CriticalSection(object):
|
||||
self._handle = gen_resource_variable_ops.mutex_v2(
|
||||
shared_name=shared_name, container=container, name=name)
|
||||
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
ops.add_to_collections(CRITICAL_SECTIONS, self)
|
||||
|
||||
@property
|
||||
@ -221,7 +221,7 @@ class CriticalSection(object):
|
||||
"This is illegal and would cause deadlocks. "
|
||||
"CriticalSection: %s." % self._handle)
|
||||
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
# Collections and op introspection does not work in eager
|
||||
# mode. This is generally ok; since eager mode (as of
|
||||
# writing) executes sequentially anyway.
|
||||
@ -250,7 +250,7 @@ class CriticalSection(object):
|
||||
return x.identity()
|
||||
elif isinstance(x, ops.Operation):
|
||||
return control_flow_ops.group(x)
|
||||
elif context.in_eager_mode() and x is None:
|
||||
elif context.executing_eagerly() and x is None:
|
||||
return None
|
||||
else:
|
||||
return array_ops.identity(x)
|
||||
@ -274,7 +274,7 @@ class CriticalSection(object):
|
||||
with ops.control_dependencies([ensure_lock_exists]):
|
||||
outputs = nest.map_structure(identity, r)
|
||||
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
signature = _ExecutionSignature(
|
||||
op=lock.op,
|
||||
handle=self._handle,
|
||||
|
@ -2746,7 +2746,7 @@ def softmax(logits, scope=None):
|
||||
logits_2d = array_ops.reshape(logits, [-1, num_logits])
|
||||
predictions = nn.softmax(logits_2d)
|
||||
predictions = array_ops.reshape(predictions, array_ops.shape(logits))
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
predictions.set_shape(logits.get_shape())
|
||||
return predictions
|
||||
|
||||
|
@ -1263,7 +1263,7 @@ def _compute_placement_auc(labels, predictions, weights, alpha,
|
||||
weights_for_true = ordered_weights * float_labels_for_true
|
||||
weights_for_false = ordered_weights * float_labels_for_false
|
||||
|
||||
# For each set of weights with the same segmented indices, we add up the
|
||||
# For each set of weights with the same segmented indices, we add up the
|
||||
# weight values. Note that for each label, we deliberately rely on weights
|
||||
# for the opposite label.
|
||||
weight_totals_for_true = math_ops.segment_sum(weights_for_false,
|
||||
@ -3646,7 +3646,7 @@ def cohen_kappa(labels,
|
||||
`updates_collections` are not a list or tuple.
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('tf.contrib.metrics.cohen_kappa is not supported'
|
||||
'when eager execution is enabled.')
|
||||
if num_classes < 2:
|
||||
|
@ -267,5 +267,5 @@ def _check_device(tensor, expected=None):
|
||||
|
||||
|
||||
def _check_graph_mode():
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise ValueError('Nccl ops are not supported in eager mode')
|
||||
|
@ -97,7 +97,7 @@ class AddSignTest(test.TestCase):
|
||||
global_step=global_step)
|
||||
neg_update = opt.apply_gradients(zip([-grads0, -grads1], [var0, var1]),
|
||||
global_step=global_step)
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
# Fetch params to validate initial values
|
||||
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
|
||||
@ -108,13 +108,13 @@ class AddSignTest(test.TestCase):
|
||||
# last 3 steps with negative gradient (sign(gm) should be -1)
|
||||
for t in range(1, 8):
|
||||
if t < 5:
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
self.evaluate(update)
|
||||
elif t > 1:
|
||||
opt.apply_gradients(zip([grads0, grads1], [var0, var1]),
|
||||
global_step=global_step)
|
||||
else:
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
self.evaluate(neg_update)
|
||||
elif t > 1:
|
||||
opt.apply_gradients(zip([-grads0, -grads1], [var0, var1]),
|
||||
|
@ -99,7 +99,7 @@ class PowerSignTest(test.TestCase):
|
||||
neg_update = opt.apply_gradients(zip([-grads0, -grads1], [var0, var1]),
|
||||
global_step=global_step)
|
||||
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
# Fetch params to validate initial values
|
||||
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
|
||||
@ -110,13 +110,13 @@ class PowerSignTest(test.TestCase):
|
||||
# last 3 steps with negative gradient (sign(gm) should be -1)
|
||||
for t in range(1, 8):
|
||||
if t < 5:
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
self.evaluate(update)
|
||||
elif t > 1:
|
||||
opt.apply_gradients(zip([grads0, grads1], [var0, var1]),
|
||||
global_step=global_step)
|
||||
else:
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
self.evaluate(neg_update)
|
||||
elif t > 1:
|
||||
opt.apply_gradients(zip([-grads0, -grads1], [var0, var1]),
|
||||
|
@ -869,7 +869,7 @@ class LSTMTest(test.TestCase):
|
||||
num_proj = 4
|
||||
max_length = 8
|
||||
sequence_length = [4, 6]
|
||||
in_graph_mode = context.in_graph_mode()
|
||||
in_graph_mode = not context.executing_eagerly()
|
||||
with self.test_session(graph=ops_lib.Graph()) as sess:
|
||||
initializer = init_ops.random_uniform_initializer(
|
||||
-0.01, 0.01, seed=self._seed)
|
||||
@ -934,8 +934,7 @@ class LSTMTest(test.TestCase):
|
||||
if in_graph_mode:
|
||||
self.assertAllEqual(outputs_static, outputs_dynamic)
|
||||
else:
|
||||
self.assertAllEqual(
|
||||
array_ops.stack(outputs_static).numpy(), outputs_dynamic.numpy())
|
||||
self.assertAllEqual(array_ops.stack(outputs_static), outputs_dynamic)
|
||||
self.assertAllEqual(np.hstack(state_static), np.hstack(state_dynamic))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
@ -946,7 +945,7 @@ class LSTMTest(test.TestCase):
|
||||
num_proj = 4
|
||||
max_length = 8
|
||||
sequence_length = [4, 6]
|
||||
in_graph_mode = context.in_graph_mode()
|
||||
in_graph_mode = not context.executing_eagerly()
|
||||
with self.test_session(graph=ops_lib.Graph()) as sess:
|
||||
initializer = init_ops.random_uniform_initializer(
|
||||
-0.01, 0.01, seed=self._seed)
|
||||
@ -1022,10 +1021,9 @@ class LSTMTest(test.TestCase):
|
||||
if in_graph_mode:
|
||||
self.assertAllEqual(outputs_static, outputs_dynamic)
|
||||
else:
|
||||
self.assertAllEqual(
|
||||
array_ops.stack(outputs_static).numpy(), outputs_dynamic.numpy())
|
||||
state_static = [s.numpy() for s in nest.flatten(state_static)]
|
||||
state_dynamic = [s.numpy() for s in nest.flatten(state_dynamic)]
|
||||
self.assertAllEqual(array_ops.stack(outputs_static), outputs_dynamic)
|
||||
state_static = nest.flatten(state_static)
|
||||
state_dynamic = nest.flatten(state_dynamic)
|
||||
self.assertAllEqual(np.hstack(state_static), np.hstack(state_dynamic))
|
||||
|
||||
def _testDynamicEquivalentToStaticRNN(self, use_sequence_length):
|
||||
@ -1043,7 +1041,7 @@ class LSTMTest(test.TestCase):
|
||||
else:
|
||||
sequence_length = None
|
||||
|
||||
in_graph_mode = context.in_graph_mode()
|
||||
in_graph_mode = not context.executing_eagerly()
|
||||
|
||||
# TODO(b/68017812): Eager ignores operation seeds, so we need to create a
|
||||
# single cell and reuse it across the static and dynamic RNNs. Remove this
|
||||
|
@ -110,7 +110,7 @@ class SummaryWriter(object):
|
||||
|
||||
def __init__(self, resource):
|
||||
self._resource = resource
|
||||
if context.in_eager_mode() and self._resource is not None:
|
||||
if context.executing_eagerly() and self._resource is not None:
|
||||
self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
|
||||
handle=self._resource, handle_device="cpu:0")
|
||||
|
||||
@ -158,7 +158,7 @@ def initialize(
|
||||
@{tf.contrib.summary.SummaryWriter}.
|
||||
ValueError: If session wasn't passed and no default session.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
return
|
||||
if context.context().summary_writer_resource is None:
|
||||
raise RuntimeError("No default tf.contrib.summary.SummaryWriter found")
|
||||
@ -269,7 +269,7 @@ def _make_summary_writer(name, factory, **kwargs):
|
||||
resource = gen_summary_ops.summary_writer(shared_name=name)
|
||||
# TODO(apassos): Consider doing this instead.
|
||||
# node = factory(resource, **kwargs)
|
||||
# if not context.in_eager_mode():
|
||||
# if not context.executing_eagerly():
|
||||
# ops.get_default_session().run(node)
|
||||
ops.add_to_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME,
|
||||
factory(resource, **kwargs))
|
||||
@ -295,7 +295,7 @@ def all_summary_ops():
|
||||
Returns:
|
||||
The summary ops.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
return None
|
||||
return ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
|
||||
|
||||
@ -309,7 +309,7 @@ def summary_writer_initializer_op():
|
||||
Raises:
|
||||
RuntimeError: If in Eager mode.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError(
|
||||
"tf.contrib.summary.summary_writer_initializer_op is only "
|
||||
"supported in graph mode.")
|
||||
@ -477,7 +477,7 @@ def graph(param, step=None, name=None):
|
||||
Raises:
|
||||
TypeError: If `param` isn't already a @{tf.Tensor} in graph mode.
|
||||
"""
|
||||
if not context.in_eager_mode() and not isinstance(param, ops.Tensor):
|
||||
if not context.executing_eagerly() and not isinstance(param, ops.Tensor):
|
||||
raise TypeError("graph() needs a tf.Tensor (e.g. tf.placeholder) in graph "
|
||||
"mode, but was: %s" % type(param))
|
||||
writer = context.context().summary_writer_resource
|
||||
|
@ -91,7 +91,7 @@ class Dataset(object):
|
||||
Raises:
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError(
|
||||
"dataset.make_initializable_iterator is not supported when eager "
|
||||
"execution is enabled.")
|
||||
@ -123,7 +123,7 @@ class Dataset(object):
|
||||
Raises:
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError(
|
||||
"dataset.make_one_shot_iterator is not supported when eager "
|
||||
"execution is enabled.")
|
||||
|
@ -65,7 +65,7 @@ class RandomSeedTest(test.TestCase):
|
||||
self.assertEqual((g_seed, op_seed), toutput, msg=msg)
|
||||
random_seed.set_random_seed(None)
|
||||
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
random_seed.set_random_seed(1)
|
||||
tinput = (1, None)
|
||||
toutput = (1, ops.get_default_graph()._last_id) # pylint: disable=protected-access
|
||||
|
@ -55,7 +55,7 @@ def c_tfe_py_fastpath_execute(a,
|
||||
transpose_b=False,
|
||||
name=None):
|
||||
ctx = context.context()
|
||||
assert not ctx.in_graph_mode(
|
||||
assert ctx.in_eager_mode(
|
||||
), "The prototype doesn't contain C code for graph construction"
|
||||
try:
|
||||
return pywrap_tensorflow.TFE_Py_FastPathExecute(
|
||||
|
@ -260,12 +260,8 @@ class Context(object):
|
||||
if mode == EAGER_MODE:
|
||||
context_stack.pop()
|
||||
|
||||
def in_graph_mode(self):
|
||||
"""Returns True if current thread is in GRAPH mode."""
|
||||
return self._eager_context.mode == GRAPH_MODE
|
||||
|
||||
def in_eager_mode(self):
|
||||
"""Returns True if current thread is in EAGER mode."""
|
||||
def executing_eagerly(self):
|
||||
"""Returns True if current thread has eager executing enabled."""
|
||||
return self._eager_context.mode == EAGER_MODE
|
||||
|
||||
def scalar_cache(self):
|
||||
@ -522,23 +518,23 @@ def internal_operation_seed():
|
||||
return context()._internal_operation_seed() # pylint: disable=protected-access
|
||||
|
||||
|
||||
def in_graph_mode():
|
||||
"""Returns True if current thread is in GRAPH mode for default context."""
|
||||
return context().in_graph_mode()
|
||||
def executing_eagerly():
|
||||
"""Returns True if the current thread has eager execution enabled."""
|
||||
return context().executing_eagerly()
|
||||
|
||||
|
||||
def in_eager_mode():
|
||||
"""Returns True if current thread is in EAGER mode for default context."""
|
||||
return context().in_eager_mode()
|
||||
"""Use executing_eagerly() instead. This function will be removed."""
|
||||
return executing_eagerly()
|
||||
|
||||
|
||||
def graph_mode():
|
||||
"""Context-manager to enable GRAPH mode for current thread."""
|
||||
"""Context-manager to disable eager execution for the current thread."""
|
||||
return context()._mode(GRAPH_MODE) # pylint: disable=protected-access
|
||||
|
||||
|
||||
def eager_mode():
|
||||
"""Context-manager to enable EAGER mode for current thread."""
|
||||
"""Context-manager to enable eager execution for the current thread."""
|
||||
return context()._mode(EAGER_MODE) # pylint: disable=protected-access
|
||||
|
||||
|
||||
@ -631,4 +627,8 @@ def export_run_metadata():
|
||||
# (for example, enable_eager_execution in python/framework/ops.py),
|
||||
# but they do all import this file. Note that IS_IN_GRAPH_MODE and
|
||||
# in_graph_mode are both parameterless functions.
|
||||
is_in_graph_mode.IS_IN_GRAPH_MODE = in_graph_mode
|
||||
def _tmp_in_graph_mode():
|
||||
return not executing_eagerly()
|
||||
|
||||
|
||||
is_in_graph_mode.IS_IN_GRAPH_MODE = _tmp_in_graph_mode
|
||||
|
@ -57,8 +57,7 @@ class TFETest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testContext(self):
|
||||
ctx = context.Context()
|
||||
self.assertFalse(ctx.in_graph_mode())
|
||||
self.assertTrue(ctx.in_eager_mode())
|
||||
self.assertTrue(ctx.executing_eagerly())
|
||||
|
||||
self.assertEqual('', ctx.scope_name)
|
||||
ctx.scope_name = 'foo'
|
||||
@ -150,9 +149,9 @@ class TFETest(test_util.TensorFlowTestCase):
|
||||
|
||||
def get_context_values(ctx):
|
||||
return [
|
||||
ctx.in_graph_mode(),
|
||||
ctx.in_eager_mode(), ctx.scope_name, ctx.summary_writer_resource,
|
||||
ctx.device_name, ctx.num_gpus()
|
||||
ctx.executing_eagerly(), ctx.scope_name, ctx.summary_writer_resource,
|
||||
ctx.device_name,
|
||||
ctx.num_gpus()
|
||||
]
|
||||
|
||||
def get_values(ctx, values):
|
||||
|
@ -112,7 +112,7 @@ def _convert_to_graph_tensor(value, dtype=None, name=None, as_ref=False):
|
||||
"""
|
||||
del as_ref # Unused.
|
||||
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
return value
|
||||
|
||||
default_graph = ops.get_default_graph()
|
||||
@ -295,7 +295,7 @@ class _EagerDefinedFunction(object):
|
||||
proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
|
||||
function_def = function_pb2.FunctionDef()
|
||||
function_def.ParseFromString(compat.as_bytes(proto_data))
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
_register(fn)
|
||||
self.definition = function_def
|
||||
self.name = function_def.signature.name
|
||||
@ -438,7 +438,14 @@ class GraphModeFunction(object):
|
||||
all_args = args + self._extra_inputs
|
||||
signature = self._forward_fdef.signature
|
||||
ctx = context.context()
|
||||
if ctx.in_graph_mode():
|
||||
if ctx.executing_eagerly():
|
||||
outputs = execute.execute(
|
||||
str(signature.name),
|
||||
num_outputs=len(signature.output_arg),
|
||||
inputs=all_args,
|
||||
attrs=None,
|
||||
ctx=ctx)
|
||||
else:
|
||||
g = ops.get_default_graph()
|
||||
g._add_function(self._forward_fdef) # pylint: disable=protected-access
|
||||
op = g.create_op(
|
||||
@ -453,13 +460,6 @@ class GraphModeFunction(object):
|
||||
outputs, (ops.Tensor, type(None))) else list(outputs)
|
||||
for i, s in enumerate(self._output_shapes):
|
||||
outputs[i].set_shape(s)
|
||||
else:
|
||||
outputs = execute.execute(
|
||||
str(signature.name),
|
||||
num_outputs=len(signature.output_arg),
|
||||
inputs=all_args,
|
||||
attrs=None,
|
||||
ctx=ctx)
|
||||
real_outputs = outputs[:len(self._returns)]
|
||||
side_outputs = outputs[len(self._returns):]
|
||||
|
||||
@ -530,7 +530,14 @@ class GraphModeFunction(object):
|
||||
return self._backprop_call(tensor_inputs)
|
||||
|
||||
ctx = context.context()
|
||||
if ctx.in_graph_mode():
|
||||
if ctx.executing_eagerly():
|
||||
result = execute.execute(
|
||||
str(self._func_name),
|
||||
num_outputs=self._num_outputs,
|
||||
inputs=tensor_inputs + self._extra_inputs,
|
||||
attrs=None,
|
||||
ctx=ctx)
|
||||
else:
|
||||
g = ops.get_default_graph()
|
||||
self.add_to_graph(g)
|
||||
signature = self._function_def.definition.signature
|
||||
@ -547,13 +554,6 @@ class GraphModeFunction(object):
|
||||
return op
|
||||
for i, s in enumerate(self._output_shapes):
|
||||
result[i].set_shape(s)
|
||||
else:
|
||||
result = execute.execute(
|
||||
str(self._func_name),
|
||||
num_outputs=self._num_outputs,
|
||||
inputs=tensor_inputs + self._extra_inputs,
|
||||
attrs=None,
|
||||
ctx=ctx)
|
||||
|
||||
return self._build_call_outputs(result)
|
||||
|
||||
@ -666,7 +666,7 @@ def _defun_internal(name, func, args, kwds):
|
||||
if x not in all_ignored_ops)
|
||||
# Register any other functions defined in the graph
|
||||
# TODO(ashankar): Oh lord, forgive me for this lint travesty.
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
for f in tmp_graph._functions.values(): # pylint: disable=protected-access
|
||||
# TODO(ashankar): What about the gradient registry?
|
||||
_register(f._c_func) # pylint: disable=protected-access
|
||||
@ -906,7 +906,7 @@ class AutomaticControlDependencies(object):
|
||||
return tensor
|
||||
|
||||
def __enter__(self):
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
return self
|
||||
# This code assumes no other thread is adding ops to the graph while
|
||||
# we're adding ops to the graph.
|
||||
@ -977,7 +977,7 @@ class AutomaticControlDependencies(object):
|
||||
merge_for_resource[o] = new_merge[0].op
|
||||
|
||||
def __exit__(self, unused_type, unused_value, unused_traceback):
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
return
|
||||
|
||||
if self._graph is not ops.get_default_graph():
|
||||
|
@ -406,7 +406,7 @@ def graph_callable(shape_and_dtypes):
|
||||
A callable graph object.
|
||||
"""
|
||||
# TODO(alive,apassos): support initialized_value and friends from tf.Variable.
|
||||
assert context.in_eager_mode(), (
|
||||
assert context.executing_eagerly(), (
|
||||
"graph_callable can only be used when Eager execution is enabled.")
|
||||
def decorator(func):
|
||||
return tf_decorator.make_decorator(func,
|
||||
|
@ -367,7 +367,7 @@ void GenEagerPythonOp::HandleGraphMode(const string& function_setup) {
|
||||
// Handle graph-mode case
|
||||
strings::StrAppend(&result_,
|
||||
" _ctx = _context.context()\n"
|
||||
" if _ctx.in_graph_mode():\n",
|
||||
" if not _ctx.executing_eagerly():\n",
|
||||
function_setup,
|
||||
" _, _, _op = _op_def_lib._apply_op_helper(\n");
|
||||
AddBodyNoReturn(" ");
|
||||
|
@ -169,7 +169,7 @@ class Tests(test.TestCase):
|
||||
def testFastpathExecute_InvalidInputs(self):
|
||||
a_2_by_2 = random_ops.random_uniform((2, 2))
|
||||
ctx = context.context()
|
||||
assert not ctx.in_graph_mode(
|
||||
assert ctx.executing_eagerly(
|
||||
), "The prototype doesn't contain C code for graph construction"
|
||||
ctx_handle = ctx._handle # pylint: disable=protected-access
|
||||
|
||||
|
@ -166,7 +166,7 @@ class Estimator(object):
|
||||
ValueError: if this is called via a subclass and if that class overrides
|
||||
a member of `Estimator`.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError(
|
||||
'Estimators are not supported when eager execution is enabled.')
|
||||
|
||||
|
@ -181,7 +181,7 @@ def constant(value, dtype=None, shape=None, name="Const", verify_shape=False):
|
||||
TypeError: if shape is incorrectly specified or unsupported.
|
||||
"""
|
||||
ctx = context.context()
|
||||
if not ctx.in_graph_mode():
|
||||
if ctx.executing_eagerly():
|
||||
t = convert_to_eager_tensor(value, ctx, dtype)
|
||||
if shape is None:
|
||||
return t
|
||||
|
@ -489,10 +489,10 @@ class _DefinedFunction(object):
|
||||
|
||||
# Adds this function into 'g'.
|
||||
# pylint: disable=protected-access
|
||||
if context.in_graph_mode():
|
||||
g._add_function(self)
|
||||
else:
|
||||
if context.executing_eagerly():
|
||||
context.context().add_function_def(self.definition)
|
||||
else:
|
||||
g._add_function(self)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
# Ensures related sub-routines are defined in 'g', too.
|
||||
|
@ -695,7 +695,7 @@ def import_scoped_meta_graph(meta_graph_or_file,
|
||||
Raises:
|
||||
ValueError: If the graph_def contains unbound inputs.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise ValueError("Exporting/importing meta graphs is not supported when "
|
||||
"eager execution is enabled.")
|
||||
if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
|
||||
@ -856,7 +856,7 @@ def export_scoped_meta_graph(filename=None,
|
||||
Raises:
|
||||
ValueError: When the `GraphDef` is larger than 2GB.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise ValueError("Exporting/importing meta graphs is not supported when "
|
||||
"Eager Execution is enabled.")
|
||||
graph = graph or ops.get_default_graph()
|
||||
|
@ -395,10 +395,10 @@ class Tensor(_TensorLike):
|
||||
"Tensor._shape cannot be assigned, use Tensor.set_shape instead.")
|
||||
|
||||
def __iter__(self):
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
raise TypeError(
|
||||
"`Tensor` objects are not iterable when eager execution is not "
|
||||
"enabled. To iterate over this tensor use `tf.map_fn`.")
|
||||
"Tensor objects are not iterable when eager execution is not "
|
||||
"enabled. To iterate over this tensor use tf.map_fn.")
|
||||
shape = self._shape_tuple()
|
||||
if shape is None:
|
||||
raise TypeError("Cannot iterate over a tensor with unknown shape.")
|
||||
@ -772,7 +772,7 @@ class _EagerTensorBase(Tensor):
|
||||
six.raise_from(core._status_to_exception(e.code, e.message), None)
|
||||
|
||||
# Record the copy on tape and define backprop copy as well.
|
||||
if not context.in_graph_mode():
|
||||
if context.executing_eagerly():
|
||||
self_device = self.device
|
||||
def grad_fun(dresult):
|
||||
return [dresult._copy(device_name=self_device)]
|
||||
@ -993,7 +993,7 @@ def internal_convert_to_tensor(value,
|
||||
|
||||
"""
|
||||
if ctx is None: ctx = context.context()
|
||||
if ctx.in_eager_mode():
|
||||
if ctx.executing_eagerly():
|
||||
# Fast path for EagerTensors that don't need any conversion.
|
||||
if isinstance(value, EagerTensor):
|
||||
# Note that we don't check that value's dtype matches the dtype
|
||||
@ -4797,15 +4797,15 @@ def device(device_name_or_function):
|
||||
Raises:
|
||||
RuntimeError: If eager execution is enabled and a function is passed in.
|
||||
"""
|
||||
if context.in_graph_mode():
|
||||
return get_default_graph().device(device_name_or_function)
|
||||
else:
|
||||
if context.executing_eagerly():
|
||||
# TODO(agarwal): support device functions in EAGER mode.
|
||||
if callable(device_name_or_function):
|
||||
raise RuntimeError(
|
||||
"tf.device does not support functions when eager execution "
|
||||
"is enabled.")
|
||||
return context.device(device_name_or_function)
|
||||
else:
|
||||
return get_default_graph().device(device_name_or_function)
|
||||
|
||||
|
||||
@tf_export("container")
|
||||
@ -4824,7 +4824,12 @@ def container(container_name):
|
||||
|
||||
@tf_export("colocate_with")
|
||||
def colocate_with(op, ignore_existing=False):
|
||||
if context.in_graph_mode():
|
||||
if context.executing_eagerly():
|
||||
if op is not None:
|
||||
return device(op.device)
|
||||
else:
|
||||
return _NullContextmanager()
|
||||
else:
|
||||
default_graph = get_default_graph()
|
||||
if isinstance(op, EagerTensor):
|
||||
if default_graph.building_function:
|
||||
@ -4833,11 +4838,6 @@ def colocate_with(op, ignore_existing=False):
|
||||
raise ValueError("Encountered an Eager-defined Tensor during graph "
|
||||
"construction, but a function was not being built.")
|
||||
return default_graph.colocate_with(op, ignore_existing)
|
||||
else:
|
||||
if op is not None:
|
||||
return device(op.device)
|
||||
else:
|
||||
return _NullContextmanager()
|
||||
|
||||
|
||||
@tf_export("control_dependencies")
|
||||
@ -4857,10 +4857,10 @@ def control_dependencies(control_inputs):
|
||||
A context manager that specifies control dependencies for all
|
||||
operations constructed within the context.
|
||||
"""
|
||||
if context.in_graph_mode():
|
||||
return get_default_graph().control_dependencies(control_inputs)
|
||||
else:
|
||||
if context.executing_eagerly():
|
||||
return _NullContextmanager()
|
||||
else:
|
||||
return get_default_graph().control_dependencies(control_inputs)
|
||||
|
||||
|
||||
class _DefaultStack(threading.local):
|
||||
@ -5123,7 +5123,7 @@ def init_scope():
|
||||
"""
|
||||
# pylint: enable=g-doc-return-or-yield,line-too-long
|
||||
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
# Fastpath.
|
||||
with tape.stop_recording():
|
||||
yield
|
||||
@ -5705,7 +5705,7 @@ class name_scope(object): # pylint: disable=invalid-name
|
||||
self._default_name = default_name
|
||||
self._values = values
|
||||
self._ctx = context.context()
|
||||
self._in_eager_mode = self._ctx.in_eager_mode()
|
||||
self._in_eager_mode = self._ctx.executing_eagerly()
|
||||
|
||||
def __enter__(self):
|
||||
"""Start the scope block.
|
||||
@ -5884,7 +5884,7 @@ def get_from_proto_function(collection_name):
|
||||
|
||||
|
||||
def _assert_collection_is_ok(collection_name):
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
if collection_name in GraphKeys._VARIABLE_COLLECTIONS: # pylint: disable=protected-access
|
||||
raise ValueError("When Eager Execution is enabled, variable "
|
||||
"collections are not supported.")
|
||||
|
@ -1763,7 +1763,13 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase):
|
||||
return constant_op.constant(2.0)
|
||||
future.calls = 0
|
||||
|
||||
if context.in_graph_mode():
|
||||
if context.executing_eagerly():
|
||||
a = constant_op.constant(1.0)
|
||||
b = future()
|
||||
with ops.control_dependencies([a, b]):
|
||||
c = constant_op.constant(3.0)
|
||||
self.assertEqual(future.calls, 1)
|
||||
else:
|
||||
g = ops.Graph()
|
||||
with g.as_default():
|
||||
a = constant_op.constant(1.0)
|
||||
@ -1772,12 +1778,6 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase):
|
||||
c = constant_op.constant(3.0)
|
||||
self.assertEqual(c.op.control_inputs, [a.op, b.op])
|
||||
self.assertEqual(future.calls, 1)
|
||||
else:
|
||||
a = constant_op.constant(1.0)
|
||||
b = future()
|
||||
with ops.control_dependencies([a, b]):
|
||||
c = constant_op.constant(3.0)
|
||||
self.assertEqual(future.calls, 1)
|
||||
|
||||
def testBasicWithConversion(self):
|
||||
g = ops.Graph()
|
||||
@ -2150,11 +2150,11 @@ class InitScopeTest(test_util.TensorFlowTestCase):
|
||||
with ops.init_scope():
|
||||
# Because g is building a function, init_scope should
|
||||
# escape out to the eager context.
|
||||
self.assertTrue(context.in_eager_mode())
|
||||
self.assertTrue(context.executing_eagerly())
|
||||
# g should be reinstated as the default graph, and the
|
||||
# graph context should be re-entered.
|
||||
self.assertIs(g, ops.get_default_graph())
|
||||
self.assertTrue(context.in_graph_mode())
|
||||
self.assertFalse(context.executing_eagerly())
|
||||
|
||||
def testStaysInEagerWhenOnlyEagerContextActive(self):
|
||||
with context.eager_mode():
|
||||
@ -2277,12 +2277,13 @@ class InitScopeTest(test_util.TensorFlowTestCase):
|
||||
with context.eager_mode():
|
||||
def foo():
|
||||
with ops.name_scope("inner"), ops.init_scope():
|
||||
if context.in_graph_mode():
|
||||
self.assertEqual(ops.get_name_scope(), "inner")
|
||||
else:
|
||||
if context.executing_eagerly():
|
||||
# A trailing slash is always appended when eager execution is
|
||||
# enabled.
|
||||
self.assertEqual(context.context().scope_name, "inner/")
|
||||
else:
|
||||
self.assertEqual(ops.get_name_scope(), "inner")
|
||||
|
||||
foo()
|
||||
self.assertEqual(ops.get_name_scope(), "")
|
||||
foo_compiled = eager_function.defun(foo)
|
||||
|
@ -52,20 +52,20 @@ def get_seed(op_seed):
|
||||
A tuple of two integers that should be used for the local seed of this
|
||||
operation.
|
||||
"""
|
||||
is_graph_mode = context.in_graph_mode()
|
||||
eager = context.executing_eagerly()
|
||||
|
||||
if is_graph_mode:
|
||||
global_seed = ops.get_default_graph().seed
|
||||
else:
|
||||
if eager:
|
||||
global_seed = context.global_seed()
|
||||
else:
|
||||
global_seed = ops.get_default_graph().seed
|
||||
|
||||
if global_seed is not None:
|
||||
if op_seed is None:
|
||||
# pylint: disable=protected-access
|
||||
if is_graph_mode:
|
||||
op_seed = ops.get_default_graph()._last_id
|
||||
else:
|
||||
if eager:
|
||||
op_seed = context.internal_operation_seed()
|
||||
else:
|
||||
op_seed = ops.get_default_graph()._last_id
|
||||
|
||||
seeds = _truncate_seed(global_seed), _truncate_seed(op_seed)
|
||||
else:
|
||||
@ -176,7 +176,7 @@ def set_random_seed(seed):
|
||||
Args:
|
||||
seed: integer.
|
||||
"""
|
||||
if context.in_graph_mode():
|
||||
ops.get_default_graph().seed = seed
|
||||
else:
|
||||
if context.executing_eagerly():
|
||||
context.set_global_seed(seed)
|
||||
else:
|
||||
ops.get_default_graph().seed = seed
|
||||
|
@ -40,13 +40,13 @@ class RandomSeedTest(test.TestCase):
|
||||
((2**31 - 1, 0), (0, 2**31 - 1)), # Don't wrap to (0, 0) either
|
||||
((0, 2**31 - 1), (0, 2**31 - 1)), # Wrapping for the other argument
|
||||
]
|
||||
if context.in_graph_mode():
|
||||
# 0 will be the default_graph._lastid.
|
||||
test_cases.append(((1, None), (1, 0)))
|
||||
else:
|
||||
if context.executing_eagerly():
|
||||
# operation seed is random number generated based on global seed.
|
||||
# it's not tested due to possibility of platform or version difference.
|
||||
pass
|
||||
else:
|
||||
# 0 will be the default_graph._lastid.
|
||||
test_cases.append(((1, None), (1, 0)))
|
||||
for tc in test_cases:
|
||||
tinput, toutput = tc[0], tc[1]
|
||||
random_seed.set_random_seed(tinput[0])
|
||||
|
@ -828,7 +828,7 @@ def constant_value_as_shape(tensor): # pylint: disable=invalid-name
|
||||
Returns:
|
||||
A `TensorShape` based on the constant value of the given `tensor`.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
return tensor_shape.as_shape(
|
||||
[dim if dim != -1 else None for dim in tensor.numpy()])
|
||||
|
||||
|
@ -816,7 +816,7 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
Returns:
|
||||
tensors numpy values.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
return self._eval_helper(tensors)
|
||||
else:
|
||||
sess = ops.get_default_session()
|
||||
|
@ -343,7 +343,7 @@ def learning_phase():
|
||||
Returns:
|
||||
Learning phase (scalar integer tensor or Python integer).
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
if 'eager' not in _GRAPH_LEARNING_PHASES:
|
||||
# Fallback to inference mode as default.
|
||||
return 0
|
||||
@ -370,7 +370,7 @@ def set_learning_phase(value):
|
||||
global _GRAPH_LEARNING_PHASES # pylint: disable=global-variable-not-assigned
|
||||
if value not in {0, 1}:
|
||||
raise ValueError('Expected learning phase to be 0 or 1.')
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
_GRAPH_LEARNING_PHASES['eager'] = value
|
||||
else:
|
||||
_GRAPH_LEARNING_PHASES[ops.get_default_graph()] = value
|
||||
@ -399,7 +399,7 @@ def learning_phase_scope(value):
|
||||
yield value
|
||||
finally:
|
||||
# Restore learning phase to initial value.
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
_GRAPH_LEARNING_PHASES['eager'] = previous_value
|
||||
else:
|
||||
_GRAPH_LEARNING_PHASES[ops.get_default_graph()] = previous_value
|
||||
@ -2625,7 +2625,7 @@ def get_value(x):
|
||||
Returns:
|
||||
A Numpy array.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
return x.numpy()
|
||||
return x.eval(session=get_session())
|
||||
|
||||
@ -2640,7 +2640,7 @@ def batch_get_value(tensors):
|
||||
Returns:
|
||||
A list of Numpy arrays.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
return [x.numpy() for x in tensors]
|
||||
if tensors:
|
||||
return get_session().run(tensors)
|
||||
@ -2658,7 +2658,7 @@ def set_value(x, value):
|
||||
(of the same shape).
|
||||
"""
|
||||
value = np.asarray(value, dtype=dtype(x))
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
x.assign(value)
|
||||
else:
|
||||
tf_dtype = dtypes_module.as_dtype(x.dtype.name.split('_')[0])
|
||||
@ -2681,7 +2681,7 @@ def batch_set_value(tuples):
|
||||
tuples: a list of tuples `(tensor, value)`.
|
||||
`value` should be a Numpy array.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
for x, value in tuples:
|
||||
x.assign(np.asarray(value, dtype=dtype(x)))
|
||||
else:
|
||||
@ -3123,7 +3123,7 @@ def rnn(step_function,
|
||||
outputs_shape[1] = inputs_shape[1]
|
||||
outputs.set_shape(outputs_shape)
|
||||
|
||||
if not context.in_eager_mode():
|
||||
if not context.executing_eagerly():
|
||||
last_output._uses_learning_phase = uses_learning_phase
|
||||
return last_output, outputs, new_states
|
||||
|
||||
|
@ -237,7 +237,7 @@ class Layer(tf_base_layers.Layer):
|
||||
"""
|
||||
# Actually call the layer (optionally building it).
|
||||
output = super(Layer, self).__call__(inputs, **kwargs)
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
return output
|
||||
|
||||
if hasattr(self, '_symbolic_set_inputs') and not self.inputs:
|
||||
|
@ -92,7 +92,7 @@ class InputLayer(base_layer.Layer):
|
||||
else:
|
||||
batch_input_shape = None
|
||||
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
# In eager mode, create a temporary placeholder to call the layer on.
|
||||
input_tensor = tf_base_layers._DeferredTensor( # pylint: disable=protected-access
|
||||
shape=batch_input_shape,
|
||||
|
@ -99,11 +99,11 @@ class Network(base_layer.Layer):
|
||||
self._losses = [] # Used in symbolic mode only.
|
||||
self._scope = None # Never used.
|
||||
self._reuse = None # Never used.
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
self._graph = None
|
||||
else:
|
||||
self._graph = ops.get_default_graph() # Used in symbolic mode only.
|
||||
# A Network does not create weights of its own, thus has no dtype.
|
||||
# A Network does not create weights of its own, thus has no dtype.
|
||||
self._dtype = None
|
||||
|
||||
# All layers in order of horizontal graph traversal.
|
||||
@ -126,7 +126,7 @@ class Network(base_layer.Layer):
|
||||
self.outputs = [outputs]
|
||||
|
||||
# User-prodived argument validation.
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
# Check that all inputs/outputs are DeferredTensors.
|
||||
for tensor in self.inputs:
|
||||
if not isinstance(tensor, tf_base_layers._DeferredTensor): # pylint: disable=protected-access
|
||||
@ -275,7 +275,7 @@ class Network(base_layer.Layer):
|
||||
self._feed_input_names.append(layer.name)
|
||||
self._feed_input_shapes.append(K.int_shape(self.inputs[i]))
|
||||
# layer.input gives an error in eager mode
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
self._feed_inputs.append(layer.input)
|
||||
for layer in self._output_layers:
|
||||
self.output_names.append(layer.name)
|
||||
@ -317,7 +317,7 @@ class Network(base_layer.Layer):
|
||||
raise NotImplementedError('`add_variable` is not supported on Networks.')
|
||||
|
||||
def add_loss(self, *args, **kwargs):
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise NotImplementedError('`add_loss` is not supported on Networks '
|
||||
'when eager execution is enabled.')
|
||||
super(Network, self).add_loss(*args, **kwargs)
|
||||
@ -483,7 +483,7 @@ class Network(base_layer.Layer):
|
||||
Returns:
|
||||
A list of update ops.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
return []
|
||||
|
||||
if not self.trainable and not self.stateful:
|
||||
@ -530,7 +530,7 @@ class Network(base_layer.Layer):
|
||||
losses = []
|
||||
for layer in self.layers:
|
||||
losses += layer.losses
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
return losses
|
||||
|
||||
if self.inputs:
|
||||
@ -623,7 +623,7 @@ class Network(base_layer.Layer):
|
||||
else:
|
||||
masks = nest.flatten(mask)
|
||||
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
# Try to retrieve cached outputs if the layer has already been called
|
||||
# on these exact inputs.
|
||||
cache_key = (tf_layers_util.object_list_uid(inputs)
|
||||
@ -829,7 +829,7 @@ class Network(base_layer.Layer):
|
||||
else:
|
||||
output_masks = [None for _ in range(len(output_tensors))]
|
||||
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
if layer.activity_regularizer is not None:
|
||||
regularization_losses = [
|
||||
layer.activity_regularizer(x) for x in output_tensors
|
||||
@ -859,7 +859,7 @@ class Network(base_layer.Layer):
|
||||
if output_masks is not None:
|
||||
output_masks = output_masks[0]
|
||||
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
# Update cache;
|
||||
# keys are based on ids on input tensors and inputs masks.
|
||||
cache_key = (tf_layers_util.object_list_uid(inputs)
|
||||
|
@ -755,7 +755,17 @@ class TopologyConstructionTest(test.TestCase):
|
||||
def compute_mask(self, inputs, mask=None):
|
||||
return array_ops.ones_like(inputs)
|
||||
|
||||
if context.in_graph_mode():
|
||||
if context.executing_eagerly():
|
||||
a = constant_op.constant([2] * 32)
|
||||
mask = constant_op.constant([0, 1] * 16)
|
||||
a._keras_mask = mask
|
||||
b = MaskedLayer().apply(a)
|
||||
self.assertTrue(hasattr(b, '_keras_mask'))
|
||||
self.assertAllEqual(
|
||||
self.evaluate(array_ops.ones_like(mask)),
|
||||
self.evaluate(getattr(b, '_keras_mask')))
|
||||
self.assertAllEqual(self.evaluate(a * mask), self.evaluate(b))
|
||||
else:
|
||||
x = keras.Input(shape=(32,))
|
||||
y = MaskedLayer()(x) # pylint: disable=not-callable
|
||||
network = keras.engine.Network(x, y)
|
||||
@ -769,15 +779,6 @@ class TopologyConstructionTest(test.TestCase):
|
||||
x_2 = array_ops.placeholder(dtype='float32', shape=(None, 32))
|
||||
y_2 = network(x_2)
|
||||
self.assertEqual(y_2.get_shape().as_list(), [None, 32])
|
||||
else:
|
||||
a = constant_op.constant([2] * 32)
|
||||
mask = constant_op.constant([0, 1] * 16)
|
||||
a._keras_mask = mask
|
||||
b = MaskedLayer().apply(a)
|
||||
self.assertTrue(hasattr(b, '_keras_mask'))
|
||||
self.assertAllEqual(self.evaluate(array_ops.ones_like(mask)),
|
||||
self.evaluate(getattr(b, '_keras_mask')))
|
||||
self.assertAllEqual(self.evaluate(a * mask), self.evaluate(b))
|
||||
|
||||
def test_activity_regularization_with_model_composition(self):
|
||||
|
||||
@ -885,13 +886,13 @@ class DeferredModeTest(test.TestCase):
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testSimpleNetworkBuilding(self):
|
||||
inputs = keras.engine.Input(shape=(32,))
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
self.assertIsInstance(inputs, tf_base_layers._DeferredTensor)
|
||||
self.assertEqual(inputs.dtype.name, 'float32')
|
||||
self.assertEqual(inputs.shape.as_list(), [None, 32])
|
||||
|
||||
x = keras.layers.Dense(2)(inputs)
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
self.assertIsInstance(x, tf_base_layers._DeferredTensor)
|
||||
self.assertEqual(x.dtype.name, 'float32')
|
||||
self.assertEqual(x.shape.as_list(), [None, 2])
|
||||
@ -900,7 +901,7 @@ class DeferredModeTest(test.TestCase):
|
||||
network = keras.engine.Network(inputs, outputs)
|
||||
self.assertIsInstance(network, keras.engine.Network)
|
||||
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
# It should be possible to call such a network on EagerTensors.
|
||||
inputs = constant_op.constant(
|
||||
np.random.random((10, 32)).astype('float32'))
|
||||
@ -925,7 +926,7 @@ class DeferredModeTest(test.TestCase):
|
||||
c = keras.layers.Dense(2)(c)
|
||||
|
||||
network = keras.engine.Network([input_a, input_b], [a, c])
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
a_val = constant_op.constant(
|
||||
np.random.random((10, 32)).astype('float32'))
|
||||
b_val = constant_op.constant(
|
||||
|
@ -162,7 +162,7 @@ class Model(Network):
|
||||
`optimizer`, `loss`, `metrics` or `sample_weight_mode`.
|
||||
"""
|
||||
loss = loss or {}
|
||||
if context.in_eager_mode() and not isinstance(
|
||||
if context.executing_eagerly() and not isinstance(
|
||||
optimizer, (tf_optimizer_module.Optimizer, optimizers.TFOptimizer)):
|
||||
raise ValueError('Only TF native optimizers are supported in Eager mode.')
|
||||
|
||||
@ -170,13 +170,13 @@ class Model(Network):
|
||||
self.loss = loss
|
||||
self.metrics = metrics or []
|
||||
self.loss_weights = loss_weights
|
||||
if context.in_eager_mode() and sample_weight_mode is not None:
|
||||
if context.executing_eagerly() and sample_weight_mode is not None:
|
||||
raise ValueError('sample_weight_mode is not supported in Eager mode.')
|
||||
self.sample_weight_mode = sample_weight_mode
|
||||
if context.in_eager_mode() and weighted_metrics is not None:
|
||||
if context.executing_eagerly() and weighted_metrics is not None:
|
||||
raise ValueError('weighted_metrics is not supported in Eager mode.')
|
||||
self.weighted_metrics = weighted_metrics
|
||||
if context.in_eager_mode() and target_tensors is not None:
|
||||
if context.executing_eagerly() and target_tensors is not None:
|
||||
raise ValueError('target_tensors is not supported in Eager mode.')
|
||||
self.target_tensors = target_tensors
|
||||
|
||||
@ -230,7 +230,7 @@ class Model(Network):
|
||||
skip_target_weighing_indices.append(i)
|
||||
|
||||
# Prepare output masks.
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
masks = self.compute_mask(self.inputs, mask=None)
|
||||
if masks is None:
|
||||
masks = [None for _ in self.outputs]
|
||||
@ -264,7 +264,7 @@ class Model(Network):
|
||||
self.loss_weights_list = loss_weights_list
|
||||
|
||||
# initialization for Eager mode execution
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
if target_tensors is not None:
|
||||
raise ValueError('target_tensors are not currently supported in Eager'
|
||||
'mode.')
|
||||
@ -738,13 +738,13 @@ class Model(Network):
|
||||
'TensorFlow tensors. '
|
||||
'You passed: x=' + str(x) + '; y=' + str(y))
|
||||
|
||||
if context.in_graph_mode():
|
||||
if context.executing_eagerly():
|
||||
target_tensors = None
|
||||
else:
|
||||
# Handle target tensors if any passed.
|
||||
if not isinstance(y, (list, tuple)):
|
||||
y = [y]
|
||||
target_tensors = [v for v in y if tensor_util.is_tensor(v)]
|
||||
else:
|
||||
target_tensors = None
|
||||
self.compile(optimizer=self.optimizer,
|
||||
loss=self.loss,
|
||||
metrics=self.metrics,
|
||||
@ -761,7 +761,7 @@ class Model(Network):
|
||||
# What follows is input validation and standardization to list format,
|
||||
# in the case where all inputs are value arrays.
|
||||
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
# In eager mode, do not do shape validation.
|
||||
feed_input_names = self.input_names
|
||||
feed_input_shapes = None
|
||||
@ -784,7 +784,7 @@ class Model(Network):
|
||||
exception_prefix='input')
|
||||
|
||||
if y is not None:
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
feed_output_names = self.output_names
|
||||
feed_output_shapes = None
|
||||
# Sample weighting not supported in this case.
|
||||
@ -835,7 +835,7 @@ class Model(Network):
|
||||
]
|
||||
# Check that all arrays have the same length.
|
||||
training_utils.check_array_lengths(x, y, sample_weights)
|
||||
if self._is_graph_network and not context.in_eager_mode():
|
||||
if self._is_graph_network and not context.executing_eagerly():
|
||||
# Additional checks to avoid users mistakenly using improper loss fns.
|
||||
training_utils.check_loss_and_target_compatibility(
|
||||
y, self._feed_loss_fns, feed_output_shapes)
|
||||
@ -874,7 +874,7 @@ class Model(Network):
|
||||
whether to build the model's graph in inference mode (False), training
|
||||
mode (True), or using the Keras learning phase (None).
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
self._eager_set_inputs(inputs)
|
||||
else:
|
||||
self._symbolic_set_inputs(inputs, training=training)
|
||||
@ -903,7 +903,7 @@ class Model(Network):
|
||||
Raises:
|
||||
ValueError: If the model's inputs are already set.
|
||||
"""
|
||||
assert context.in_eager_mode()
|
||||
assert context.executing_eagerly()
|
||||
if self.inputs:
|
||||
raise ValueError('Model inputs are already set.')
|
||||
# On-the-fly setting of model inputs/outputs as DeferredTensors,
|
||||
@ -950,7 +950,7 @@ class Model(Network):
|
||||
Raises:
|
||||
ValueError: If the model's inputs are already set.
|
||||
"""
|
||||
assert context.in_graph_mode()
|
||||
assert not context.executing_eagerly()
|
||||
if self.inputs:
|
||||
raise ValueError('Model inputs are already set.')
|
||||
|
||||
@ -1186,7 +1186,7 @@ class Model(Network):
|
||||
val_y = None
|
||||
val_sample_weights = None
|
||||
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
return training_eager.fit_loop(
|
||||
self,
|
||||
inputs=x,
|
||||
@ -1289,7 +1289,7 @@ class Model(Network):
|
||||
sample_weight=sample_weight,
|
||||
batch_size=batch_size)
|
||||
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
return training_eager.test_loop(
|
||||
self, inputs=x, targets=y, sample_weights=sample_weights,
|
||||
batch_size=batch_size, verbose=verbose, steps=steps)
|
||||
@ -1330,7 +1330,7 @@ class Model(Network):
|
||||
'argument.')
|
||||
x, _, _ = self._standardize_user_data(x)
|
||||
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
return training_eager.predict_loop(
|
||||
self, x, batch_size=batch_size, verbose=verbose, steps=steps)
|
||||
else:
|
||||
@ -1381,7 +1381,7 @@ class Model(Network):
|
||||
sample_weight=sample_weight,
|
||||
class_weight=class_weight)
|
||||
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
outputs = training_eager.train_on_batch(
|
||||
self, x, y, sample_weights=sample_weights)
|
||||
else:
|
||||
@ -1431,7 +1431,7 @@ class Model(Network):
|
||||
x, y, sample_weights = self._standardize_user_data(
|
||||
x, y, sample_weight=sample_weight)
|
||||
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
outputs = training_eager.test_on_batch(
|
||||
self, x, y, sample_weights=sample_weights)
|
||||
else:
|
||||
@ -1458,11 +1458,11 @@ class Model(Network):
|
||||
"""
|
||||
x, _, _ = self._standardize_user_data(x)
|
||||
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
inputs = [ops.convert_to_tensor(val, dtype=K.floatx()) for val in x]
|
||||
return self(inputs) # pylint: disable=not-callable
|
||||
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
|
||||
ins = x + [0]
|
||||
else:
|
||||
|
@ -553,7 +553,7 @@ class ZeroPaddingTest(test.TestCase):
|
||||
layer = keras.layers.ZeroPadding1D(padding=2)
|
||||
layer.build(shape)
|
||||
output = layer(keras.backend.variable(inputs))
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
np_output = output.numpy()
|
||||
else:
|
||||
np_output = keras.backend.eval(output)
|
||||
@ -564,7 +564,7 @@ class ZeroPaddingTest(test.TestCase):
|
||||
layer = keras.layers.ZeroPadding1D(padding=(1, 2))
|
||||
layer.build(shape)
|
||||
output = layer(keras.backend.variable(inputs))
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
np_output = output.numpy()
|
||||
else:
|
||||
np_output = keras.backend.eval(output)
|
||||
@ -610,7 +610,7 @@ class ZeroPaddingTest(test.TestCase):
|
||||
padding=(2, 2), data_format=data_format)
|
||||
layer.build(inputs.shape)
|
||||
output = layer(keras.backend.variable(inputs))
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
np_output = output.numpy()
|
||||
else:
|
||||
np_output = keras.backend.eval(output)
|
||||
@ -629,7 +629,7 @@ class ZeroPaddingTest(test.TestCase):
|
||||
padding=((1, 2), (3, 4)), data_format=data_format)
|
||||
layer.build(inputs.shape)
|
||||
output = layer(keras.backend.variable(inputs))
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
np_output = output.numpy()
|
||||
else:
|
||||
np_output = keras.backend.eval(output)
|
||||
@ -683,7 +683,7 @@ class ZeroPaddingTest(test.TestCase):
|
||||
layer = keras.layers.ZeroPadding3D(padding=(2, 2, 2))
|
||||
layer.build(inputs.shape)
|
||||
output = layer(keras.backend.variable(inputs))
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
np_output = output.numpy()
|
||||
else:
|
||||
np_output = keras.backend.eval(output)
|
||||
@ -737,7 +737,7 @@ class UpSamplingTest(test.TestCase):
|
||||
size=(length_row, length_col), data_format=data_format)
|
||||
layer.build(inputs.shape)
|
||||
output = layer(keras.backend.variable(inputs))
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
np_output = output.numpy()
|
||||
else:
|
||||
np_output = keras.backend.eval(output)
|
||||
@ -790,7 +790,7 @@ class UpSamplingTest(test.TestCase):
|
||||
data_format=data_format)
|
||||
layer.build(inputs.shape)
|
||||
output = layer(keras.backend.variable(inputs))
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
np_output = output.numpy()
|
||||
else:
|
||||
np_output = keras.backend.eval(output)
|
||||
@ -865,7 +865,7 @@ class CroppingTest(test.TestCase):
|
||||
cropping=cropping, data_format=data_format)
|
||||
layer.build(inputs.shape)
|
||||
output = layer(keras.backend.variable(inputs))
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
np_output = output.numpy()
|
||||
else:
|
||||
np_output = keras.backend.eval(output)
|
||||
@ -892,7 +892,7 @@ class CroppingTest(test.TestCase):
|
||||
cropping=cropping, data_format=data_format)
|
||||
layer.build(inputs.shape)
|
||||
output = layer(keras.backend.variable(inputs))
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
np_output = output.numpy()
|
||||
else:
|
||||
np_output = keras.backend.eval(output)
|
||||
@ -937,7 +937,7 @@ class CroppingTest(test.TestCase):
|
||||
cropping=cropping, data_format=data_format)
|
||||
layer.build(inputs.shape)
|
||||
output = layer(keras.backend.variable(inputs))
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
np_output = output.numpy()
|
||||
else:
|
||||
np_output = keras.backend.eval(output)
|
||||
@ -954,7 +954,7 @@ class CroppingTest(test.TestCase):
|
||||
cropping[2][0]:-cropping[2][1], :]
|
||||
np.testing.assert_allclose(np_output, expected_out)
|
||||
|
||||
# test incorrect use
|
||||
# test incorrect use
|
||||
with self.assertRaises(ValueError):
|
||||
keras.layers.Cropping3D(cropping=(1, 1))
|
||||
with self.assertRaises(ValueError):
|
||||
|
@ -124,7 +124,7 @@ class Dropout(tf_core_layers.Dropout, Layer):
|
||||
training = K.learning_phase()
|
||||
output = super(Dropout, self).call(inputs, training=training)
|
||||
# EagerTensor object has no attribute _uses_learning_phase
|
||||
if not context.in_eager_mode() and training is K.learning_phase():
|
||||
if not context.executing_eagerly() and training is K.learning_phase():
|
||||
output._uses_learning_phase = True # pylint: disable=protected-access
|
||||
return output
|
||||
|
||||
|
@ -111,7 +111,7 @@ class BatchNormalization(tf_normalization_layers.BatchNormalization, Layer):
|
||||
if training is None:
|
||||
training = K.learning_phase()
|
||||
output = super(BatchNormalization, self).call(inputs, training=training)
|
||||
if context.in_graph_mode() and training is K.learning_phase():
|
||||
if not context.executing_eagerly() and training is K.learning_phase():
|
||||
output._uses_learning_phase = True # pylint: disable=protected-access
|
||||
return output
|
||||
|
||||
|
@ -105,7 +105,7 @@ class Pooling2DTest(test.TestCase):
|
||||
|
||||
# This part of the test can only run on GPU but doesn't appear
|
||||
# to be properly assigned to a GPU when running in eager mode.
|
||||
if not context.in_eager_mode():
|
||||
if not context.executing_eagerly():
|
||||
# Only runs on GPU with CUDA, channels_first is not supported on CPU.
|
||||
# TODO(b/62340061): Support channels_first on CPU.
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
|
@ -936,7 +936,7 @@ class SimpleRNNCell(Layer):
|
||||
|
||||
# Properly set learning phase on output tensor.
|
||||
if 0 < self.dropout + self.recurrent_dropout:
|
||||
if training is None and not context.in_eager_mode():
|
||||
if training is None and not context.executing_eagerly():
|
||||
# This would be harmless to set in eager mode, but eager tensors
|
||||
# disallow setting arbitrary attributes.
|
||||
output._uses_learning_phase = True
|
||||
@ -1384,7 +1384,7 @@ class GRUCell(Layer):
|
||||
hh = self.activation(x_h + recurrent_h)
|
||||
h = z * h_tm1 + (1 - z) * hh
|
||||
if 0 < self.dropout + self.recurrent_dropout:
|
||||
if training is None and not context.in_eager_mode():
|
||||
if training is None and not context.executing_eagerly():
|
||||
# This would be harmless to set in eager mode, but eager tensors
|
||||
# disallow setting arbitrary attributes.
|
||||
h._uses_learning_phase = True
|
||||
@ -1877,7 +1877,7 @@ class LSTMCell(Layer):
|
||||
|
||||
h = o * self.activation(c)
|
||||
if 0 < self.dropout + self.recurrent_dropout:
|
||||
if training is None and not context.in_eager_mode():
|
||||
if training is None and not context.executing_eagerly():
|
||||
# This would be harmless to set in eager mode, but eager tensors
|
||||
# disallow setting arbitrary attributes.
|
||||
h._uses_learning_phase = True
|
||||
|
@ -83,14 +83,14 @@ class AtrousConvolutionTest(test.TestCase):
|
||||
checks = []
|
||||
|
||||
def add_check(check, *args, **kwargs):
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
args_val, kwargs_val = self.evaluate([args, kwargs])
|
||||
check(*args_val, **kwargs_val)
|
||||
else:
|
||||
checks.append((check, args, kwargs))
|
||||
|
||||
yield add_check
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
all_values = self.evaluate([[args, kwargs] for _, args, kwargs in checks])
|
||||
for (check, _, _), (args, kwargs) in zip(checks, all_values):
|
||||
check(*args, **kwargs)
|
||||
|
@ -102,17 +102,15 @@ class AssertEqualTest(test.TestCase):
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError, "fail"):
|
||||
check_ops.assert_equal(static_big, static_small, message="fail")
|
||||
|
||||
# Dynamic check
|
||||
if context.in_graph_mode():
|
||||
with self.test_session():
|
||||
small = array_ops.placeholder(dtypes.int32, name="small")
|
||||
big = array_ops.placeholder(dtypes.int32, name="big")
|
||||
with ops.control_dependencies(
|
||||
[check_ops.assert_equal(
|
||||
big, small, message="fail")]):
|
||||
out = array_ops.identity(small)
|
||||
with self.assertRaisesOpError("fail.*big.*small"):
|
||||
out.eval(feed_dict={small: [1, 2], big: [3, 4]})
|
||||
def test_raises_when_greater_dynamic(self):
|
||||
with self.test_session():
|
||||
small = array_ops.placeholder(dtypes.int32, name="small")
|
||||
big = array_ops.placeholder(dtypes.int32, name="big")
|
||||
with ops.control_dependencies(
|
||||
[check_ops.assert_equal(big, small, message="fail")]):
|
||||
out = array_ops.identity(small)
|
||||
with self.assertRaisesOpError("fail.*big.*small"):
|
||||
out.eval(feed_dict={small: [1, 2], big: [3, 4]})
|
||||
|
||||
def test_error_message_eager(self):
|
||||
expected_error_msg_full = r"""big does not equal small
|
||||
@ -182,15 +180,14 @@ First 2 elements of y:
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError, "fail"):
|
||||
check_ops.assert_equal(static_big, static_small, message="fail")
|
||||
|
||||
# Dynamic check
|
||||
if context.in_graph_mode():
|
||||
with self.test_session():
|
||||
small = array_ops.placeholder(dtypes.int32, name="small")
|
||||
big = array_ops.placeholder(dtypes.int32, name="big")
|
||||
with ops.control_dependencies([check_ops.assert_equal(small, big)]):
|
||||
out = array_ops.identity(small)
|
||||
with self.assertRaisesOpError("small.*big"):
|
||||
out.eval(feed_dict={small: [3, 1], big: [4, 2]})
|
||||
def test_raises_when_less_dynamic(self):
|
||||
with self.test_session():
|
||||
small = array_ops.placeholder(dtypes.int32, name="small")
|
||||
big = array_ops.placeholder(dtypes.int32, name="big")
|
||||
with ops.control_dependencies([check_ops.assert_equal(small, big)]):
|
||||
out = array_ops.identity(small)
|
||||
with self.assertRaisesOpError("small.*big"):
|
||||
out.eval(feed_dict={small: [3, 1], big: [4, 2]})
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def test_doesnt_raise_when_equal_and_broadcastable_shapes(self):
|
||||
|
@ -360,7 +360,7 @@ class PyFuncTest(test.TestCase):
|
||||
raise py_exp("blah") # pylint: disable=not-callable
|
||||
|
||||
if eager:
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
with self.assertRaisesRegexp(tf_exp, "blah"):
|
||||
f = script_ops.eager_py_func(raise_exception, [], [])
|
||||
return
|
||||
@ -432,7 +432,7 @@ class PyFuncTest(test.TestCase):
|
||||
|
||||
output = script_ops.eager_py_func(no_return_value, inp=[], Tout=[])
|
||||
ret = self.evaluate(output)
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
self.assertEquals(len(ret), 0)
|
||||
else:
|
||||
self.assertIsNone(ret)
|
||||
|
@ -279,15 +279,12 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
|
||||
|
||||
# Tests for the 'read_value' argument:
|
||||
assign_with_read = v.assign(3.0, read_value=True)
|
||||
if context.in_graph_mode():
|
||||
self.assertEqual(3.0, assign_with_read.eval())
|
||||
else:
|
||||
self.assertEqual(3.0, self.evaluate(assign_with_read))
|
||||
self.assertEqual(3.0, self.evaluate(assign_with_read))
|
||||
assign_without_read = v.assign(4.0, read_value=False)
|
||||
if context.in_graph_mode():
|
||||
self.assertIsInstance(assign_without_read, ops.Operation)
|
||||
else:
|
||||
if context.executing_eagerly():
|
||||
self.assertIsNone(assign_without_read)
|
||||
else:
|
||||
self.assertIsInstance(assign_without_read, ops.Operation)
|
||||
self.evaluate(assign_without_read)
|
||||
self.assertEqual(4.0, self.evaluate(v.value()))
|
||||
|
||||
@ -355,15 +352,12 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
|
||||
|
||||
# Tests for the 'read_value' argument:
|
||||
assign_with_read = v.assign_add(1.0, read_value=True)
|
||||
if context.in_graph_mode():
|
||||
self.assertEqual(3.0, assign_with_read.eval())
|
||||
else:
|
||||
self.assertEqual(3.0, self.evaluate(assign_with_read))
|
||||
self.assertEqual(3.0, self.evaluate(assign_with_read))
|
||||
assign_without_read = v.assign_add(1.0, read_value=False)
|
||||
if context.in_graph_mode():
|
||||
self.assertIsInstance(assign_without_read, ops.Operation)
|
||||
else:
|
||||
if context.executing_eagerly():
|
||||
self.assertIsNone(assign_without_read)
|
||||
else:
|
||||
self.assertIsInstance(assign_without_read, ops.Operation)
|
||||
self.evaluate(assign_without_read)
|
||||
self.assertEqual(4.0, self.evaluate(v.value()))
|
||||
|
||||
@ -376,15 +370,12 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
|
||||
|
||||
# Tests for the 'read_value' argument:
|
||||
assign_with_read = v.assign_sub(1.0, read_value=True)
|
||||
if context.in_graph_mode():
|
||||
self.assertEqual(1.0, assign_with_read.eval())
|
||||
else:
|
||||
self.assertEqual(1.0, self.evaluate(assign_with_read))
|
||||
self.assertEqual(1.0, self.evaluate(assign_with_read))
|
||||
assign_without_read = v.assign_sub(1.0, read_value=False)
|
||||
if context.in_graph_mode():
|
||||
self.assertIsInstance(assign_without_read, ops.Operation)
|
||||
else:
|
||||
if context.executing_eagerly():
|
||||
self.assertIsNone(assign_without_read)
|
||||
else:
|
||||
self.assertIsInstance(assign_without_read, ops.Operation)
|
||||
self.evaluate(assign_without_read)
|
||||
self.assertEqual(0.0, self.evaluate(v.value()))
|
||||
|
||||
@ -485,7 +476,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual("(10, 20, 35)", str(v.get_shape()))
|
||||
self.assertEqual("(10, 20, 35)", str(v.value().shape))
|
||||
self.assertEqual("(3, 20, 35)", str(v.sparse_read([0, 1, 2]).shape))
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
self.assertEqual(
|
||||
"<unknown>",
|
||||
str(v.sparse_read(array_ops.placeholder(dtypes.int32)).shape))
|
||||
|
@ -111,10 +111,10 @@ class RNNTest(test.TestCase):
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testInvalidSequenceLengthShape(self):
|
||||
cell = Plus1RNNCell()
|
||||
if context.in_graph_mode():
|
||||
inputs = [array_ops.placeholder(dtypes.float32, shape=(3, 4))]
|
||||
else:
|
||||
if context.executing_eagerly():
|
||||
inputs = [constant_op.constant(np.ones((3, 4)))]
|
||||
else:
|
||||
inputs = [array_ops.placeholder(dtypes.float32, shape=(3, 4))]
|
||||
with self.assertRaisesRegexp(ValueError, "must be a vector"):
|
||||
rnn.dynamic_rnn(
|
||||
cell,
|
||||
@ -125,38 +125,30 @@ class RNNTest(test.TestCase):
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testBatchSizeFromInput(self):
|
||||
cell = Plus1RNNCell()
|
||||
in_graph_mode = context.in_graph_mode()
|
||||
in_eager_mode = context.executing_eagerly()
|
||||
# With static batch size
|
||||
if in_graph_mode:
|
||||
inputs = array_ops.placeholder(dtypes.float32, shape=(3, 4, 5))
|
||||
initial_state = array_ops.placeholder(dtypes.float32, shape=(3, 5))
|
||||
else:
|
||||
if in_eager_mode:
|
||||
inputs = np.zeros((3, 4, 5), dtype=np.float32)
|
||||
initial_state = np.zeros((3, 5), dtype=np.float32)
|
||||
else:
|
||||
inputs = array_ops.placeholder(dtypes.float32, shape=(3, 4, 5))
|
||||
initial_state = array_ops.placeholder(dtypes.float32, shape=(3, 5))
|
||||
|
||||
# - Without initial_state
|
||||
outputs, state = rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
if in_graph_mode:
|
||||
self.assertEqual(3, outputs.shape[0].value)
|
||||
self.assertEqual(3, state.shape[0].value)
|
||||
else:
|
||||
self.assertEqual(3, outputs.shape[0])
|
||||
self.assertEqual(3, state.shape[0])
|
||||
self.assertEqual(3, outputs.shape[0])
|
||||
self.assertEqual(3, state.shape[0])
|
||||
|
||||
# - With initial_state
|
||||
outputs, state = rnn.dynamic_rnn(
|
||||
cell, inputs, initial_state=initial_state)
|
||||
if in_graph_mode:
|
||||
self.assertEqual(3, outputs.shape[0].value)
|
||||
self.assertEqual(3, state.shape[0].value)
|
||||
else:
|
||||
self.assertEqual(3, outputs.shape[0])
|
||||
self.assertEqual(3, state.shape[0])
|
||||
self.assertEqual(3, outputs.shape[0])
|
||||
self.assertEqual(3, state.shape[0])
|
||||
|
||||
# Without static batch size
|
||||
# Tensor shapes are fully determined in Eager mode, so only run this
|
||||
# test in graph mode.
|
||||
if in_graph_mode:
|
||||
# Tensor shapes are fully determined with eager execution enabled,
|
||||
# so only run this test for graph construction.
|
||||
if not in_eager_mode:
|
||||
inputs = array_ops.placeholder(dtypes.float32, shape=(None, 4, 5))
|
||||
# - Without initial_state
|
||||
outputs, state = rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
@ -173,56 +165,46 @@ class RNNTest(test.TestCase):
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testScalarStateIsAccepted(self):
|
||||
cell = ScalarStateRNNCell()
|
||||
in_graph_mode = context.in_graph_mode()
|
||||
in_eager_mode = context.executing_eagerly()
|
||||
|
||||
if in_graph_mode:
|
||||
inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))
|
||||
else:
|
||||
if in_eager_mode:
|
||||
inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32)
|
||||
else:
|
||||
inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))
|
||||
|
||||
with self.test_session() as sess:
|
||||
outputs, state = rnn.dynamic_rnn(
|
||||
cell, inputs, dtype=dtypes.float32, sequence_length=[4])
|
||||
if in_graph_mode:
|
||||
if not in_eager_mode:
|
||||
outputs, state = sess.run(
|
||||
[outputs, state], feed_dict={inputs: [[[1], [2], [3], [4]]]})
|
||||
|
||||
if in_graph_mode:
|
||||
self.assertAllEqual(outputs, np.array([[[1], [2], [3], [4]]]))
|
||||
self.assertEqual(state, 4)
|
||||
else:
|
||||
self.assertAllEqual(outputs.numpy(), np.array([[[1], [2], [3], [4]]]))
|
||||
self.assertEqual(state.numpy(), 4)
|
||||
self.assertAllEqual([[[1], [2], [3], [4]]], outputs)
|
||||
self.assertAllEqual(4, state)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testTensorArrayStateIsAccepted(self):
|
||||
cell = TensorArrayStateRNNCell()
|
||||
in_graph_mode = context.in_graph_mode()
|
||||
in_eager_mode = context.executing_eagerly()
|
||||
|
||||
if in_graph_mode:
|
||||
inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))
|
||||
else:
|
||||
if in_eager_mode:
|
||||
inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32)
|
||||
else:
|
||||
inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))
|
||||
|
||||
with self.test_session() as sess:
|
||||
outputs, state = rnn.dynamic_rnn(
|
||||
cell, inputs, dtype=dtypes.float32, sequence_length=[4])
|
||||
state = (state[0], state[1].stack())
|
||||
if in_graph_mode:
|
||||
if not in_eager_mode:
|
||||
outputs, state = sess.run(
|
||||
[outputs, state], feed_dict={
|
||||
inputs: [[[1], [2], [3], [4]]]
|
||||
})
|
||||
|
||||
if in_graph_mode:
|
||||
self.assertAllEqual(outputs, np.array([[[1], [2], [3], [4]]]))
|
||||
self.assertEqual(state[0], 4)
|
||||
self.assertAllEqual(state[1], np.array([[[1]], [[2]], [[3]], [[4]]]))
|
||||
else:
|
||||
self.assertAllEqual(outputs.numpy(), np.array([[[1], [2], [3], [4]]]))
|
||||
self.assertEqual(state[0].numpy(), 4)
|
||||
self.assertAllEqual(state[1].numpy(),
|
||||
np.array([[[1]], [[2]], [[3]], [[4]]]))
|
||||
self.assertAllEqual([[[1], [2], [3], [4]]], outputs)
|
||||
self.assertAllEqual(4, state[0])
|
||||
self.assertAllEqual([[[1]], [[2]], [[3]], [[4]]], state[1])
|
||||
|
||||
|
||||
######### Benchmarking RNN code
|
||||
|
@ -283,7 +283,7 @@ class SliceTest(test.TestCase):
|
||||
# unintended behavior is prevented.
|
||||
c = constant_op.constant(5.0)
|
||||
with self.assertRaisesWithPredicateMatch(
|
||||
TypeError, lambda e: "`Tensor` objects are not iterable" in str(e)):
|
||||
TypeError, lambda e: "Tensor objects are not iterable" in str(e)):
|
||||
for _ in c:
|
||||
pass
|
||||
|
||||
|
@ -562,7 +562,7 @@ class TemplateTest(test.TestCase):
|
||||
outputs_b, _ = linear1(inputs)
|
||||
self.assertEquals("foo", linear1.variable_scope.name)
|
||||
self.assertEquals("foo/w:0", w1.name)
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
self.assertEquals("foo/add:0", outputs_a.name,
|
||||
"First application of template should get "
|
||||
"same name scope as variables.")
|
||||
@ -577,7 +577,7 @@ class TemplateTest(test.TestCase):
|
||||
"New template gets a freshly uniquified variable scope "
|
||||
"because 'foo' is already taken.")
|
||||
self.assertEquals("foo_1/w:0", w2.name)
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
self.assertEquals("foo_1_1/add:0", outputs_c.name,
|
||||
"First application of template would get "
|
||||
"same name scope as variables, but 'foo_1' is already "
|
||||
@ -592,7 +592,7 @@ class TemplateTest(test.TestCase):
|
||||
with variable_scope.variable_scope("foo"):
|
||||
# Create two templates with the same name, ensure scopes are made unique.
|
||||
ta = template.make_template("bar", variable_scoped_function, True)
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
tb = template.make_template("s", function_with_side_create,
|
||||
trainable=False)
|
||||
else:
|
||||
|
@ -399,28 +399,14 @@ class TensorArrayTest(test.TestCase):
|
||||
def testTensorArrayWriteWrongIndexOrDataTypeFails(self):
|
||||
with self.test_session(use_gpu=True):
|
||||
ta = _make_ta(3, "foo", dtype=dtypes.float32)
|
||||
in_graph_mode = context.in_graph_mode()
|
||||
# Test writing the wrong datatype
|
||||
if in_graph_mode:
|
||||
with self.assertRaisesOpError(
|
||||
"TensorArray dtype is float but Op is trying to write "
|
||||
"dtype string"):
|
||||
self.evaluate(ta.write(0, "wrong_type_scalar").flow)
|
||||
else:
|
||||
with self.assertRaisesOpError(
|
||||
"TensorArray dtype is float32 but Op is trying to write "
|
||||
"dtype string"):
|
||||
self.evaluate(ta.write(0, "wrong_type_scalar").flow)
|
||||
with self.assertRaisesOpError(
|
||||
"TensorArray dtype is (float|float32) but Op is trying to write "
|
||||
"dtype string"):
|
||||
self.evaluate(ta.write(0, "wrong_type_scalar").flow)
|
||||
|
||||
if context.in_graph_mode():
|
||||
with self.assertRaisesOpError(
|
||||
"Tried to write to index -1 but array is not "
|
||||
"resizeable and size is: 3"):
|
||||
self.evaluate(ta.write(-1, 3.0).flow)
|
||||
else:
|
||||
with self.assertRaisesOpError(
|
||||
r"Writing to negative indices \(index -1\) is not allowed."):
|
||||
self.evaluate(ta.write(-1, 3.0).flow)
|
||||
with self.assertRaisesOpError("index -1"):
|
||||
self.evaluate(ta.write(-1, 3.0).flow)
|
||||
|
||||
# Test reading from too large an index
|
||||
with self.assertRaisesOpError(
|
||||
@ -435,8 +421,8 @@ class TensorArrayTest(test.TestCase):
|
||||
|
||||
w0 = ta.write(0, [[4.0, 5.0]])
|
||||
|
||||
# Test reading wrong datatype, which is only possible in graph mode
|
||||
if context.in_graph_mode():
|
||||
# Test reading wrong datatype (only possible when constructing graphs).
|
||||
if not context.executing_eagerly():
|
||||
r0_bad = gen_data_flow_ops.tensor_array_read_v3(
|
||||
handle=w0.handle, index=0, dtype=dtypes.float64, flow_in=w0.flow)
|
||||
with self.assertRaisesOpError(
|
||||
@ -444,14 +430,8 @@ class TensorArrayTest(test.TestCase):
|
||||
r0_bad.eval()
|
||||
|
||||
# Test reading from a negative index, which is not allowed
|
||||
if context.in_graph_mode():
|
||||
with self.assertRaisesOpError(
|
||||
r"Tried to read from index -1 but array size is: 3"):
|
||||
self.evaluate(ta.read(-1))
|
||||
else:
|
||||
with self.assertRaisesOpError(
|
||||
r"Reading from negative indices \(index -1\) is not allowed."):
|
||||
self.evaluate(ta.read(-1))
|
||||
with self.assertRaisesOpError("index -1"):
|
||||
self.evaluate(ta.read(-1))
|
||||
|
||||
# Test reading from too large an index
|
||||
with self.assertRaisesOpError(
|
||||
@ -467,10 +447,7 @@ class TensorArrayTest(test.TestCase):
|
||||
with self.assertRaisesOpError(
|
||||
"Could not write to TensorArray index 2 because "
|
||||
"it has already been written to."):
|
||||
if context.in_graph_mode():
|
||||
self.evaluate(ta.write(2, 3.0).write(2, 3.0).flow)
|
||||
else:
|
||||
self.evaluate(ta.write(2, 3.0).write(2, 3.0))
|
||||
self.evaluate(ta.write(2, 3.0).write(2, 3.0).flow)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testTensorArrayConcatIncompatibleShapesFails(self):
|
||||
@ -499,58 +476,40 @@ class TensorArrayTest(test.TestCase):
|
||||
w2 = w1.write(1, [4.0])
|
||||
w3 = w2.write(2, [[3.0]])
|
||||
|
||||
# The eager-mode implementation just passes up array_op.concat's error
|
||||
# message.
|
||||
if context.in_graph_mode():
|
||||
with self.assertRaisesOpError(
|
||||
r"TensorArray has inconsistent shapes. Index 0 has "
|
||||
r"\(excepting dimension 0\) shape: \[\] but index 2 has "
|
||||
r"\(excepting dimension 0\) shape: \[1\]"):
|
||||
self.evaluate(w3.concat())
|
||||
else:
|
||||
with self.assertRaisesOpError(
|
||||
r".*Ranks of all input tensors should match: shape\[0\] "
|
||||
r"= \[1\] vs\. shape\[2\] = \[1,1\].*"):
|
||||
self.evaluate(w3.concat())
|
||||
# The exact error messages differ between eager execution and graph
|
||||
# construction as the former bubbles up the error from array_op.concat.
|
||||
with self.assertRaisesOpError("shape"):
|
||||
self.evaluate(w3.concat())
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testTensorArraySplitIncompatibleShapesFails(self):
|
||||
with self.test_session(use_gpu=True):
|
||||
in_graph_mode = context.in_graph_mode()
|
||||
in_eager_mode = context.executing_eagerly()
|
||||
ta = _make_ta(3, "foo")
|
||||
with self.assertRaisesOpError(
|
||||
r"Expected lengths to be a vector, received shape: \[\]"):
|
||||
if in_graph_mode:
|
||||
if in_eager_mode:
|
||||
self.evaluate(ta.split([1.0, 2.0, 3.0], 1))
|
||||
else:
|
||||
lengths = array_ops.placeholder(dtypes.int64)
|
||||
ta.split([1.0, 2.0, 3.0], lengths).flow.eval(feed_dict={lengths: 1})
|
||||
else:
|
||||
self.evaluate(ta.split([1.0, 2.0, 3.0], 1))
|
||||
|
||||
with self.assertRaisesOpError(
|
||||
r"Expected sum of lengths to be equal to values.shape\[0\], "
|
||||
r"but sum of lengths is 1 and value's shape is: \[3\]"):
|
||||
if in_graph_mode:
|
||||
self.evaluate(ta.split([1.0, 2.0, 3.0], [1]).flow)
|
||||
else:
|
||||
self.evaluate(ta.split([1.0, 2.0, 3.0], [1]))
|
||||
self.evaluate(ta.split([1.0, 2.0, 3.0], [1]).flow)
|
||||
|
||||
ta = _make_ta(1, "baz")
|
||||
with self.assertRaisesOpError(
|
||||
r"Expected value to be at least a vector, but received shape: \[\]"):
|
||||
if in_graph_mode:
|
||||
self.evaluate(ta.split(1.0, [1]).flow)
|
||||
else:
|
||||
self.evaluate(ta.split(1.0, [1]))
|
||||
self.evaluate(ta.split(1.0, [1]).flow)
|
||||
|
||||
ta = _make_ta(2, "buz")
|
||||
with self.assertRaisesOpError(
|
||||
r"TensorArray's size is not equal to the size of lengths "
|
||||
r"\(2 vs. 1\), and the TensorArray is not marked as "
|
||||
r"dynamically resizeable"):
|
||||
if in_graph_mode:
|
||||
self.evaluate(ta.split([1.0], [1]).flow)
|
||||
else:
|
||||
self.evaluate(ta.split([1.0], [1]))
|
||||
self.evaluate(ta.split([1.0], [1]).flow)
|
||||
|
||||
def _testTensorArrayWriteGradientAddMultipleAdds(self, dtype):
|
||||
with self.test_session(use_gpu=True):
|
||||
@ -868,14 +827,14 @@ class TensorArrayTest(test.TestCase):
|
||||
|
||||
vout = func(v0, state0, var)
|
||||
grad_val = -np.arange(3 * 5, dtype=np_dtype).reshape(3, 5)
|
||||
if context.in_graph_mode():
|
||||
if context.executing_eagerly():
|
||||
grad_fn = backprop.gradients_function(func)
|
||||
v0_grad, state0_grad, var_grad = grad_fn(v0, state0, var, dy=grad_val)
|
||||
else:
|
||||
v0_grad = gradients_impl.gradients([vout], [v0], [grad_val])[0]
|
||||
state0_grad = gradients_impl.gradients([vout], [state0], [grad_val])[0]
|
||||
var_grad = gradients_impl.gradients([vout], [var], [grad_val])[0]
|
||||
variables.global_variables_initializer().run()
|
||||
else:
|
||||
grad_fn = backprop.gradients_function(func)
|
||||
v0_grad, state0_grad, var_grad = grad_fn(v0, state0, var, dy=grad_val)
|
||||
|
||||
state0_t, var_t, v0_t, vout_t, v0_grad_t, var_grad_t, state0_grad_t = (
|
||||
self.evaluate(
|
||||
@ -959,10 +918,10 @@ class TensorArrayTest(test.TestCase):
|
||||
return r
|
||||
|
||||
x = constant_op.constant(2.0, name="x")
|
||||
if context.in_graph_mode():
|
||||
grad = gradients_impl.gradients(loop(x), [x])[0]
|
||||
else:
|
||||
if context.executing_eagerly():
|
||||
grad = backprop.gradients_function(loop)(x)[0]
|
||||
else:
|
||||
grad = gradients_impl.gradients(loop(x), [x])[0]
|
||||
self.assertAllClose(31.0, self.evaluate(grad))
|
||||
|
||||
def testSumOfTwoReadVariablesWithoutRepeatGrad(self):
|
||||
@ -1158,14 +1117,14 @@ class TensorArrayTest(test.TestCase):
|
||||
infer_shape=True)
|
||||
w0 = ta1.split(value, [1, 2])
|
||||
r0 = w0.read(0)
|
||||
if context.in_graph_mode():
|
||||
if context.executing_eagerly():
|
||||
self.assertEqual((1, 2), r0.get_shape())
|
||||
self.assertEqual((2, 2), w0.read(1).get_shape())
|
||||
else:
|
||||
self.assertEqual(r0.get_shape().ndims, None)
|
||||
self.assertEqual(
|
||||
tensor_shape.TensorShape(
|
||||
ta1.handle.op.get_attr("element_shape")).ndims, None)
|
||||
else:
|
||||
self.assertEqual((1, 2), r0.get_shape())
|
||||
self.assertEqual((2, 2), w0.read(1).get_shape())
|
||||
|
||||
def testWriteUnknownShape(self):
|
||||
with self.test_session(use_gpu=True):
|
||||
@ -1297,13 +1256,13 @@ class TensorArrayTest(test.TestCase):
|
||||
g = func(values)
|
||||
grad_ys = [[[2.0, 3.0], [4.0, 5.0]]]
|
||||
# Test combined gradients + aggregation of read(0)
|
||||
if context.in_graph_mode():
|
||||
grad = gradients_impl.gradients(ys=[g], xs=[values], grad_ys=grad_ys)
|
||||
g_vals, grad_vals = session.run([[g], grad])
|
||||
else:
|
||||
if context.executing_eagerly():
|
||||
g_vals = [g]
|
||||
grad_vals = backprop.gradients_function(func)(
|
||||
values, dy=constant_op.constant(grad_ys[0], dtype=dtypes.float32))
|
||||
else:
|
||||
grad = gradients_impl.gradients(ys=[g], xs=[values], grad_ys=grad_ys)
|
||||
g_vals, grad_vals = session.run([[g], grad])
|
||||
|
||||
# Gradients for 8 of the 10 unread components are zero.
|
||||
expected_grad = np.zeros((10, 2))
|
||||
@ -1453,13 +1412,13 @@ class TensorArrayTest(test.TestCase):
|
||||
# Tests correct properties on new TensorArrays.
|
||||
self.assertEqual(dtypes.float32, ta0.dtype)
|
||||
self.assertEqual(dtypes.int32, ta1.dtype)
|
||||
if context.in_graph_mode():
|
||||
self.assertEqual(tensor_shape.unknown_shape(), read0.get_shape())
|
||||
if context.executing_eagerly():
|
||||
self.assertEqual(tensor_shape.scalar(), read0.get_shape())
|
||||
else:
|
||||
self.assertEqual(tensor_shape.scalar(), read1.get_shape())
|
||||
self.assertEqual(tensor_shape.unknown_shape(), read0.get_shape())
|
||||
self.assertEqual(tensor_shape.scalar(), read1.get_shape())
|
||||
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
variables.global_variables_initializer().run()
|
||||
|
||||
read0_v, read1_v, size0_v, size1_v = self.evaluate((read0, read1, size0,
|
||||
|
@ -166,12 +166,10 @@ class VariableScopeTest(test.TestCase):
|
||||
self.evaluate(variables_lib.variables_initializer([w]))
|
||||
self.assertAllClose(self.evaluate(w.value()), [1, 2, 3])
|
||||
|
||||
if context.in_graph_mode():
|
||||
with self.assertRaises(TypeError):
|
||||
variable_scope.get_variable("x4", initializer={})
|
||||
else:
|
||||
with self.assertRaises(ValueError):
|
||||
variable_scope.get_variable("x4", initializer={})
|
||||
# A quirk to be revisited?
|
||||
error = ValueError if context.executing_eagerly() else TypeError
|
||||
with self.assertRaises(error):
|
||||
variable_scope.get_variable("x4", initializer={})
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testInitFromNonInitializer(self):
|
||||
@ -267,7 +265,7 @@ class VariableScopeTest(test.TestCase):
|
||||
self.assertAllClose(self.evaluate(losses[2]), 0.5)
|
||||
with variable_scope.variable_scope("foo", reuse=True):
|
||||
# reuse=True is for now only supported when eager execution is disabled.
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
v = variable_scope.get_variable("v",
|
||||
[]) # "v" is alredy there, reused
|
||||
losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
|
||||
@ -374,7 +372,7 @@ class VariableScopeTest(test.TestCase):
|
||||
v = variable_scope.get_variable("v", [])
|
||||
self.evaluate(variables_lib.variables_initializer([v]))
|
||||
self.assertAllClose(self.evaluate(v.value()), 0.3)
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
# Check that we can set reuse.
|
||||
variable_scope.get_variable_scope().reuse_variables()
|
||||
with self.assertRaises(ValueError): # Fail, w does not exist yet.
|
||||
@ -408,7 +406,7 @@ class VariableScopeTest(test.TestCase):
|
||||
with variable_scope.variable_scope("tower") as tower:
|
||||
with ops.name_scope("scope2") as sc2:
|
||||
self.assertEqual(sc2, "testVarScopeNameScope1/tower/scope2/")
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
with variable_scope.variable_scope(
|
||||
tower): # Re-entering acts like another "tower".
|
||||
with ops.name_scope("scope2") as sc2:
|
||||
@ -422,7 +420,7 @@ class VariableScopeTest(test.TestCase):
|
||||
with variable_scope.variable_scope("tower"):
|
||||
with ops.name_scope("scope2") as sc2:
|
||||
self.assertEqual(sc2, "testVarScopeNameScope2/tower/scope2/")
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
with variable_scope.variable_scope(tower):
|
||||
with ops.name_scope("scope2") as sc2:
|
||||
self.assertEqual(sc2, "testVarScopeNameScope2/tower_1/scope2/")
|
||||
@ -903,17 +901,15 @@ class VariableScopeTest(test.TestCase):
|
||||
"w", [], collections=["foo"])
|
||||
self.assertEqual(local_var.name, "outer/w:0")
|
||||
|
||||
# Since variable is local, it should be in the local variable collection
|
||||
# but not the trainable collection.
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
# Since variable is local, it should be in the local variable collection
|
||||
# but not the trainable collection.
|
||||
self.assertIn(local_var,
|
||||
ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
|
||||
self.assertIn(local_var, ops.get_collection("foo"))
|
||||
self.assertNotIn(local_var,
|
||||
ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
|
||||
|
||||
# Check that local variable respects `reuse`.
|
||||
if context.in_graph_mode():
|
||||
# Check that local variable respects `reuse`.
|
||||
with variable_scope.variable_scope(outer, "default", reuse=True):
|
||||
self.assertEqual(
|
||||
variable_scope.get_local_variable("w", []).name, "outer/w:0")
|
||||
|
@ -115,7 +115,7 @@ class Layer(checkpointable.CheckpointableBase):
|
||||
# Provides information about which inputs are compatible with the layer.
|
||||
self.input_spec = None
|
||||
|
||||
if activity_regularizer and context.in_eager_mode():
|
||||
if activity_regularizer and context.executing_eagerly():
|
||||
raise ValueError(
|
||||
('Activity regularization is not supported when executing eagerly. '
|
||||
'Got activity_regularizer=%s') % (activity_regularizer,))
|
||||
@ -228,7 +228,7 @@ class Layer(checkpointable.CheckpointableBase):
|
||||
|
||||
@property
|
||||
def updates(self):
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('Layer.updates not supported in Eager mode.')
|
||||
if not self.trainable and not self.stateful:
|
||||
return []
|
||||
@ -260,7 +260,7 @@ class Layer(checkpointable.CheckpointableBase):
|
||||
have is available at runtime.
|
||||
A step counter might fall into this category.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
return # Updates already applied when in eager mode.
|
||||
|
||||
updates = _to_list(updates)
|
||||
@ -286,7 +286,7 @@ class Layer(checkpointable.CheckpointableBase):
|
||||
Raises:
|
||||
RuntimeError: If called in Eager mode.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('`get_updates_for()` not supported in Eager mode.')
|
||||
|
||||
# Updates disabled if layer is not trainable and not explicitly stateful.
|
||||
@ -317,7 +317,7 @@ class Layer(checkpointable.CheckpointableBase):
|
||||
Returns:
|
||||
A list of tensors.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
# _losses may only contain variable regularization losses when executing
|
||||
# eagerly, and they have been saved as lambdas to be executed when
|
||||
# requested.
|
||||
@ -355,7 +355,7 @@ class Layer(checkpointable.CheckpointableBase):
|
||||
Raises:
|
||||
RuntimeError: If called in Eager mode.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
# TODO(fchollet): it should be possible (and highly desirable) to support
|
||||
# `add_loss` in eager mode. This allows great convenience and flexibility
|
||||
# in defining custom losses on the fly (e.g. in VAEs).
|
||||
@ -389,7 +389,7 @@ class Layer(checkpointable.CheckpointableBase):
|
||||
Raises:
|
||||
RuntimeError: If called in Eager mode.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('Layer.get_losses_for not supported in Eager mode.')
|
||||
|
||||
if inputs is None:
|
||||
@ -509,7 +509,7 @@ class Layer(checkpointable.CheckpointableBase):
|
||||
# will occur; it should be None if and only if initialization will take
|
||||
# place in the eager context.
|
||||
init_graph = None
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
default_graph = ops.get_default_graph()
|
||||
if default_graph.building_function:
|
||||
with ops.init_scope():
|
||||
@ -517,7 +517,7 @@ class Layer(checkpointable.CheckpointableBase):
|
||||
# will be lifted; if initialization ops will be lifted into
|
||||
# the eager context, then there is nothing to retrieve, since variable
|
||||
# collections are not supported when eager execution is enabled.
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
init_graph = ops.get_default_graph()
|
||||
existing_variables = set(tf_variables.global_variables())
|
||||
else:
|
||||
@ -624,17 +624,17 @@ class Layer(checkpointable.CheckpointableBase):
|
||||
self._set_scope(kwargs.pop('scope', None))
|
||||
input_list = nest.flatten(inputs)
|
||||
|
||||
in_graph_mode = context.in_graph_mode()
|
||||
build_graph = not context.executing_eagerly()
|
||||
in_deferred_mode = isinstance(input_list[0], _DeferredTensor)
|
||||
# Ensure the Layer, if being reused, is working with inputs from
|
||||
# the same graph as where it was created.
|
||||
if in_graph_mode:
|
||||
if build_graph:
|
||||
try:
|
||||
# Set layer's "graph" at build time
|
||||
self._graph = ops._get_graph_from_inputs(input_list, graph=self._graph) # pylint: disable=protected-access
|
||||
except ValueError as e:
|
||||
raise ValueError('Input graph and Layer graph are not the same: %s' % e)
|
||||
if in_graph_mode or in_deferred_mode:
|
||||
if build_graph or in_deferred_mode:
|
||||
user_kwargs = copy.copy(kwargs)
|
||||
|
||||
# Handle Keras mask propagation from previous layer to current layer.
|
||||
@ -669,13 +669,14 @@ class Layer(checkpointable.CheckpointableBase):
|
||||
with scope_context_manager as scope:
|
||||
with ops.name_scope(self._name_scope_name(scope)):
|
||||
if not self.built:
|
||||
if not in_graph_mode:
|
||||
if not build_graph:
|
||||
# Activity regularization is currently unsupported in Eager mode.
|
||||
if self._activity_regularizer:
|
||||
raise ValueError('activity_regularizer currently unsupported in '
|
||||
'Eager mode. Found an activity_regularizer in '
|
||||
'%s(%s).' % (self.__class__.__name__, self))
|
||||
if not in_graph_mode and not in_deferred_mode:
|
||||
raise ValueError(
|
||||
'activity_regularizer currently unsupported with '
|
||||
'eager execution enabled. Found an activity_regularizer in '
|
||||
'%s(%s).' % (self.__class__.__name__, self))
|
||||
if not build_graph and not in_deferred_mode:
|
||||
# TODO(agarwal): support _keras_history in Eager mode.
|
||||
for x in input_list:
|
||||
if hasattr(x, '_keras_history'):
|
||||
@ -706,7 +707,7 @@ class Layer(checkpointable.CheckpointableBase):
|
||||
if call_has_scope_arg:
|
||||
kwargs['scope'] = scope
|
||||
# Check input assumptions set after layer building, e.g. input shape.
|
||||
if in_graph_mode or in_deferred_mode:
|
||||
if build_graph or in_deferred_mode:
|
||||
self._assert_input_compatibility(inputs)
|
||||
|
||||
if not in_deferred_mode:
|
||||
@ -730,7 +731,7 @@ class Layer(checkpointable.CheckpointableBase):
|
||||
if len(outputs) == 1:
|
||||
outputs = outputs[0]
|
||||
|
||||
if in_graph_mode:
|
||||
if build_graph:
|
||||
# Apply activity regularization.
|
||||
# Note that it should be applied every time the layer creates a new
|
||||
# output, since it is output-specific.
|
||||
@ -752,7 +753,7 @@ class Layer(checkpointable.CheckpointableBase):
|
||||
else:
|
||||
outputs._keras_mask = output_mask # pylint: disable=protected-access
|
||||
|
||||
if in_graph_mode:
|
||||
if build_graph:
|
||||
# If all input tensors have history metadata,
|
||||
# we update the output tensors
|
||||
# with corresponding history metadata, thus eventually allowing to use
|
||||
@ -775,7 +776,7 @@ class Layer(checkpointable.CheckpointableBase):
|
||||
# Update global default collections.
|
||||
_add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
|
||||
|
||||
if in_deferred_mode or in_graph_mode:
|
||||
if in_deferred_mode or build_graph:
|
||||
if _have_all_keras_metadata(inputs):
|
||||
# Add an inbound node to the layer, so it can keep track of this call.
|
||||
# This updates the layer history of the output tensor(s).
|
||||
@ -787,7 +788,7 @@ class Layer(checkpointable.CheckpointableBase):
|
||||
|
||||
@property
|
||||
def graph(self):
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('Layer.graph not supported in Eager mode.')
|
||||
return self._graph
|
||||
|
||||
@ -891,7 +892,7 @@ class Layer(checkpointable.CheckpointableBase):
|
||||
mode.
|
||||
ValueError: If the index provided does not match any node.
|
||||
"""
|
||||
assert context.in_graph_mode()
|
||||
assert not context.executing_eagerly()
|
||||
if not self._inbound_nodes:
|
||||
raise RuntimeError('The layer has never been called '
|
||||
'and thus has no defined ' + attr_name + '.')
|
||||
@ -921,7 +922,7 @@ class Layer(checkpointable.CheckpointableBase):
|
||||
Raises:
|
||||
RuntimeError: If called in Eager mode.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError(
|
||||
'Layer.get_input_shape_at not supported in Eager mode.')
|
||||
return self._get_node_attribute_at_index(node_index, 'input_shapes',
|
||||
@ -943,7 +944,7 @@ class Layer(checkpointable.CheckpointableBase):
|
||||
Raises:
|
||||
RuntimeError: If called in Eager mode.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError(
|
||||
'Layer.get_output_shape_at not supported in Eager mode.')
|
||||
return self._get_node_attribute_at_index(node_index, 'output_shapes',
|
||||
@ -964,7 +965,7 @@ class Layer(checkpointable.CheckpointableBase):
|
||||
Raises:
|
||||
RuntimeError: If called in Eager mode.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('Layer.get_input_at not supported in Eager mode.')
|
||||
return self._get_node_attribute_at_index(node_index, 'input_tensors',
|
||||
'input')
|
||||
@ -984,7 +985,7 @@ class Layer(checkpointable.CheckpointableBase):
|
||||
Raises:
|
||||
RuntimeError: If called in Eager mode.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('Layer.get_output_at not supported in Eager mode.')
|
||||
return self._get_node_attribute_at_index(node_index, 'output_tensors',
|
||||
'output')
|
||||
@ -1007,7 +1008,7 @@ class Layer(checkpointable.CheckpointableBase):
|
||||
RuntimeError: If called in Eager mode.
|
||||
AttributeError: If no inbound nodes are found.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('Layer.input not supported in Eager mode.')
|
||||
if not self._inbound_nodes:
|
||||
raise AttributeError('Layer ' + self.name +
|
||||
@ -1029,7 +1030,7 @@ class Layer(checkpointable.CheckpointableBase):
|
||||
layers.
|
||||
RuntimeError: if called in Eager mode.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('Layer.output not supported in Eager mode.')
|
||||
if not self._inbound_nodes:
|
||||
raise AttributeError('Layer ' + self.name + ' has no inbound nodes.')
|
||||
@ -1051,7 +1052,7 @@ class Layer(checkpointable.CheckpointableBase):
|
||||
AttributeError: if the layer has no defined input_shape.
|
||||
RuntimeError: if called in Eager mode.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('Layer.input_shape not supported in Eager mode.')
|
||||
if not self._inbound_nodes:
|
||||
raise AttributeError('The layer has never been called '
|
||||
@ -1112,7 +1113,7 @@ class Layer(checkpointable.CheckpointableBase):
|
||||
AttributeError: if the layer has no defined output shape.
|
||||
RuntimeError: if called in Eager mode.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('Layer.output_shape not supported in Eager mode.')
|
||||
if not self._inbound_nodes:
|
||||
raise AttributeError('The layer has never been called '
|
||||
@ -1470,7 +1471,7 @@ def _to_list(x):
|
||||
|
||||
|
||||
def _add_elements_to_collection(elements, collection_list):
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('Using collections from Layers not supported in Eager '
|
||||
'mode. Tried to add %s to %s' % (elements,
|
||||
collection_list))
|
||||
|
@ -44,7 +44,7 @@ class BaseLayerTest(test.TestCase):
|
||||
self.assertEqual(layer.variables, [])
|
||||
self.assertEqual(layer.trainable_variables, [])
|
||||
self.assertEqual(layer.non_trainable_variables, [])
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
# updates, losses only supported in GRAPH mode
|
||||
self.assertEqual(layer.updates, [])
|
||||
self.assertEqual(layer.losses, [])
|
||||
@ -63,7 +63,7 @@ class BaseLayerTest(test.TestCase):
|
||||
self.assertEqual(layer.variables, [variable])
|
||||
self.assertEqual(layer.trainable_variables, [variable])
|
||||
self.assertEqual(layer.non_trainable_variables, [])
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
self.assertEqual(
|
||||
layer.variables,
|
||||
ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
|
||||
@ -77,7 +77,7 @@ class BaseLayerTest(test.TestCase):
|
||||
self.assertEqual(layer.variables, [variable, variable_2])
|
||||
self.assertEqual(layer.trainable_variables, [variable])
|
||||
self.assertEqual(layer.non_trainable_variables, [variable_2])
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
self.assertEqual(
|
||||
len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 1)
|
||||
|
||||
@ -161,7 +161,7 @@ class BaseLayerTest(test.TestCase):
|
||||
inputs = random_ops.random_uniform((5,), seed=1)
|
||||
outputs = layer.apply(inputs)
|
||||
self.assertEqual(layer.built, True)
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
# op is only supported in GRAPH mode
|
||||
self.assertEqual(outputs.op.name, 'my_layer/Square')
|
||||
|
||||
@ -210,7 +210,7 @@ class BaseLayerTest(test.TestCase):
|
||||
inputs = random_ops.random_uniform((5,), seed=1)
|
||||
outputs = layer.apply(inputs)
|
||||
self.assertEqual(layer.built, True)
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
# op only supported in GRAPH mode.
|
||||
self.assertEqual(outputs.op.name, 'my_layer/Square')
|
||||
|
||||
@ -280,7 +280,7 @@ class BaseLayerTest(test.TestCase):
|
||||
def call(self, inputs):
|
||||
return inputs
|
||||
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
layer = CustomerLayer()
|
||||
with self.assertRaisesRegexp(ValueError, r'requires a defined rank'):
|
||||
layer.apply(array_ops.placeholder('int32'))
|
||||
@ -307,7 +307,7 @@ class BaseLayerTest(test.TestCase):
|
||||
def call(self, inputs):
|
||||
return inputs
|
||||
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
layer = CustomerLayer()
|
||||
with self.assertRaisesRegexp(ValueError, r'requires a defined rank'):
|
||||
layer.apply(array_ops.placeholder('int32'))
|
||||
@ -335,7 +335,7 @@ class BaseLayerTest(test.TestCase):
|
||||
def call(self, inputs):
|
||||
return inputs
|
||||
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
layer = CustomerLayer()
|
||||
with self.assertRaisesRegexp(ValueError, r'requires a defined rank'):
|
||||
layer.apply(array_ops.placeholder('int32'))
|
||||
@ -430,7 +430,7 @@ class BaseLayerTest(test.TestCase):
|
||||
layer.apply(constant_op.constant(1))
|
||||
|
||||
# Works
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
layer.apply(array_ops.placeholder('int32'))
|
||||
layer.apply(array_ops.placeholder('int32', shape=(2, 3)))
|
||||
|
||||
@ -453,13 +453,7 @@ class BaseLayerTest(test.TestCase):
|
||||
return {'l' + key: inputs[key] for key in inputs}
|
||||
|
||||
layer = DictLayer()
|
||||
if context.in_graph_mode():
|
||||
i1 = array_ops.placeholder('int32')
|
||||
i2 = array_ops.placeholder('float32')
|
||||
result = layer.apply({'abel': i1, 'ogits': i2})
|
||||
self.assertTrue(isinstance(result, dict))
|
||||
self.assertEqual(set(['label', 'logits']), set(result.keys()))
|
||||
else:
|
||||
if context.executing_eagerly():
|
||||
i1 = constant_op.constant(3)
|
||||
i2 = constant_op.constant(4.0)
|
||||
result = layer.apply({'abel': i1, 'ogits': i2})
|
||||
@ -467,6 +461,12 @@ class BaseLayerTest(test.TestCase):
|
||||
self.assertEqual(set(['label', 'logits']), set(result.keys()))
|
||||
self.assertEqual(3, result['label'].numpy())
|
||||
self.assertEqual(4.0, result['logits'].numpy())
|
||||
else:
|
||||
i1 = array_ops.placeholder('int32')
|
||||
i2 = array_ops.placeholder('float32')
|
||||
result = layer.apply({'abel': i1, 'ogits': i2})
|
||||
self.assertTrue(isinstance(result, dict))
|
||||
self.assertEqual(set(['label', 'logits']), set(result.keys()))
|
||||
|
||||
def testActivityRegularizer(self):
|
||||
regularizer = math_ops.reduce_sum
|
||||
|
@ -1664,7 +1664,7 @@ class Conv2DTranspose(Conv2D):
|
||||
padding=self.padding.upper(),
|
||||
data_format=utils.convert_data_format(self.data_format, ndim=4))
|
||||
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
# Infer the static output shape:
|
||||
out_shape = inputs.get_shape().as_list()
|
||||
out_shape[c_axis] = self.filters
|
||||
@ -1969,7 +1969,7 @@ class Conv3DTranspose(Conv3D):
|
||||
data_format=utils.convert_data_format(self.data_format, ndim=5),
|
||||
padding=self.padding.upper())
|
||||
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
# Infer the static output shape:
|
||||
out_shape = inputs.get_shape().as_list()
|
||||
out_shape[c_axis] = self.filters
|
||||
|
@ -156,7 +156,7 @@ class Dense(base.Layer):
|
||||
outputs = standard_ops.tensordot(inputs, self.kernel, [[len(shape) - 1],
|
||||
[0]])
|
||||
# Reshape the output back to the original ndim of the input.
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
output_shape = shape[:-1] + [self.units]
|
||||
outputs.set_shape(output_shape)
|
||||
else:
|
||||
@ -374,7 +374,7 @@ class Flatten(base.Layer):
|
||||
|
||||
def call(self, inputs):
|
||||
outputs = array_ops.reshape(inputs, (array_ops.shape(inputs)[0], -1))
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
outputs.set_shape(self.compute_output_shape(inputs.get_shape()))
|
||||
return outputs
|
||||
|
||||
|
@ -77,7 +77,7 @@ class DenseTest(test.TestCase):
|
||||
self.assertListEqual(dense.trainable_variables,
|
||||
[dense.kernel, dense.bias])
|
||||
self.assertListEqual(dense.non_trainable_variables, [])
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
self.assertEqual(
|
||||
len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 2)
|
||||
self.assertEqual(dense.kernel.name, 'my_dense/kernel:0')
|
||||
@ -98,7 +98,7 @@ class DenseTest(test.TestCase):
|
||||
self.assertListEqual(dense.variables, [dense.kernel])
|
||||
self.assertListEqual(dense.trainable_variables, [dense.kernel])
|
||||
self.assertListEqual(dense.non_trainable_variables, [])
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
self.assertEqual(
|
||||
len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 1)
|
||||
self.assertEqual(dense.kernel.name, 'my_dense/kernel:0')
|
||||
@ -113,7 +113,7 @@ class DenseTest(test.TestCase):
|
||||
self.assertListEqual(dense.non_trainable_variables,
|
||||
[dense.kernel, dense.bias])
|
||||
self.assertListEqual(dense.trainable_variables, [])
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
self.assertEqual(
|
||||
len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 0)
|
||||
|
||||
@ -162,13 +162,13 @@ class DenseTest(test.TestCase):
|
||||
dense = core_layers.Dense(2, activation=nn_ops.relu, name='dense1')
|
||||
inputs = random_ops.random_uniform((5, 3), seed=1)
|
||||
outputs = dense(inputs)
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
self.assertEqual(outputs.op.name, 'dense1/Relu')
|
||||
|
||||
dense = core_layers.Dense(2, name='dense2')
|
||||
inputs = random_ops.random_uniform((5, 3), seed=1)
|
||||
outputs = dense(inputs)
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
self.assertEqual(outputs.op.name, 'dense2/BiasAdd')
|
||||
|
||||
def testActivityRegularizer(self):
|
||||
@ -374,7 +374,7 @@ class DropoutTest(test.TestCase):
|
||||
dp = core_layers.Dropout(0.5)
|
||||
inputs = array_ops.ones((5, 3))
|
||||
dropped = dp.apply(inputs, training=True)
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
np_output = self.evaluate(dropped)
|
||||
self.assertAlmostEqual(0., np_output.min())
|
||||
|
@ -338,8 +338,9 @@ class BatchNormalization(base.Layer):
|
||||
return var
|
||||
|
||||
with ops.device(None):
|
||||
device = ((lambda _: self.moving_mean.device)
|
||||
if context.in_graph_mode() else self.moving_mean.device)
|
||||
device = (
|
||||
self.moving_mean.device if context.executing_eagerly() else
|
||||
(lambda _: self.moving_mean.device))
|
||||
with ops.device(device):
|
||||
self.renorm_mean = _renorm_variable('renorm_mean', param_shape)
|
||||
self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ())
|
||||
@ -347,8 +348,9 @@ class BatchNormalization(base.Layer):
|
||||
# renorm_stddev_weight. This allows us to (1) mix the average
|
||||
# stddev with the minibatch stddev early in training, and (2) compute
|
||||
# the unbiased average stddev by dividing renorm_stddev by the weight.
|
||||
device = ((lambda _: self.moving_variance.device)
|
||||
if context.in_graph_mode() else self.moving_variance.device)
|
||||
device = (
|
||||
self.moving_variance.device if context.executing_eagerly() else
|
||||
(lambda _: self.moving_variance.device))
|
||||
with ops.device(device):
|
||||
self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape)
|
||||
self.renorm_stddev_weight = _renorm_variable(
|
||||
@ -420,7 +422,7 @@ class BatchNormalization(base.Layer):
|
||||
one_minus_decay)
|
||||
variance_update = self._assign_moving_average(self.moving_variance,
|
||||
variance, one_minus_decay)
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
# Note that in Eager mode, the updates are already executed when running
|
||||
# assign_moving_averages. So we do not need to put them into
|
||||
# collections.
|
||||
@ -493,7 +495,7 @@ class BatchNormalization(base.Layer):
|
||||
return (r, d, new_mean, new_variance)
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
in_eager_mode = context.in_eager_mode()
|
||||
in_eager_mode = context.executing_eagerly()
|
||||
if self.virtual_batch_size is not None:
|
||||
# Virtual batches (aka ghost batches) can be simulated by reshaping the
|
||||
# Tensor and reusing the existing batch norm implementation
|
||||
@ -610,7 +612,7 @@ class BatchNormalization(base.Layer):
|
||||
training,
|
||||
lambda: _do_update(self.moving_variance, new_variance),
|
||||
lambda: self.moving_variance)
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
self.add_update(mean_update, inputs=inputs)
|
||||
self.add_update(variance_update, inputs=inputs)
|
||||
|
||||
|
@ -80,7 +80,7 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
|
||||
|
||||
def _ExtractInputShapes(inputs):
|
||||
"""Extract the shapes of a set of input tensors."""
|
||||
if not context.in_graph_mode():
|
||||
if context.executing_eagerly():
|
||||
return array_ops.shape_n(inputs)
|
||||
sizes = []
|
||||
fully_known = True
|
||||
@ -106,7 +106,7 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
|
||||
|
||||
out_grads = []
|
||||
if isinstance(grad, ops.Tensor):
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
# Using mod here for convenience since concat_dim is already verified
|
||||
# in concat implementation to be within the allowed [-rank, rank) range.
|
||||
non_neg_concat_dim = (
|
||||
@ -428,7 +428,7 @@ def _GatherV2Grad(op, grad):
|
||||
|
||||
# For axis 0 gathers, build an appropriately shaped IndexedSlices.
|
||||
if axis_static == 0:
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
params_tail_shape = params_shape.cpu()[1:]
|
||||
else:
|
||||
params_tail_shape = params_shape[1:]
|
||||
@ -578,7 +578,7 @@ def _TileGrad(op, grad):
|
||||
axes = math_ops.range(0, array_ops.size(split_shape), 2)
|
||||
input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes)
|
||||
# Fix shape inference
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
input_grad.set_shape(op.inputs[0].get_shape())
|
||||
return [input_grad, None]
|
||||
|
||||
|
@ -128,9 +128,7 @@ def identity(input, name=None): # pylint: disable=redefined-builtin
|
||||
Returns:
|
||||
A `Tensor`. Has the same type as `input`.
|
||||
"""
|
||||
if context.in_graph_mode():
|
||||
return gen_array_ops.identity(input, name=name)
|
||||
else:
|
||||
if context.executing_eagerly():
|
||||
input = ops.convert_to_tensor(input)
|
||||
in_device = input.device
|
||||
# TODO(ashankar): Does 'identity' need to invoke execution callbacks?
|
||||
@ -140,6 +138,8 @@ def identity(input, name=None): # pylint: disable=redefined-builtin
|
||||
if context_device != in_device:
|
||||
return input._copy() # pylint: disable=protected-access
|
||||
return input
|
||||
else:
|
||||
return gen_array_ops.identity(input, name=name)
|
||||
|
||||
|
||||
# pylint: disable=redefined-builtin,protected-access
|
||||
@ -305,7 +305,7 @@ def shape_internal(input, name=None, optimize=True, out_type=dtypes.int32):
|
||||
sparse_tensor.SparseTensorValue)):
|
||||
return gen_math_ops.cast(input.dense_shape, out_type)
|
||||
else:
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
input_tensor = ops.convert_to_tensor(input)
|
||||
input_shape = input_tensor.get_shape()
|
||||
if optimize and input_shape.is_fully_defined():
|
||||
@ -330,7 +330,7 @@ def shape_n(input, out_type=dtypes.int32, name=None):
|
||||
"""
|
||||
|
||||
output = gen_array_ops.shape_n(input, out_type=out_type, name=name)
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
for i, input_tensor in enumerate(input):
|
||||
input_tensor = ops.convert_to_tensor(input_tensor)
|
||||
input_shape = input_tensor.get_shape()
|
||||
@ -385,9 +385,8 @@ def size_internal(input, name=None, optimize=True, out_type=dtypes.int32):
|
||||
Returns:
|
||||
A `Tensor` of type `out_type`. Defaults to `tf.int32`.
|
||||
"""
|
||||
if context.in_eager_mode() and not isinstance(
|
||||
input, (sparse_tensor.SparseTensor,
|
||||
sparse_tensor.SparseTensorValue)):
|
||||
if context.executing_eagerly() and not isinstance(
|
||||
input, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
|
||||
return np.prod(ops.convert_to_tensor(input)._shape_tuple()) # pylint: disable=protected-access
|
||||
with ops.name_scope(name, "Size", [input]) as name:
|
||||
if isinstance(input, (sparse_tensor.SparseTensor,
|
||||
@ -783,7 +782,7 @@ def strided_slice(input_,
|
||||
new_axis_mask=new_axis_mask,
|
||||
shrink_axis_mask=shrink_axis_mask)
|
||||
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
# TODO(apassos) In eager mode assignment will be done by overriding
|
||||
# __setitem__ instead.
|
||||
op.assign = assign
|
||||
@ -1457,7 +1456,7 @@ def transpose(a, perm=None, name="transpose", conjugate=False):
|
||||
ret = transpose_fn(a, perm, name=name)
|
||||
# NOTE(mrry): Setting the shape explicitly because
|
||||
# reverse is not handled by the shape function.
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
input_shape = ret.op.inputs[0].get_shape().dims
|
||||
if input_shape is not None:
|
||||
ret.set_shape(input_shape[::-1])
|
||||
@ -1622,7 +1621,7 @@ def zeros_like(tensor, dtype=None, name=None, optimize=True):
|
||||
with ops.name_scope(name, "zeros_like", [tensor]) as name:
|
||||
tensor = ops.convert_to_tensor(tensor, name="tensor")
|
||||
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
if dtype is not None and dtype != tensor.dtype:
|
||||
return zeros(
|
||||
shape_internal(tensor, optimize=optimize), dtype=dtype, name=name)
|
||||
@ -1678,7 +1677,7 @@ def ones_like(tensor, dtype=None, name=None, optimize=True):
|
||||
if dtype is None:
|
||||
dtype = tensor.dtype
|
||||
ret = ones(ones_shape, dtype=dtype, name=name)
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
ret.set_shape(tensor.get_shape())
|
||||
return ret
|
||||
|
||||
@ -1759,7 +1758,7 @@ def placeholder(dtype, shape=None, name=None):
|
||||
Raises:
|
||||
RuntimeError: if eager execution is enabled
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError("tf.placeholder() is not compatible with "
|
||||
"eager execution.")
|
||||
|
||||
@ -1822,7 +1821,7 @@ def sparse_placeholder(dtype, shape=None, name=None):
|
||||
Raises:
|
||||
RuntimeError: if eager execution is enabled
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError("tf.placeholder() is not compatible with "
|
||||
"eager execution.")
|
||||
|
||||
@ -1921,7 +1920,7 @@ def pad(tensor, paddings, mode="CONSTANT", name=None, constant_values=0): # pyl
|
||||
raise ValueError("Unknown padding mode: %s" % mode)
|
||||
|
||||
# Restore shape information where possible.
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
paddings_constant = tensor_util.constant_value(
|
||||
result.op.inputs[1], partial=True)
|
||||
input_shape = result.op.inputs[0].shape
|
||||
|
@ -169,7 +169,7 @@ def assert_negative(x, data=None, summarize=None, message=None, name=None):
|
||||
with ops.name_scope(name, 'assert_negative', [x, data]):
|
||||
x = ops.convert_to_tensor(x, name='x')
|
||||
if data is None:
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
name = _shape_and_dtype_str(x)
|
||||
else:
|
||||
name = x.name
|
||||
@ -210,7 +210,7 @@ def assert_positive(x, data=None, summarize=None, message=None, name=None):
|
||||
with ops.name_scope(name, 'assert_positive', [x, data]):
|
||||
x = ops.convert_to_tensor(x, name='x')
|
||||
if data is None:
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
name = _shape_and_dtype_str(x)
|
||||
else:
|
||||
name = x.name
|
||||
@ -251,7 +251,7 @@ def assert_non_negative(x, data=None, summarize=None, message=None, name=None):
|
||||
with ops.name_scope(name, 'assert_non_negative', [x, data]):
|
||||
x = ops.convert_to_tensor(x, name='x')
|
||||
if data is None:
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
name = _shape_and_dtype_str(x)
|
||||
else:
|
||||
name = x.name
|
||||
@ -293,7 +293,7 @@ def assert_non_positive(x, data=None, summarize=None, message=None, name=None):
|
||||
with ops.name_scope(name, 'assert_non_positive', [x, data]):
|
||||
x = ops.convert_to_tensor(x, name='x')
|
||||
if data is None:
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
name = _shape_and_dtype_str(x)
|
||||
else:
|
||||
name = x.name
|
||||
@ -343,7 +343,7 @@ def assert_equal(x, y, data=None, summarize=None, message=None, name=None):
|
||||
x = ops.convert_to_tensor(x, name='x')
|
||||
y = ops.convert_to_tensor(y, name='y')
|
||||
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
eq = math_ops.equal(x, y)
|
||||
condition = math_ops.reduce_all(eq)
|
||||
if not condition:
|
||||
@ -435,7 +435,7 @@ def assert_none_equal(
|
||||
with ops.name_scope(name, 'assert_none_equal', [x, y, data]):
|
||||
x = ops.convert_to_tensor(x, name='x')
|
||||
y = ops.convert_to_tensor(y, name='y')
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
x_name = _shape_and_dtype_str(x)
|
||||
y_name = _shape_and_dtype_str(y)
|
||||
else:
|
||||
@ -512,7 +512,7 @@ def assert_near(
|
||||
rtol = ops.convert_to_tensor(rtol, name='rtol', dtype=x.dtype)
|
||||
atol = ops.convert_to_tensor(atol, name='atol', dtype=x.dtype)
|
||||
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
x_name = _shape_and_dtype_str(x)
|
||||
y_name = _shape_and_dtype_str(y)
|
||||
else:
|
||||
@ -562,7 +562,7 @@ def assert_less(x, y, data=None, summarize=None, message=None, name=None):
|
||||
with ops.name_scope(name, 'assert_less', [x, y, data]):
|
||||
x = ops.convert_to_tensor(x, name='x')
|
||||
y = ops.convert_to_tensor(y, name='y')
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
x_name = _shape_and_dtype_str(x)
|
||||
y_name = _shape_and_dtype_str(y)
|
||||
else:
|
||||
@ -610,7 +610,7 @@ def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None):
|
||||
with ops.name_scope(name, 'assert_less_equal', [x, y, data]):
|
||||
x = ops.convert_to_tensor(x, name='x')
|
||||
y = ops.convert_to_tensor(y, name='y')
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
x_name = _shape_and_dtype_str(x)
|
||||
y_name = _shape_and_dtype_str(y)
|
||||
else:
|
||||
@ -658,7 +658,7 @@ def assert_greater(x, y, data=None, summarize=None, message=None, name=None):
|
||||
with ops.name_scope(name, 'assert_greater', [x, y, data]):
|
||||
x = ops.convert_to_tensor(x, name='x')
|
||||
y = ops.convert_to_tensor(y, name='y')
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
x_name = _shape_and_dtype_str(x)
|
||||
y_name = _shape_and_dtype_str(y)
|
||||
else:
|
||||
@ -708,7 +708,7 @@ def assert_greater_equal(x, y, data=None, summarize=None, message=None,
|
||||
with ops.name_scope(name, 'assert_greater_equal', [x, y, data]):
|
||||
x = ops.convert_to_tensor(x, name='x')
|
||||
y = ops.convert_to_tensor(y, name='y')
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
x_name = _shape_and_dtype_str(x)
|
||||
y_name = _shape_and_dtype_str(y)
|
||||
else:
|
||||
@ -808,7 +808,7 @@ def assert_rank(x, rank, data=None, summarize=None, message=None, name=None):
|
||||
static_condition = lambda actual_rank, given_rank: actual_rank == given_rank
|
||||
dynamic_condition = math_ops.equal
|
||||
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
name = ''
|
||||
else:
|
||||
name = x.name
|
||||
@ -873,7 +873,7 @@ def assert_rank_at_least(
|
||||
static_condition = lambda actual_rank, given_rank: actual_rank >= given_rank
|
||||
dynamic_condition = math_ops.greater_equal
|
||||
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
name = ''
|
||||
else:
|
||||
name = x.name
|
||||
@ -1001,7 +1001,7 @@ def assert_rank_in(
|
||||
ranks = tuple([ops.convert_to_tensor(rank, name='rank') for rank in ranks])
|
||||
message = message or ''
|
||||
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
name = ''
|
||||
else:
|
||||
name = x.name
|
||||
@ -1054,7 +1054,7 @@ def assert_integer(x, message=None, name=None):
|
||||
with ops.name_scope(name, 'assert_integer', [x]):
|
||||
x = ops.convert_to_tensor(x, name='x')
|
||||
if not x.dtype.is_integer:
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
name = 'tensor'
|
||||
else:
|
||||
name = x.name
|
||||
@ -1087,12 +1087,11 @@ def assert_type(tensor, tf_type, message=None, name=None):
|
||||
with ops.name_scope(name, 'assert_type', [tensor]):
|
||||
tensor = ops.convert_to_tensor(tensor, name='tensor')
|
||||
if tensor.dtype != tf_type:
|
||||
if context.in_graph_mode():
|
||||
raise TypeError(
|
||||
'%s %s must be of type %s' % (message, tensor.name, tf_type))
|
||||
if context.executing_eagerly():
|
||||
raise TypeError('%s tensor must be of type %s' % (message, tf_type))
|
||||
else:
|
||||
raise TypeError(
|
||||
'%s tensor must be of type %s' % (message, tf_type))
|
||||
raise TypeError('%s %s must be of type %s' % (message, tensor.name,
|
||||
tf_type))
|
||||
|
||||
return control_flow_ops.no_op('statically_determined_correct_type')
|
||||
|
||||
@ -1240,7 +1239,7 @@ def assert_scalar(tensor, name=None):
|
||||
tensor = ops.convert_to_tensor(tensor, name=name_scope)
|
||||
shape = tensor.get_shape()
|
||||
if shape.ndims != 0:
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise ValueError('Expected scalar shape, saw shape: %s.'
|
||||
% (shape,))
|
||||
else:
|
||||
|
@ -152,7 +152,7 @@ def Assert(condition, data, summarize=None, name=None):
|
||||
@compatibility{eager} `tf.errors.InvalidArgumentError` if `condition`
|
||||
is not true
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
if not condition:
|
||||
xs = ops.convert_n_to_tensor(data)
|
||||
data_str = [_summarize_eager(x, summarize) for x in xs]
|
||||
@ -178,7 +178,7 @@ def Assert(condition, data, summarize=None, name=None):
|
||||
condition, data, summarize, name="Assert")
|
||||
|
||||
guarded_assert = cond(condition, no_op, true_assert, name="AssertGuard")
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
return
|
||||
return guarded_assert.op
|
||||
|
||||
@ -2025,7 +2025,7 @@ def cond(pred,
|
||||
raise TypeError("false_fn must be callable.")
|
||||
|
||||
with ops.name_scope(name, "cond", [pred]):
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
if pred:
|
||||
return _UnpackIfSingleton(true_fn())
|
||||
return _UnpackIfSingleton(false_fn())
|
||||
@ -3177,7 +3177,7 @@ def while_loop(cond,
|
||||
math_ops.logical_and(i < maximum_iterations, orig_cond(*lv)))
|
||||
body = lambda i, lv: (i + 1, orig_body(*lv))
|
||||
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
while cond(*loop_vars):
|
||||
loop_vars = body(*loop_vars)
|
||||
if maximum_iterations is not None:
|
||||
@ -3271,7 +3271,7 @@ def with_dependencies(dependencies, output_tensor, name=None):
|
||||
Raises:
|
||||
TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
return output_tensor
|
||||
with ops.name_scope(name, "control_dependency",
|
||||
list(dependencies) + [output_tensor]) as name:
|
||||
@ -3316,7 +3316,7 @@ def group(*inputs, **kwargs):
|
||||
Raises:
|
||||
ValueError: If an unknown keyword argument is provided.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
return None
|
||||
name = kwargs.pop("name", None)
|
||||
if kwargs:
|
||||
@ -3396,7 +3396,7 @@ def tuple(tensors, name=None, control_inputs=None): # pylint: disable=redefined
|
||||
objects.
|
||||
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
return tensors
|
||||
with ops.name_scope(name, "tuple", tensors) as name:
|
||||
tensors = [t if (isinstance(t, ops.Operation)
|
||||
|
@ -92,7 +92,7 @@ def custom_gradient(f):
|
||||
|
||||
def decorated(*args, **kwargs):
|
||||
"""Decorated function with custom gradient."""
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
if kwargs:
|
||||
raise ValueError(
|
||||
"The custom_gradient decorator currently suports keywords "
|
||||
|
@ -159,7 +159,7 @@ class QueueBase(object):
|
||||
ValueError: If one of the arguments is invalid.
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError(
|
||||
"Queues are not supported when eager execution is enabled. "
|
||||
"Instead, please use tf.data to get data into your model.")
|
||||
@ -177,10 +177,10 @@ class QueueBase(object):
|
||||
else:
|
||||
self._names = None
|
||||
self._queue_ref = queue_ref
|
||||
if context.in_graph_mode():
|
||||
self._name = self._queue_ref.op.name.split("/")[-1]
|
||||
else:
|
||||
if context.executing_eagerly():
|
||||
self._name = context.context().scope_name
|
||||
else:
|
||||
self._name = self._queue_ref.op.name.split("/")[-1]
|
||||
|
||||
@staticmethod
|
||||
def from_list(index, queues):
|
||||
@ -231,9 +231,9 @@ class QueueBase(object):
|
||||
@property
|
||||
def name(self):
|
||||
"""The name of the underlying queue."""
|
||||
if context.in_graph_mode():
|
||||
return self._queue_ref.op.name
|
||||
return self._name
|
||||
if context.executing_eagerly():
|
||||
return self._name
|
||||
return self._queue_ref.op.name
|
||||
|
||||
@property
|
||||
def dtypes(self):
|
||||
@ -444,7 +444,7 @@ class QueueBase(object):
|
||||
|
||||
# NOTE(mrry): Not using a shape function because we need access to
|
||||
# the `QueueBase` object.
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
op = ret[0].op
|
||||
for output, shape in zip(op.values(), self._shapes):
|
||||
output.set_shape(shape)
|
||||
@ -484,7 +484,7 @@ class QueueBase(object):
|
||||
|
||||
# NOTE(mrry): Not using a shape function because we need access to
|
||||
# the Queue object.
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
op = ret[0].op
|
||||
batch_dim = tensor_shape.Dimension(
|
||||
tensor_util.constant_value(op.inputs[1]))
|
||||
@ -528,7 +528,7 @@ class QueueBase(object):
|
||||
|
||||
# NOTE(mrry): Not using a shape function because we need access to
|
||||
# the Queue object.
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
op = ret[0].op
|
||||
for output, shape in zip(op.values(), self._shapes):
|
||||
output.set_shape(tensor_shape.TensorShape([None]).concatenate(shape))
|
||||
@ -990,10 +990,10 @@ class Barrier(object):
|
||||
shapes=self._shapes,
|
||||
shared_name=shared_name,
|
||||
name=name)
|
||||
if context.in_graph_mode():
|
||||
self._name = self._barrier_ref.op.name.split("/")[-1]
|
||||
else:
|
||||
if context.executing_eagerly():
|
||||
self._name = context.context().scope_name
|
||||
else:
|
||||
self._name = self._barrier_ref.op.name.split("/")[-1]
|
||||
|
||||
@property
|
||||
def barrier_ref(self):
|
||||
@ -1003,9 +1003,9 @@ class Barrier(object):
|
||||
@property
|
||||
def name(self):
|
||||
"""The name of the underlying barrier."""
|
||||
if context.in_graph_mode():
|
||||
return self._barrier_ref.op.name
|
||||
return self._name
|
||||
if context.executing_eagerly():
|
||||
return self._name
|
||||
return self._barrier_ref.op.name
|
||||
|
||||
def insert_many(self, component_index, keys, values, name=None):
|
||||
"""For each key, assigns the respective value to the specified component.
|
||||
@ -1083,7 +1083,7 @@ class Barrier(object):
|
||||
|
||||
# NOTE(mrry): Not using a shape function because we need access to
|
||||
# the Barrier object.
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
op = ret[0].op
|
||||
if allow_small_batch:
|
||||
batch_dim = None
|
||||
@ -1183,10 +1183,10 @@ class ConditionalAccumulatorBase(object):
|
||||
else:
|
||||
self._shape = tensor_shape.unknown_shape()
|
||||
self._accumulator_ref = accumulator_ref
|
||||
if context.in_graph_mode():
|
||||
self._name = self._accumulator_ref.op.name.split("/")[-1]
|
||||
else:
|
||||
if context.executing_eagerly():
|
||||
self._name = context.context().scope_name
|
||||
else:
|
||||
self._name = self._accumulator_ref.op.name.split("/")[-1]
|
||||
|
||||
@property
|
||||
def accumulator_ref(self):
|
||||
|
@ -90,7 +90,7 @@ def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
|
||||
if not callable(fn):
|
||||
raise TypeError("fn must be callable.")
|
||||
|
||||
in_graph_mode = context.in_graph_mode()
|
||||
in_graph_mode = not context.executing_eagerly()
|
||||
with ops.name_scope(name, "foldl", [elems]):
|
||||
# TODO(akshayka): Remove the in_graph_mode check once caching devices are
|
||||
# supported in Eager
|
||||
@ -178,7 +178,7 @@ def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
|
||||
if not callable(fn):
|
||||
raise TypeError("fn must be callable.")
|
||||
|
||||
in_graph_mode = context.in_graph_mode()
|
||||
in_graph_mode = not context.executing_eagerly()
|
||||
with ops.name_scope(name, "foldr", [elems]):
|
||||
# TODO(akshayka): Remove the in_graph_mode check once caching devices are
|
||||
# supported in Eager
|
||||
@ -343,7 +343,7 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
|
||||
|
||||
elems_flat = input_flatten(elems)
|
||||
|
||||
in_graph_mode = context.in_graph_mode()
|
||||
in_graph_mode = not context.executing_eagerly()
|
||||
with ops.name_scope(name, "map", elems_flat):
|
||||
# TODO(akshayka): Remove the in_graph_mode check once caching devices are
|
||||
# supported in Eager
|
||||
@ -536,7 +536,7 @@ def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
|
||||
|
||||
elems_flat = input_flatten(elems)
|
||||
|
||||
in_graph_mode = context.in_graph_mode()
|
||||
in_graph_mode = not context.executing_eagerly()
|
||||
with ops.name_scope(name, "scan", elems_flat):
|
||||
# TODO(akshayka): Remove the in_graph_mode check once caching devices are
|
||||
# supported in Eager
|
||||
|
@ -86,7 +86,7 @@ def _IndexedSlicesToTensor(value, dtype=None, name=None, as_ref=False):
|
||||
% str(value))
|
||||
# TODO(mrry): Consider adding static shape information to
|
||||
# IndexedSlices, to avoid using numpy here.
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
dense_shape_value = tensor_util.constant_value(value.dense_shape)
|
||||
if dense_shape_value is not None:
|
||||
num_elements = np.prod(dense_shape_value)
|
||||
@ -491,9 +491,10 @@ def gradients(ys,
|
||||
def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
|
||||
gate_gradients, aggregation_method, stop_gradients):
|
||||
"""Implementation of gradients()."""
|
||||
if context.in_eager_mode():
|
||||
raise RuntimeError("tf.gradients not supported in EAGER mode. Use "
|
||||
"functions in tf.contrib.eager.backprop instead.")
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError("tf.gradients not supported when eager execution "
|
||||
"is enabled. Use tf.contrib.eager.GradientTape "
|
||||
"instead.")
|
||||
ys = _AsList(ys)
|
||||
xs = _AsList(xs)
|
||||
stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)
|
||||
|
@ -173,7 +173,7 @@ class ReaderBase(object):
|
||||
Raises:
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError(
|
||||
"Readers are not supported when eager execution is enabled. "
|
||||
"Instead, please use tf.data to get data into your model.")
|
||||
|
@ -157,10 +157,10 @@ class InitializableLookupTableBase(LookupInterface):
|
||||
default_value: The value to use if a key is missing in the table.
|
||||
initializer: The table initializer to use.
|
||||
"""
|
||||
if context.in_graph_mode():
|
||||
name = table_ref.op.name.split("/")[-1]
|
||||
else:
|
||||
if context.executing_eagerly():
|
||||
name = context.context().scope_name
|
||||
else:
|
||||
name = table_ref.op.name.split("/")[-1]
|
||||
super(InitializableLookupTableBase,
|
||||
self).__init__(initializer.key_dtype, initializer.value_dtype,
|
||||
name)
|
||||
@ -521,7 +521,7 @@ class TextFileInitializer(TableInitializerBase):
|
||||
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
|
||||
# If the filename tensor is anything other than a string constant (e.g., if
|
||||
# it is a placeholder) then it does not make sense to track it as an asset.
|
||||
if context.in_graph_mode() and constant_op.is_constant(filename):
|
||||
if not context.executing_eagerly() and constant_op.is_constant(filename):
|
||||
ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, filename)
|
||||
return init_op
|
||||
|
||||
|
@ -136,7 +136,7 @@ def _num_present(losses, weights, per_batch=False):
|
||||
`[batch_size]`. Otherwise, a single scalar tensor is returned.
|
||||
"""
|
||||
if ((isinstance(weights, float) and weights != 0.0) or
|
||||
(context.in_eager_mode() and weights._rank() == 0 # pylint: disable=protected-access
|
||||
(context.executing_eagerly() and weights._rank() == 0 # pylint: disable=protected-access
|
||||
and not math_ops.equal(weights, 0.0))):
|
||||
return _num_elements(losses)
|
||||
with ops.name_scope(None, "num_present", (losses, weights)) as scope:
|
||||
|
@ -52,14 +52,14 @@ def _SumGrad(op, grad):
|
||||
if axes is not None:
|
||||
rank = len(input_0_shape)
|
||||
if np.array_equal(axes, np.arange(rank)): # Reduce all dims.
|
||||
if context.in_graph_mode():
|
||||
new_shape = [1] * rank
|
||||
else:
|
||||
if context.executing_eagerly():
|
||||
ctx = context.context()
|
||||
new_shape = ctx.ones_rank_cache().get(rank)
|
||||
if new_shape is None:
|
||||
new_shape = constant_op.constant([1] * rank, dtype=dtypes.int32)
|
||||
ctx.ones_rank_cache().put(rank, new_shape)
|
||||
else:
|
||||
new_shape = [1] * rank
|
||||
grad = array_ops.reshape(grad, new_shape)
|
||||
# If shape is not fully defined (but rank is), we use Shape.
|
||||
if None not in input_0_shape:
|
||||
@ -997,7 +997,7 @@ def _SparseMatMulGrad(op, grad):
|
||||
op.inputs[0]: op.get_attr("a_is_sparse"),
|
||||
op.inputs[1]: op.get_attr("b_is_sparse"),
|
||||
# Use heuristic to figure out if grad might be sparse
|
||||
grad: context.in_graph_mode() and (grad.op.type == "ReluGrad")
|
||||
grad: not context.executing_eagerly() and (grad.op.type == "ReluGrad")
|
||||
}
|
||||
|
||||
def _SparseMatMul(t1, t2, out_dtype, transpose_a=False, transpose_b=False):
|
||||
|
@ -2007,14 +2007,14 @@ def matmul(a,
|
||||
if transpose_b and adjoint_b:
|
||||
raise ValueError("Only one of transpose_b and adjoint_b can be True.")
|
||||
|
||||
if context.in_graph_mode():
|
||||
a = ops.convert_to_tensor(a, name="a")
|
||||
b = ops.convert_to_tensor(b, name="b")
|
||||
else:
|
||||
if context.executing_eagerly():
|
||||
if not isinstance(a, (ops.EagerTensor, _resource_variable_type)):
|
||||
a = ops.convert_to_tensor(a, name="a")
|
||||
if not isinstance(b, (ops.EagerTensor, _resource_variable_type)):
|
||||
b = ops.convert_to_tensor(b, name="b")
|
||||
else:
|
||||
a = ops.convert_to_tensor(a, name="a")
|
||||
b = ops.convert_to_tensor(b, name="b")
|
||||
|
||||
# TODO(apassos) remove _shape_tuple here when it is not needed.
|
||||
a_shape = a._shape_tuple() # pylint: disable=protected-access
|
||||
@ -2249,7 +2249,7 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None):
|
||||
return inputs[0]
|
||||
elif len(inputs) == 1 and name is not None:
|
||||
return array_ops.identity(inputs[0], name=name)
|
||||
elif context.in_eager_mode():
|
||||
elif context.executing_eagerly():
|
||||
# TemporaryVariable not currently supported in eager mode; fall back
|
||||
# onto AddN for now.
|
||||
# TODO(frreiss) remove this once the lifetime of eager variables gets
|
||||
|
@ -60,7 +60,7 @@ class ReduceTest(test_util.TensorFlowTestCase):
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testReduceInvalidAxis(self):
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
# The shape check is in run a graph construction time. In eager mode,
|
||||
# it misses the check, magically return result given wrong shape.
|
||||
return
|
||||
@ -249,7 +249,7 @@ class ScalarMulTest(test_util.TensorFlowTestCase):
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testAcceptsRefs(self):
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
var = resource_variable_ops.ResourceVariable(10, name="var")
|
||||
else:
|
||||
var = variables.Variable(10)
|
||||
|
@ -308,7 +308,7 @@ def mean(values,
|
||||
or tuple.
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('tf.metrics.mean is not supported when eager execution '
|
||||
'is enabled.')
|
||||
|
||||
@ -394,7 +394,7 @@ def accuracy(labels,
|
||||
tuple.
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('tf.metrics.accuracy is not supported when eager '
|
||||
'execution is enabled.')
|
||||
|
||||
@ -644,7 +644,7 @@ def auc(labels,
|
||||
tuple.
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('tf.metrics.auc is not supported when eager execution '
|
||||
'is enabled.')
|
||||
|
||||
@ -758,7 +758,7 @@ def mean_absolute_error(labels,
|
||||
tuple.
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('tf.metrics.mean_absolute_error is not supported '
|
||||
'when eager execution is enabled.')
|
||||
|
||||
@ -818,7 +818,7 @@ def mean_cosine_distance(labels,
|
||||
tuple.
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('tf.metrics.mean_cosine_distance is not supported when '
|
||||
'eager execution is enabled.')
|
||||
|
||||
@ -891,7 +891,7 @@ def mean_per_class_accuracy(labels,
|
||||
tuple.
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('tf.metrics.mean_per_class_accuracy is not supported '
|
||||
'when eager execution is enabled.')
|
||||
|
||||
@ -996,7 +996,7 @@ def mean_iou(labels,
|
||||
tuple.
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('tf.metrics.mean_iou is not supported when '
|
||||
'eager execution is enabled.')
|
||||
|
||||
@ -1098,7 +1098,7 @@ def mean_relative_error(labels,
|
||||
tuple.
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('tf.metrics.mean_relative_error is not supported when '
|
||||
'eager execution is enabled.')
|
||||
|
||||
@ -1165,7 +1165,7 @@ def mean_squared_error(labels,
|
||||
tuple.
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('tf.metrics.mean_squared_error is not supported when '
|
||||
'eager execution is enabled.')
|
||||
|
||||
@ -1223,7 +1223,7 @@ def mean_tensor(values,
|
||||
or tuple.
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('tf.metrics.mean_tensor is not supported when '
|
||||
'eager execution is enabled.')
|
||||
|
||||
@ -1304,7 +1304,7 @@ def percentage_below(values,
|
||||
or tuple.
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('tf.metrics.percentage_below is not supported when '
|
||||
'eager execution is enabled.')
|
||||
|
||||
@ -1397,7 +1397,7 @@ def false_negatives(labels,
|
||||
or tuple.
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('tf.metrics.false_negatives is not supported when '
|
||||
'eager execution is enabled.')
|
||||
|
||||
@ -1453,7 +1453,7 @@ def false_negatives_at_thresholds(labels,
|
||||
tuple.
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('tf.metrics.false_negatives_at_thresholds is not '
|
||||
'supported when eager execution is enabled.')
|
||||
|
||||
@ -1507,7 +1507,7 @@ def false_positives(labels,
|
||||
tuple.
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('tf.metrics.false_positives is not supported when '
|
||||
'eager execution is enabled.')
|
||||
|
||||
@ -1563,7 +1563,7 @@ def false_positives_at_thresholds(labels,
|
||||
tuple.
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('tf.metrics.false_positives_at_thresholds is not '
|
||||
'supported when eager execution is enabled.')
|
||||
|
||||
@ -1617,7 +1617,7 @@ def true_negatives(labels,
|
||||
tuple.
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('tf.metrics.true_negatives is not '
|
||||
'supported when eager execution is enabled.')
|
||||
|
||||
@ -1673,7 +1673,7 @@ def true_negatives_at_thresholds(labels,
|
||||
tuple.
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('tf.metrics.true_negatives_at_thresholds is not '
|
||||
'supported when eager execution is enabled.')
|
||||
|
||||
@ -1727,7 +1727,7 @@ def true_positives(labels,
|
||||
tuple.
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('tf.metrics.true_positives is not '
|
||||
'supported when eager execution is enabled.')
|
||||
|
||||
@ -1783,7 +1783,7 @@ def true_positives_at_thresholds(labels,
|
||||
tuple.
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('tf.metrics.true_positives_at_thresholds is not '
|
||||
'supported when eager execution is enabled.')
|
||||
|
||||
@ -1851,7 +1851,7 @@ def precision(labels,
|
||||
tuple.
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('tf.metrics.precision is not '
|
||||
'supported when eager execution is enabled.')
|
||||
|
||||
@ -1947,7 +1947,7 @@ def precision_at_thresholds(labels,
|
||||
tuple.
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('tf.metrics.precision_at_thresholds is not '
|
||||
'supported when eager execution is enabled.')
|
||||
|
||||
@ -2023,7 +2023,7 @@ def recall(labels,
|
||||
tuple.
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('tf.metrics.recall is not supported is not '
|
||||
'supported when eager execution is enabled.')
|
||||
|
||||
@ -2400,7 +2400,7 @@ def recall_at_k(labels,
|
||||
are not a list or tuple.
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('tf.metrics.recall_at_k is not '
|
||||
'supported when eager execution is enabled.')
|
||||
|
||||
@ -2549,7 +2549,7 @@ def recall_at_thresholds(labels,
|
||||
tuple.
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('tf.metrics.recall_at_thresholds is not '
|
||||
'supported when eager execution is enabled.')
|
||||
|
||||
@ -2626,7 +2626,7 @@ def root_mean_squared_error(labels,
|
||||
tuple.
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('tf.metrics.root_mean_squared_error is not '
|
||||
'supported when eager execution is enabled.')
|
||||
|
||||
@ -2707,7 +2707,7 @@ def sensitivity_at_specificity(labels,
|
||||
or `updates_collections` are not a list or tuple.
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('tf.metrics.sensitivity_at_specificity is not '
|
||||
'supported when eager execution is enabled.')
|
||||
|
||||
@ -3098,7 +3098,7 @@ def average_precision_at_k(labels,
|
||||
ValueError: if k is invalid.
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('tf.metrics.sparse_average_precision_at_k is not '
|
||||
'supported when eager execution is enabled.')
|
||||
|
||||
@ -3267,7 +3267,7 @@ def precision_at_top_k(labels,
|
||||
are not a list or tuple.
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('tf.metrics.precision_at_top_k is not '
|
||||
'supported when eager execution is enabled.')
|
||||
|
||||
@ -3396,7 +3396,7 @@ def precision_at_k(labels,
|
||||
are not a list or tuple.
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('tf.metrics.sparse_precision_at_k is not '
|
||||
'supported when eager execution is enabled.')
|
||||
|
||||
@ -3473,7 +3473,7 @@ def specificity_at_sensitivity(labels,
|
||||
or `updates_collections` are not a list or tuple.
|
||||
RuntimeError: If eager execution is enabled.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('tf.metrics.specificity_at_sensitivity is not '
|
||||
'supported when eager execution is enabled.')
|
||||
|
||||
|
@ -456,7 +456,7 @@ def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
|
||||
|
||||
def IsZero(g):
|
||||
# Some introspection to check if the gradient is feeding zeros
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
# TODO(apassos) add an efficient way to detect eager zeros here.
|
||||
return False
|
||||
if g.op.type in ("ZerosLike", "Zeros"):
|
||||
|
@ -1504,7 +1504,7 @@ def bias_add(value, bias, data_format=None, name=None):
|
||||
A `Tensor` with the same type as `value`.
|
||||
"""
|
||||
with ops.name_scope(name, "BiasAdd", [value, bias]) as name:
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
value = ops.convert_to_tensor(value, name="input")
|
||||
bias = ops.convert_to_tensor(bias, dtype=value.dtype, name="bias")
|
||||
return gen_nn_ops.bias_add(value, bias, data_format=data_format, name=name)
|
||||
@ -1616,7 +1616,7 @@ def _flatten_outer_dims(logits):
|
||||
output = array_ops.reshape(logits, array_ops.concat([[-1], last_dim_size], 0))
|
||||
|
||||
# Set output shape if known.
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
shape = logits.get_shape()
|
||||
if shape is not None and shape.dims is not None:
|
||||
shape = shape.as_list()
|
||||
@ -1881,7 +1881,8 @@ def softmax_cross_entropy_with_logits_v2(
|
||||
|
||||
# Make shape inference work since reshape and transpose may erase its static
|
||||
# shape.
|
||||
if context.in_graph_mode() and shape is not None and shape.dims is not None:
|
||||
if not context.executing_eagerly(
|
||||
) and shape is not None and shape.dims is not None:
|
||||
shape = shape.as_list()
|
||||
del shape[dim]
|
||||
cost.set_shape(shape)
|
||||
@ -2318,7 +2319,7 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: di
|
||||
# 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
|
||||
binary_tensor = math_ops.floor(random_tensor)
|
||||
ret = math_ops.div(x, keep_prob) * binary_tensor
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
ret.set_shape(x.get_shape())
|
||||
return ret
|
||||
|
||||
|
@ -74,7 +74,7 @@ def add_check_numerics_ops():
|
||||
the checked operations.
|
||||
@enc_compatibility
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError(
|
||||
"add_check_numerics_ops() is not compatible with eager execution. "
|
||||
"To check for Inf's and NaN's under eager execution, call "
|
||||
|
@ -135,10 +135,10 @@ class EagerResourceDeleter(object):
|
||||
# valid, and so on. Printing warnings in these cases is silly
|
||||
# (exceptions raised from __del__ are printed as warnings to stderr).
|
||||
pass # 'NoneType' object is not callable when the handle has been
|
||||
# partially unloaded.
|
||||
# partially unloaded.
|
||||
except AttributeError:
|
||||
pass # 'NoneType' object has no attribute 'eager_mode' when context has
|
||||
# been unloaded. Will catch other module unloads as well.
|
||||
# been unloaded. Will catch other module unloads as well.
|
||||
|
||||
|
||||
def shape_safe_assign_variable_handle(handle, shape, value, name=None):
|
||||
@ -267,9 +267,9 @@ class ResourceVariable(variables.Variable):
|
||||
if initial_value is not None:
|
||||
raise ValueError("variable_def and initial_value are mutually "
|
||||
"exclusive.")
|
||||
if not context.in_graph_mode():
|
||||
raise ValueError("Creating ResourceVariable from variable_def"
|
||||
" only supported in GRAPH mode.")
|
||||
if context.executing_eagerly():
|
||||
raise ValueError("Creating ResourceVariable from variable_def is "
|
||||
"not supported when eager execution is enabled.")
|
||||
self._init_from_proto(variable_def, import_scope=import_scope)
|
||||
else:
|
||||
self._init_from_args(
|
||||
@ -363,7 +363,7 @@ class ResourceVariable(variables.Variable):
|
||||
# this graph.
|
||||
self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
|
||||
with ops.init_scope():
|
||||
self._in_graph_mode = context.in_graph_mode()
|
||||
self._in_graph_mode = not context.executing_eagerly()
|
||||
with ops.name_scope(name, "Variable", []
|
||||
if init_from_fn else [initial_value]) as name:
|
||||
# pylint: disable=protected-access
|
||||
@ -470,7 +470,7 @@ class ResourceVariable(variables.Variable):
|
||||
self._cached_value = self._read_variable_op()
|
||||
else:
|
||||
self._cached_value = None
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
ops.add_to_collections(collections, self)
|
||||
elif ops.GraphKeys.GLOBAL_STEP in collections:
|
||||
ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self)
|
||||
@ -489,7 +489,7 @@ class ResourceVariable(variables.Variable):
|
||||
def _init_from_proto(self, variable_def, import_scope=None):
|
||||
"""Initializes from `VariableDef` proto."""
|
||||
# Note that init_from_proto is currently not supported in Eager mode.
|
||||
assert context.in_graph_mode()
|
||||
assert not context.executing_eagerly()
|
||||
self._in_graph_mode = True
|
||||
assert isinstance(variable_def, variable_pb2.VariableDef)
|
||||
if not variable_def.is_resource:
|
||||
@ -582,7 +582,8 @@ class ResourceVariable(variables.Variable):
|
||||
def create(self):
|
||||
"""The op responsible for initializing this variable."""
|
||||
if not self._in_graph_mode:
|
||||
raise RuntimeError("Calling create in EAGER mode not supported.")
|
||||
raise RuntimeError("Calling create is not supported when eager execution"
|
||||
" is enabled.")
|
||||
return self._initializer_op
|
||||
|
||||
@property
|
||||
@ -610,7 +611,7 @@ class ResourceVariable(variables.Variable):
|
||||
@property
|
||||
def initial_value(self):
|
||||
"""Returns the Tensor used as the initial value for the variable."""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError("initial_value not supported in EAGER mode.")
|
||||
return self._initial_value
|
||||
|
||||
@ -631,15 +632,15 @@ class ResourceVariable(variables.Variable):
|
||||
|
||||
def eval(self, session=None):
|
||||
"""Evaluates and returns the value of this variable."""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError("Trying to eval in EAGER mode")
|
||||
return self._graph_element.eval(session=session)
|
||||
|
||||
def numpy(self):
|
||||
if context.in_graph_mode():
|
||||
raise NotImplementedError(
|
||||
"numpy() is only available when eager execution is enabled.")
|
||||
return self.read_value().numpy()
|
||||
if context.executing_eagerly():
|
||||
return self.read_value().numpy()
|
||||
raise NotImplementedError(
|
||||
"numpy() is only available when eager execution is enabled.")
|
||||
|
||||
def count_up_to(self, limit):
|
||||
"""Increments this variable until it reaches `limit`.
|
||||
@ -720,7 +721,7 @@ class ResourceVariable(variables.Variable):
|
||||
A `VariableDef` protocol buffer, or `None` if the `Variable` is not
|
||||
in the specified name scope.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError("to_proto not supported in EAGER mode.")
|
||||
if export_scope is None or self.handle.name.startswith(export_scope):
|
||||
var_def = variable_pb2.VariableDef()
|
||||
@ -747,7 +748,7 @@ class ResourceVariable(variables.Variable):
|
||||
|
||||
@staticmethod
|
||||
def from_proto(variable_def, import_scope=None):
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError("from_proto not supported in EAGER mode.")
|
||||
return ResourceVariable(
|
||||
variable_def=variable_def, import_scope=import_scope)
|
||||
@ -984,10 +985,10 @@ class _UnreadVariable(ResourceVariable):
|
||||
self._is_initialized_op = None
|
||||
self._initializer_op = None
|
||||
self._parent_op = parent_op
|
||||
if context.in_graph_mode():
|
||||
self._graph_element = self.read_value()
|
||||
else:
|
||||
if context.executing_eagerly():
|
||||
self._graph_element = None
|
||||
else:
|
||||
self._graph_element = self.read_value()
|
||||
self._handle_deleter = deleter
|
||||
|
||||
def value(self):
|
||||
|
@ -575,7 +575,7 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
|
||||
# Create a new scope in which the caching device is either
|
||||
# determined by the parent scope, or is set to place the cached
|
||||
# Variable using the same placement as for the rest of the RNN.
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
if varscope.caching_device is None:
|
||||
varscope.set_caching_device(lambda op: op.device)
|
||||
|
||||
@ -616,7 +616,7 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
|
||||
["Expected shape for Tensor %s is " % x.name,
|
||||
packed_shape, " but saw shape: ", x_shape])
|
||||
|
||||
if context.in_graph_mode() and sequence_length is not None:
|
||||
if not context.executing_eagerly() and sequence_length is not None:
|
||||
# Perform some shape validation
|
||||
with ops.control_dependencies(
|
||||
[_assert_has_shape(sequence_length, [batch_size])]):
|
||||
@ -742,7 +742,7 @@ def _dynamic_rnn_loop(cell,
|
||||
element_shape=element_shape,
|
||||
tensor_array_name=base_name + name)
|
||||
|
||||
in_graph_mode = context.in_graph_mode()
|
||||
in_graph_mode = not context.executing_eagerly()
|
||||
if in_graph_mode:
|
||||
output_ta = tuple(
|
||||
_create_ta(
|
||||
@ -1027,7 +1027,7 @@ def raw_rnn(cell, loop_fn,
|
||||
# determined by the parent scope, or is set to place the cached
|
||||
# Variable using the same placement as for the rest of the RNN.
|
||||
with vs.variable_scope(scope or "rnn") as varscope:
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
if varscope.caching_device is None:
|
||||
varscope.set_caching_device(lambda op: op.device)
|
||||
|
||||
@ -1242,7 +1242,7 @@ def static_rnn(cell,
|
||||
# determined by the parent scope, or is set to place the cached
|
||||
# Variable using the same placement as for the rest of the RNN.
|
||||
with vs.variable_scope(scope or "rnn") as varscope:
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
if varscope.caching_device is None:
|
||||
varscope.set_caching_device(lambda op: op.device)
|
||||
|
||||
|
@ -128,7 +128,7 @@ def _zero_state_tensors(state_size, batch_size, dtype):
|
||||
"""Combine s with batch_size to get a proper tensor shape."""
|
||||
c = _concat(batch_size, s)
|
||||
size = array_ops.zeros(c, dtype=dtype)
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
c_static = _concat(batch_size, s, static=True)
|
||||
size.set_shape(c_static)
|
||||
return size
|
||||
@ -192,12 +192,13 @@ class RNNCell(base_layer.Layer):
|
||||
|
||||
def _rnn_get_variable(self, getter, *args, **kwargs):
|
||||
variable = getter(*args, **kwargs)
|
||||
if context.in_graph_mode():
|
||||
trainable = (variable in tf_variables.trainable_variables() or
|
||||
(isinstance(variable, tf_variables.PartitionedVariable) and
|
||||
list(variable)[0] in tf_variables.trainable_variables()))
|
||||
else:
|
||||
if context.executing_eagerly():
|
||||
trainable = variable._trainable # pylint: disable=protected-access
|
||||
else:
|
||||
trainable = (
|
||||
variable in tf_variables.trainable_variables() or
|
||||
(isinstance(variable, tf_variables.PartitionedVariable) and
|
||||
list(variable)[0] in tf_variables.trainable_variables()))
|
||||
if trainable and variable not in self._trainable_weights:
|
||||
self._trainable_weights.append(variable)
|
||||
elif not trainable and variable not in self._non_trainable_weights:
|
||||
@ -241,7 +242,7 @@ class RNNCell(base_layer.Layer):
|
||||
# Try to use the last cached zero_state. This is done to avoid recreating
|
||||
# zeros, especially when eager execution is enabled.
|
||||
state_size = self.state_size
|
||||
is_eager = context.in_eager_mode()
|
||||
is_eager = context.executing_eagerly()
|
||||
if is_eager and hasattr(self, "_last_zero_state"):
|
||||
(last_state_size, last_batch_size, last_dtype,
|
||||
last_output) = getattr(self, "_last_zero_state")
|
||||
|
@ -317,7 +317,7 @@ def py_func(func, inp, Tout, stateful=True, name=None):
|
||||
Returns:
|
||||
A list of `Tensor` or a single `Tensor` which `func` computes.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
result = func(*[x.numpy() for x in inp])
|
||||
result = nest.flatten(result)
|
||||
|
||||
|
@ -186,7 +186,7 @@ def is_variable_initialized(ref, name=None):
|
||||
if ref.dtype._is_ref_dtype:
|
||||
return gen_state_ops.is_variable_initialized(ref=ref, name=name)
|
||||
# Handle resource variables.
|
||||
if context.in_eager_mode() or ref.op.type == "VarHandleOp":
|
||||
if context.executing_eagerly() or ref.op.type == "VarHandleOp":
|
||||
return gen_resource_variable_ops.var_is_initialized_op(ref.handle,
|
||||
name=name)
|
||||
|
||||
|
@ -204,7 +204,7 @@ def make_template_internal(name_,
|
||||
if kwargs:
|
||||
func_ = tf_decorator.make_decorator(func_, functools.partial(
|
||||
func_, **kwargs))
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
if unique_name_ is not None:
|
||||
raise ValueError(
|
||||
"unique_name_ cannot be used when eager exeuction is enabled.")
|
||||
@ -364,7 +364,7 @@ class Template(checkpointable.CheckpointableBase):
|
||||
"""
|
||||
def _call_next_creator_renaming_initializer(initializer, **inner_kwargs):
|
||||
inner_kwargs.pop("name") # Ignored; this is the scope-stripped name which
|
||||
# we don't want to propagate.
|
||||
# we don't want to propagate.
|
||||
return next_creator(
|
||||
initial_value=initializer,
|
||||
name=name,
|
||||
@ -647,7 +647,7 @@ class EagerTemplate(Template):
|
||||
Raises:
|
||||
RuntimeError: if eager execution is not enabled.
|
||||
"""
|
||||
if not context.in_eager_mode():
|
||||
if not context.executing_eagerly():
|
||||
raise RuntimeError(
|
||||
"{} objects can only be used when eager execution is enabled, use "
|
||||
"tf.Template for graph construction".
|
||||
|
@ -338,7 +338,7 @@ class _GraphTensorArray(object):
|
||||
with ops.name_scope(name, "TensorArrayScatter",
|
||||
[self._handle, value, indices]):
|
||||
value = ops.convert_to_tensor(value, name="value")
|
||||
if self._infer_shape and context.in_graph_mode():
|
||||
if self._infer_shape and not context.executing_eagerly():
|
||||
self._merge_element_shape(value.shape[1:])
|
||||
with self._maybe_colocate_with(value):
|
||||
flow_out = gen_data_flow_ops.tensor_array_scatter_v3(
|
||||
@ -363,7 +363,7 @@ class _GraphTensorArray(object):
|
||||
value = ops.convert_to_tensor(value, name="value")
|
||||
with self._maybe_colocate_with(value):
|
||||
lengths_64 = math_ops.to_int64(lengths)
|
||||
if self._infer_shape and context.in_graph_mode():
|
||||
if self._infer_shape and not context.executing_eagerly():
|
||||
clengths = tensor_util.constant_value(lengths_64)
|
||||
if value.shape.dims is not None:
|
||||
if clengths is not None and clengths.max() == clengths.min():
|
||||
@ -774,10 +774,10 @@ class TensorArray(object):
|
||||
ValueError: if both handle and tensor_array_name are provided.
|
||||
TypeError: if handle is provided but is not a Tensor.
|
||||
"""
|
||||
if context.in_graph_mode():
|
||||
implementation = _GraphTensorArray
|
||||
else:
|
||||
if context.executing_eagerly():
|
||||
implementation = _EagerTensorArray
|
||||
else:
|
||||
implementation = _GraphTensorArray
|
||||
|
||||
self._implementation = implementation(
|
||||
dtype,
|
||||
|
@ -321,7 +321,7 @@ class _VariableStore(object):
|
||||
raise ValueError(
|
||||
"Passed a custom_getter which is not callable: %s" % custom_getter)
|
||||
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
if not self._store_eager_variables and reuse:
|
||||
raise RuntimeError(
|
||||
"When eager execution is enabled variable reuse is only supported"
|
||||
@ -518,7 +518,7 @@ class _VariableStore(object):
|
||||
when violating reuse during variable creation, or if an existing
|
||||
sharded variable exists for the given name but with different sharding.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise NotImplementedError("Partitioned variables are not yet supported "
|
||||
"when eager execution is enabled.")
|
||||
|
||||
@ -798,7 +798,7 @@ class _VariableStore(object):
|
||||
validate_shape=validate_shape,
|
||||
constraint=constraint,
|
||||
use_resource=use_resource)
|
||||
if context.in_graph_mode() or self._store_eager_variables:
|
||||
if not context.executing_eagerly() or self._store_eager_variables:
|
||||
# In eager mode we do not want to keep default references to Variable
|
||||
# objects as this will prevent their memory from being released.
|
||||
self._vars[name] = v
|
||||
@ -811,12 +811,12 @@ class _VariableStore(object):
|
||||
with ops.name_scope(name + "/Regularizer/"):
|
||||
loss = regularizer(v)
|
||||
if loss is not None:
|
||||
if context.in_graph_mode():
|
||||
v_name = v.name
|
||||
loss_name = loss.name
|
||||
else:
|
||||
if context.executing_eagerly():
|
||||
v_name = "v_%s" % type(v)
|
||||
loss_name = "loss_%s" % type(loss)
|
||||
else:
|
||||
v_name = v.name
|
||||
loss_name = loss.name
|
||||
logging.vlog(1, "Applied regularizer to %s and added the result %s "
|
||||
"to REGULARIZATION_LOSSES.", v_name, loss_name)
|
||||
ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, loss)
|
||||
@ -920,7 +920,7 @@ class VariableScope(object):
|
||||
self._dtype = dtype
|
||||
self._use_resource = use_resource
|
||||
self._constraint = constraint
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
if self._caching_device is not None:
|
||||
raise NotImplementedError("Caching devices is not yet supported "
|
||||
"when eager execution is enabled.")
|
||||
@ -988,7 +988,7 @@ class VariableScope(object):
|
||||
|
||||
def set_use_resource(self, use_resource):
|
||||
"""Sets whether to use ResourceVariables for this scope."""
|
||||
if context.in_eager_mode() and not use_resource:
|
||||
if context.executing_eagerly() and not use_resource:
|
||||
raise ValueError("When eager execution is enabled, "
|
||||
"use_resource cannot be set to false.")
|
||||
self._use_resource = use_resource
|
||||
@ -999,14 +999,14 @@ class VariableScope(object):
|
||||
|
||||
def set_caching_device(self, caching_device):
|
||||
"""Set caching_device for this scope."""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise NotImplementedError("Caching devices are not yet supported "
|
||||
"when eager execution is enabled.")
|
||||
self._caching_device = caching_device
|
||||
|
||||
def set_partitioner(self, partitioner):
|
||||
"""Set partitioner for this scope."""
|
||||
if partitioner and context.in_eager_mode():
|
||||
if partitioner and context.executing_eagerly():
|
||||
raise NotImplementedError("Partitioned variables are not yet supported "
|
||||
"when eager execution is enabled.")
|
||||
self._partitioner = partitioner
|
||||
@ -1057,14 +1057,14 @@ class VariableScope(object):
|
||||
partitioner = self._partitioner
|
||||
if custom_getter is None:
|
||||
custom_getter = self._custom_getter
|
||||
if context.in_graph_mode():
|
||||
if context.executing_eagerly():
|
||||
reuse = False
|
||||
use_resource = True
|
||||
else:
|
||||
if reuse is None:
|
||||
reuse = self._reuse
|
||||
if use_resource is None:
|
||||
use_resource = self._use_resource
|
||||
else:
|
||||
reuse = False
|
||||
use_resource = True
|
||||
|
||||
full_name = self.name + "/" + name if self.name else name
|
||||
# Variable names only depend on variable_scope (full_name here),
|
||||
@ -1107,7 +1107,7 @@ class VariableScope(object):
|
||||
use_resource=None,
|
||||
constraint=None):
|
||||
"""Gets an existing variable with this name or create a new one."""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise NotImplementedError("Partitioned variables are not yet supported "
|
||||
"when eager execution is enabled.")
|
||||
if initializer is None:
|
||||
@ -1871,7 +1871,7 @@ class variable_scope(object):
|
||||
raise ValueError("The reuse parameter must be True or False or None.")
|
||||
if self._values is None:
|
||||
self._values = []
|
||||
self._in_graph_mode = not context.in_eager_mode()
|
||||
self._in_graph_mode = not context.executing_eagerly()
|
||||
if self._in_graph_mode:
|
||||
self._graph = ops._get_graph_from_inputs(self._values) # pylint: disable=protected-access
|
||||
self._cached_pure_variable_scope = None
|
||||
@ -2111,13 +2111,13 @@ def default_variable_creator(next_creator=None, **kwargs):
|
||||
use_resource = kwargs.get("use_resource", None)
|
||||
if use_resource is None:
|
||||
use_resource = get_variable_scope().use_resource
|
||||
if use_resource or (use_resource is None and context.in_eager_mode()):
|
||||
if use_resource or (use_resource is None and context.executing_eagerly()):
|
||||
return resource_variable_ops.ResourceVariable(
|
||||
initial_value=initial_value, trainable=trainable,
|
||||
collections=collections, validate_shape=validate_shape,
|
||||
caching_device=caching_device, name=name, dtype=dtype,
|
||||
constraint=constraint)
|
||||
elif not use_resource and context.in_eager_mode():
|
||||
elif not use_resource and context.executing_eagerly():
|
||||
raise RuntimeError(
|
||||
"VariableScope should use resource variable when eager execution is"
|
||||
" enabled, but use_resource is False."
|
||||
|
@ -210,10 +210,11 @@ class Variable(checkpointable.CheckpointableBase):
|
||||
for details on how variables work in eager execution.
|
||||
@end_compatibility
|
||||
"""
|
||||
if not context.in_graph_mode():
|
||||
raise RuntimeError("tf.Variable not supported in Eager mode. "
|
||||
"Please use tfe.Variable instead")
|
||||
self._in_graph_mode = context.in_graph_mode()
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError(
|
||||
"tf.Variable not supported when eager execution is enabled. "
|
||||
"Please use tf.contrib.eager.Variable instead")
|
||||
self._in_graph_mode = True
|
||||
if variable_def:
|
||||
# If variable_def is provided, recreates the variable from its fields.
|
||||
if initial_value:
|
||||
@ -234,7 +235,7 @@ class Variable(checkpointable.CheckpointableBase):
|
||||
constraint=constraint)
|
||||
|
||||
def __repr__(self):
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
return "<tf.Variable '%s' shape=%s dtype=%s, numpy=%s>" % (
|
||||
self.name, self.get_shape(), self.dtype.name,
|
||||
ops.numpy_text(self.read_value(), is_repr=True))
|
||||
@ -740,15 +741,15 @@ class Variable(checkpointable.CheckpointableBase):
|
||||
Raises:
|
||||
ValueError: Session is not passed and no default session
|
||||
"""
|
||||
if context.in_graph_mode():
|
||||
if context.executing_eagerly():
|
||||
self.assign(value)
|
||||
else:
|
||||
session = session or ops.get_default_session()
|
||||
if session is None:
|
||||
raise ValueError(
|
||||
"Either session argument should be provided or default session "
|
||||
"should be established")
|
||||
session.run(self._initializer_op, {self._initializer_op.inputs[1]: value})
|
||||
else:
|
||||
self.assign(value)
|
||||
|
||||
# Conversion to tensor.
|
||||
@staticmethod
|
||||
@ -1248,9 +1249,9 @@ class PartitionedVariable(object):
|
||||
information does not match `shape`, or `partitions` has invalid values.
|
||||
RuntimeError: If eager execution is enabled
|
||||
"""
|
||||
if not context.in_graph_mode():
|
||||
raise RuntimeError("tf.PartitionedVariable not supported in "
|
||||
"eager mode. Please use tfe.Variable instead")
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError(
|
||||
"tf.PartitionedVariable not supported with eager execution enabled.")
|
||||
if not isinstance(variable_list, (list, tuple)):
|
||||
raise TypeError(
|
||||
"variable_list is not a list or tuple: %s" % variable_list)
|
||||
@ -1541,7 +1542,7 @@ def variables_initializer(var_list, name="init"):
|
||||
Returns:
|
||||
An Op that run the initializers of all the specified variables.
|
||||
"""
|
||||
if var_list and context.in_graph_mode():
|
||||
if var_list and not context.executing_eagerly():
|
||||
return control_flow_ops.group(*[v.initializer for v in var_list], name=name)
|
||||
return control_flow_ops.no_op(name=name)
|
||||
|
||||
@ -1563,7 +1564,7 @@ def global_variables_initializer():
|
||||
Returns:
|
||||
An Op that initializes global variables in the graph.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
return control_flow_ops.no_op(name="global_variables_initializer")
|
||||
return variables_initializer(global_variables())
|
||||
|
||||
@ -1585,7 +1586,7 @@ def local_variables_initializer():
|
||||
Returns:
|
||||
An Op that initializes all local variables in the graph.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
return control_flow_ops.no_op(name="local_variables_initializer")
|
||||
return variables_initializer(local_variables())
|
||||
|
||||
|
@ -172,7 +172,7 @@ class Profiler(object):
|
||||
op_log: optional. tensorflow::tfprof::OpLogProto proto. Used to define
|
||||
extra op types.
|
||||
"""
|
||||
if not graph and context.in_graph_mode():
|
||||
if not graph and not context.executing_eagerly():
|
||||
graph = ops.get_default_graph()
|
||||
self._coverage = 0.0
|
||||
self._graph = graph
|
||||
@ -336,7 +336,7 @@ def profile(graph=None,
|
||||
If cmd is 'op' or 'code', returns MultiGraphNodeProto proto.
|
||||
Side effect: stdout/file/timeline.json depending on options['output']
|
||||
"""
|
||||
if not graph and context.in_graph_mode():
|
||||
if not graph and not context.executing_eagerly():
|
||||
graph = ops.get_default_graph()
|
||||
|
||||
if options == _DEFAULT_PROFILE_OPTIONS:
|
||||
|
@ -156,7 +156,7 @@ def merge_default_with_oplog(graph, op_log=None, run_meta=None,
|
||||
Returns:
|
||||
tmp_op_log: Merged OpLogProto proto.
|
||||
"""
|
||||
if not graph and context.in_graph_mode():
|
||||
if not graph and not context.executing_eagerly():
|
||||
graph = ops.get_default_graph()
|
||||
|
||||
tmp_op_log = tfprof_log_pb2.OpLogProto()
|
||||
@ -210,7 +210,7 @@ def write_op_log(graph, log_dir, op_log=None, run_meta=None, add_trace=True):
|
||||
add_trace: Whether to add python code trace information.
|
||||
Used to support "code" view.
|
||||
"""
|
||||
if not graph and context.in_graph_mode():
|
||||
if not graph and not context.executing_eagerly():
|
||||
graph = ops.get_default_graph()
|
||||
op_log = merge_default_with_oplog(graph, op_log, run_meta, add_trace)
|
||||
|
||||
|
@ -278,7 +278,7 @@ def merge(inputs, collections=None, name=None):
|
||||
@end_compatbility
|
||||
"""
|
||||
# pylint: enable=line-too-long
|
||||
if _context.in_eager_mode():
|
||||
if _context.executing_eagerly():
|
||||
raise RuntimeError(
|
||||
'Merging tf.summary.* ops is not compatible with eager execution. '
|
||||
'Use tf.contrib.summary instead.')
|
||||
@ -311,7 +311,7 @@ def merge_all(key=_ops.GraphKeys.SUMMARIES, scope=None):
|
||||
summaries under eager execution, use `tf.contrib.summary` instead.
|
||||
@end_compatbility
|
||||
"""
|
||||
if _context.in_eager_mode():
|
||||
if _context.executing_eagerly():
|
||||
raise RuntimeError(
|
||||
'Merging tf.summary.* ops is not compatible with eager execution. '
|
||||
'Use tf.contrib.summary instead.')
|
||||
|
@ -343,7 +343,7 @@ class FileWriter(SummaryToEventTransformer):
|
||||
summaries under eager execution, use `tf.contrib.summary` instead.
|
||||
@end_compatbility
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError(
|
||||
"tf.summary.FileWriter is not compatible with eager execution. "
|
||||
"Use tf.contrib.summary instead.")
|
||||
|
@ -106,10 +106,10 @@ class AdamOptimizer(optimizer.Optimizer):
|
||||
self._updated_lr = None
|
||||
|
||||
def _get_beta_accumulators(self):
|
||||
if context.in_graph_mode():
|
||||
graph = ops.get_default_graph()
|
||||
else:
|
||||
if context.executing_eagerly():
|
||||
graph = None
|
||||
else:
|
||||
graph = ops.get_default_graph()
|
||||
return (self._get_non_slot_variable("beta1_power", graph=graph),
|
||||
self._get_non_slot_variable("beta2_power", graph=graph))
|
||||
|
||||
|
@ -184,7 +184,7 @@ class AdamOptimizerTest(test.TestCase):
|
||||
# Shouldn't return non-slot variables from other graphs.
|
||||
self.assertEqual(0, len(opt.variables()))
|
||||
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
# Fetch params to validate initial values
|
||||
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
|
||||
@ -194,7 +194,7 @@ class AdamOptimizerTest(test.TestCase):
|
||||
|
||||
# Run 3 steps of Adam
|
||||
for t in range(1, 4):
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
self.evaluate(update)
|
||||
elif t > 1:
|
||||
opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
|
||||
|
@ -208,7 +208,7 @@ class _CheckpointPosition(object):
|
||||
# Name saveables based on the name this object had when it was checkpointed.
|
||||
named_saveables = {}
|
||||
restore_ops = []
|
||||
in_graph_mode = context.in_graph_mode()
|
||||
building_graph = not context.executing_eagerly()
|
||||
for serialized_tensor in self.object_proto.attributes:
|
||||
saveable_object = saveables.get(serialized_tensor.name, None)
|
||||
if saveable_object is None:
|
||||
@ -219,7 +219,7 @@ class _CheckpointPosition(object):
|
||||
self._checkpoint.unused_attributes.setdefault(
|
||||
self.checkpointable, []).append(serialized_tensor.name)
|
||||
continue
|
||||
if in_graph_mode:
|
||||
if building_graph:
|
||||
existing_ops = self._checkpoint.restore_ops_by_name.get(
|
||||
serialized_tensor.name, None)
|
||||
else:
|
||||
@ -245,7 +245,7 @@ class _CheckpointPosition(object):
|
||||
saveable_index:saveable_index + num_specs]
|
||||
saveable_index += num_specs
|
||||
restore_op = saveable.restore(saveable_tensors, restored_shapes=None)
|
||||
if in_graph_mode:
|
||||
if building_graph:
|
||||
assert saveable.name not in self._checkpoint.restore_ops_by_name
|
||||
self._checkpoint.restore_ops_by_name[saveable.name] = restore_op
|
||||
restore_ops.append(restore_op)
|
||||
@ -388,7 +388,7 @@ class CheckpointableBase(object):
|
||||
"Checkpointable._add_variable called to create another with "
|
||||
"that name. Variable names must be unique within a Checkpointable "
|
||||
"object.") % (name,))
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
# If this is a variable with a single Tensor stored in the checkpoint, we
|
||||
# can set that value as an initializer rather than initializing and then
|
||||
# assigning (when executing eagerly). This call returns None if there is
|
||||
|
@ -71,6 +71,6 @@ class GradientDescentOptimizer(optimizer.Optimizer):
|
||||
return var.scatter_sub(delta, use_locking=self._use_locking)
|
||||
|
||||
def _prepare(self):
|
||||
if context.in_graph_mode() or self._learning_rate_tensor is None:
|
||||
if not context.executing_eagerly() or self._learning_rate_tensor is None:
|
||||
self._learning_rate_tensor = ops.convert_to_tensor(self._learning_rate,
|
||||
name="learning_rate")
|
||||
|
@ -159,7 +159,7 @@ def input_producer(input_tensor,
|
||||
enabled. Please use the `tf.data` API to ingest data under eager execution.
|
||||
@end_compatibility
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError(
|
||||
"Input pipelines based on Queues are not supported when eager execution"
|
||||
" is enabled. Please use tf.data to ingest data into your model"
|
||||
@ -737,7 +737,7 @@ def _batch(tensors, batch_size, keep_input, num_threads=1, capacity=32,
|
||||
allow_smaller_final_batch=False, shared_name=None,
|
||||
name=None):
|
||||
"""Helper function for `batch` and `maybe_batch`."""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise ValueError(
|
||||
"Input pipelines based on Queues are not supported when eager execution"
|
||||
" is enabled. Please use tf.data to ingest data into your model"
|
||||
@ -775,7 +775,7 @@ def _batch_join(tensors_list, batch_size, keep_input, capacity=32,
|
||||
enqueue_many=False, shapes=None, dynamic_pad=False,
|
||||
allow_smaller_final_batch=False, shared_name=None, name=None):
|
||||
"""Helper function for `batch_join` and `maybe_batch_join`."""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise ValueError(
|
||||
"Input pipelines based on Queues are not supported when eager execution"
|
||||
" is enabled. Please use tf.data to ingest data into your model"
|
||||
@ -810,7 +810,7 @@ def _shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
|
||||
shapes=None, allow_smaller_final_batch=False,
|
||||
shared_name=None, name=None):
|
||||
"""Helper function for `shuffle_batch` and `maybe_shuffle_batch`."""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise ValueError(
|
||||
"Input pipelines based on Queues are not supported when eager execution"
|
||||
" is enabled. Please use tf.data to ingest data into your model"
|
||||
@ -855,7 +855,7 @@ def _shuffle_batch_join(tensors_list, batch_size, capacity,
|
||||
allow_smaller_final_batch=False, shared_name=None,
|
||||
name=None):
|
||||
"""Helper function for `shuffle_batch_join` and `maybe_shuffle_batch_join`."""
|
||||
if context.in_eager_mode():
|
||||
if context.executing_eagerly():
|
||||
raise ValueError(
|
||||
"Input pipelines based on Queues are not supported when eager execution"
|
||||
" is enabled. Please use tf.data to ingest data into your model"
|
||||
|
@ -113,7 +113,7 @@ class LRDecayTest(test_util.TensorFlowTestCase):
|
||||
learning_rate_decay.piecewise_constant(x, boundaries, values)
|
||||
|
||||
# Test that ref types are valid.
|
||||
if context.in_graph_mode():
|
||||
if not context.executing_eagerly():
|
||||
x = variables.Variable(0.0)
|
||||
x_ref = x.op.outputs[0] # float32_ref tensor should be accepted
|
||||
boundaries, values = [1.0, 2.0], [1, 2, 3]
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user