diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index f2e0e5127dd..0a55525e753 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -98,6 +98,7 @@ from tensorflow.python.ops.distributions import distributions from tensorflow.python.ops.linalg import linalg from tensorflow.python.ops.linalg.sparse import sparse from tensorflow.python.ops.losses import losses +from tensorflow.python.ops.ragged import ragged_ops as _ragged_ops from tensorflow.python.ops.signal import signal from tensorflow.python.profiler import profiler from tensorflow.python.profiler import profiler_client @@ -108,6 +109,9 @@ from tensorflow.python.tpu import api from tensorflow.python.user_ops import user_ops from tensorflow.python.util import compat +# Update the RaggedTensor package docs w/ a list of ops that support dispatch. +ragged.__doc__ += _ragged_ops.ragged_dispatch.ragged_op_list() + # Import to make sure the ops are registered. from tensorflow.python.ops import gen_audio_ops from tensorflow.python.ops import gen_boosted_trees_ops diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 1f6ac5b654f..65b59c72e8a 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -67,11 +67,12 @@ from tensorflow.python.framework import versions from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import control_flow_util_v2 -from tensorflow.python.ops.ragged import ragged_tensor -from tensorflow.python.ops.ragged import ragged_tensor_value from tensorflow.python.ops import script_ops from tensorflow.python.ops import summary_ops_v2 from tensorflow.python.ops import variables +from tensorflow.python.ops.ragged import ragged_ops # pylint: disable=unused-import +from tensorflow.python.ops.ragged import ragged_tensor +from tensorflow.python.ops.ragged import ragged_tensor_value from tensorflow.python.platform import googletest from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import server_lib @@ -80,9 +81,9 @@ from tensorflow.python.util import deprecation from tensorflow.python.util import nest from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect +from tensorflow.python.util.compat import collections_abc from tensorflow.python.util.protobuf import compare from tensorflow.python.util.tf_export import tf_export -from tensorflow.python.util.compat import collections_abc # If the below import is made available through the BUILD rule, then this diff --git a/tensorflow/python/ops/parsing_config.py b/tensorflow/python/ops/parsing_config.py index bed17c7859e..64553a1f169 100644 --- a/tensorflow/python/ops/parsing_config.py +++ b/tensorflow/python/ops/parsing_config.py @@ -29,21 +29,12 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops.ragged import ragged_math_ops +from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import tf_logging -from tensorflow.python.util.lazy_loader import LazyLoader from tensorflow.python.util.tf_export import tf_export -# Avoid circular dependencies with RaggedTensor. -# TODO(b/141170488) Refactor ragged modules so this is unnecessary. -ragged_tensor = LazyLoader( - "ragged_tensor", globals(), - "tensorflow.python.ops.ragged.ragged_tensor") -ragged_math_ops = LazyLoader( - "ragged_math_ops", globals(), - "tensorflow.python.ops.ragged.ragged_math_ops") - - # TODO(b/122887740) Refactor code: # * Move input verification to feature configuration objects (e.g., # VarLenFeature should check that dtype is a valid dtype). diff --git a/tensorflow/python/ops/ragged/BUILD b/tensorflow/python/ops/ragged/BUILD index bb70a706b4b..2c17ee953ac 100644 --- a/tensorflow/python/ops/ragged/BUILD +++ b/tensorflow/python/ops/ragged/BUILD @@ -41,6 +41,7 @@ py_library( ":ragged_map_ops", ":ragged_math_ops", ":ragged_operators", + ":ragged_ops", ":ragged_string_ops", ":ragged_tensor", ":ragged_tensor_shape", @@ -231,6 +232,35 @@ py_library( ], ) +py_library( + name = "ragged_ops", + srcs = ["ragged_ops.py"], + srcs_version = "PY2AND3", + deps = [ + ":ragged_array_ops", + ":ragged_batch_gather_ops", + ":ragged_batch_gather_with_default_op", + ":ragged_concat_ops", + ":ragged_config", + ":ragged_conversion_ops", + ":ragged_dispatch", + ":ragged_factory_ops", + ":ragged_functional_ops", + ":ragged_gather_ops", + ":ragged_getitem", + ":ragged_map_ops", + ":ragged_math_ops", + ":ragged_operators", + ":ragged_string_ops", + ":ragged_tensor", + ":ragged_tensor_shape", + ":ragged_tensor_value", + ":ragged_util", + ":ragged_where_op", + ":segment_id_ops", + ], +) + py_library( name = "ragged_string_ops", srcs = ["ragged_string_ops.py"], diff --git a/tensorflow/python/ops/ragged/__init__.py b/tensorflow/python/ops/ragged/__init__.py index e9232a1c641..1874943b913 100644 --- a/tensorflow/python/ops/ragged/__init__.py +++ b/tensorflow/python/ops/ragged/__init__.py @@ -27,26 +27,3 @@ and the [Ragged Tensor Guide](/guide/ragged_tensors). from __future__ import absolute_import from __future__ import division from __future__ import print_function - -from tensorflow.python.ops.ragged import ragged_array_ops -from tensorflow.python.ops.ragged import ragged_batch_gather_ops -from tensorflow.python.ops.ragged import ragged_batch_gather_with_default_op -from tensorflow.python.ops.ragged import ragged_concat_ops -from tensorflow.python.ops.ragged import ragged_conversion_ops -from tensorflow.python.ops.ragged import ragged_dispatch -from tensorflow.python.ops.ragged import ragged_factory_ops -from tensorflow.python.ops.ragged import ragged_functional_ops -from tensorflow.python.ops.ragged import ragged_gather_ops -from tensorflow.python.ops.ragged import ragged_getitem -from tensorflow.python.ops.ragged import ragged_map_ops -from tensorflow.python.ops.ragged import ragged_math_ops -from tensorflow.python.ops.ragged import ragged_operators -from tensorflow.python.ops.ragged import ragged_string_ops -from tensorflow.python.ops.ragged import ragged_tensor -from tensorflow.python.ops.ragged import ragged_tensor_shape -from tensorflow.python.ops.ragged import ragged_tensor_value -from tensorflow.python.ops.ragged import ragged_where_op -from tensorflow.python.ops.ragged import segment_id_ops - -# Add a list of the ops that support Ragged Tensors. -__doc__ += ragged_dispatch.ragged_op_list() # pylint: disable=redefined-builtin diff --git a/tensorflow/python/ops/ragged/ragged_ops.py b/tensorflow/python/ops/ragged/ragged_ops.py new file mode 100644 index 00000000000..4c5a9909141 --- /dev/null +++ b/tensorflow/python/ops/ragged/ragged_ops.py @@ -0,0 +1,50 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Import all modules in the `ragged` package that define exported symbols. + +Additional, import ragged_dispatch (which has the side-effect of registering +dispatch handlers for many standard TF ops) and ragged_operators (which has the +side-effect of overriding RaggedTensor operators, such as RaggedTensor.__add__). + +We don't import these modules from ragged/__init__.py, since we want to avoid +circular dependencies. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +# pylint: disable=unused-import +from tensorflow.python.ops.ragged import ragged_array_ops +from tensorflow.python.ops.ragged import ragged_batch_gather_ops +from tensorflow.python.ops.ragged import ragged_batch_gather_with_default_op +from tensorflow.python.ops.ragged import ragged_concat_ops +from tensorflow.python.ops.ragged import ragged_conversion_ops +from tensorflow.python.ops.ragged import ragged_dispatch +from tensorflow.python.ops.ragged import ragged_factory_ops +from tensorflow.python.ops.ragged import ragged_functional_ops +from tensorflow.python.ops.ragged import ragged_gather_ops +from tensorflow.python.ops.ragged import ragged_getitem +from tensorflow.python.ops.ragged import ragged_map_ops +from tensorflow.python.ops.ragged import ragged_math_ops +from tensorflow.python.ops.ragged import ragged_operators +from tensorflow.python.ops.ragged import ragged_squeeze_op +from tensorflow.python.ops.ragged import ragged_string_ops +from tensorflow.python.ops.ragged import ragged_tensor +from tensorflow.python.ops.ragged import ragged_tensor_shape +from tensorflow.python.ops.ragged import ragged_tensor_value +from tensorflow.python.ops.ragged import ragged_where_op +from tensorflow.python.ops.ragged import segment_id_ops diff --git a/tensorflow/python/ops/ragged/ragged_tensor.py b/tensorflow/python/ops/ragged/ragged_tensor.py index fccdf8fe3c1..22d7ff3c588 100644 --- a/tensorflow/python/ops/ragged/ragged_tensor.py +++ b/tensorflow/python/ops/ragged/ragged_tensor.py @@ -2117,15 +2117,50 @@ class RaggedTensor(composite_tensor.CompositeTensor): return isinstance(rt, ops.EagerTensor) #============================================================================= - # Indexing & Slicing + # Operators #============================================================================= - def __getitem__(self, key): - """Returns the specified piece of this RaggedTensor.""" - # See ragged_getitem.py for the documentation and implementation of this - # method. - # - # Note: the imports in ragged/__init__.py ensure that this method always - # gets overridden before it is called. + # To avoid circular dependencies, we define stub methods for operators here, + # and then override them when the ragged_operators module is imported. + + def _overloaded_operator(name): # pylint: disable=no-self-argument + def stub(*args, **kwargs): + del args, kwargs + raise ValueError( + "You must import 'tensorflow.python.ops.ragged.ragged_ops' " + "before using RaggedTensor.%s" % name) + return stub + + __getitem__ = _overloaded_operator("__getitem__") + __ge__ = _overloaded_operator("__ge__") + __gt__ = _overloaded_operator("__gt__") + __le__ = _overloaded_operator("__le__") + __lt__ = _overloaded_operator("__lt__") + __and__ = _overloaded_operator("__and__") + __rand__ = _overloaded_operator("__rand__") + __invert__ = _overloaded_operator("__invert__") + __ror__ = _overloaded_operator("__ror__") + __or__ = _overloaded_operator("__or__") + __xor__ = _overloaded_operator("__xor__") + __rxor__ = _overloaded_operator("__rxor__") + __abs__ = _overloaded_operator("__abs__") + __add__ = _overloaded_operator("__add__") + __radd__ = _overloaded_operator("__radd__") + __div__ = _overloaded_operator("__div__") + __rdiv__ = _overloaded_operator("__rdiv__") + __floordiv__ = _overloaded_operator("__floordiv__") + __rfloordiv__ = _overloaded_operator("__rfloordiv__") + __mod__ = _overloaded_operator("__mod__") + __rmod__ = _overloaded_operator("__rmod__") + __mul__ = _overloaded_operator("__mul__") + __rmul__ = _overloaded_operator("__rmul__") + __neg__ = _overloaded_operator("__neg__") + __pow__ = _overloaded_operator("__pow__") + __rpow__ = _overloaded_operator("__rpow__") + __sub__ = _overloaded_operator("__sub__") + __rsub__ = _overloaded_operator("__rsub__") + __truediv__ = _overloaded_operator("__truediv__") + __rtruediv__ = _overloaded_operator("__rtruediv__") + del _overloaded_operator #============================================================================= # Name Scope