Cleanup unused/deprecated functions in tf.layers.utils.
PiperOrigin-RevId: 302527461 Change-Id: I25146c03710baa61c96c666564e3e57f8809bd1f
This commit is contained in:
parent
b3dab55405
commit
c0d13627b3
@ -7089,7 +7089,6 @@ py_library(
|
||||
deps = [
|
||||
":control_flow_ops",
|
||||
":smart_cond",
|
||||
":util",
|
||||
":variables",
|
||||
],
|
||||
)
|
||||
|
@ -18,10 +18,9 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.framework import smart_cond as smart_module
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import variables
|
||||
|
||||
|
||||
def convert_data_format(data_format, ndim):
|
||||
@ -225,61 +224,3 @@ def constant_value(pred):
|
||||
if isinstance(pred, variables.Variable):
|
||||
return None
|
||||
return smart_module.smart_constant_value(pred)
|
||||
|
||||
|
||||
def object_list_uid(object_list):
|
||||
"""Creates a single string from object ids."""
|
||||
object_list = nest.flatten(object_list)
|
||||
return ', '.join(str(abs(id(x))) for x in object_list)
|
||||
|
||||
|
||||
def static_shape(x):
|
||||
"""Get the static shape of a Tensor, or None if it is unavailable."""
|
||||
if x is None:
|
||||
return None
|
||||
try:
|
||||
return tuple(x.get_shape().as_list())
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def get_reachable_from_inputs(inputs, targets=None):
|
||||
"""Returns the set of tensors reachable from `inputs`.
|
||||
|
||||
Stops if all targets have been found (target is optional).
|
||||
|
||||
Only valid in Symbolic mode, not Eager mode.
|
||||
|
||||
Args:
|
||||
inputs: List of tensors.
|
||||
targets: List of tensors.
|
||||
|
||||
Returns:
|
||||
A set of tensors reachable from the inputs (includes the inputs themselves).
|
||||
"""
|
||||
reachable = set(inputs)
|
||||
if targets:
|
||||
targets = set(targets)
|
||||
queue = inputs[:]
|
||||
|
||||
while queue:
|
||||
x = queue.pop()
|
||||
outputs = []
|
||||
try:
|
||||
consumers = x.consumers()
|
||||
except AttributeError:
|
||||
# Case where x is a variable type
|
||||
consumers = [x.op]
|
||||
for z in consumers:
|
||||
consumer_outputs = z.outputs
|
||||
if consumer_outputs: # May be None
|
||||
outputs += consumer_outputs
|
||||
|
||||
for y in outputs:
|
||||
if y not in reachable:
|
||||
reachable.add(y)
|
||||
queue.insert(0, y)
|
||||
|
||||
if targets and targets.issubset(reachable):
|
||||
return reachable
|
||||
return reachable
|
||||
|
@ -115,27 +115,5 @@ class ConstantValueTest(test.TestCase):
|
||||
utils.constant_value(5)
|
||||
|
||||
|
||||
class GetReachableFromInputsTest(test.TestCase):
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testGetReachableFromInputs(self):
|
||||
|
||||
pl_1 = array_ops.placeholder(shape=None, dtype='float32')
|
||||
pl_2 = array_ops.placeholder(shape=None, dtype='float32')
|
||||
pl_3 = array_ops.placeholder(shape=None, dtype='float32')
|
||||
x_1 = pl_1 + pl_2
|
||||
x_2 = pl_2 * 2
|
||||
x_3 = pl_3 + 1
|
||||
x_4 = x_1 + x_2
|
||||
x_5 = x_3 * pl_1
|
||||
|
||||
self.assertEqual({pl_1, x_1, x_4, x_5},
|
||||
utils.get_reachable_from_inputs([pl_1]))
|
||||
self.assertEqual({pl_1, pl_2, x_1, x_2, x_4, x_5},
|
||||
utils.get_reachable_from_inputs([pl_1, pl_2]))
|
||||
self.assertEqual({pl_3, x_3, x_5}, utils.get_reachable_from_inputs([pl_3]))
|
||||
self.assertEqual({x_3, x_5}, utils.get_reachable_from_inputs([x_3]))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user