Cleanup unused/deprecated functions in tf.layers.utils.
PiperOrigin-RevId: 302527461 Change-Id: I25146c03710baa61c96c666564e3e57f8809bd1f
This commit is contained in:
parent
b3dab55405
commit
c0d13627b3
@ -7089,7 +7089,6 @@ py_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":control_flow_ops",
|
":control_flow_ops",
|
||||||
":smart_cond",
|
":smart_cond",
|
||||||
":util",
|
|
||||||
":variables",
|
":variables",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -18,10 +18,9 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python.ops import variables
|
|
||||||
from tensorflow.python.ops import control_flow_ops
|
|
||||||
from tensorflow.python.framework import smart_cond as smart_module
|
from tensorflow.python.framework import smart_cond as smart_module
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.ops import control_flow_ops
|
||||||
|
from tensorflow.python.ops import variables
|
||||||
|
|
||||||
|
|
||||||
def convert_data_format(data_format, ndim):
|
def convert_data_format(data_format, ndim):
|
||||||
@ -225,61 +224,3 @@ def constant_value(pred):
|
|||||||
if isinstance(pred, variables.Variable):
|
if isinstance(pred, variables.Variable):
|
||||||
return None
|
return None
|
||||||
return smart_module.smart_constant_value(pred)
|
return smart_module.smart_constant_value(pred)
|
||||||
|
|
||||||
|
|
||||||
def object_list_uid(object_list):
|
|
||||||
"""Creates a single string from object ids."""
|
|
||||||
object_list = nest.flatten(object_list)
|
|
||||||
return ', '.join(str(abs(id(x))) for x in object_list)
|
|
||||||
|
|
||||||
|
|
||||||
def static_shape(x):
|
|
||||||
"""Get the static shape of a Tensor, or None if it is unavailable."""
|
|
||||||
if x is None:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
return tuple(x.get_shape().as_list())
|
|
||||||
except ValueError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def get_reachable_from_inputs(inputs, targets=None):
|
|
||||||
"""Returns the set of tensors reachable from `inputs`.
|
|
||||||
|
|
||||||
Stops if all targets have been found (target is optional).
|
|
||||||
|
|
||||||
Only valid in Symbolic mode, not Eager mode.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
inputs: List of tensors.
|
|
||||||
targets: List of tensors.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A set of tensors reachable from the inputs (includes the inputs themselves).
|
|
||||||
"""
|
|
||||||
reachable = set(inputs)
|
|
||||||
if targets:
|
|
||||||
targets = set(targets)
|
|
||||||
queue = inputs[:]
|
|
||||||
|
|
||||||
while queue:
|
|
||||||
x = queue.pop()
|
|
||||||
outputs = []
|
|
||||||
try:
|
|
||||||
consumers = x.consumers()
|
|
||||||
except AttributeError:
|
|
||||||
# Case where x is a variable type
|
|
||||||
consumers = [x.op]
|
|
||||||
for z in consumers:
|
|
||||||
consumer_outputs = z.outputs
|
|
||||||
if consumer_outputs: # May be None
|
|
||||||
outputs += consumer_outputs
|
|
||||||
|
|
||||||
for y in outputs:
|
|
||||||
if y not in reachable:
|
|
||||||
reachable.add(y)
|
|
||||||
queue.insert(0, y)
|
|
||||||
|
|
||||||
if targets and targets.issubset(reachable):
|
|
||||||
return reachable
|
|
||||||
return reachable
|
|
||||||
|
@ -115,27 +115,5 @@ class ConstantValueTest(test.TestCase):
|
|||||||
utils.constant_value(5)
|
utils.constant_value(5)
|
||||||
|
|
||||||
|
|
||||||
class GetReachableFromInputsTest(test.TestCase):
|
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testGetReachableFromInputs(self):
|
|
||||||
|
|
||||||
pl_1 = array_ops.placeholder(shape=None, dtype='float32')
|
|
||||||
pl_2 = array_ops.placeholder(shape=None, dtype='float32')
|
|
||||||
pl_3 = array_ops.placeholder(shape=None, dtype='float32')
|
|
||||||
x_1 = pl_1 + pl_2
|
|
||||||
x_2 = pl_2 * 2
|
|
||||||
x_3 = pl_3 + 1
|
|
||||||
x_4 = x_1 + x_2
|
|
||||||
x_5 = x_3 * pl_1
|
|
||||||
|
|
||||||
self.assertEqual({pl_1, x_1, x_4, x_5},
|
|
||||||
utils.get_reachable_from_inputs([pl_1]))
|
|
||||||
self.assertEqual({pl_1, pl_2, x_1, x_2, x_4, x_5},
|
|
||||||
utils.get_reachable_from_inputs([pl_1, pl_2]))
|
|
||||||
self.assertEqual({pl_3, x_3, x_5}, utils.get_reachable_from_inputs([pl_3]))
|
|
||||||
self.assertEqual({x_3, x_5}, utils.get_reachable_from_inputs([x_3]))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user