Update eager uniform replay buffer microbenchmarks to compare against graph functions when possible.
PiperOrigin-RevId: 187075418
This commit is contained in:
parent
c7ea6ace71
commit
ba2cc572f9
@ -143,7 +143,7 @@ class CriticalSection(object):
|
||||
def _init_from_args(self, name, shared_name): # pylint: disable=invalid-name
|
||||
"""Initialize the CriticalSection from constructor arguments."""
|
||||
with ops.name_scope(name, "CriticalSection", []) as name:
|
||||
with ops.control_dependencies(None):
|
||||
with ops.init_scope():
|
||||
# pylint: disable=protected-access
|
||||
container = ops.get_default_graph()._container
|
||||
# pylint: enable=protected-access
|
||||
@ -226,7 +226,9 @@ class CriticalSection(object):
|
||||
# mode. This is generally ok; since eager mode (as of
|
||||
# writing) executes sequentially anyway.
|
||||
for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS):
|
||||
if sg.handle.name == self._handle.name:
|
||||
sg_handle_name = ops.convert_to_tensor(sg.handle).name
|
||||
self_handle_name = ops.convert_to_tensor(self._handle).name
|
||||
if sg_handle_name == self_handle_name:
|
||||
# Other executions in the same critical section are allowed.
|
||||
continue
|
||||
if not (exclusive_resource_access or sg.exclusive_resource_access):
|
||||
|
@ -4805,7 +4805,14 @@ def container(container_name):
|
||||
@tf_export("colocate_with")
|
||||
def colocate_with(op, ignore_existing=False):
|
||||
if context.in_graph_mode():
|
||||
return get_default_graph().colocate_with(op, ignore_existing)
|
||||
default_graph = get_default_graph()
|
||||
if isinstance(op, EagerTensor):
|
||||
if default_graph.building_function:
|
||||
op = internal_convert_to_tensor(op)
|
||||
else:
|
||||
raise ValueError("Encountered an Eager-defined Tensor during graph "
|
||||
"construction, but a function was not being built.")
|
||||
return default_graph.colocate_with(op, ignore_existing)
|
||||
else:
|
||||
if op is not None:
|
||||
return device(op.device)
|
||||
|
Loading…
Reference in New Issue
Block a user