Remove circular dependencies by removing submodule imports from ragged package.

(If ragged/__init__.py imports all ragged modules, then it's impossible to depend on one ragged module without depending on all of them.)

PiperOrigin-RevId: 299479545
Change-Id: I469e160a49efdd47c137497328ea47800a621e85
This commit is contained in:
Edward Loper 2020-03-06 18:23:35 -08:00 committed by TensorFlower Gardener
parent 56ee67f7b1
commit 6d0f422525
7 changed files with 133 additions and 45 deletions

View File

@ -98,6 +98,7 @@ from tensorflow.python.ops.distributions import distributions
from tensorflow.python.ops.linalg import linalg
from tensorflow.python.ops.linalg.sparse import sparse
from tensorflow.python.ops.losses import losses
from tensorflow.python.ops.ragged import ragged_ops as _ragged_ops
from tensorflow.python.ops.signal import signal
from tensorflow.python.profiler import profiler
from tensorflow.python.profiler import profiler_client
@ -108,6 +109,9 @@ from tensorflow.python.tpu import api
from tensorflow.python.user_ops import user_ops
from tensorflow.python.util import compat
# Update the RaggedTensor package docs w/ a list of ops that support dispatch.
ragged.__doc__ += _ragged_ops.ragged_dispatch.ragged_op_list()
# Import to make sure the ops are registered.
from tensorflow.python.ops import gen_audio_ops
from tensorflow.python.ops import gen_boosted_trees_ops

View File

@ -67,11 +67,12 @@ from tensorflow.python.framework import versions
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_util_v2
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_ops # pylint: disable=unused-import
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
@ -80,9 +81,9 @@ from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.protobuf import compare
from tensorflow.python.util.tf_export import tf_export
from tensorflow.python.util.compat import collections_abc
# If the below import is made available through the BUILD rule, then this

View File

@ -29,21 +29,12 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export
# Avoid circular dependencies with RaggedTensor.
# TODO(b/141170488) Refactor ragged modules so this is unnecessary.
ragged_tensor = LazyLoader(
"ragged_tensor", globals(),
"tensorflow.python.ops.ragged.ragged_tensor")
ragged_math_ops = LazyLoader(
"ragged_math_ops", globals(),
"tensorflow.python.ops.ragged.ragged_math_ops")
# TODO(b/122887740) Refactor code:
# * Move input verification to feature configuration objects (e.g.,
# VarLenFeature should check that dtype is a valid dtype).

View File

@ -41,6 +41,7 @@ py_library(
":ragged_map_ops",
":ragged_math_ops",
":ragged_operators",
":ragged_ops",
":ragged_string_ops",
":ragged_tensor",
":ragged_tensor_shape",
@ -231,6 +232,35 @@ py_library(
],
)
py_library(
name = "ragged_ops",
srcs = ["ragged_ops.py"],
srcs_version = "PY2AND3",
deps = [
":ragged_array_ops",
":ragged_batch_gather_ops",
":ragged_batch_gather_with_default_op",
":ragged_concat_ops",
":ragged_config",
":ragged_conversion_ops",
":ragged_dispatch",
":ragged_factory_ops",
":ragged_functional_ops",
":ragged_gather_ops",
":ragged_getitem",
":ragged_map_ops",
":ragged_math_ops",
":ragged_operators",
":ragged_string_ops",
":ragged_tensor",
":ragged_tensor_shape",
":ragged_tensor_value",
":ragged_util",
":ragged_where_op",
":segment_id_ops",
],
)
py_library(
name = "ragged_string_ops",
srcs = ["ragged_string_ops.py"],

View File

@ -27,26 +27,3 @@ and the [Ragged Tensor Guide](/guide/ragged_tensors).
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_batch_gather_ops
from tensorflow.python.ops.ragged import ragged_batch_gather_with_default_op
from tensorflow.python.ops.ragged import ragged_concat_ops
from tensorflow.python.ops.ragged import ragged_conversion_ops
from tensorflow.python.ops.ragged import ragged_dispatch
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_gather_ops
from tensorflow.python.ops.ragged import ragged_getitem
from tensorflow.python.ops.ragged import ragged_map_ops
from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_operators
from tensorflow.python.ops.ragged import ragged_string_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_shape
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.ops.ragged import ragged_where_op
from tensorflow.python.ops.ragged import segment_id_ops
# Add a list of the ops that support Ragged Tensors.
__doc__ += ragged_dispatch.ragged_op_list() # pylint: disable=redefined-builtin

View File

@ -0,0 +1,50 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Import all modules in the `ragged` package that define exported symbols.
Additional, import ragged_dispatch (which has the side-effect of registering
dispatch handlers for many standard TF ops) and ragged_operators (which has the
side-effect of overriding RaggedTensor operators, such as RaggedTensor.__add__).
We don't import these modules from ragged/__init__.py, since we want to avoid
circular dependencies.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import
from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_batch_gather_ops
from tensorflow.python.ops.ragged import ragged_batch_gather_with_default_op
from tensorflow.python.ops.ragged import ragged_concat_ops
from tensorflow.python.ops.ragged import ragged_conversion_ops
from tensorflow.python.ops.ragged import ragged_dispatch
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_gather_ops
from tensorflow.python.ops.ragged import ragged_getitem
from tensorflow.python.ops.ragged import ragged_map_ops
from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_operators
from tensorflow.python.ops.ragged import ragged_squeeze_op
from tensorflow.python.ops.ragged import ragged_string_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_shape
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.ops.ragged import ragged_where_op
from tensorflow.python.ops.ragged import segment_id_ops

View File

@ -2117,15 +2117,50 @@ class RaggedTensor(composite_tensor.CompositeTensor):
return isinstance(rt, ops.EagerTensor)
#=============================================================================
# Indexing & Slicing
# Operators
#=============================================================================
def __getitem__(self, key):
"""Returns the specified piece of this RaggedTensor."""
# See ragged_getitem.py for the documentation and implementation of this
# method.
#
# Note: the imports in ragged/__init__.py ensure that this method always
# gets overridden before it is called.
# To avoid circular dependencies, we define stub methods for operators here,
# and then override them when the ragged_operators module is imported.
def _overloaded_operator(name): # pylint: disable=no-self-argument
def stub(*args, **kwargs):
del args, kwargs
raise ValueError(
"You must import 'tensorflow.python.ops.ragged.ragged_ops' "
"before using RaggedTensor.%s" % name)
return stub
__getitem__ = _overloaded_operator("__getitem__")
__ge__ = _overloaded_operator("__ge__")
__gt__ = _overloaded_operator("__gt__")
__le__ = _overloaded_operator("__le__")
__lt__ = _overloaded_operator("__lt__")
__and__ = _overloaded_operator("__and__")
__rand__ = _overloaded_operator("__rand__")
__invert__ = _overloaded_operator("__invert__")
__ror__ = _overloaded_operator("__ror__")
__or__ = _overloaded_operator("__or__")
__xor__ = _overloaded_operator("__xor__")
__rxor__ = _overloaded_operator("__rxor__")
__abs__ = _overloaded_operator("__abs__")
__add__ = _overloaded_operator("__add__")
__radd__ = _overloaded_operator("__radd__")
__div__ = _overloaded_operator("__div__")
__rdiv__ = _overloaded_operator("__rdiv__")
__floordiv__ = _overloaded_operator("__floordiv__")
__rfloordiv__ = _overloaded_operator("__rfloordiv__")
__mod__ = _overloaded_operator("__mod__")
__rmod__ = _overloaded_operator("__rmod__")
__mul__ = _overloaded_operator("__mul__")
__rmul__ = _overloaded_operator("__rmul__")
__neg__ = _overloaded_operator("__neg__")
__pow__ = _overloaded_operator("__pow__")
__rpow__ = _overloaded_operator("__rpow__")
__sub__ = _overloaded_operator("__sub__")
__rsub__ = _overloaded_operator("__rsub__")
__truediv__ = _overloaded_operator("__truediv__")
__rtruediv__ = _overloaded_operator("__rtruediv__")
del _overloaded_operator
#=============================================================================
# Name Scope