From 45cfe7126619b79ce3d88234a775d0468faa48cb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 5 Dec 2018 14:51:36 -0800 Subject: [PATCH] Internal Change PiperOrigin-RevId: 224225849 --- .../api_def/python_api/api_def_FloorDiv.pbtxt | 4 +- .../api_def/python_api/api_def_FloorMod.pbtxt | 7 +- .../api_def/python_api/api_def_RealDiv.pbtxt | 4 +- .../python_api/api_def_TruncateDiv.pbtxt | 4 +- .../python_api/api_def_TruncateMod.pbtxt | 4 +- tensorflow/python/framework/python_op_gen.cc | 4 +- tensorflow/python/ops/array_ops.py | 11 + tensorflow/python/ops/clip_ops.py | 2 + tensorflow/python/ops/math_ops.py | 32 +- tensorflow/python/ops/ragged/BUILD | 66 +-- tensorflow/python/ops/ragged/__init__.py | 110 ++--- .../python/ops/ragged/ragged_array_ops.py | 121 ++--- .../python/ops/ragged/ragged_dispatch.py | 441 ++++++++++++++++++ ...se_ops_test.py => ragged_dispatch_test.py} | 246 +++++----- .../ops/ragged/ragged_elementwise_ops.py | 389 --------------- .../ops/ragged/ragged_map_fn_op_test.py | 11 +- .../python/ops/ragged/ragged_math_ops.py | 111 +++-- .../python/ops/ragged/ragged_operators.py | 61 ++- tensorflow/python/ops/sparse_ops.py | 1 + tensorflow/python/ops/string_ops.py | 10 + tensorflow/python/util/dispatch.py | 7 +- 21 files changed, 889 insertions(+), 757 deletions(-) create mode 100644 tensorflow/python/ops/ragged/ragged_dispatch.py rename tensorflow/python/ops/ragged/{ragged_elementwise_ops_test.py => ragged_dispatch_test.py} (77%) delete mode 100644 tensorflow/python/ops/ragged/ragged_elementwise_ops.py diff --git a/tensorflow/core/api_def/python_api/api_def_FloorDiv.pbtxt b/tensorflow/core/api_def/python_api/api_def_FloorDiv.pbtxt index 26598ab1fb9..efd42b888d2 100644 --- a/tensorflow/core/api_def/python_api/api_def_FloorDiv.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_FloorDiv.pbtxt @@ -1,4 +1,6 @@ op { graph_op_name: "FloorDiv" - visibility: HIDDEN + endpoint { + name: "floor_div" + } } diff --git a/tensorflow/core/api_def/python_api/api_def_FloorMod.pbtxt b/tensorflow/core/api_def/python_api/api_def_FloorMod.pbtxt index ef562e93a0d..e5db6d49b29 100644 --- a/tensorflow/core/api_def/python_api/api_def_FloorMod.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_FloorMod.pbtxt @@ -1,4 +1,9 @@ op { graph_op_name: "FloorMod" - visibility: HIDDEN + endpoint { + name: "floormod" + } + endpoint { + name: "mod" + } } diff --git a/tensorflow/core/api_def/python_api/api_def_RealDiv.pbtxt b/tensorflow/core/api_def/python_api/api_def_RealDiv.pbtxt index bd87eef8240..f9e01eb5674 100644 --- a/tensorflow/core/api_def/python_api/api_def_RealDiv.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_RealDiv.pbtxt @@ -1,4 +1,6 @@ op { graph_op_name: "RealDiv" - visibility: HIDDEN + endpoint { + name: "realdiv" + } } diff --git a/tensorflow/core/api_def/python_api/api_def_TruncateDiv.pbtxt b/tensorflow/core/api_def/python_api/api_def_TruncateDiv.pbtxt index 2a547f771cf..8e46c5e663a 100644 --- a/tensorflow/core/api_def/python_api/api_def_TruncateDiv.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_TruncateDiv.pbtxt @@ -1,4 +1,6 @@ op { graph_op_name: "TruncateDiv" - visibility: HIDDEN + endpoint { + name: "truncatediv" + } } diff --git a/tensorflow/core/api_def/python_api/api_def_TruncateMod.pbtxt b/tensorflow/core/api_def/python_api/api_def_TruncateMod.pbtxt index 0731e8810e2..97fb816a7ad 100644 --- a/tensorflow/core/api_def/python_api/api_def_TruncateMod.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_TruncateMod.pbtxt @@ -1,4 +1,6 @@ op { graph_op_name: "TruncateMod" - visibility: HIDDEN + endpoint { + name: "truncatemod" + } } diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc index d91f7b0bdde..d460168631c 100644 --- a/tensorflow/python/framework/python_op_gen.cc +++ b/tensorflow/python/framework/python_op_gen.cc @@ -634,7 +634,9 @@ void GenEagerPythonOp::AddEagerFunctionTeardown( bool GenEagerPythonOp::AddEagerFastPathAndGraphCode( const string& parameters, const std::vector& output_sizes, const string& eager_not_allowed_error) { - strings::StrAppend(&result_, "@_dispatch.add_dispatch_list\n"); + if (api_def_.visibility() == ApiDef::VISIBLE) { + strings::StrAppend(&result_, "@_dispatch.add_dispatch_list\n"); + } AddExport(); AddDefLine(function_name_, parameters); AddDocStringDescription(); diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 7b6242b1cd7..b555f63cebb 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -56,6 +56,7 @@ _BaseSlice = slice @tf_export("identity") +@dispatch.add_dispatch_support def identity(input, name=None): # pylint: disable=redefined-builtin r"""Return a tensor with the same shape and contents as input. @@ -139,6 +140,7 @@ def expand_dims(input, axis=None, name=None, dim=None): @tf_export("expand_dims", v1=[]) +@dispatch.add_dispatch_support def expand_dims_v2(input, axis, name=None): """Inserts a dimension of 1 into a tensor's shape. @@ -941,6 +943,7 @@ def parallel_stack(values, name="parallel_stack"): @tf_export("stack") +@dispatch.add_dispatch_support def stack(values, axis=0, name="stack"): """Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor. @@ -1151,6 +1154,7 @@ def unstack(value, num=None, axis=0, name="unstack"): @tf_export("concat") +@dispatch.add_dispatch_support def concat(values, axis, name="concat"): """Concatenates tensors along one dimension. @@ -1328,6 +1332,7 @@ def boolean_mask(tensor, mask, name="boolean_mask", axis=None): @tf_export("boolean_mask", v1=[]) +@dispatch.add_dispatch_support def boolean_mask_v2(tensor, mask, axis=None, name="boolean_mask"): """Apply boolean mask to tensor. @@ -1810,6 +1815,7 @@ def zeros(shape, dtype=dtypes.float32, name=None): @tf_export(v1=["zeros_like"]) +@dispatch.add_dispatch_support def zeros_like(tensor, dtype=None, name=None, optimize=True): """Creates a tensor with all elements set to zero. @@ -1840,6 +1846,7 @@ def zeros_like(tensor, dtype=None, name=None, optimize=True): @tf_export("zeros_like", v1=[]) +@dispatch.add_dispatch_support def zeros_like_v2( input, # pylint: disable=redefined-builtin dtype=None, @@ -1899,6 +1906,7 @@ def zeros_like_impl(tensor, dtype, name, optimize=True): @tf_export(v1=["ones_like"]) +@dispatch.add_dispatch_support def ones_like(tensor, dtype=None, name=None, optimize=True): """Creates a tensor with all elements set to 1. @@ -1929,6 +1937,7 @@ def ones_like(tensor, dtype=None, name=None, optimize=True): @tf_export("ones_like", v1=[]) +@dispatch.add_dispatch_support def ones_like_v2( input, # pylint: disable=redefined-builtin dtype=None, @@ -3115,6 +3124,7 @@ def squeeze_v2(input, axis=None, name=None): @tf_export("where") +@dispatch.add_dispatch_support def where(condition, x=None, y=None, name=None): """Return the elements, either from `x` or `y`, depending on the `condition`. @@ -3234,6 +3244,7 @@ def gather(params, indices, validate_indices=None, name=None, axis=0): @tf_export("gather", v1=[]) +@dispatch.add_dispatch_support def gather_v2(params, indices, validate_indices=None, axis=0, name=None): return gather(params, indices, validate_indices=validate_indices, name=name, axis=axis) diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py index 82803ac3516..a237cfff826 100644 --- a/tensorflow/python/ops/clip_ops.py +++ b/tensorflow/python/ops/clip_ops.py @@ -31,10 +31,12 @@ from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import numerics from tensorflow.python.util import deprecation +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export @tf_export("clip_by_value") +@dispatch.add_dispatch_support def clip_by_value(t, clip_value_min, clip_value_max, name=None): """Clips tensor values to a specified min and max. diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index f0d8bed5087..e2b634ee8f8 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -230,6 +230,7 @@ class DivideDelegateWithName(object): @tf_export("math.divide", "divide") +@dispatch.add_dispatch_support def divide(x, y, name=None): """Computes Python style division of `x` by `y`.""" @@ -242,6 +243,7 @@ def divide(x, y, name=None): @tf_export("math.multiply", "multiply") +@dispatch.add_dispatch_support def multiply(x, y, name=None): return gen_math_ops.mul(x, y, name) @@ -262,6 +264,7 @@ _mul.__doc__ = ( @tf_export("math.subtract", "subtract") +@dispatch.add_dispatch_support def subtract(x, y, name=None): return gen_math_ops.sub(x, y, name) @@ -347,6 +350,7 @@ def scalar_mul_v2(scalar, x, name=None): @tf_export("math.pow", "pow") +@dispatch.add_dispatch_support def pow(x, y, name=None): # pylint: disable=redefined-builtin r"""Computes the power of one value to another. @@ -375,6 +379,7 @@ def pow(x, y, name=None): # pylint: disable=redefined-builtin # pylint: disable=redefined-builtin,redefined-outer-name @tf_export("dtypes.complex", "complex") +@dispatch.add_dispatch_support def complex(real, imag, name=None): r"""Converts two real numbers to a complex number. @@ -418,6 +423,7 @@ def complex(real, imag, name=None): @tf_export("math.real", v1=["math.real", "real"]) @deprecation.deprecated_endpoints("real") +@dispatch.add_dispatch_support def real(input, name=None): r"""Returns the real part of a complex (or real) tensor. @@ -450,6 +456,7 @@ def real(input, name=None): @tf_export("math.imag", v1=["math.imag", "imag"]) @deprecation.deprecated_endpoints("imag") +@dispatch.add_dispatch_support def imag(input, name=None): r"""Returns the imaginary part of a complex (or real) tensor. @@ -481,6 +488,7 @@ def imag(input, name=None): @tf_export("math.angle", v1=["math.angle", "angle"]) @deprecation.deprecated_endpoints("angle") +@dispatch.add_dispatch_support def angle(input, name=None): r"""Returns the element-wise argument of a complex (or real) tensor. @@ -520,6 +528,7 @@ def angle(input, name=None): @tf_export("math.round", "round") +@dispatch.add_dispatch_support def round(x, name=None): # pylint: disable=redefined-builtin """Rounds the values of a tensor to the nearest integer, element-wise. @@ -547,6 +556,7 @@ def round(x, name=None): # pylint: disable=redefined-builtin @tf_export("dtypes.cast", "cast") +@dispatch.add_dispatch_support def cast(x, dtype, name=None): """Casts a tensor to a new type. @@ -610,6 +620,7 @@ def cast(x, dtype, name=None): @tf_export("dtypes.saturate_cast", "saturate_cast") +@dispatch.add_dispatch_support def saturate_cast(value, dtype, name=None): """Performs a safe saturating cast of `value` to `dtype`. @@ -935,6 +946,7 @@ def _div_python2(x, y, name=None): @tf_export("math.truediv", "truediv") +@dispatch.add_dispatch_support def truediv(x, y, name=None): """Divides x / y elementwise (using Python 3 division operator semantics). @@ -992,6 +1004,7 @@ def div(x, y, name=None): @tf_export("div_no_nan") +@dispatch.add_dispatch_support def div_no_nan(x, y, name=None): """Computes an unsafe divide which returns 0 if the y is zero. @@ -1021,6 +1034,7 @@ mod = gen_math_ops.floor_mod # TODO(aselle): Deprecate this once all internal functionality uses # tf.truncatediv @tf_export("math.floordiv", v1=["math.floordiv", "floordiv"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("floordiv") def floordiv(x, y, name=None): """Divides `x / y` elementwise, rounding toward the most negative integer. @@ -1050,16 +1064,11 @@ def floordiv(x, y, name=None): realdiv = gen_math_ops.real_div -tf_export("realdiv")(realdiv) truncatediv = gen_math_ops.truncate_div -tf_export("truncatediv")(truncatediv) # TODO(aselle): Rename this to floordiv when we can. floor_div = gen_math_ops.floor_div -tf_export("floor_div")(floor_div) truncatemod = gen_math_ops.truncate_mod -tf_export("truncatemod")(truncatemod) floormod = gen_math_ops.floor_mod -tf_export("floormod", "mod")(floormod) def _mul_dispatch(x, y, name=None): @@ -1095,6 +1104,7 @@ _OverrideBinaryOperatorHelper(pow, "pow") @tf_export("math.logical_xor", v1=["math.logical_xor", "logical_xor"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("logical_xor") def logical_xor(x, y, name="LogicalXor"): """x ^ y = (x | y) & ~(x & y).""" @@ -1277,6 +1287,7 @@ def reduce_sum_v1(input_tensor, @tf_export("math.reduce_sum", "reduce_sum", v1=[]) +@dispatch.add_dispatch_support def reduce_sum(input_tensor, axis=None, keepdims=False, name=None): """Computes the sum of elements across dimensions of a tensor. @@ -1524,6 +1535,7 @@ def reduce_mean_v1(input_tensor, @tf_export("math.reduce_mean", "reduce_mean", v1=[]) +@dispatch.add_dispatch_support def reduce_mean(input_tensor, axis=None, keepdims=False, name=None): """Computes the mean of elements across dimensions of a tensor. @@ -1675,6 +1687,7 @@ def reduce_std(input_tensor, axis=None, keepdims=False, name=None): @tf_export("math.reduce_prod", "reduce_prod", v1=[]) +@dispatch.add_dispatch_support def reduce_prod(input_tensor, axis=None, keepdims=False, name=None): """Computes the product of elements across dimensions of a tensor. @@ -1796,6 +1809,7 @@ def reduce_min_v1(input_tensor, @tf_export("math.reduce_min", "reduce_min", v1=[]) +@dispatch.add_dispatch_support def reduce_min(input_tensor, axis=None, keepdims=False, name=None): """Computes the minimum of elements across dimensions of a tensor. @@ -1874,6 +1888,7 @@ def reduce_max_v1(input_tensor, @tf_export("math.reduce_max", "reduce_max", v1=[]) +@dispatch.add_dispatch_support def reduce_max(input_tensor, axis=None, keepdims=False, name=None): """Computes the maximum of elements across dimensions of a tensor. @@ -1961,6 +1976,7 @@ def reduce_all_v1(input_tensor, @tf_export("reduce_all", "math.reduce_all", v1=[]) +@dispatch.add_dispatch_support def reduce_all(input_tensor, axis=None, keepdims=False, name=None): """Computes the "logical and" of elements across dimensions of a tensor. @@ -2057,6 +2073,7 @@ def reduce_any_v1(input_tensor, @tf_export("math.reduce_any", "reduce_any", v1=[]) +@dispatch.add_dispatch_support def reduce_any(input_tensor, axis=None, keepdims=False, name=None): """Computes the "logical or" of elements across dimensions of a tensor. @@ -2619,6 +2636,7 @@ def _as_indexed_slices_list(inputs, optimize=True): @tf_export("math.add_n", "add_n") +@dispatch.add_dispatch_support def add_n(inputs, name=None): """Adds all input tensors element-wise. @@ -2764,6 +2782,7 @@ def sigmoid(x, name=None): @tf_export("math.log_sigmoid", v1=["math.log_sigmoid", "log_sigmoid"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("log_sigmoid") def log_sigmoid(x, name=None): """Computes log sigmoid of `x` element-wise. @@ -2973,6 +2992,7 @@ def cumprod(x, axis=0, exclusive=False, reverse=False, name=None): @tf_export("math.conj", v1=["math.conj", "conj"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("conj") def conj(x, name=None): r"""Returns the complex conjugate of a complex number. @@ -3077,6 +3097,7 @@ def _unsorted_segment_N(data, segment_ids, num_segments): "math.unsorted_segment_mean", v1=["math.unsorted_segment_mean", "unsorted_segment_mean"]) @deprecation.deprecated_endpoints("unsorted_segment_mean") +@dispatch.add_dispatch_support def unsorted_segment_mean(data, segment_ids, num_segments, name=None): r"""Computes the mean along segments of a tensor. @@ -3122,6 +3143,7 @@ def unsorted_segment_mean(data, segment_ids, num_segments, name=None): "math.unsorted_segment_sqrt_n", v1=["math.unsorted_segment_sqrt_n", "unsorted_segment_sqrt_n"]) @deprecation.deprecated_endpoints("unsorted_segment_sqrt_n") +@dispatch.add_dispatch_support def unsorted_segment_sqrt_n(data, segment_ids, num_segments, name=None): r"""Computes the sum along segments of a tensor divided by the sqrt(N). diff --git a/tensorflow/python/ops/ragged/BUILD b/tensorflow/python/ops/ragged/BUILD index e335c5cb6f3..fcd9adad218 100644 --- a/tensorflow/python/ops/ragged/BUILD +++ b/tensorflow/python/ops/ragged/BUILD @@ -25,7 +25,7 @@ py_library( deps = [ ":ragged_array_ops", ":ragged_conversion_ops", - ":ragged_elementwise_ops", + ":ragged_dispatch", ":ragged_factory_ops", ":ragged_functional_ops", ":ragged_getitem", @@ -150,33 +150,14 @@ py_library( ], ) -py_library( - name = "ragged_elementwise_ops", - srcs = ["ragged_elementwise_ops.py"], - srcs_version = "PY2AND3", - deps = [ - ":ragged_factory_ops", - ":ragged_tensor", - ":ragged_tensor_shape", - ":ragged_util", - "//tensorflow/python:array_ops", - "//tensorflow/python:clip_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:string_ops", - "//tensorflow/python:util", - ], -) - py_library( name = "ragged_operators", srcs = ["ragged_operators.py"], srcs_version = "PY2AND3", deps = [ - ":ragged_elementwise_ops", ":ragged_getitem", ":ragged_tensor", + "//tensorflow/python:math_ops", "//tensorflow/python:util", ], ) @@ -186,12 +167,13 @@ py_library( srcs = ["ragged_string_ops.py"], srcs_version = "PY2AND3", deps = [ - ":ragged_array_ops", ":ragged_conversion_ops", ":ragged_factory_ops", ":ragged_tensor", "//tensorflow/python:array_ops", - "//tensorflow/python:math_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:string_ops", "//tensorflow/python:util", ], ) @@ -219,10 +201,11 @@ py_library( ":ragged_tensor", ":ragged_util", "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", - "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_util", ], ) @@ -285,6 +268,29 @@ py_library( ], ) +py_library( + name = "ragged_dispatch", + srcs = ["ragged_dispatch.py"], + srcs_version = "PY2AND3", + deps = [ + ":ragged_array_ops", + ":ragged_factory_ops", + ":ragged_math_ops", + ":ragged_tensor", + ":ragged_tensor_shape", + ":ragged_util", + "//tensorflow/python:array_ops", + "//tensorflow/python:clip_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:string_ops", + "//tensorflow/python:util", + "//third_party/py/numpy", + ], +) + #------------------------------------------------------------------------------- # RaggedTensor Tests #------------------------------------------------------------------------------- @@ -458,6 +464,7 @@ py_test( "//tensorflow/python:errors", "//tensorflow/python:framework_test_lib", "//tensorflow/python:gradients_impl", + "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], ) @@ -684,17 +691,21 @@ py_test( ) py_test( - name = "ragged_elementwise_ops_test", - srcs = ["ragged_elementwise_ops_test.py"], + name = "ragged_dispatch_test", + srcs = ["ragged_dispatch_test.py"], srcs_version = "PY2AND3", deps = [ ":ragged", "//tensorflow/python:array_ops", + "//tensorflow/python:clip_ops", "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:parsing_ops", "//tensorflow/python:platform_test", + "//tensorflow/python:string_ops", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], @@ -725,6 +736,7 @@ py_test( "//tensorflow/python:platform_test", "//tensorflow/python:string_ops", "//tensorflow/python/keras:backend", + "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], ) @@ -735,8 +747,10 @@ py_test( srcs_version = "PY2AND3", deps = [ ":ragged", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", + "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/python/ops/ragged/__init__.py b/tensorflow/python/ops/ragged/__init__.py index 1b2a7be95fc..bfcaa366fc6 100644 --- a/tensorflow/python/ops/ragged/__init__.py +++ b/tensorflow/python/ops/ragged/__init__.py @@ -1,76 +1,53 @@ """Ragged Tensors. -This package defines the [`RaggedTensor`](ragged/RaggedTensor.md) class, which +This package defines the `tf.RaggedTensor` class, which represents tensors with non-uniform shapes. In particular, each `RaggedTensor` has one or more *ragged dimensions*, which are dimensions whose slices may have different lengths. For example, the inner (column) dimension of `rt=[[3, 1, 4, 1], [], [5, 9, 2], [6], []]` is ragged, since the column slices (`rt[0, :]`, ..., `rt[4, :]`) have different lengths. For a more detailed -description of ragged tensors, see the [`RaggedTensor`](ragged/RaggedTensor.md) +description of ragged tensors, see the `tf.RaggedTensor` class documentation. -## RaggedTensor Operations +## `RaggedTensor` Operations -This package also defines a collection of operations for manipulating -ragged tensors. +### `RaggedTensor` Factory ops -### RaggedTensor Versions of Standard Tensor Operations +* `tf.ragged.constant` +* `tf.ragged.from_row_splits` +* `tf.ragged.from_row_splits` +* `tf.ragged.from_row_lengths` +* `tf.ragged.from_row_starts` +* `tf.ragged.from_row_limits` +* `tf.ragged.from_value_rowids` +* `tf.ragged.from_nested_row_splits` +* `tf.ragged.from_nested_value_rowids` -Many of the operations defined by this package are analogous to -[`Tensor`](https://www.tensorflow.org/api_docs/python/tf/Tensor) -operations, but they accept `RaggedTensor`s as input and can return -`RaggedTensor`s as output. For example, `ragged.add` performs elementwise -addition just like `tf.add`, but can be used on `RaggedTensor`s. +### `RaggedTensor` Conversion ops -These `RaggedTensor` versions of the standard `Tensor` operations can also be -used with standard `Tensors`; and for the most part, they will return the same -value that the standard `Tensor` operation would return. However, there are -a few notable exceptions: +* `tf.ragged.from_tensor` +* `tf.ragged.to_tensor` +* `tf.ragged.from_sparse` +* `tf.ragged.to_sparse` +* `tf.ragged.from_variant` +* `tf.ragged.to_variant` +* `tf.ragged.convert_to_tensor_or_ragged_tensor` -* For [`ragged.stack(...)`](ragged/stack.md) and - [`ragged.concat(...)`](ragged/concat.md), the input tensors are not required - to have matching shapes. In the returned tensor, all dimensions up to the - `axis` dimension will be ragged. +### `RaggedTensor` Shape ops -### Ragged-Tensor Specific Operations +* `tf.ragged.row_splits` +* `tf.ragged.row_lengths` +* `tf.ragged.row_starts` +* `tf.ragged.row_limits` +* `tf.ragged.value_rowids` +* `tf.ragged.nrows` +* `tf.ragged.nested_row_splits` +* `tf.ragged.row_splits_to_segment_ids` +* `tf.ragged.segment_ids_to_row_splits` +* `tf.ragged.bounding_shape` -The following operations are specific to ragged tensors: - -* **Factory ops**: - [`constant(...)`](ragged/constant.md), - [`from_row_splits(...)`](ragged/from_row_splits.md), - [`from_row_lengths(...)`](ragged/from_row_lengths.md), - [`from_row_starts(...)`](ragged/from_row_starts.md), - [`from_row_limits(...)`](ragged/from_row_limits.md), - [`from_value_rowids(...)`](ragged/from_value_rowids.md), - [`from_nested_row_splits(...)`](ragged/from_nested_row_splits.md), - [`from_nested_value_rowids(...)`](ragged/from_nested_value_rowids.md). - -* **Conversion ops**: - [`from_tensor(...)`](ragged/from_tensor.md), - [`to_tensor(...)`](ragged/to_tensor.md), - [`from_sparse(...)`](ragged/from_sparse.md), - [`to_sparse(...)`](ragged/to_sparse.md), - [`from_variant(...)`](ragged/from_variant.md), - [`to_variant(...)`](ragged/to_variant.md), - [`convert_to_tensor_or_ragged_tensor(...)`]( - ragged/convert_to_tensor_or_ragged_tensor.md). - -* **Shape ops**: - [`row_splits(...)`](ragged/row_splits.md), - [`row_lengths(...)`](ragged/row_lengths.md), - [`row_starts(...)`](ragged/row_starts.md), - [`row_limits(...)`](ragged/row_limits.md), - [`value_rowids(...)`](ragged/value_rowids.md), - [`nrows(...)`](ragged/nrows.md), - [`nested_row_splits(...)`](ragged/nested_row_splits.md), - [`row_splits_to_segment_ids(...)`](ragged/row_splits_to_segment_ids.md), - [`segment_ids_to_row_splits(...)`](ragged/segment_ids_to_row_splits.md), - [`bounding_shape(...)`](ragged/bounding_shape.md). - -* **Functional ops**: - [`map_inner_values(...)`](ragged/map_inner_values.md), - [`make_elementwise_op(...)`](ragged/make_elementwise_op.md). +### Functional ops +* `tf.ragged.map_inner_values` @@ -140,21 +117,17 @@ The following operations are specific to ragged tensors: @@map_inner_values @@map_fn - -@@make_elementwise_op - @@RaggedTensorDynamicShape @@broadcast_to @@broadcast_dynamic_shape - - """ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.ops.ragged import ragged_dispatch from tensorflow.python.ops.ragged import ragged_operators from tensorflow.python.ops.ragged import ragged_string_ops @@ -179,11 +152,6 @@ from tensorflow.python.ops.ragged.ragged_conversion_ops import from_tensor from tensorflow.python.ops.ragged.ragged_conversion_ops import to_sparse from tensorflow.python.ops.ragged.ragged_conversion_ops import to_tensor -# pylint: disable=protected-access, wildcard-import -from tensorflow.python.ops.ragged.ragged_elementwise_ops import * -from tensorflow.python.ops.ragged.ragged_elementwise_ops import _symbols_to_export as _elementwise_ops -# pylint: enable=protected-access, wildcard-import - from tensorflow.python.ops.ragged.ragged_factory_ops import constant from tensorflow.python.ops.ragged.ragged_factory_ops import constant_value from tensorflow.python.ops.ragged.ragged_factory_ops import convert_to_tensor_or_ragged_tensor @@ -231,6 +199,10 @@ from tensorflow.python.ops.ragged.segment_id_ops import segment_ids_to_row_split from tensorflow.python.util import all_util as _all_util + +# Register OpDispatchers that override standard TF ops to work w/ RaggedTensors. +__doc__ += ragged_dispatch.register_dispatchers() # pylint: disable=redefined-builtin + # Any symbol that is not referenced (with "@@name") in the module docstring -# above, or included in the "_elementwise_ops" whitelist, will be removed. -_all_util.remove_undocumented(__name__, _elementwise_ops) +# above will be removed. +_all_util.remove_undocumented(__name__) diff --git a/tensorflow/python/ops/ragged/ragged_array_ops.py b/tensorflow/python/ops/ragged/ragged_array_ops.py index 603e39d1dcf..25317ba93ea 100644 --- a/tensorflow/python/ops/ragged/ragged_array_ops.py +++ b/tensorflow/python/ops/ragged/ragged_array_ops.py @@ -308,7 +308,7 @@ def bounding_shape(rt_input, axis=None, name=None): # ragged_gather #=============================================================================== # TODO(edloper): Add an `axis` argument -def gather(params, indices, name=None): +def gather(params, indices, validate_indices=None, axis=0, name=None): """Gathers ragged slices from `params` axis `0` according to `indices`. Returns `RaggedTensor` output, such that: @@ -347,6 +347,8 @@ def gather(params, indices, name=None): indices: The potentially ragged tensor indicating which values to gather. Must have dtype `int32` or `int64`. Values must be in the range `[0, params.shape[0]]`. + validate_indices: Ignored. + axis: Must be zero. name: A name for the operation (optional). Returns: @@ -357,6 +359,9 @@ def gather(params, indices, name=None): Raises: ValueError: If indices.shape.ndims is not known statically. """ + del validate_indices + if not isinstance(axis, int) or axis != 0: + raise ValueError('axis>0 is not supported for ragged gather yet.') with ops.name_scope(name, 'RaggedGather', [params, indices]): params = ragged_factory_ops.convert_to_tensor_or_ragged_tensor( params, name='params') @@ -812,29 +817,29 @@ def boolean_mask(data, mask, keepdims=False, name=None): #=============================================================================== # Concatenation and Stacking #=============================================================================== -def concat(rt_inputs, axis, name=None): +def concat(values, axis, name=None): """Concatenates potentially ragged tensors along one dimension. Given a list of tensors with the same rank `K` (`K >= axis`), returns a rank-`K` `RaggedTensor` `result` such that `result[i0...iaxis]` is the - concatenation of `[rt[i0...iaxis] for rt in rt_inputs]`. + concatenation of `[rt[i0...iaxis] for rt in values]`. Args: - rt_inputs: A list of potentially ragged tensors. May not be empty. All - `rt_inputs` must have the same rank and the same dtype; but unlike + values: A list of potentially ragged tensors. May not be empty. All + `values` must have the same rank and the same dtype; but unlike `tf.concat`, they can have arbitrary shapes. axis: A python integer, indicating the dimension along which to concatenate. (Note: Unlike `tf.concat`, the `axis` parameter must be statically known.) Negative values are supported only if the rank of at least one - `rt_inputs` value is statically known. + `values` value is statically known. name: A name prefix for the returned tensor (optional). Returns: A `RaggedTensor` with rank `K`. - `result.ragged_rank=max(axis, max(rt.ragged_rank for rt in rt_inputs]))`. + `result.ragged_rank=max(axis, max(rt.ragged_rank for rt in values]))`. Raises: - ValueError: If `rt_inputs` is empty, if `axis` is out of bounds or if + ValueError: If `values` is empty, if `axis` is out of bounds or if the input tensors have different ranks. #### Example: @@ -847,35 +852,35 @@ def concat(rt_inputs, axis, name=None): [[1, 2, 6], [3, 4, 5, 7, 8, 9]] ``` """ - if not isinstance(rt_inputs, (list, tuple)): - rt_inputs = [rt_inputs] - with ops.name_scope(name, 'RaggedConcat', rt_inputs): - return _ragged_stack_concat_helper(rt_inputs, axis, stack_values=False) + if not isinstance(values, (list, tuple)): + values = [values] + with ops.name_scope(name, 'RaggedConcat', values): + return _ragged_stack_concat_helper(values, axis, stack_values=False) -def stack(rt_inputs, axis, name=None): +def stack(values, axis, name=None): """Stacks potentially ragged tensors along one dimension. Given a list of tensors with the same rank `K` (`K >= axis`), returns a rank-`K+1` `RaggedTensor` `result` such that `result[i0...iaxis]` is the - list `[rt[i0...iaxis] for rt in rt_inputs]`. + list `[rt[i0...iaxis] for rt in values]`. Args: - rt_inputs: A list of potentially ragged tensors. May not be empty. All - `rt_inputs` must have the same rank and the same dtype; but unlike + values: A list of potentially ragged tensors. May not be empty. All + `values` must have the same rank and the same dtype; but unlike `tf.concat`, they can have arbitrary shapes. axis: A python integer, indicating the dimension along which to stack. (Note: Unlike `tf.stack`, the `axis` parameter must be statically known.) Negative values are supported only if the rank of at least one - `rt_inputs` value is statically known. + `values` value is statically known. name: A name prefix for the returned tensor (optional). Returns: A `RaggedTensor` with rank `K+1`. - `result.ragged_rank=max(axis, max(rt.ragged_rank for rt in rt_inputs]))`. + `result.ragged_rank=max(axis, max(rt.ragged_rank for rt in values]))`. Raises: - ValueError: If `rt_inputs` is empty, if `axis` is out of bounds or if + ValueError: If `values` is empty, if `axis` is out of bounds or if the input tensors have different ranks. #### Example: @@ -888,10 +893,10 @@ def stack(rt_inputs, axis, name=None): [[[1, 2], [6]], [[3, 4, 5], [7, 8, 9]]] ``` """ - if not isinstance(rt_inputs, (list, tuple)): - rt_inputs = [rt_inputs] - with ops.name_scope(name, 'RaggedConcat', rt_inputs): - return _ragged_stack_concat_helper(rt_inputs, axis, stack_values=True) + if not isinstance(values, (list, tuple)): + values = [values] + with ops.name_scope(name, 'RaggedConcat', values): + return _ragged_stack_concat_helper(values, axis, stack_values=True) def _ragged_stack_concat_helper(rt_inputs, axis, stack_values): @@ -1065,22 +1070,22 @@ def _copy_row_shape(rt_inputs, splits): #=============================================================================== # Tiling #=============================================================================== -def tile(rt_input, multiples, name=None): +def tile(input, multiples, name=None): # pylint: disable=redefined-builtin """Constructs a `RaggedTensor` by tiling a given `RaggedTensor`. - The values of `rt_input` are replicated `multiples[i]` times along the + The values of `input` are replicated `multiples[i]` times along the `i`th dimension (for each dimension `i`). For every dimension `axis` in - `rt_input`, the length of each output element in that dimension is the + `input`, the length of each output element in that dimension is the length of corresponding input element multiplied by `multiples[axis]`. Args: - rt_input: A `RaggedTensor`. + input: A `RaggedTensor`. multiples: A 1-D integer `Tensor`. Length must be the same as the number of - dimensions in `rt_input`. + dimensions in `input`. name: A name for the operation (optional). Returns: - A `RaggedTensor` with the same type, rank, and ragged_rank as `rt_input`. + A `RaggedTensor` with the same type, rank, and ragged_rank as `input`. #### Example: ```python @@ -1089,22 +1094,22 @@ def tile(rt_input, multiples, name=None): [[1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3]] ``` """ - with ops.name_scope(name, 'RaggedTile', [rt_input, multiples]): - rt_input = ragged_factory_ops.convert_to_tensor_or_ragged_tensor( - rt_input, name='rt_input') + with ops.name_scope(name, 'RaggedTile', [input, multiples]): + input = ragged_factory_ops.convert_to_tensor_or_ragged_tensor( + input, name='input') multiples = ragged_util.convert_to_int_tensor( multiples, name='multiples', dtype=dtypes.int64) multiples.shape.assert_has_rank(1) - if not ragged_tensor.is_ragged(rt_input): - return array_ops.tile(rt_input, multiples, name) + if not ragged_tensor.is_ragged(input): + return array_ops.tile(input, multiples, name) # If the constant value of `multiples` is available, then we can use it # to skip tiling dimensions where `multiples=1`. const_multiples = tensor_util.constant_value(multiples) return ragged_factory_ops.from_nested_row_splits( - _tile_ragged_values(rt_input, multiples, const_multiples), - _tile_ragged_splits(rt_input, multiples, const_multiples)) + _tile_ragged_values(input, multiples, const_multiples), + _tile_ragged_splits(input, multiples, const_multiples)) def _tile_ragged_values(rt_input, multiples, const_multiples=None): @@ -1240,26 +1245,26 @@ def _tile_ragged_splits(rt_input, multiples, const_multiples=None): #=============================================================================== -def expand_dims(rt_input, axis, name=None): +def expand_dims(input, axis, name=None): # pylint: disable=redefined-builtin """Inserts a dimension with shape 1 into a potentially ragged tensor's shape. - Given a potentially ragged tenor `rt_input`, this operation inserts a - dimension with size 1 at the dimension `axis` of `rt_input`'s shape. + Given a potentially ragged tenor `input`, this operation inserts a + dimension with size 1 at the dimension `axis` of `input`'s shape. - * If `rt_input` is a `Tensor`, then this is equivalent to + * If `input` is a `Tensor`, then this is equivalent to `tf.expand_dims`. - * If `rt_input` is ragged, and `axis=0`, then the new dimension will be + * If `input` is ragged, and `axis=0`, then the new dimension will be uniform; but the previously outermost dimension will become ragged. - * If `rt_input` is ragged, and `0 < axis < rt_input.ragged_rank`, then the + * If `input` is ragged, and `0 < axis < input.ragged_rank`, then the new dimension will be ragged. - * If `rt_input` is ragged, and axis >= rt_input.ragged_rank`, then the new + * If `input` is ragged, and axis >= input.ragged_rank`, then the new dimension will be uniform. The following table gives some examples showing how `ragged.expand_dims` impacts the shapes of different input tensors. Ragged dimensions are indicated by enclosing them in parentheses. - rt_input.shape | axis | result.shape + input.shape | axis | result.shape ----------------------- | ---- | ----------------------------- `[D1, D2]` | `0` | `[1, D1, D2]` `[D1, D2]` | `1` | `[D1, 1, D2]` @@ -1271,14 +1276,14 @@ def expand_dims(rt_input, axis, name=None): `[D1, (D2), (D3), D4]` | `4` | `[D1, (D2), (D3), D4, 1]` Args: - rt_input: The potentially tensor that should be expanded with a new + input: The potentially tensor that should be expanded with a new dimension. axis: An integer constant indicating where the new dimension should be inserted. name: A name for the operation (optional). Returns: - A tensor with the same values as `rt_input`, with an added dimension of + A tensor with the same values as `input`, with an added dimension of size 1 at `axis`. #### Examples: @@ -1300,24 +1305,24 @@ def expand_dims(rt_input, axis, name=None): TensorShape([2, None, 1]) [[[1], [2]], [[3]]] ``` """ - with ops.name_scope(name, 'RaggedExpandDims', [rt_input]): - rt_input = ragged_factory_ops.convert_to_tensor_or_ragged_tensor( - rt_input, name='rt_input') + with ops.name_scope(name, 'RaggedExpandDims', [input]): + input = ragged_factory_ops.convert_to_tensor_or_ragged_tensor( + input, name='input') - if not ragged_tensor.is_ragged(rt_input): - return array_ops.expand_dims(rt_input, axis) + if not ragged_tensor.is_ragged(input): + return array_ops.expand_dims(input, axis) - ndims = None if rt_input.shape.ndims is None else rt_input.shape.ndims + 1 + ndims = None if input.shape.ndims is None else input.shape.ndims + 1 axis = ragged_util.get_positive_axis(axis, ndims) if axis == 0: - values = rt_input - splits = array_ops.stack([0, nrows(rt_input)]) + values = input + splits = array_ops.stack([0, nrows(input)]) elif axis == 1: - values = rt_input - splits = math_ops.range(nrows(rt_input) + 1) + values = input + splits = math_ops.range(nrows(input) + 1) else: - values = expand_dims(rt_input.values, axis - 1) - splits = rt_input.row_splits + values = expand_dims(input.values, axis - 1) + splits = input.row_splits return ragged_factory_ops.from_row_splits(values, splits) diff --git a/tensorflow/python/ops/ragged/ragged_dispatch.py b/tensorflow/python/ops/ragged/ragged_dispatch.py new file mode 100644 index 00000000000..7f44ac2ec1e --- /dev/null +++ b/tensorflow/python/ops/ragged/ragged_dispatch.py @@ -0,0 +1,441 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Operator dispatch for RaggedTensors.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import numpy as np + +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import string_ops +from tensorflow.python.ops import variables +from tensorflow.python.ops.ragged import ragged_array_ops +from tensorflow.python.ops.ragged import ragged_factory_ops +from tensorflow.python.ops.ragged import ragged_math_ops +from tensorflow.python.ops.ragged import ragged_tensor +from tensorflow.python.ops.ragged import ragged_tensor_shape +from tensorflow.python.ops.ragged import ragged_util +from tensorflow.python.util import dispatch +from tensorflow.python.util import tf_decorator +from tensorflow.python.util import tf_export +from tensorflow.python.util import tf_inspect + +# @TODO(edloper): Set this to True in the CL that exports RaggedTensors. +_UPDATE_DOCSTRINGS = False + +# Information about an argument to an operation: The name of the argument, its +# position in the argument list, and a boolean flag indicating whether it +# expects a list of tensors. +_ArgInfo = collections.namedtuple('ArgInfo', ['name', 'position', 'is_list']) + + +def _get_arg_infos(func, arg_names): + """Returns an `_ArgInfo` for each argument of `func` specified by `arg_names`. + + Args: + func: The function whose arguments should be described. + arg_names: The names of the arguments to get info for. + + Returns: + A tuple of `_ArgInfo`s. + """ + arg_infos = [] + + # Inspect the func's argspec to find the position of each arg. + arg_spec = tf_inspect.getargspec(func) + for argname in arg_names: + assert isinstance(argname, str) + is_list = argname.startswith('[') and argname.endswith(']') + if is_list: + argname = argname[1:-1] + if argname not in arg_spec.args: + raise ValueError('Argument %r not found function in %s. Args=%s' % + (argname, func, arg_spec.args)) + arg_infos.append(_ArgInfo(argname, arg_spec.args.index(argname), is_list)) + return arg_infos + + +def _is_convertible_to_tensor(value): + """Returns true if `value` is convertible to a `Tensor`.""" + if isinstance(value, + (ops.Tensor, variables.Variable, np.ndarray, int, float, str)): + return True + elif isinstance(value, (sparse_tensor.SparseTensor,)): + return False + else: + try: + ops.convert_to_tensor(value) + return True + except (TypeError, ValueError): + return False + + +class UnaryRaggedElementwiseDispatcher(dispatch.OpDispatcher): + """OpDispatcher for unary ops that map a base op across ragged values.""" + + def __init__(self, original_op, arg_is_list=False): + self._original_op = original_op + self._arg_is_list = arg_is_list + arg_names = tf_inspect.getfullargspec(original_op)[0] + self._x = arg_names[0] + if _UPDATE_DOCSTRINGS: + original_op.__doc__ = ( + original_op.__doc__.rstrip() + '\n\n' + + ' `{x}` may be a `tf.RaggedTensor`.\n'.format(x=self._x)) + + def handle(self, args, kwargs): + if args: + x, args = args[0], args[1:] + else: + kwargs = kwargs.copy() + x = kwargs.pop(self._x, None) + if x is None: + return self.NOT_SUPPORTED + if self._arg_is_list: + found_ragged = False + for elt in x: + if ragged_tensor.is_ragged(elt): + found_ragged = True + elif not _is_convertible_to_tensor(elt): + return self.NOT_SUPPORTED + if found_ragged: + nested_splits_lists = [ + elt.nested_row_splits for elt in x if ragged_tensor.is_ragged(elt) + ] + inner_values = [ + elt.inner_values if ragged_tensor.is_ragged(elt) else elt + for elt in x + ] + with ops.control_dependencies( + ragged_util.assert_splits_match(nested_splits_lists)): + return ragged_factory_ops.from_nested_row_splits( + self._original_op(inner_values, *args, **kwargs), + nested_splits_lists[0]) + else: + return self.NOT_SUPPORTED + else: + found_ragged = ragged_tensor.is_ragged(x) + if found_ragged: + mapped_values = self._original_op(x.inner_values, *args, **kwargs) + return x.with_inner_values(mapped_values) + else: + return self.NOT_SUPPORTED + + +class BinaryRaggedElementwiseDispatcher(dispatch.OpDispatcher): + """OpDispatcher for binary ops that map a base op across ragged values. + + Supports broadcasting. + """ + + def __init__(self, original_op): + self._original_op = original_op + arg_names = tf_inspect.getfullargspec(original_op)[0] + self._x = arg_names[0] + self._y = arg_names[1] + if _UPDATE_DOCSTRINGS: + original_op.__doc__ = ( + original_op.__doc__.rstrip() + '\n\n' + + ' `{x}` and `{y}` may be a `tf.RaggedTensor`.\n'.format( + x=self._x, y=self._y)) + + def handle(self, args, kwargs): + # Extract the binary args. + if len(args) > 1: + x = args[0] + y = args[1] + args = args[2:] + elif args: + kwargs = kwargs.copy() + x = args[0] + y = kwargs.pop(self._y, None) + args = args[1:] + else: + kwargs = kwargs.copy() + x = kwargs.pop(self._x, None) + y = kwargs.pop(self._y, None) + + # Bail if we don't have at least one ragged argument. + x_is_ragged = ragged_tensor.is_ragged(x) + y_is_ragged = ragged_tensor.is_ragged(y) + if not (x_is_ragged or y_is_ragged): + return self.NOT_SUPPORTED + + # Convert args to tensors. Bail if conversion fails. + try: + if not x_is_ragged: + x = ops.convert_to_tensor(x, name=self._x, preferred_dtype=y.dtype) + if not y_is_ragged: + y = ops.convert_to_tensor(y, name=self._y, preferred_dtype=x.dtype) + except (TypeError, ValueError): + return self.NOT_SUPPORTED + + if ((x_is_ragged and y_is_ragged) or + (x_is_ragged and x.inner_values.shape.ndims <= y.shape.ndims) or + (y_is_ragged and y.inner_values.shape.ndims <= x.shape.ndims)): + bcast_shape = ragged_tensor_shape.broadcast_dynamic_shape( + ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(x), + ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(y)) + x = ragged_tensor_shape.broadcast_to( + x, bcast_shape, broadcast_inner_dimensions=False) + y = ragged_tensor_shape.broadcast_to( + y, bcast_shape, broadcast_inner_dimensions=False) + + x_values = x.inner_values if ragged_tensor.is_ragged(x) else x + y_values = y.inner_values if ragged_tensor.is_ragged(y) else y + mapped_values = self._original_op(x_values, y_values, *args, **kwargs) + if ragged_tensor.is_ragged(x): + return x.with_inner_values(mapped_values) + else: + return y.with_inner_values(mapped_values) + + +class RaggedDispatcher(dispatch.OpDispatcher): + """OpDispatcher for ragged ops. + + Dispatches to a wrapped op-handler if at least one of the `tensor_args` + arguments is a RaggedTensor or a RaggedTensorValue; and all of the + `tensor_args` arguments are convertible to Tensor or RaggedTensor. + """ + + def __init__(self, original_op, ragged_op, ragged_args): + op_arg_names = tf_inspect.getfullargspec(original_op)[0] + ragged_arg_names = tf_inspect.getfullargspec(ragged_op)[0] + if op_arg_names != ragged_arg_names: + raise AssertionError( + 'Signature must exactly match when overriding %s with %s: %s vs %s' % + (original_op, ragged_op, op_arg_names, ragged_arg_names)) + self._ragged_op = ragged_op + self._ragged_args = _get_arg_infos(ragged_op, ragged_args) + if _UPDATE_DOCSTRINGS: + arg_list = ' and '.join('`%s`' % arg for arg in ragged_args) + original_op.__doc__ = ( + original_op.__doc__.rstrip() + '\n\n' + + ' {0} may be a `tf.RaggedTensor`.\n'.format(arg_list)) + + def handle(self, args, kwargs): + if self.is_supported(args, kwargs): + return self._ragged_op(*args, **kwargs) + else: + return self.NOT_SUPPORTED + + def is_supported(self, args, kwargs): + found_ragged = False + for arg_info in self._ragged_args: + if arg_info.position < len(args): + arg = args[arg_info.position] + else: + arg = kwargs.get(arg_info.name, None) + + if arg_info.is_list: + if not isinstance(arg, (list, tuple)): + return False + for elt in arg: + if ragged_tensor.is_ragged(elt): + found_ragged = True + elif not _is_convertible_to_tensor(elt): + return False + else: + if ragged_tensor.is_ragged(arg): + found_ragged = True + elif not _is_convertible_to_tensor(arg): + return False + return found_ragged + + +def ragged_dispatch(original_op, tensor_args): + + def decorator(ragged_op): + dispatch.RaggedDispatcher(original_op, ragged_op, + tensor_args).register(original_op) + return ragged_op + + return decorator + + +_UNARY_ELEMENTWISE_OPS = [ + array_ops.check_numerics, + array_ops.identity, + array_ops.ones_like, + array_ops.ones_like_v2, + array_ops.zeros_like, + array_ops.zeros_like_v2, + clip_ops.clip_by_value, + math_ops.abs, + math_ops.acos, + math_ops.acosh, + math_ops.angle, + math_ops.asin, + math_ops.asinh, + math_ops.atan, + math_ops.atanh, + math_ops.cast, + math_ops.ceil, + math_ops.conj, + math_ops.cos, + math_ops.cosh, + math_ops.digamma, + math_ops.erf, + math_ops.erfc, + math_ops.exp, + math_ops.expm1, + math_ops.floor, + math_ops.imag, + math_ops.is_finite, + math_ops.is_inf, + math_ops.is_nan, + math_ops.lgamma, + math_ops.log, + math_ops.log1p, + math_ops.log_sigmoid, + math_ops.logical_not, + math_ops.negative, + math_ops.real, + math_ops.reciprocal, + math_ops.rint, + math_ops.round, + math_ops.rsqrt, + math_ops.saturate_cast, + math_ops.sign, + math_ops.sin, + math_ops.sinh, + math_ops.sqrt, + math_ops.square, + math_ops.tan, + parsing_ops.decode_compressed, + string_ops.string_to_number, + string_ops.string_to_hash_bucket, + string_ops.as_string, + string_ops.decode_base64, + string_ops.encode_base64, + string_ops.regex_full_match, + string_ops.regex_replace, + string_ops.string_strip, + string_ops.string_to_hash_bucket, + string_ops.string_to_hash_bucket_fast, + string_ops.string_to_hash_bucket_strong, + string_ops.substr, + string_ops.substr_v2, + string_ops.string_length, + string_ops.string_length_v2, + string_ops.unicode_script, +] + +_UNARY_LIST_ELEMENTWISE_OPS = [ + math_ops.add_n, + string_ops.string_join, +] + +_BINARY_ELEMENTWISE_OPS = [ + math_ops.add, + math_ops.atan2, + math_ops.complex, + math_ops.div_no_nan, + math_ops.divide, + math_ops.equal, + math_ops.floordiv, + math_ops.floormod, + math_ops.greater, + math_ops.greater_equal, + math_ops.less, + math_ops.less_equal, + math_ops.logical_and, + math_ops.logical_or, + math_ops.logical_xor, + math_ops.maximum, + math_ops.minimum, + math_ops.multiply, + math_ops.not_equal, + math_ops.pow, + math_ops.realdiv, + math_ops.squared_difference, + math_ops.subtract, + math_ops.truediv, + math_ops.truncatediv, + math_ops.truncatemod, +] + +# (original_op, ragged_op, ragged_args) +_RAGGED_DISPATCH_OPS = [ + (array_ops.batch_gather, ragged_array_ops.batch_gather, + ['params', 'indices']), + (array_ops.concat, ragged_array_ops.concat, ['values']), + (array_ops.expand_dims_v2, ragged_array_ops.expand_dims, ['input']), + (array_ops.gather_v2, ragged_array_ops.gather, ['params', 'indices']), + (array_ops.gather_nd, ragged_array_ops.gather_nd, ['params', 'indices']), + (array_ops.stack, ragged_array_ops.stack, ['values']), + (array_ops.tile, ragged_array_ops.tile, ['input']), + (array_ops.where, ragged_array_ops.where, ['condition', 'x', 'y']), + (math_ops.unsorted_segment_sum, ragged_math_ops.segment_sum, + ['data', 'segment_ids']), + (math_ops.unsorted_segment_prod, ragged_math_ops.segment_prod, + ['data', 'segment_ids']), + (math_ops.unsorted_segment_min, ragged_math_ops.segment_min, + ['data', 'segment_ids']), + (math_ops.unsorted_segment_max, ragged_math_ops.segment_max, + ['data', 'segment_ids']), + (math_ops.unsorted_segment_mean, ragged_math_ops.segment_mean, + ['data', 'segment_ids']), + (math_ops.unsorted_segment_sqrt_n, ragged_math_ops.segment_sqrt_n, + ['data', 'segment_ids']), + (math_ops.reduce_sum, ragged_math_ops.reduce_sum, ['input_tensor']), + (math_ops.reduce_prod, ragged_math_ops.reduce_prod, ['input_tensor']), + (math_ops.reduce_min, ragged_math_ops.reduce_min, ['input_tensor']), + (math_ops.reduce_max, ragged_math_ops.reduce_max, ['input_tensor']), + (math_ops.reduce_mean, ragged_math_ops.reduce_mean, ['input_tensor']), + (math_ops.reduce_any, ragged_math_ops.reduce_any, ['input_tensor']), + (math_ops.reduce_all, ragged_math_ops.reduce_all, ['input_tensor']), +] + + +def register_dispatchers(): + """Constructs & registers OpDispatchers for ragged ops.""" + + op_list = ( + _UNARY_ELEMENTWISE_OPS + _UNARY_LIST_ELEMENTWISE_OPS + + _BINARY_ELEMENTWISE_OPS + [x[0] for x in _RAGGED_DISPATCH_OPS]) + for op in op_list: + _, undecorated_op = tf_decorator.unwrap(op) + if not hasattr(undecorated_op, tf_export.API_ATTRS['tensorflow'].names): + raise AssertionError('Expected %s to be an exported symbol ' + '(while adding a RaggedTensor dispatcher)') + + for op in _UNARY_ELEMENTWISE_OPS: + UnaryRaggedElementwiseDispatcher(op).register(op) + + for op in _UNARY_LIST_ELEMENTWISE_OPS: + UnaryRaggedElementwiseDispatcher(op, True).register(op) + + for op in _BINARY_ELEMENTWISE_OPS: + BinaryRaggedElementwiseDispatcher(op).register(op) + + for (original_op, ragged_op, args) in _RAGGED_DISPATCH_OPS: + RaggedDispatcher(original_op, ragged_op, args).register(original_op) + + docstring = ( + '\n\n### Additional ops that support `RaggedTensor`\n\n' + '\n'.join([ + '* `tf.%s`' % tf_export.get_canonical_name_for_symbol(op) + for op in op_list + ])) + + return docstring diff --git a/tensorflow/python/ops/ragged/ragged_elementwise_ops_test.py b/tensorflow/python/ops/ragged/ragged_dispatch_test.py similarity index 77% rename from tensorflow/python/ops/ragged/ragged_elementwise_ops_test.py rename to tensorflow/python/ops/ragged/ragged_dispatch_test.py index 305a96df9cc..2bb10adce0e 100644 --- a/tensorflow/python/ops/ragged/ragged_elementwise_ops_test.py +++ b/tensorflow/python/ops/ragged/ragged_dispatch_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for ragged.elementwise_ops.""" +"""Tests for RaggedTensor operator dispatch.""" from __future__ import absolute_import from __future__ import division @@ -21,106 +21,108 @@ from __future__ import print_function from absl.testing import parameterized import numpy as np - from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import ragged +from tensorflow.python.ops import string_ops from tensorflow.python.platform import googletest -# Constants listing various op types to test. Each elementwise operation +# Constants listing various op types to test. Each operation # should be included in at least one list below, or tested separately if # necessary (e.g., because it expects additional arguments). UNARY_FLOAT_OPS = [ - ragged.abs, - ragged.acos, - ragged.acosh, - ragged.angle, - ragged.asin, - ragged.asinh, - ragged.atan, - ragged.atanh, - ragged.ceil, - ragged.conj, - ragged.cos, - ragged.cosh, - ragged.digamma, - ragged.erf, - ragged.erfc, - ragged.exp, - ragged.expm1, - ragged.floor, - ragged.imag, - ragged.is_finite, - ragged.is_inf, - ragged.is_nan, - ragged.lgamma, - ragged.log, - ragged.log1p, - ragged.log_sigmoid, - ragged.negative, - ragged.real, - ragged.reciprocal, - ragged.rint, - ragged.round, - ragged.rsqrt, - ragged.sign, - ragged.sin, - ragged.sinh, - ragged.sqrt, - ragged.square, - ragged.tan, - ragged.as_string, - ragged.identity, - ragged.ones_like, - ragged.zeros_like, + math_ops.abs, + math_ops.acos, + math_ops.acosh, + math_ops.angle, + math_ops.asin, + math_ops.asinh, + math_ops.atan, + math_ops.atanh, + math_ops.ceil, + math_ops.conj, + math_ops.cos, + math_ops.cosh, + math_ops.digamma, + math_ops.erf, + math_ops.erfc, + math_ops.exp, + math_ops.expm1, + math_ops.floor, + math_ops.imag, + math_ops.is_finite, + math_ops.is_inf, + math_ops.is_nan, + math_ops.lgamma, + math_ops.log, + math_ops.log1p, + math_ops.log_sigmoid, + math_ops.negative, + math_ops.real, + math_ops.reciprocal, + math_ops.rint, + math_ops.round, + math_ops.rsqrt, + math_ops.sign, + math_ops.sin, + math_ops.sinh, + math_ops.sqrt, + math_ops.square, + math_ops.tan, + array_ops.identity, + array_ops.ones_like, + array_ops.zeros_like, ] UNARY_BOOL_OPS = [ - ragged.logical_not, + math_ops.logical_not, ] UNARY_STRING_OPS = [ - ragged.decode_base64, - ragged.encode_base64, - ragged.string_strip, - ragged.decode_compressed, + string_ops.decode_base64, + string_ops.encode_base64, + string_ops.string_strip, + parsing_ops.decode_compressed, ] BINARY_FLOAT_OPS = [ - ragged.add, - ragged.atan2, - ragged.complex, - ragged.div, - ragged.div_no_nan, - ragged.divide, - ragged.equal, - ragged.floordiv, - ragged.floormod, - ragged.greater, - ragged.greater_equal, - ragged.less, - ragged.less_equal, - ragged.maximum, - ragged.minimum, - ragged.multiply, - ragged.not_equal, - ragged.pow, - ragged.realdiv, - ragged.squared_difference, - ragged.subtract, - ragged.truediv, + math_ops.add, + math_ops.atan2, + math_ops.complex, + math_ops.div_no_nan, + math_ops.divide, + math_ops.equal, + math_ops.floordiv, + math_ops.floormod, + math_ops.greater, + math_ops.greater_equal, + math_ops.less, + math_ops.less_equal, + math_ops.maximum, + math_ops.minimum, + math_ops.multiply, + math_ops.not_equal, + math_ops.pow, + math_ops.realdiv, + math_ops.squared_difference, + math_ops.subtract, + math_ops.truediv, ] BINARY_BOOL_OPS = [ - ragged.logical_and, - ragged.logical_or, - ragged.logical_xor, + math_ops.logical_and, + math_ops.logical_or, + math_ops.logical_xor, ] UNARY_INT_OPS = [ - ragged.unicode_script, + string_ops.unicode_script, ] BINARY_INT_OPS = [ - ragged.truncatediv, - ragged.truncatemod, + math_ops.truncatediv, + math_ops.truncatemod, ] @@ -171,50 +173,49 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase, [{'x': ragged.constant_value([['abcd', 'efgh'], ['aabbccdd']]), 'op': op} for op in UNARY_STRING_OPS] + [ - {'op': ragged.clip_by_value, + {'op': clip_ops.clip_by_value, 'x': ragged.constant_value([[-2.0, 3.0], [-3.0]]), 'clip_value_min': 0.1, 'clip_value_max': 4.0}, - {'op': ragged.cast, + {'op': math_ops.cast, 'x': ragged.constant_value([[-2.0, 3.0], [-3.0]]), 'dtype': dtypes.int32}, - {'op': ragged.saturate_cast, + {'op': math_ops.saturate_cast, 'x': ragged.constant_value([[-2.0, 3.0], [-3.0]]), 'dtype': dtypes.int32}, - {'op': ragged.string_to_hash_bucket, + {'op': string_ops.string_to_hash_bucket, 'x': ragged.constant_value([['abcd', 'efgh'], ['aabbccdd']]), 'num_buckets': 1000}, - {'op': ragged.string_to_hash_bucket_fast, + {'op': string_ops.string_to_hash_bucket_fast, 'x': ragged.constant_value([['abcd', 'efgh'], ['aabbccdd']]), 'num_buckets': 1000}, - {'op': ragged.string_to_hash_bucket_strong, + {'op': string_ops.string_to_hash_bucket_strong, 'x': ragged.constant_value([['abcd', 'efgh'], ['aabbccdd']]), 'num_buckets': 1000, 'key': [1231, 12512]}, - {'op': ragged.string_to_number, + {'op': string_ops.string_to_number, 'x': ragged.constant_value([['-2.0', '3.0'], ['-3.0']])}, - {'op': ragged.regex_full_match, + {'op': string_ops.regex_full_match, 'x': ragged.constant_value([['hello', '123'], ['1+1']]), 'pattern': r'\w+'}, - {'op': ragged.regex_replace, + {'op': string_ops.regex_replace, 'x': ragged.constant_value([['hello', '123'], ['1+1']]), 'pattern': r'\d', 'rewrite': '#'}, - {'op': ragged.substr, + {'op': string_ops.substr, 'x': ragged.constant_value([['hello', '123'], ['1+1']]), 'pos': 2, 'len': 3}, - {'op': ragged.check_numerics, + {'op': array_ops.check_numerics, 'x': ragged.constant_value([[-2.0, 3.0], [-3.0]]), 'message': 'check-numerics'}, ] ) # pyformat: disable - def testUnaryOp(self, x, op=ragged.abs, **extra_args): + def testUnaryElementwiseOp(self, x, op=math_ops.abs, **extra_args): x = ragged.convert_to_tensor_or_ragged_tensor(x) result = op(x, **extra_args) # Run the wrapped op on the dense values, for comparison. dense_x = x.inner_values if isinstance(x, ragged.RaggedTensor) else x - expected_flat_values = array_ops.reshape( - op.__wrapped__(dense_x, **extra_args), [-1]) + expected_flat_values = array_ops.reshape(op(dense_x, **extra_args), [-1]) with self.test_session(): # Check that the result has the expected shape. @@ -285,12 +286,17 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase, #===================================================================== {'x': ragged.constant_value([[[1, 2], [3], [4]], [[], [5, 7, 8]]]), 'y': ragged.constant_value([[[3, 8], [2], [5]], [[], [1, 9, 8]]]), - 'use_kwargs': True}, + 'use_kwargs': ('x', 'y')}, {'x': ragged.constant_value([[[1, 2]], [[3, 4], [5, 6], [7, 8]]], ragged_rank=1), 'y': ragged.constant_value([[[9, 3]], [[5, 2], [3, 4], [7, 6]]], ragged_rank=1), - 'use_kwargs': True}, + 'use_kwargs': ('x', 'y')}, + {'x': ragged.constant_value([[[1, 2]], [[3, 4], [5, 6], [7, 8]]], + ragged_rank=1), + 'y': ragged.constant_value([[[9, 3]], [[5, 2], [3, 4], [7, 6]]], + ragged_rank=1), + 'use_kwargs': ('x',)}, ] + #========================================================================= # Test each unary op. @@ -306,16 +312,16 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase, [{'x': ragged.constant_value([[True, True], [False]]), 'y': ragged.constant_value([[False, True], [False]]), 'op': op} - for op in BINARY_BOOL_OPS] + - [ - ] + for op in BINARY_BOOL_OPS] ) # pyformat: disable - def testBinaryOp(self, x, y, op=ragged.add, **extra_args): - use_kwargs = extra_args.pop('use_kwargs', False) + def testBinaryElementwiseOp(self, x, y, op=math_ops.add, **extra_args): + use_kwargs = extra_args.pop('use_kwargs', ()) x = ragged.convert_to_tensor_or_ragged_tensor(x) y = ragged.convert_to_tensor_or_ragged_tensor(y) - if use_kwargs: + if 'x' in use_kwargs and 'y' in use_kwargs: result = op(x=x, y=y, **extra_args) + elif 'y' in use_kwargs: + result = op(x, y=y, **extra_args) else: result = op(x, y, **extra_args) @@ -323,7 +329,7 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase, dense_x = x.inner_values if isinstance(x, ragged.RaggedTensor) else x dense_y = y.inner_values if isinstance(y, ragged.RaggedTensor) else y expected_flat_values = array_ops.reshape( - op.__wrapped__(dense_x, dense_y, **extra_args), [-1]) + op(dense_x, dense_y, **extra_args), [-1]) with self.test_session(): # Check that the result has the expected shape. @@ -358,16 +364,17 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase, ragged.constant_value([[[2, 9], [12]], [[8]]])), 'use_kwargs': True}, ] + [ - {'op': ragged.add_n, + {'op': math_ops.add_n, 'inputs': (ragged.constant_value([[1, 3], [-3]]), ragged.constant_value([[4, 7], [88]]), ragged.constant_value([[2, 9], [12]]))}, - {'op': ragged.string_join, + {'op': string_ops.string_join, 'inputs': (ragged.constant_value([['a', 'b'], ['c']]), ragged.constant_value([['foo', 'bar'], ['baz']]), ragged.constant_value([['2', '9'], ['12']]))}, ]) # pyformat: disable - def testListValuedOp(self, inputs, op=ragged.add_n, **extra_args): + def testListValuedElementwiseOp(self, inputs, op=math_ops.add_n, + **extra_args): use_kwargs = extra_args.pop('use_kwargs', False) inputs = [ragged.convert_to_tensor_or_ragged_tensor(x) for x in inputs] if use_kwargs: @@ -381,7 +388,7 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase, for x in inputs ] expected_flat_values = array_ops.reshape( - op.__wrapped__(dense_inputs, **extra_args), [-1]) + op(dense_inputs, **extra_args), [-1]) with self.test_session(): # Check that the result has the expected shape. @@ -395,13 +402,13 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase, self.assertAllEqual(expected_flat_values, result_flat_values) @test_util.run_deprecated_v1 - def testUnknownRankError(self): + def testElementwiseOpUnknownRankError(self): x = ragged.constant([[1, 2], [3]]) y = ragged.from_row_splits( array_ops.placeholder_with_default([1, 2, 3], shape=None), x.row_splits) - with self.assertRaisesRegexp( - ValueError, r'Unable to broadcast: unknown rank'): - ragged.add(x, y) + with self.assertRaisesRegexp(ValueError, + r'Unable to broadcast: unknown rank'): + math_ops.add(x, y) @parameterized.parameters([ dict( @@ -417,26 +424,31 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase, y=ragged.constant_value([[1]]), expected=[[[2]]]), ]) - def testBroadcastAdd(self, x, y, expected): + def testElementwiseOpBroadcast(self, x, y, expected): x = ragged.convert_to_tensor_or_ragged_tensor(x, dtype=dtypes.int32) y = ragged.convert_to_tensor_or_ragged_tensor(y, dtype=dtypes.int32) result = x + y with self.cached_session(): self.assertEqual(result.eval().tolist(), expected) - def testShapeMismatch(self): + def testElementwiseOpShapeMismatch(self): x = ragged.constant([[1, 2, 3], [4, 5]]) y = ragged.constant([[1, 2, 3], [4, 5, 6]]) with self.assertRaisesRegexp(errors.InvalidArgumentError, 'Incompatible shapes'): with self.cached_session(): - ragged.add(x, y).eval() + math_ops.add(x, y).eval() - def testDocstring(self): - self.assertRegexpMatches( - ragged.add.__doc__, - 'Ragged version of the elementwise operation `tf.math.add`') - self.assertEqual(ragged.add.__name__, 'add') + def testBinaryOpSparseAndRagged(self): + x = ragged.constant([[1, 2, 3], [4, 5]]) + y = sparse_tensor.SparseTensor([[0, 0], [0, 1], [2, 0]], [1, 2, 3], [3, 2]) + with self.assertRaises(TypeError): + with self.cached_session(): + math_ops.add(x, y).eval() + + with self.assertRaises(TypeError): + with self.cached_session(): + math_ops.add_n([x, y]).eval() if __name__ == '__main__': diff --git a/tensorflow/python/ops/ragged/ragged_elementwise_ops.py b/tensorflow/python/ops/ragged/ragged_elementwise_ops.py deleted file mode 100644 index 59b7dd16617..00000000000 --- a/tensorflow/python/ops/ragged/ragged_elementwise_ops.py +++ /dev/null @@ -1,389 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Elementwise operations for RaggedTensors.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections - -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import clip_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import parsing_ops -from tensorflow.python.ops import string_ops -from tensorflow.python.ops.ragged import ragged_factory_ops -from tensorflow.python.ops.ragged import ragged_tensor -from tensorflow.python.ops.ragged import ragged_tensor_shape -from tensorflow.python.util import tf_decorator -from tensorflow.python.util import tf_export -from tensorflow.python.util import tf_inspect - -# Information about an argument to an operation: The name of the argument, its -# position in the argument list, and a boolean flag indicating whether it -# expects a list of tensors. -_ArgInfo = collections.namedtuple('ArgInfo', ['name', 'position', 'is_list']) - - -def make_elementwise_op(op, *elementwise_args): - """Returns a ragged-tensor version of the elementwise operation `op`. - - The returned operation will: - - 1. Broadcast the elementwise arguments to have a compatible shape. - An exception is raised if the tensors not broadcast-compatible. - 2. Call `op`, substituting the dense values of the broadcasted tensor for - each elementwise argument. - 3. Return a potentially ragged tensor constructed from the output of `op` - and the broadcasted tensors' nested row splits. - - For example, you can construct a ragged-tensor version of the standard - operation `tf.add` by calling `make_elementwise_op(tf.add, 'x', 'y')`. - - Args: - op: The operation to wrap. - *elementwise_args: The names of arguments to `op` that are treated as - elementwise. Arguments that take a list of tensors should have their - names wrapped in square brackets (e.g. "[inputs]"). - - Raises: - ValueError: If any name specified in `elementwise_args` is not the name - of an argument to `op`. - """ - elementwise_arg_infos = _get_arg_infos(op, elementwise_args) - - def ragged_op(*args, **kwargs): - """Ragged version of `op`.""" - args = list(args) - - # Collect all of the elementwise arguments, and put them in a single - # dict whose values are the (potentially ragged) tensors that need to - # be broadcast to a common shape. The keys of this dict are tuples - # (argkey, index), where argkey is an int for poitional args or a string - # for keyword args; and index is None for non-list args and the index of the - # tensor for list args. - elementwise_args = {} - for (name, position, is_list) in elementwise_arg_infos.values(): - if position < len(args): - if is_list: - args[position] = list(args[position]) - for (index, arg) in enumerate(args[position]): - elementwise_args[position, index] = arg - else: - elementwise_args[position, None] = args[position] - elif name in kwargs: - if is_list: - kwargs[name] = list(kwargs[name]) - for (i, arg) in enumerate(kwargs[name]): - elementwise_args[name, i] = arg - else: - elementwise_args[name, None] = kwargs[name] - - with ops.name_scope(None, op.__name__, elementwise_args.values()): - # Convert all inputs to tensors or ragged tensors. - for ((key, index), tensor) in elementwise_args.items(): - argname = elementwise_arg_infos[key].name - converted = ragged_factory_ops.convert_to_tensor_or_ragged_tensor( - tensor, name=argname) - elementwise_args[key, index] = converted - - # Broadcast tensors to have compatible shapes. - broadcast_args, result_splits, broadcast_check_ops = \ - _broadcast_elementwise_args(elementwise_args) - - # Replace tensor arguments with their dense values. - for ((key, index), tensor) in broadcast_args.items(): - if ragged_tensor.is_ragged(tensor): - if isinstance(key, int) and index is None: - args[key] = tensor.inner_values - elif isinstance(key, int) and index is not None: - args[key][index] = tensor.inner_values - elif isinstance(key, str) and index is None: - kwargs[key] = tensor.inner_values - else: - assert isinstance(key, str) and index is not None - kwargs[key][index] = tensor.inner_values - - # Call the elementwise op on the broadcasted dense values. - with ops.control_dependencies(broadcast_check_ops): - result_values = op(*args, **kwargs) - - # Restore any ragged dimensions that we stripped off, and return the - # result. - return ragged_factory_ops.from_nested_row_splits(result_values, - result_splits) - - # Construct the docstring. - op_name = tf_export.get_canonical_name_for_symbol(op) - assert op_name is not None, op - argnames = ', '.join('`%s`' % s.strip('[]') for s in elementwise_args) - docstring = _ELEMENTWISE_DOCSTRING % dict(op_name=op_name, argnames=argnames) - - # Update name, docstring, signature, etc., for the wrapper, and return it. - return tf_decorator.make_decorator(op, ragged_op, decorator_doc=docstring) - - -_ELEMENTWISE_DOCSTRING = """\ -Ragged version of the elementwise operation `tf.%(op_name)s`. - - The following elementwise arguments may be ragged or dense: - %(argnames)s. - These arguments will be broadcast to a compatible shape if necessary. - """ - - -def _get_arg_infos(func, elementwise_args): - """Returns `_ArgInfo`s for each `func` arg specified by `elementwise_args`. - - Args: - func: The function whose arguments should be described. - elementwise_args: The names of the arguments to get info for. - - Returns: - A dictionary that maps both names and positions of arguments to - `_ArgInfo` tuples. - """ - arg_infos = {} - - # Inspect the func's argspec to find the position of each arg. - arg_spec = tf_inspect.getargspec(func) - for argname in elementwise_args: - assert isinstance(argname, str) - is_list = argname.startswith('[') and argname.endswith(']') - if is_list: - argname = argname[1:-1] - assert argname in arg_spec.args, (func, argname, arg_spec.args) - arg_info = _ArgInfo(argname, arg_spec.args.index(argname), is_list) - arg_infos[arg_info.name] = arg_info - arg_infos[arg_info.position] = arg_info - return arg_infos - - -def _broadcast_elementwise_args(elementwise_args): - """Broadcasts the values of `elementwise_args` to have compatible shapes. - - Args: - elementwise_args: A dictionary whose keys are potentially ragged tensors. - - Returns: - A tuple `(broadcast_args, broadcast_splits, checks)` where: - - * `broadcast_args` is a dictionary with the same keys as - `elementwise_args`, mapping to broadcasted tensors. - * `broadcast_splits` is the broadcasted nested row splits. - * `checks` is a possibly empty tuple of assertion operations that should - be added as control dependencies. - - Raises: - ValueError: If broadcasting fails. - """ - # No elementwise arguments were used: nothing to do! - if not elementwise_args: - return elementwise_args, (), () - - # A single elementwise argument was used: no broadcasting necessary. - if len(elementwise_args) == 1: - arg = list(elementwise_args.values())[0] - if ragged_tensor.is_ragged(arg): - return elementwise_args, arg.nested_row_splits, () - else: - return elementwise_args, (), () - - # Multiple elementwise arguments. - else: - is_ragged = [ragged_tensor.is_ragged(t) for t in elementwise_args.values()] - if not any(is_ragged): - return elementwise_args, (), () - - # If we have a single ragged tensor plus a set of scalars, then we can - # rely on the underlying elementwise op to do broadcasting. - if (sum(is_ragged) == 1 and - all((ragged_tensor.is_ragged(t) or t.shape.ndims == 0) - for t in elementwise_args.values())): - nested_splits_lists = [ - t.nested_row_splits - for t in elementwise_args.values() - if ragged_tensor.is_ragged(t)][0] - return elementwise_args, nested_splits_lists, () - - else: - # Get the shapes of all the elementwise arguments. - shapes = [ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(t) - for t in elementwise_args.values()] - - # Broadcast the shapes to all have the same rank (the max rank). - ranks = [t.shape.ndims for t in elementwise_args.values()] - if any(rank is None for rank in ranks): - raise ValueError('Unable to broadcast: unknown rank') - broadcast_rank = max(ranks) - shapes = [shape.broadcast_to_rank(broadcast_rank) for shape in shapes] - - # For each dimension, broadcast the shapes to be compatible. - for axis in range(broadcast_rank): - # For each i, broadcast shape[i+1] to be compatible with shape[i]; and - # then finally broadcast shape[0] to be compatible with shape[-1]. - for i in range(len(shapes)): - j = (i + 1) % len(shapes) - dim_size = shapes[i].dimension_size(axis) - shapes[j] = shapes[j].broadcast_dimension(axis, dim_size) - broadcast_shape = shapes[0] - - # Broadcast every elementwise arg to the shape that we calculated. - elementwise_args = dict([ - (key, ragged_tensor_shape.broadcast_to(t, broadcast_shape, False)) - for (key, t) in elementwise_args.items()]) - nested_splits_lists = list(elementwise_args.values())[0].nested_row_splits - return elementwise_args, nested_splits_lists, () - - -# A list of symbols that should be exported in the "ragged" package. -_symbols_to_export = [] - - -def _add_elementwise_ops_to_this_module(specs, verbose=False): - """Adds ragged versions of the given ops to this module. - - Args: - specs: A list of tuples containing the arguments for `make_elementwise_op`. - verbose: If true, then display each op that gets added. - """ - for spec in specs: - original_op = spec[0] - ragged_op = make_elementwise_op(*spec) - canonical_name = tf_export.get_canonical_name_for_symbol(original_op) - if '.' not in canonical_name: - op_name = canonical_name - else: - op_name = original_op.__name__ - - # Temporary hack (will be removed once dispatch is added for RaggedTensors): - if op_name == 'neg': op_name = 'negative' - - if verbose: - print('Adding ragged_elementwise_op: tf.ragged.%s (based on tf.%s)' % - (op_name, canonical_name)) - globals()[op_name] = ragged_op - _symbols_to_export.append(op_name) - - -# A list of tuples containing arguments for `make_elementwise_op`, for each -# elementwise operation that should have a ragged version built. Each tuple -# contains a standard `Tensor` operation, and the names of any arguments -# that are processed in elementwise fashion. -_TF_ELEMENTWISE_OPS = [ - # Unary math operations. - (clip_ops.clip_by_value, 't'), - (math_ops.abs, 'x'), - (math_ops.acos, 'x'), - (math_ops.acosh, 'x'), - (math_ops.angle, 'input'), - (math_ops.asin, 'x'), - (math_ops.asinh, 'x'), - (math_ops.atan, 'x'), - (math_ops.atanh, 'x'), - (math_ops.cast, 'x'), - (math_ops.ceil, 'x'), - (math_ops.conj, 'x'), - (math_ops.cos, 'x'), - (math_ops.cosh, 'x'), - (math_ops.digamma, 'x'), - (math_ops.erf, 'x'), - (math_ops.erfc, 'x'), - (math_ops.exp, 'x'), - (math_ops.expm1, 'x'), - (math_ops.floor, 'x'), - (math_ops.imag, 'input'), - (math_ops.is_finite, 'x'), - (math_ops.is_inf, 'x'), - (math_ops.is_nan, 'x'), - (math_ops.lgamma, 'x'), - (math_ops.log, 'x'), - (math_ops.log1p, 'x'), - (math_ops.log_sigmoid, 'x'), - (math_ops.logical_not, 'x'), - (math_ops.negative, 'x'), - (math_ops.real, 'input'), - (math_ops.reciprocal, 'x'), - (math_ops.rint, 'x'), - (math_ops.round, 'x'), - (math_ops.rsqrt, 'x'), - (math_ops.saturate_cast, 'value'), - (math_ops.sign, 'x'), - (math_ops.sin, 'x'), - (math_ops.sinh, 'x'), - (math_ops.sqrt, 'x'), - (math_ops.square, 'x'), - (math_ops.tan, 'x'), - - # Binary math operations - (math_ops.add, 'x', 'y'), - (math_ops.atan2, 'y', 'x'), - (math_ops.complex, 'real', 'imag'), - (math_ops.div, 'x', 'y'), - (math_ops.div_no_nan, 'x', 'y'), - (math_ops.divide, 'x', 'y'), - (math_ops.equal, 'x', 'y'), - (math_ops.floordiv, 'x', 'y'), - (math_ops.floormod, 'x', 'y'), - (math_ops.greater, 'x', 'y'), - (math_ops.greater_equal, 'x', 'y'), - (math_ops.less, 'x', 'y'), - (math_ops.less_equal, 'x', 'y'), - (math_ops.logical_and, 'x', 'y'), - (math_ops.logical_or, 'x', 'y'), - (math_ops.logical_xor, 'x', 'y'), - (math_ops.maximum, 'x', 'y'), - (math_ops.minimum, 'x', 'y'), - (math_ops.multiply, 'x', 'y'), - (math_ops.not_equal, 'x', 'y'), - (math_ops.pow, 'x', 'y'), - (math_ops.realdiv, 'x', 'y'), - (math_ops.squared_difference, 'x', 'y'), - (math_ops.subtract, 'x', 'y'), - (math_ops.truediv, 'x', 'y'), - (math_ops.truncatediv, 'x', 'y'), - (math_ops.truncatemod, 'x', 'y'), - - # N-ary math operations - (math_ops.add_n, '[inputs]'), - - # String operations - (string_ops.as_string, 'input'), - (string_ops.decode_base64, 'input'), - (string_ops.encode_base64, 'input'), - (string_ops.regex_full_match, 'input'), - (string_ops.regex_replace, 'input'), - (string_ops.string_join, '[inputs]'), - (string_ops.string_strip, 'input'), - (string_ops.string_to_hash_bucket, 'input'), - (string_ops.string_to_hash_bucket_fast, 'input'), - (string_ops.string_to_hash_bucket_strong, 'input'), - (string_ops.substr, 'input'), - (string_ops.unicode_script, 'input'), - - # Array ops - (array_ops.check_numerics, 'tensor'), - (array_ops.identity, 'input'), - (array_ops.ones_like, 'tensor'), - (array_ops.zeros_like, 'tensor'), - - # Parsing ops - (parsing_ops.decode_compressed, 'bytes'), - (parsing_ops.string_to_number, 'string_tensor'), -] -_add_elementwise_ops_to_this_module(_TF_ELEMENTWISE_OPS) - diff --git a/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py b/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py index 7a8603c949a..ecd78a91b2d 100644 --- a/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py +++ b/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function from absl.testing import parameterized +import numpy as np from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util @@ -55,7 +56,7 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase, parameterized.TestCase): ), # [d1, (d2)] -> [d1, (d2)] dict( - fn=lambda x: x+1, + fn=lambda x: x + np.int64(1), elems=[[1, 2, 3], [4, 5], [6, 7]], expected_output=[[2, 3, 4], [5, 6], [7, 8]], dtype=dtypes.int64, @@ -64,7 +65,7 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase, parameterized.TestCase): ), # [d1, (d2), d3] -> [d1, (d2), d3] dict( - fn=lambda x: x+1, + fn=lambda x: x + np.int64(1), elems=[[[1, 2], [3, 4]], [], [[5, 6], [7, 8], [9, 0]]], elems_ragged_rank=1, expected_ragged_rank=1, @@ -131,7 +132,7 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase, parameterized.TestCase): ), # [d1, (d2), (d3), (d4a), (d5)] -> [d1, (d2), (d3), (d4b), (d5)] dict( - fn=lambda x: ragged.add(x, 1), + fn=lambda x: x + np.int64(1), elems=[[[[[1, 2, 3]], [[4], [5]]]], [[[[6, 7]]], [[[8], []]]]], expected_output=[[[[[2, 3, 4]], [[5], [6]]]], [[[[7, 8]]], [[[9], []]]]], @@ -196,8 +197,8 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase, parameterized.TestCase): def _increment(f): return { - 'batman': ragged.add(f['batman'], 1), - 'robin': ragged.add(f['robin'], 1), + 'batman': f['batman'] + 1, + 'robin': f['robin'] + 1, } output = ragged.map_fn( diff --git a/tensorflow/python/ops/ragged/ragged_math_ops.py b/tensorflow/python/ops/ragged/ragged_math_ops.py index 857b8dbfa36..d661563a9f0 100644 --- a/tensorflow/python/ops/ragged/ragged_math_ops.py +++ b/tensorflow/python/ops/ragged/ragged_math_ops.py @@ -143,8 +143,11 @@ Computes the %(combination)s along segments of a RaggedTensor. """ -def _ragged_segment_aggregate(unsorted_segment_op, data, segment_ids, - num_segments, name=None): +def _ragged_segment_aggregate(unsorted_segment_op, + data, + segment_ids, + num_segments, + name=None): """Aggregates along segments of a RaggedTensor using `unsorted_segment_op`. Returns a RaggedTensor `output` with `num_segments` rows, where the row @@ -212,12 +215,11 @@ def _ragged_segment_aggregate(unsorted_segment_op, data, segment_ids, assert output_row_lengths.dtype == dtypes.int64 # Build the splits tensor for the output RaggedTensor. - output_splits = array_ops.concat( - [ - array_ops.zeros([1], dtypes.int64), - math_ops.cumsum(output_row_lengths) - ], - axis=0) + output_splits = array_ops.concat([ + array_ops.zeros([1], dtypes.int64), + math_ops.cumsum(output_row_lengths) + ], + axis=0) # For each row in `data`, find the start & limit position where that row's # values will be aggregated in output.values. @@ -311,7 +313,7 @@ _set_ragged_segment_docstring(segment_sqrt_n, 'sum divided by sqrt(N)', _RAGGED_REDUCE_DOCSTRING = """\ Computes the %(combination)s of elements across dimensions of a `RaggedTensor`. - Reduces `rt_input` along the dimensions given in `axis` by taking the + Reduces `input_tensor` along the dimensions given in `axis` by taking the %(combination)s of values. If a reduced dimension has no elements for some index, then the value for that index will be %(default)s. @@ -319,18 +321,18 @@ Computes the %(combination)s of elements across dimensions of a `RaggedTensor`. `axis` is not specified, then all dimensions are reduced, and a scalar value is returned. Args: - rt_input: A `RaggedTensor` containing the values to be %(combined)s. + input_tensor: A `RaggedTensor` containing the values to be %(combined)s. axis: The dimensions to reduce. May be `None` (to reduce all axes), an `int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce a given set of axes), or a `Tensor` with a constant value. Must be in - the range `[0, rt_input.rank]`. + the range `[0, input_tensor.rank]`. name: A name prefix for the returned tensor (optional). Returns: A `RaggedTensor` containing the %(combined)s values. The returned tensor has the same dtype as `data`, and its shape is given by removing the - dimensions specified in `axis` from `rt_input.shape`. The `ragged_rank` + dimensions specified in `axis` from `input_tensor.shape`. The `ragged_rank` of the returned tensor is given by substracting any ragged dimensions - specified in `axis` from `rt_input.ragged_rank`. + specified in `axis` from `input_tensor.ragged_rank`. Raises: ValueError: If `axis` contains a `Tensor` whose value is not constant. ####Example: @@ -387,7 +389,11 @@ _RAGGED_REDUCE_ANY_EXAMPLE = """ """ -def _ragged_reduce_aggregate(reduce_op, unsorted_segment_op, rt_input, axis, +def _ragged_reduce_aggregate(reduce_op, + unsorted_segment_op, + rt_input, + axis, + keepdims, name=None): """Aggregates across axes of a RaggedTensor using the given `Tensor` ops. @@ -412,6 +418,7 @@ def _ragged_reduce_aggregate(reduce_op, unsorted_segment_op, rt_input, axis, `int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce a given set of axes), or a `Tensor` with a constant value. Must be in the range `[0, rt_input.rank)`. + keepdims: If true, retains reduced dimensions with length 1. name: A name prefix for the returned tensor (optional). Returns: @@ -426,6 +433,9 @@ def _ragged_reduce_aggregate(reduce_op, unsorted_segment_op, rt_input, axis, if not ragged_tensor.is_ragged(rt_input): return reduce_op(rt_input, axis, name=name) + if keepdims: + raise ValueError('keepdims=True is not supported for RaggedTensors.') + if isinstance(axis, ops.Tensor): axis = tensor_util.constant_value(axis) if axis is None: @@ -448,9 +458,9 @@ def _ragged_reduce_aggregate(reduce_op, unsorted_segment_op, rt_input, axis, # once will probably require a nontrivial c++ op. axis = sorted(axis) inner_reduced = _ragged_reduce_aggregate(reduce_op, unsorted_segment_op, - rt_input, axis[-1]) + rt_input, axis[-1], keepdims) return _ragged_reduce_aggregate(reduce_op, unsorted_segment_op, - inner_reduced, axis[:-1]) + inner_reduced, axis[:-1], keepdims) axis = ragged_util.get_positive_axis(axis, rt_input.shape.ndims) @@ -476,48 +486,48 @@ def _ragged_reduce_aggregate(reduce_op, unsorted_segment_op, rt_input, axis, # sum_{j} rt_input [i_0, ..., i_[axis-1], j, i_axis+1], ..., i_N] return rt_input.with_values( _ragged_reduce_aggregate(reduce_op, unsorted_segment_op, - rt_input.values, axis - 1)) + rt_input.values, axis - 1, keepdims)) -def reduce_sum(rt_input, axis=None, name=None): +def reduce_sum(input_tensor, axis=None, keepdims=None, name=None): """For docs, see: _RAGGED_REDUCE_DOCSTRING.""" return _ragged_reduce_aggregate(math_ops.reduce_sum, - math_ops.unsorted_segment_sum, rt_input, axis, - name or 'RaggedReduceSum') + math_ops.unsorted_segment_sum, input_tensor, + axis, keepdims, name or 'RaggedReduceSum') -def reduce_prod(rt_input, axis=None, name=None): +def reduce_prod(input_tensor, axis=None, keepdims=None, name=None): """For docs, see: _RAGGED_REDUCE_DOCSTRING.""" return _ragged_reduce_aggregate(math_ops.reduce_prod, - math_ops.unsorted_segment_prod, rt_input, - axis, name or 'RaggedReduceProd') + math_ops.unsorted_segment_prod, input_tensor, + axis, keepdims, name or 'RaggedReduceProd') -def reduce_min(rt_input, axis=None, name=None): +def reduce_min(input_tensor, axis=None, keepdims=None, name=None): """For docs, see: _RAGGED_REDUCE_DOCSTRING.""" return _ragged_reduce_aggregate(math_ops.reduce_min, - math_ops.unsorted_segment_min, rt_input, axis, - name or 'RaggedReduceMin') + math_ops.unsorted_segment_min, input_tensor, + axis, keepdims, name or 'RaggedReduceMin') -def reduce_max(rt_input, axis=None, name=None): +def reduce_max(input_tensor, axis=None, keepdims=None, name=None): """For docs, see: _RAGGED_REDUCE_DOCSTRING.""" return _ragged_reduce_aggregate(math_ops.reduce_max, - math_ops.unsorted_segment_max, rt_input, axis, - name or 'RaggedReduceMax') + math_ops.unsorted_segment_max, input_tensor, + axis, keepdims, name or 'RaggedReduceMax') -def reduce_mean(rt_input, axis=None, name=None): +def reduce_mean(input_tensor, axis=None, keepdims=None, name=None): """For docs, see: _RAGGED_REDUCE_DOCSTRING.""" - with ops.name_scope(name, 'RaggedReduceMean', [rt_input, axis]): - total = reduce_sum(rt_input, axis) - if ragged_tensor.is_ragged(rt_input): + with ops.name_scope(name, 'RaggedReduceMean', [input_tensor, axis]): + total = reduce_sum(input_tensor, axis, keepdims) + if ragged_tensor.is_ragged(input_tensor): ones = ragged_factory_ops.from_nested_row_splits( - array_ops.ones_like(rt_input.inner_values), - rt_input.nested_row_splits) + array_ops.ones_like(input_tensor.inner_values), + input_tensor.nested_row_splits) else: - ones = array_ops.ones_like(rt_input) - count = reduce_sum(ones, axis) + ones = array_ops.ones_like(input_tensor) + count = reduce_sum(ones, axis, keepdims) if ragged_tensor.is_ragged(total): return ragged_factory_ops.from_nested_row_splits( total.inner_values / count.inner_values, total.nested_row_splits) @@ -525,20 +535,25 @@ def reduce_mean(rt_input, axis=None, name=None): return total / count -def _cast(rt_input, dtype): - return ragged_functional_ops.map_inner_values(math_ops.cast, rt_input, dtype) +def _cast(input_tensor, dtype): + return ragged_functional_ops.map_inner_values(math_ops.cast, input_tensor, + dtype) -def reduce_all(rt_input, axis=None, name=None): +def reduce_all(input_tensor, axis=None, keepdims=None, name=None): """For docs, see: _RAGGED_REDUCE_DOCSTRING.""" - with ops.name_scope(name, 'RaggedReduceAll', [rt_input, axis]): - return _cast(reduce_prod(_cast(rt_input, dtypes.int32), axis), dtypes.bool) + with ops.name_scope(name, 'RaggedReduceAll', [input_tensor, axis]): + return _cast( + reduce_prod(_cast(input_tensor, dtypes.int32), axis, keepdims), + dtypes.bool) -def reduce_any(rt_input, axis=None, name=None): +def reduce_any(input_tensor, axis=None, keepdims=None, name=None): """For docs, see: _RAGGED_REDUCE_DOCSTRING.""" - with ops.name_scope(name, 'RaggedReduceAny', [rt_input, axis]): - return _cast(reduce_sum(_cast(rt_input, dtypes.int32), axis), dtypes.bool) + with ops.name_scope(name, 'RaggedReduceAny', [input_tensor, axis]): + return _cast( + reduce_sum(_cast(input_tensor, dtypes.int32), axis, keepdims), + dtypes.bool) def _set_ragged_reduce_docstring(func, combination, combined, default, example): @@ -554,9 +569,11 @@ _set_ragged_reduce_docstring(reduce_sum, 'sum', 'summed', '0', _set_ragged_reduce_docstring(reduce_prod, 'product', 'multiplied', '1', _RAGGED_REDUCE_PROD_EXAMPLE) _set_ragged_reduce_docstring(reduce_min, 'minimum', 'minimized', - '`rt_input.dtype.min`', _RAGGED_REDUCE_MIN_EXAMPLE) + '`input_tensor.dtype.min`', + _RAGGED_REDUCE_MIN_EXAMPLE) _set_ragged_reduce_docstring(reduce_max, 'maximum', 'maximized', - '`rt_input.dtype.max`', _RAGGED_REDUCE_MAX_EXAMPLE) + '`input_tensor.dtype.max`', + _RAGGED_REDUCE_MAX_EXAMPLE) _set_ragged_reduce_docstring(reduce_mean, 'mean', 'averaged', 'NaN', _RAGGED_REDUCE_MEAN_EXAMPLE) diff --git a/tensorflow/python/ops/ragged/ragged_operators.py b/tensorflow/python/ops/ragged/ragged_operators.py index 223ba0d2e7f..7654fa22b1e 100644 --- a/tensorflow/python/ops/ragged/ragged_operators.py +++ b/tensorflow/python/ops/ragged/ragged_operators.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.ops.ragged import ragged_elementwise_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops.ragged import ragged_getitem from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.util import tf_decorator @@ -33,40 +33,39 @@ def _right(operator): ragged_tensor.RaggedTensor.__getitem__ = ragged_getitem.ragged_tensor_getitem # Ordering operators -ragged_tensor.RaggedTensor.__ge__ = ragged_elementwise_ops.greater_equal -ragged_tensor.RaggedTensor.__gt__ = ragged_elementwise_ops.greater -ragged_tensor.RaggedTensor.__le__ = ragged_elementwise_ops.less_equal -ragged_tensor.RaggedTensor.__lt__ = ragged_elementwise_ops.less +ragged_tensor.RaggedTensor.__ge__ = math_ops.greater_equal +ragged_tensor.RaggedTensor.__gt__ = math_ops.greater +ragged_tensor.RaggedTensor.__le__ = math_ops.less_equal +ragged_tensor.RaggedTensor.__lt__ = math_ops.less # Logical operators -ragged_tensor.RaggedTensor.__and__ = ragged_elementwise_ops.logical_and -ragged_tensor.RaggedTensor.__rand__ = _right(ragged_elementwise_ops.logical_and) -ragged_tensor.RaggedTensor.__invert__ = ragged_elementwise_ops.logical_not -ragged_tensor.RaggedTensor.__ror__ = _right(ragged_elementwise_ops.logical_or) -ragged_tensor.RaggedTensor.__or__ = ragged_elementwise_ops.logical_or -ragged_tensor.RaggedTensor.__xor__ = ragged_elementwise_ops.logical_xor -ragged_tensor.RaggedTensor.__rxor__ = _right(ragged_elementwise_ops.logical_xor) +ragged_tensor.RaggedTensor.__and__ = math_ops.logical_and +ragged_tensor.RaggedTensor.__rand__ = _right(math_ops.logical_and) +ragged_tensor.RaggedTensor.__invert__ = math_ops.logical_not +ragged_tensor.RaggedTensor.__ror__ = _right(math_ops.logical_or) +ragged_tensor.RaggedTensor.__or__ = math_ops.logical_or +ragged_tensor.RaggedTensor.__xor__ = math_ops.logical_xor +ragged_tensor.RaggedTensor.__rxor__ = _right(math_ops.logical_xor) # Arithmetic operators -ragged_tensor.RaggedTensor.__abs__ = ragged_elementwise_ops.abs -ragged_tensor.RaggedTensor.__add__ = ragged_elementwise_ops.add -ragged_tensor.RaggedTensor.__radd__ = _right(ragged_elementwise_ops.add) -ragged_tensor.RaggedTensor.__div__ = ragged_elementwise_ops.div -ragged_tensor.RaggedTensor.__rdiv__ = _right(ragged_elementwise_ops.div) -ragged_tensor.RaggedTensor.__floordiv__ = ragged_elementwise_ops.floordiv -ragged_tensor.RaggedTensor.__rfloordiv__ = _right( - ragged_elementwise_ops.floordiv) -ragged_tensor.RaggedTensor.__mod__ = ragged_elementwise_ops.floormod -ragged_tensor.RaggedTensor.__rmod__ = _right(ragged_elementwise_ops.floormod) -ragged_tensor.RaggedTensor.__mul__ = ragged_elementwise_ops.multiply -ragged_tensor.RaggedTensor.__rmul__ = _right(ragged_elementwise_ops.multiply) -ragged_tensor.RaggedTensor.__neg__ = ragged_elementwise_ops.negative -ragged_tensor.RaggedTensor.__pow__ = ragged_elementwise_ops.pow -ragged_tensor.RaggedTensor.__rpow__ = _right(ragged_elementwise_ops.pow) -ragged_tensor.RaggedTensor.__sub__ = ragged_elementwise_ops.subtract -ragged_tensor.RaggedTensor.__rsub__ = _right(ragged_elementwise_ops.subtract) -ragged_tensor.RaggedTensor.__truediv__ = ragged_elementwise_ops.truediv -ragged_tensor.RaggedTensor.__rtruediv__ = _right(ragged_elementwise_ops.truediv) +ragged_tensor.RaggedTensor.__abs__ = math_ops.abs +ragged_tensor.RaggedTensor.__add__ = math_ops.add +ragged_tensor.RaggedTensor.__radd__ = _right(math_ops.add) +ragged_tensor.RaggedTensor.__div__ = math_ops.div +ragged_tensor.RaggedTensor.__rdiv__ = _right(math_ops.div) +ragged_tensor.RaggedTensor.__floordiv__ = math_ops.floordiv +ragged_tensor.RaggedTensor.__rfloordiv__ = _right(math_ops.floordiv) +ragged_tensor.RaggedTensor.__mod__ = math_ops.floormod +ragged_tensor.RaggedTensor.__rmod__ = _right(math_ops.floormod) +ragged_tensor.RaggedTensor.__mul__ = math_ops.multiply +ragged_tensor.RaggedTensor.__rmul__ = _right(math_ops.multiply) +ragged_tensor.RaggedTensor.__neg__ = math_ops.negative +ragged_tensor.RaggedTensor.__pow__ = math_ops.pow +ragged_tensor.RaggedTensor.__rpow__ = _right(math_ops.pow) +ragged_tensor.RaggedTensor.__sub__ = math_ops.subtract +ragged_tensor.RaggedTensor.__rsub__ = _right(math_ops.subtract) +ragged_tensor.RaggedTensor.__truediv__ = math_ops.truediv +ragged_tensor.RaggedTensor.__rtruediv__ = _right(math_ops.truediv) # Dummy methods diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index 346ab9c0cb4..feff7df8501 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -2696,6 +2696,7 @@ class _UnaryMapValueDispatcher(dispatch.OpDispatcher): if args: x, args = args[0], args[1:] else: + kwargs = kwargs.copy() x = kwargs.pop(self._x, None) if isinstance(x, sparse_tensor.SparseTensor): return sparse_tensor.SparseTensor( diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py index b6b329c4865..046459706c0 100644 --- a/tensorflow/python/ops/string_ops.py +++ b/tensorflow/python/ops/string_ops.py @@ -38,6 +38,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops.gen_string_ops import * from tensorflow.python.util import compat as util_compat from tensorflow.python.util import deprecation +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export # pylint: enable=g-bad-import-order # pylint: enable=wildcard-import @@ -45,6 +46,7 @@ from tensorflow.python.util.tf_export import tf_export # pylint: disable=redefined-builtin @tf_export("strings.regex_full_match") +@dispatch.add_dispatch_support def regex_full_match(input, pattern, name=None): r"""Match elements of `input` with regex `pattern`. @@ -76,6 +78,7 @@ regex_full_match.__doc__ = gen_string_ops.regex_full_match.__doc__ @tf_export( "strings.regex_replace", v1=["strings.regex_replace", "regex_replace"]) @deprecation.deprecated_endpoints("regex_replace") +@dispatch.add_dispatch_support def regex_replace(input, pattern, rewrite, replace_global=True, name=None): r"""Replace elements of `input` matching regex `pattern` with `rewrite`. @@ -350,10 +353,13 @@ reduce_join.__doc__ = reduce_join.__doc__.replace("tf.reduce_join(", # This wrapper provides backwards compatibility for code that predates the # unit argument and that passed 'name' as a positional argument. @tf_export(v1=["strings.length"]) +@dispatch.add_dispatch_support def string_length(input, name=None, unit="BYTE"): return gen_string_ops.string_length(input, unit=unit, name=name) + @tf_export("strings.length", v1=[]) +@dispatch.add_dispatch_support def string_length_v2(input, unit="BYTE", name=None): return string_length(input, name, unit) @@ -370,11 +376,13 @@ substr_deprecated.__doc__ = gen_string_ops.substr.__doc__ @tf_export(v1=["strings.substr"]) +@dispatch.add_dispatch_support def substr(input, pos, len, name=None, unit="BYTE"): return gen_string_ops.substr(input, pos, len, unit=unit, name=name) @tf_export("strings.substr", v1=[]) +@dispatch.add_dispatch_support def substr_v2(input, pos, len, unit="BYTE", name=None): return substr(input, pos, len, name=name, unit=unit) @@ -395,6 +403,7 @@ ops.NotDifferentiable("DecodeBase64") @tf_export("strings.to_number", v1=[]) +@dispatch.add_dispatch_support def string_to_number(input, out_type=dtypes.float32, name=None): r"""Converts each string in the input Tensor to the specified numeric type. @@ -418,6 +427,7 @@ tf_export(v1=["strings.to_number", "string_to_number"])( @tf_export("strings.to_hash_bucket", v1=[]) +@dispatch.add_dispatch_support def string_to_hash_bucket(input, num_buckets, name=None): # pylint: disable=line-too-long r"""Converts each string in the input Tensor to its hash mod by a number of buckets. diff --git a/tensorflow/python/util/dispatch.py b/tensorflow/python/util/dispatch.py index e7a56b5922c..e94e3345348 100644 --- a/tensorflow/python/util/dispatch.py +++ b/tensorflow/python/util/dispatch.py @@ -166,15 +166,14 @@ def dispatch_for_types(op, *types): def add_dispatch_list(target): """Decorator that adds a dispatch_list attribute to an op.""" - assert not hasattr(target, DISPATCH_ATTR) + if hasattr(target, DISPATCH_ATTR): + raise AssertionError("%s already has a dispatch list" % target) setattr(target, DISPATCH_ATTR, []) return target def add_dispatch_support(target): """Decorator that adds a dispatch handling wrapper to an op.""" - add_dispatch_list(target) - def wrapper(*args, **kwargs): """Call target, and fall back on dispatchers if there is a TypeError.""" try: @@ -188,5 +187,5 @@ def add_dispatch_support(target): else: raise - setattr(wrapper, DISPATCH_ATTR, []) + add_dispatch_list(wrapper) return tf_decorator.make_decorator(target, wrapper)