Remove circular dependencies by removing submodule imports from ragged package.

(If ragged/__init__.py imports all ragged modules, then it's impossible to depend on one ragged module without depending on all of them.)

PiperOrigin-RevId: 299479545
Change-Id: I469e160a49efdd47c137497328ea47800a621e85
This commit is contained in:
Edward Loper 2020-03-06 18:23:35 -08:00 committed by TensorFlower Gardener
parent 56ee67f7b1
commit 6d0f422525
7 changed files with 133 additions and 45 deletions

View File

@ -98,6 +98,7 @@ from tensorflow.python.ops.distributions import distributions
from tensorflow.python.ops.linalg import linalg from tensorflow.python.ops.linalg import linalg
from tensorflow.python.ops.linalg.sparse import sparse from tensorflow.python.ops.linalg.sparse import sparse
from tensorflow.python.ops.losses import losses from tensorflow.python.ops.losses import losses
from tensorflow.python.ops.ragged import ragged_ops as _ragged_ops
from tensorflow.python.ops.signal import signal from tensorflow.python.ops.signal import signal
from tensorflow.python.profiler import profiler from tensorflow.python.profiler import profiler
from tensorflow.python.profiler import profiler_client from tensorflow.python.profiler import profiler_client
@ -108,6 +109,9 @@ from tensorflow.python.tpu import api
from tensorflow.python.user_ops import user_ops from tensorflow.python.user_ops import user_ops
from tensorflow.python.util import compat from tensorflow.python.util import compat
# Update the RaggedTensor package docs w/ a list of ops that support dispatch.
ragged.__doc__ += _ragged_ops.ragged_dispatch.ragged_op_list()
# Import to make sure the ops are registered. # Import to make sure the ops are registered.
from tensorflow.python.ops import gen_audio_ops from tensorflow.python.ops import gen_audio_ops
from tensorflow.python.ops import gen_boosted_trees_ops from tensorflow.python.ops import gen_boosted_trees_ops

View File

@ -67,11 +67,12 @@ from tensorflow.python.framework import versions
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_util_v2 from tensorflow.python.ops import control_flow_util_v2
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.ops import script_ops from tensorflow.python.ops import script_ops
from tensorflow.python.ops import summary_ops_v2 from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_ops # pylint: disable=unused-import
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.platform import googletest from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib from tensorflow.python.training import server_lib
@ -80,9 +81,9 @@ from tensorflow.python.util import deprecation
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect from tensorflow.python.util import tf_inspect
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.protobuf import compare from tensorflow.python.util.protobuf import compare
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
from tensorflow.python.util.compat import collections_abc
# If the below import is made available through the BUILD rule, then this # If the below import is made available through the BUILD rule, then this

View File

@ -29,21 +29,12 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging from tensorflow.python.platform import tf_logging
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
# Avoid circular dependencies with RaggedTensor.
# TODO(b/141170488) Refactor ragged modules so this is unnecessary.
ragged_tensor = LazyLoader(
"ragged_tensor", globals(),
"tensorflow.python.ops.ragged.ragged_tensor")
ragged_math_ops = LazyLoader(
"ragged_math_ops", globals(),
"tensorflow.python.ops.ragged.ragged_math_ops")
# TODO(b/122887740) Refactor code: # TODO(b/122887740) Refactor code:
# * Move input verification to feature configuration objects (e.g., # * Move input verification to feature configuration objects (e.g.,
# VarLenFeature should check that dtype is a valid dtype). # VarLenFeature should check that dtype is a valid dtype).

View File

@ -41,6 +41,7 @@ py_library(
":ragged_map_ops", ":ragged_map_ops",
":ragged_math_ops", ":ragged_math_ops",
":ragged_operators", ":ragged_operators",
":ragged_ops",
":ragged_string_ops", ":ragged_string_ops",
":ragged_tensor", ":ragged_tensor",
":ragged_tensor_shape", ":ragged_tensor_shape",
@ -231,6 +232,35 @@ py_library(
], ],
) )
py_library(
name = "ragged_ops",
srcs = ["ragged_ops.py"],
srcs_version = "PY2AND3",
deps = [
":ragged_array_ops",
":ragged_batch_gather_ops",
":ragged_batch_gather_with_default_op",
":ragged_concat_ops",
":ragged_config",
":ragged_conversion_ops",
":ragged_dispatch",
":ragged_factory_ops",
":ragged_functional_ops",
":ragged_gather_ops",
":ragged_getitem",
":ragged_map_ops",
":ragged_math_ops",
":ragged_operators",
":ragged_string_ops",
":ragged_tensor",
":ragged_tensor_shape",
":ragged_tensor_value",
":ragged_util",
":ragged_where_op",
":segment_id_ops",
],
)
py_library( py_library(
name = "ragged_string_ops", name = "ragged_string_ops",
srcs = ["ragged_string_ops.py"], srcs = ["ragged_string_ops.py"],

View File

@ -27,26 +27,3 @@ and the [Ragged Tensor Guide](/guide/ragged_tensors).
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_batch_gather_ops
from tensorflow.python.ops.ragged import ragged_batch_gather_with_default_op
from tensorflow.python.ops.ragged import ragged_concat_ops
from tensorflow.python.ops.ragged import ragged_conversion_ops
from tensorflow.python.ops.ragged import ragged_dispatch
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_gather_ops
from tensorflow.python.ops.ragged import ragged_getitem
from tensorflow.python.ops.ragged import ragged_map_ops
from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_operators
from tensorflow.python.ops.ragged import ragged_string_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_shape
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.ops.ragged import ragged_where_op
from tensorflow.python.ops.ragged import segment_id_ops
# Add a list of the ops that support Ragged Tensors.
__doc__ += ragged_dispatch.ragged_op_list() # pylint: disable=redefined-builtin

View File

@ -0,0 +1,50 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Import all modules in the `ragged` package that define exported symbols.
Additional, import ragged_dispatch (which has the side-effect of registering
dispatch handlers for many standard TF ops) and ragged_operators (which has the
side-effect of overriding RaggedTensor operators, such as RaggedTensor.__add__).
We don't import these modules from ragged/__init__.py, since we want to avoid
circular dependencies.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import
from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_batch_gather_ops
from tensorflow.python.ops.ragged import ragged_batch_gather_with_default_op
from tensorflow.python.ops.ragged import ragged_concat_ops
from tensorflow.python.ops.ragged import ragged_conversion_ops
from tensorflow.python.ops.ragged import ragged_dispatch
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_gather_ops
from tensorflow.python.ops.ragged import ragged_getitem
from tensorflow.python.ops.ragged import ragged_map_ops
from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_operators
from tensorflow.python.ops.ragged import ragged_squeeze_op
from tensorflow.python.ops.ragged import ragged_string_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_shape
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.ops.ragged import ragged_where_op
from tensorflow.python.ops.ragged import segment_id_ops

View File

@ -2117,15 +2117,50 @@ class RaggedTensor(composite_tensor.CompositeTensor):
return isinstance(rt, ops.EagerTensor) return isinstance(rt, ops.EagerTensor)
#============================================================================= #=============================================================================
# Indexing & Slicing # Operators
#============================================================================= #=============================================================================
def __getitem__(self, key): # To avoid circular dependencies, we define stub methods for operators here,
"""Returns the specified piece of this RaggedTensor.""" # and then override them when the ragged_operators module is imported.
# See ragged_getitem.py for the documentation and implementation of this
# method. def _overloaded_operator(name): # pylint: disable=no-self-argument
# def stub(*args, **kwargs):
# Note: the imports in ragged/__init__.py ensure that this method always del args, kwargs
# gets overridden before it is called. raise ValueError(
"You must import 'tensorflow.python.ops.ragged.ragged_ops' "
"before using RaggedTensor.%s" % name)
return stub
__getitem__ = _overloaded_operator("__getitem__")
__ge__ = _overloaded_operator("__ge__")
__gt__ = _overloaded_operator("__gt__")
__le__ = _overloaded_operator("__le__")
__lt__ = _overloaded_operator("__lt__")
__and__ = _overloaded_operator("__and__")
__rand__ = _overloaded_operator("__rand__")
__invert__ = _overloaded_operator("__invert__")
__ror__ = _overloaded_operator("__ror__")
__or__ = _overloaded_operator("__or__")
__xor__ = _overloaded_operator("__xor__")
__rxor__ = _overloaded_operator("__rxor__")
__abs__ = _overloaded_operator("__abs__")
__add__ = _overloaded_operator("__add__")
__radd__ = _overloaded_operator("__radd__")
__div__ = _overloaded_operator("__div__")
__rdiv__ = _overloaded_operator("__rdiv__")
__floordiv__ = _overloaded_operator("__floordiv__")
__rfloordiv__ = _overloaded_operator("__rfloordiv__")
__mod__ = _overloaded_operator("__mod__")
__rmod__ = _overloaded_operator("__rmod__")
__mul__ = _overloaded_operator("__mul__")
__rmul__ = _overloaded_operator("__rmul__")
__neg__ = _overloaded_operator("__neg__")
__pow__ = _overloaded_operator("__pow__")
__rpow__ = _overloaded_operator("__rpow__")
__sub__ = _overloaded_operator("__sub__")
__rsub__ = _overloaded_operator("__rsub__")
__truediv__ = _overloaded_operator("__truediv__")
__rtruediv__ = _overloaded_operator("__rtruediv__")
del _overloaded_operator
#============================================================================= #=============================================================================
# Name Scope # Name Scope