Cleanup unused/deprecated functions in tf.layers.utils.

PiperOrigin-RevId: 302527461
Change-Id: I25146c03710baa61c96c666564e3e57f8809bd1f
This commit is contained in:
Scott Zhu 2020-03-23 15:09:42 -07:00 committed by TensorFlower Gardener
parent b3dab55405
commit c0d13627b3
3 changed files with 2 additions and 84 deletions

View File

@ -7089,7 +7089,6 @@ py_library(
deps = [ deps = [
":control_flow_ops", ":control_flow_ops",
":smart_cond", ":smart_cond",
":util",
":variables", ":variables",
], ],
) )

View File

@ -18,10 +18,9 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.ops import variables
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.framework import smart_cond as smart_module from tensorflow.python.framework import smart_cond as smart_module
from tensorflow.python.util import nest from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variables
def convert_data_format(data_format, ndim): def convert_data_format(data_format, ndim):
@ -225,61 +224,3 @@ def constant_value(pred):
if isinstance(pred, variables.Variable): if isinstance(pred, variables.Variable):
return None return None
return smart_module.smart_constant_value(pred) return smart_module.smart_constant_value(pred)
def object_list_uid(object_list):
"""Creates a single string from object ids."""
object_list = nest.flatten(object_list)
return ', '.join(str(abs(id(x))) for x in object_list)
def static_shape(x):
"""Get the static shape of a Tensor, or None if it is unavailable."""
if x is None:
return None
try:
return tuple(x.get_shape().as_list())
except ValueError:
return None
def get_reachable_from_inputs(inputs, targets=None):
"""Returns the set of tensors reachable from `inputs`.
Stops if all targets have been found (target is optional).
Only valid in Symbolic mode, not Eager mode.
Args:
inputs: List of tensors.
targets: List of tensors.
Returns:
A set of tensors reachable from the inputs (includes the inputs themselves).
"""
reachable = set(inputs)
if targets:
targets = set(targets)
queue = inputs[:]
while queue:
x = queue.pop()
outputs = []
try:
consumers = x.consumers()
except AttributeError:
# Case where x is a variable type
consumers = [x.op]
for z in consumers:
consumer_outputs = z.outputs
if consumer_outputs: # May be None
outputs += consumer_outputs
for y in outputs:
if y not in reachable:
reachable.add(y)
queue.insert(0, y)
if targets and targets.issubset(reachable):
return reachable
return reachable

View File

@ -115,27 +115,5 @@ class ConstantValueTest(test.TestCase):
utils.constant_value(5) utils.constant_value(5)
class GetReachableFromInputsTest(test.TestCase):
@test_util.run_deprecated_v1
def testGetReachableFromInputs(self):
pl_1 = array_ops.placeholder(shape=None, dtype='float32')
pl_2 = array_ops.placeholder(shape=None, dtype='float32')
pl_3 = array_ops.placeholder(shape=None, dtype='float32')
x_1 = pl_1 + pl_2
x_2 = pl_2 * 2
x_3 = pl_3 + 1
x_4 = x_1 + x_2
x_5 = x_3 * pl_1
self.assertEqual({pl_1, x_1, x_4, x_5},
utils.get_reachable_from_inputs([pl_1]))
self.assertEqual({pl_1, pl_2, x_1, x_2, x_4, x_5},
utils.get_reachable_from_inputs([pl_1, pl_2]))
self.assertEqual({pl_3, x_3, x_5}, utils.get_reachable_from_inputs([pl_3]))
self.assertEqual({x_3, x_5}, utils.get_reachable_from_inputs([x_3]))
if __name__ == '__main__': if __name__ == '__main__':
test.main() test.main()