Update eager uniform replay buffer microbenchmarks to compare against graph functions when possible.

PiperOrigin-RevId: 187075418
This commit is contained in:
Akshay Agrawal 2018-02-26 13:54:02 -08:00 committed by TensorFlower Gardener
parent c7ea6ace71
commit ba2cc572f9
2 changed files with 12 additions and 3 deletions

View File

@ -143,7 +143,7 @@ class CriticalSection(object):
def _init_from_args(self, name, shared_name): # pylint: disable=invalid-name
"""Initialize the CriticalSection from constructor arguments."""
with ops.name_scope(name, "CriticalSection", []) as name:
with ops.control_dependencies(None):
with ops.init_scope():
# pylint: disable=protected-access
container = ops.get_default_graph()._container
# pylint: enable=protected-access
@ -226,7 +226,9 @@ class CriticalSection(object):
# mode. This is generally ok; since eager mode (as of
# writing) executes sequentially anyway.
for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS):
if sg.handle.name == self._handle.name:
sg_handle_name = ops.convert_to_tensor(sg.handle).name
self_handle_name = ops.convert_to_tensor(self._handle).name
if sg_handle_name == self_handle_name:
# Other executions in the same critical section are allowed.
continue
if not (exclusive_resource_access or sg.exclusive_resource_access):

View File

@ -4805,7 +4805,14 @@ def container(container_name):
@tf_export("colocate_with")
def colocate_with(op, ignore_existing=False):
if context.in_graph_mode():
return get_default_graph().colocate_with(op, ignore_existing)
default_graph = get_default_graph()
if isinstance(op, EagerTensor):
if default_graph.building_function:
op = internal_convert_to_tensor(op)
else:
raise ValueError("Encountered an Eager-defined Tensor during graph "
"construction, but a function was not being built.")
return default_graph.colocate_with(op, ignore_existing)
else:
if op is not None:
return device(op.device)