Internal Change

PiperOrigin-RevId: 224225849
This commit is contained in:
A. Unique TensorFlower 2018-12-05 14:51:36 -08:00 committed by TensorFlower Gardener
parent 2b0fd9b66a
commit 45cfe71266
21 changed files with 889 additions and 757 deletions

View File

@ -1,4 +1,6 @@
op {
graph_op_name: "FloorDiv"
visibility: HIDDEN
endpoint {
name: "floor_div"
}
}

View File

@ -1,4 +1,9 @@
op {
graph_op_name: "FloorMod"
visibility: HIDDEN
endpoint {
name: "floormod"
}
endpoint {
name: "mod"
}
}

View File

@ -1,4 +1,6 @@
op {
graph_op_name: "RealDiv"
visibility: HIDDEN
endpoint {
name: "realdiv"
}
}

View File

@ -1,4 +1,6 @@
op {
graph_op_name: "TruncateDiv"
visibility: HIDDEN
endpoint {
name: "truncatediv"
}
}

View File

@ -1,4 +1,6 @@
op {
graph_op_name: "TruncateMod"
visibility: HIDDEN
endpoint {
name: "truncatemod"
}
}

View File

@ -634,7 +634,9 @@ void GenEagerPythonOp::AddEagerFunctionTeardown(
bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
const string& parameters, const std::vector<string>& output_sizes,
const string& eager_not_allowed_error) {
strings::StrAppend(&result_, "@_dispatch.add_dispatch_list\n");
if (api_def_.visibility() == ApiDef::VISIBLE) {
strings::StrAppend(&result_, "@_dispatch.add_dispatch_list\n");
}
AddExport();
AddDefLine(function_name_, parameters);
AddDocStringDescription();

View File

@ -56,6 +56,7 @@ _BaseSlice = slice
@tf_export("identity")
@dispatch.add_dispatch_support
def identity(input, name=None): # pylint: disable=redefined-builtin
r"""Return a tensor with the same shape and contents as input.
@ -139,6 +140,7 @@ def expand_dims(input, axis=None, name=None, dim=None):
@tf_export("expand_dims", v1=[])
@dispatch.add_dispatch_support
def expand_dims_v2(input, axis, name=None):
"""Inserts a dimension of 1 into a tensor's shape.
@ -941,6 +943,7 @@ def parallel_stack(values, name="parallel_stack"):
@tf_export("stack")
@dispatch.add_dispatch_support
def stack(values, axis=0, name="stack"):
"""Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor.
@ -1151,6 +1154,7 @@ def unstack(value, num=None, axis=0, name="unstack"):
@tf_export("concat")
@dispatch.add_dispatch_support
def concat(values, axis, name="concat"):
"""Concatenates tensors along one dimension.
@ -1328,6 +1332,7 @@ def boolean_mask(tensor, mask, name="boolean_mask", axis=None):
@tf_export("boolean_mask", v1=[])
@dispatch.add_dispatch_support
def boolean_mask_v2(tensor, mask, axis=None, name="boolean_mask"):
"""Apply boolean mask to tensor.
@ -1810,6 +1815,7 @@ def zeros(shape, dtype=dtypes.float32, name=None):
@tf_export(v1=["zeros_like"])
@dispatch.add_dispatch_support
def zeros_like(tensor, dtype=None, name=None, optimize=True):
"""Creates a tensor with all elements set to zero.
@ -1840,6 +1846,7 @@ def zeros_like(tensor, dtype=None, name=None, optimize=True):
@tf_export("zeros_like", v1=[])
@dispatch.add_dispatch_support
def zeros_like_v2(
input, # pylint: disable=redefined-builtin
dtype=None,
@ -1899,6 +1906,7 @@ def zeros_like_impl(tensor, dtype, name, optimize=True):
@tf_export(v1=["ones_like"])
@dispatch.add_dispatch_support
def ones_like(tensor, dtype=None, name=None, optimize=True):
"""Creates a tensor with all elements set to 1.
@ -1929,6 +1937,7 @@ def ones_like(tensor, dtype=None, name=None, optimize=True):
@tf_export("ones_like", v1=[])
@dispatch.add_dispatch_support
def ones_like_v2(
input, # pylint: disable=redefined-builtin
dtype=None,
@ -3115,6 +3124,7 @@ def squeeze_v2(input, axis=None, name=None):
@tf_export("where")
@dispatch.add_dispatch_support
def where(condition, x=None, y=None, name=None):
"""Return the elements, either from `x` or `y`, depending on the `condition`.
@ -3234,6 +3244,7 @@ def gather(params, indices, validate_indices=None, name=None, axis=0):
@tf_export("gather", v1=[])
@dispatch.add_dispatch_support
def gather_v2(params, indices, validate_indices=None, axis=0, name=None):
return gather(params, indices, validate_indices=validate_indices, name=name,
axis=axis)

View File

@ -31,10 +31,12 @@ from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import numerics
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@tf_export("clip_by_value")
@dispatch.add_dispatch_support
def clip_by_value(t, clip_value_min, clip_value_max,
name=None):
"""Clips tensor values to a specified min and max.

View File

@ -230,6 +230,7 @@ class DivideDelegateWithName(object):
@tf_export("math.divide", "divide")
@dispatch.add_dispatch_support
def divide(x, y, name=None):
"""Computes Python style division of `x` by `y`."""
@ -242,6 +243,7 @@ def divide(x, y, name=None):
@tf_export("math.multiply", "multiply")
@dispatch.add_dispatch_support
def multiply(x, y, name=None):
return gen_math_ops.mul(x, y, name)
@ -262,6 +264,7 @@ _mul.__doc__ = (
@tf_export("math.subtract", "subtract")
@dispatch.add_dispatch_support
def subtract(x, y, name=None):
return gen_math_ops.sub(x, y, name)
@ -347,6 +350,7 @@ def scalar_mul_v2(scalar, x, name=None):
@tf_export("math.pow", "pow")
@dispatch.add_dispatch_support
def pow(x, y, name=None): # pylint: disable=redefined-builtin
r"""Computes the power of one value to another.
@ -375,6 +379,7 @@ def pow(x, y, name=None): # pylint: disable=redefined-builtin
# pylint: disable=redefined-builtin,redefined-outer-name
@tf_export("dtypes.complex", "complex")
@dispatch.add_dispatch_support
def complex(real, imag, name=None):
r"""Converts two real numbers to a complex number.
@ -418,6 +423,7 @@ def complex(real, imag, name=None):
@tf_export("math.real", v1=["math.real", "real"])
@deprecation.deprecated_endpoints("real")
@dispatch.add_dispatch_support
def real(input, name=None):
r"""Returns the real part of a complex (or real) tensor.
@ -450,6 +456,7 @@ def real(input, name=None):
@tf_export("math.imag", v1=["math.imag", "imag"])
@deprecation.deprecated_endpoints("imag")
@dispatch.add_dispatch_support
def imag(input, name=None):
r"""Returns the imaginary part of a complex (or real) tensor.
@ -481,6 +488,7 @@ def imag(input, name=None):
@tf_export("math.angle", v1=["math.angle", "angle"])
@deprecation.deprecated_endpoints("angle")
@dispatch.add_dispatch_support
def angle(input, name=None):
r"""Returns the element-wise argument of a complex (or real) tensor.
@ -520,6 +528,7 @@ def angle(input, name=None):
@tf_export("math.round", "round")
@dispatch.add_dispatch_support
def round(x, name=None): # pylint: disable=redefined-builtin
"""Rounds the values of a tensor to the nearest integer, element-wise.
@ -547,6 +556,7 @@ def round(x, name=None): # pylint: disable=redefined-builtin
@tf_export("dtypes.cast", "cast")
@dispatch.add_dispatch_support
def cast(x, dtype, name=None):
"""Casts a tensor to a new type.
@ -610,6 +620,7 @@ def cast(x, dtype, name=None):
@tf_export("dtypes.saturate_cast", "saturate_cast")
@dispatch.add_dispatch_support
def saturate_cast(value, dtype, name=None):
"""Performs a safe saturating cast of `value` to `dtype`.
@ -935,6 +946,7 @@ def _div_python2(x, y, name=None):
@tf_export("math.truediv", "truediv")
@dispatch.add_dispatch_support
def truediv(x, y, name=None):
"""Divides x / y elementwise (using Python 3 division operator semantics).
@ -992,6 +1004,7 @@ def div(x, y, name=None):
@tf_export("div_no_nan")
@dispatch.add_dispatch_support
def div_no_nan(x, y, name=None):
"""Computes an unsafe divide which returns 0 if the y is zero.
@ -1021,6 +1034,7 @@ mod = gen_math_ops.floor_mod
# TODO(aselle): Deprecate this once all internal functionality uses
# tf.truncatediv
@tf_export("math.floordiv", v1=["math.floordiv", "floordiv"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("floordiv")
def floordiv(x, y, name=None):
"""Divides `x / y` elementwise, rounding toward the most negative integer.
@ -1050,16 +1064,11 @@ def floordiv(x, y, name=None):
realdiv = gen_math_ops.real_div
tf_export("realdiv")(realdiv)
truncatediv = gen_math_ops.truncate_div
tf_export("truncatediv")(truncatediv)
# TODO(aselle): Rename this to floordiv when we can.
floor_div = gen_math_ops.floor_div
tf_export("floor_div")(floor_div)
truncatemod = gen_math_ops.truncate_mod
tf_export("truncatemod")(truncatemod)
floormod = gen_math_ops.floor_mod
tf_export("floormod", "mod")(floormod)
def _mul_dispatch(x, y, name=None):
@ -1095,6 +1104,7 @@ _OverrideBinaryOperatorHelper(pow, "pow")
@tf_export("math.logical_xor", v1=["math.logical_xor", "logical_xor"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("logical_xor")
def logical_xor(x, y, name="LogicalXor"):
"""x ^ y = (x | y) & ~(x & y)."""
@ -1277,6 +1287,7 @@ def reduce_sum_v1(input_tensor,
@tf_export("math.reduce_sum", "reduce_sum", v1=[])
@dispatch.add_dispatch_support
def reduce_sum(input_tensor, axis=None, keepdims=False, name=None):
"""Computes the sum of elements across dimensions of a tensor.
@ -1524,6 +1535,7 @@ def reduce_mean_v1(input_tensor,
@tf_export("math.reduce_mean", "reduce_mean", v1=[])
@dispatch.add_dispatch_support
def reduce_mean(input_tensor, axis=None, keepdims=False, name=None):
"""Computes the mean of elements across dimensions of a tensor.
@ -1675,6 +1687,7 @@ def reduce_std(input_tensor, axis=None, keepdims=False, name=None):
@tf_export("math.reduce_prod", "reduce_prod", v1=[])
@dispatch.add_dispatch_support
def reduce_prod(input_tensor, axis=None, keepdims=False, name=None):
"""Computes the product of elements across dimensions of a tensor.
@ -1796,6 +1809,7 @@ def reduce_min_v1(input_tensor,
@tf_export("math.reduce_min", "reduce_min", v1=[])
@dispatch.add_dispatch_support
def reduce_min(input_tensor, axis=None, keepdims=False, name=None):
"""Computes the minimum of elements across dimensions of a tensor.
@ -1874,6 +1888,7 @@ def reduce_max_v1(input_tensor,
@tf_export("math.reduce_max", "reduce_max", v1=[])
@dispatch.add_dispatch_support
def reduce_max(input_tensor, axis=None, keepdims=False, name=None):
"""Computes the maximum of elements across dimensions of a tensor.
@ -1961,6 +1976,7 @@ def reduce_all_v1(input_tensor,
@tf_export("reduce_all", "math.reduce_all", v1=[])
@dispatch.add_dispatch_support
def reduce_all(input_tensor, axis=None, keepdims=False, name=None):
"""Computes the "logical and" of elements across dimensions of a tensor.
@ -2057,6 +2073,7 @@ def reduce_any_v1(input_tensor,
@tf_export("math.reduce_any", "reduce_any", v1=[])
@dispatch.add_dispatch_support
def reduce_any(input_tensor, axis=None, keepdims=False, name=None):
"""Computes the "logical or" of elements across dimensions of a tensor.
@ -2619,6 +2636,7 @@ def _as_indexed_slices_list(inputs, optimize=True):
@tf_export("math.add_n", "add_n")
@dispatch.add_dispatch_support
def add_n(inputs, name=None):
"""Adds all input tensors element-wise.
@ -2764,6 +2782,7 @@ def sigmoid(x, name=None):
@tf_export("math.log_sigmoid", v1=["math.log_sigmoid", "log_sigmoid"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("log_sigmoid")
def log_sigmoid(x, name=None):
"""Computes log sigmoid of `x` element-wise.
@ -2973,6 +2992,7 @@ def cumprod(x, axis=0, exclusive=False, reverse=False, name=None):
@tf_export("math.conj", v1=["math.conj", "conj"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("conj")
def conj(x, name=None):
r"""Returns the complex conjugate of a complex number.
@ -3077,6 +3097,7 @@ def _unsorted_segment_N(data, segment_ids, num_segments):
"math.unsorted_segment_mean",
v1=["math.unsorted_segment_mean", "unsorted_segment_mean"])
@deprecation.deprecated_endpoints("unsorted_segment_mean")
@dispatch.add_dispatch_support
def unsorted_segment_mean(data, segment_ids, num_segments, name=None):
r"""Computes the mean along segments of a tensor.
@ -3122,6 +3143,7 @@ def unsorted_segment_mean(data, segment_ids, num_segments, name=None):
"math.unsorted_segment_sqrt_n",
v1=["math.unsorted_segment_sqrt_n", "unsorted_segment_sqrt_n"])
@deprecation.deprecated_endpoints("unsorted_segment_sqrt_n")
@dispatch.add_dispatch_support
def unsorted_segment_sqrt_n(data, segment_ids, num_segments, name=None):
r"""Computes the sum along segments of a tensor divided by the sqrt(N).

View File

@ -25,7 +25,7 @@ py_library(
deps = [
":ragged_array_ops",
":ragged_conversion_ops",
":ragged_elementwise_ops",
":ragged_dispatch",
":ragged_factory_ops",
":ragged_functional_ops",
":ragged_getitem",
@ -150,33 +150,14 @@ py_library(
],
)
py_library(
name = "ragged_elementwise_ops",
srcs = ["ragged_elementwise_ops.py"],
srcs_version = "PY2AND3",
deps = [
":ragged_factory_ops",
":ragged_tensor",
":ragged_tensor_shape",
":ragged_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:clip_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:string_ops",
"//tensorflow/python:util",
],
)
py_library(
name = "ragged_operators",
srcs = ["ragged_operators.py"],
srcs_version = "PY2AND3",
deps = [
":ragged_elementwise_ops",
":ragged_getitem",
":ragged_tensor",
"//tensorflow/python:math_ops",
"//tensorflow/python:util",
],
)
@ -186,12 +167,13 @@ py_library(
srcs = ["ragged_string_ops.py"],
srcs_version = "PY2AND3",
deps = [
":ragged_array_ops",
":ragged_conversion_ops",
":ragged_factory_ops",
":ragged_tensor",
"//tensorflow/python:array_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:string_ops",
"//tensorflow/python:util",
],
)
@ -219,10 +201,11 @@ py_library(
":ragged_tensor",
":ragged_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_util",
],
)
@ -285,6 +268,29 @@ py_library(
],
)
py_library(
name = "ragged_dispatch",
srcs = ["ragged_dispatch.py"],
srcs_version = "PY2AND3",
deps = [
":ragged_array_ops",
":ragged_factory_ops",
":ragged_math_ops",
":ragged_tensor",
":ragged_tensor_shape",
":ragged_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:clip_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",
"//tensorflow/python:util",
"//third_party/py/numpy",
],
)
#-------------------------------------------------------------------------------
# RaggedTensor Tests
#-------------------------------------------------------------------------------
@ -458,6 +464,7 @@ py_test(
"//tensorflow/python:errors",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:gradients_impl",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
)
@ -684,17 +691,21 @@ py_test(
)
py_test(
name = "ragged_elementwise_ops_test",
srcs = ["ragged_elementwise_ops_test.py"],
name = "ragged_dispatch_test",
srcs = ["ragged_dispatch_test.py"],
srcs_version = "PY2AND3",
deps = [
":ragged",
"//tensorflow/python:array_ops",
"//tensorflow/python:clip_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:string_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
@ -725,6 +736,7 @@ py_test(
"//tensorflow/python:platform_test",
"//tensorflow/python:string_ops",
"//tensorflow/python/keras:backend",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)
@ -735,8 +747,10 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":ragged",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)

View File

@ -1,76 +1,53 @@
"""Ragged Tensors.
This package defines the [`RaggedTensor`](ragged/RaggedTensor.md) class, which
This package defines the `tf.RaggedTensor` class, which
represents tensors with non-uniform shapes. In particular, each `RaggedTensor`
has one or more *ragged dimensions*, which are dimensions whose slices may have
different lengths. For example, the inner (column) dimension of
`rt=[[3, 1, 4, 1], [], [5, 9, 2], [6], []]` is ragged, since the column slices
(`rt[0, :]`, ..., `rt[4, :]`) have different lengths. For a more detailed
description of ragged tensors, see the [`RaggedTensor`](ragged/RaggedTensor.md)
description of ragged tensors, see the `tf.RaggedTensor`
class documentation.
## RaggedTensor Operations
## `RaggedTensor` Operations
This package also defines a collection of operations for manipulating
ragged tensors.
### `RaggedTensor` Factory ops
### RaggedTensor Versions of Standard Tensor Operations
* `tf.ragged.constant`
* `tf.ragged.from_row_splits`
* `tf.ragged.from_row_splits`
* `tf.ragged.from_row_lengths`
* `tf.ragged.from_row_starts`
* `tf.ragged.from_row_limits`
* `tf.ragged.from_value_rowids`
* `tf.ragged.from_nested_row_splits`
* `tf.ragged.from_nested_value_rowids`
Many of the operations defined by this package are analogous to
[`Tensor`](https://www.tensorflow.org/api_docs/python/tf/Tensor)
operations, but they accept `RaggedTensor`s as input and can return
`RaggedTensor`s as output. For example, `ragged.add` performs elementwise
addition just like `tf.add`, but can be used on `RaggedTensor`s.
### `RaggedTensor` Conversion ops
These `RaggedTensor` versions of the standard `Tensor` operations can also be
used with standard `Tensors`; and for the most part, they will return the same
value that the standard `Tensor` operation would return. However, there are
a few notable exceptions:
* `tf.ragged.from_tensor`
* `tf.ragged.to_tensor`
* `tf.ragged.from_sparse`
* `tf.ragged.to_sparse`
* `tf.ragged.from_variant`
* `tf.ragged.to_variant`
* `tf.ragged.convert_to_tensor_or_ragged_tensor`
* For [`ragged.stack(...)`](ragged/stack.md) and
[`ragged.concat(...)`](ragged/concat.md), the input tensors are not required
to have matching shapes. In the returned tensor, all dimensions up to the
`axis` dimension will be ragged.
### `RaggedTensor` Shape ops
### Ragged-Tensor Specific Operations
* `tf.ragged.row_splits`
* `tf.ragged.row_lengths`
* `tf.ragged.row_starts`
* `tf.ragged.row_limits`
* `tf.ragged.value_rowids`
* `tf.ragged.nrows`
* `tf.ragged.nested_row_splits`
* `tf.ragged.row_splits_to_segment_ids`
* `tf.ragged.segment_ids_to_row_splits`
* `tf.ragged.bounding_shape`
The following operations are specific to ragged tensors:
* **Factory ops**:
[`constant(...)`](ragged/constant.md),
[`from_row_splits(...)`](ragged/from_row_splits.md),
[`from_row_lengths(...)`](ragged/from_row_lengths.md),
[`from_row_starts(...)`](ragged/from_row_starts.md),
[`from_row_limits(...)`](ragged/from_row_limits.md),
[`from_value_rowids(...)`](ragged/from_value_rowids.md),
[`from_nested_row_splits(...)`](ragged/from_nested_row_splits.md),
[`from_nested_value_rowids(...)`](ragged/from_nested_value_rowids.md).
* **Conversion ops**:
[`from_tensor(...)`](ragged/from_tensor.md),
[`to_tensor(...)`](ragged/to_tensor.md),
[`from_sparse(...)`](ragged/from_sparse.md),
[`to_sparse(...)`](ragged/to_sparse.md),
[`from_variant(...)`](ragged/from_variant.md),
[`to_variant(...)`](ragged/to_variant.md),
[`convert_to_tensor_or_ragged_tensor(...)`](
ragged/convert_to_tensor_or_ragged_tensor.md).
* **Shape ops**:
[`row_splits(...)`](ragged/row_splits.md),
[`row_lengths(...)`](ragged/row_lengths.md),
[`row_starts(...)`](ragged/row_starts.md),
[`row_limits(...)`](ragged/row_limits.md),
[`value_rowids(...)`](ragged/value_rowids.md),
[`nrows(...)`](ragged/nrows.md),
[`nested_row_splits(...)`](ragged/nested_row_splits.md),
[`row_splits_to_segment_ids(...)`](ragged/row_splits_to_segment_ids.md),
[`segment_ids_to_row_splits(...)`](ragged/segment_ids_to_row_splits.md),
[`bounding_shape(...)`](ragged/bounding_shape.md).
* **Functional ops**:
[`map_inner_values(...)`](ragged/map_inner_values.md),
[`make_elementwise_op(...)`](ragged/make_elementwise_op.md).
### Functional ops
* `tf.ragged.map_inner_values`
<!-- Ragged Classes & related helper functions -->
@ -140,21 +117,17 @@ The following operations are specific to ragged tensors:
@@map_inner_values
@@map_fn
<!-- Elementwise Ops -->
@@make_elementwise_op
<!-- Shape & broadcasting -->
@@RaggedTensorDynamicShape
@@broadcast_to
@@broadcast_dynamic_shape
<!-- Symbols from ragged_elementwise_ops._symbols_to_export are whitelisted -->
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.ops.ragged import ragged_dispatch
from tensorflow.python.ops.ragged import ragged_operators
from tensorflow.python.ops.ragged import ragged_string_ops
@ -179,11 +152,6 @@ from tensorflow.python.ops.ragged.ragged_conversion_ops import from_tensor
from tensorflow.python.ops.ragged.ragged_conversion_ops import to_sparse
from tensorflow.python.ops.ragged.ragged_conversion_ops import to_tensor
# pylint: disable=protected-access, wildcard-import
from tensorflow.python.ops.ragged.ragged_elementwise_ops import *
from tensorflow.python.ops.ragged.ragged_elementwise_ops import _symbols_to_export as _elementwise_ops
# pylint: enable=protected-access, wildcard-import
from tensorflow.python.ops.ragged.ragged_factory_ops import constant
from tensorflow.python.ops.ragged.ragged_factory_ops import constant_value
from tensorflow.python.ops.ragged.ragged_factory_ops import convert_to_tensor_or_ragged_tensor
@ -231,6 +199,10 @@ from tensorflow.python.ops.ragged.segment_id_ops import segment_ids_to_row_split
from tensorflow.python.util import all_util as _all_util
# Register OpDispatchers that override standard TF ops to work w/ RaggedTensors.
__doc__ += ragged_dispatch.register_dispatchers() # pylint: disable=redefined-builtin
# Any symbol that is not referenced (with "@@name") in the module docstring
# above, or included in the "_elementwise_ops" whitelist, will be removed.
_all_util.remove_undocumented(__name__, _elementwise_ops)
# above will be removed.
_all_util.remove_undocumented(__name__)

View File

@ -308,7 +308,7 @@ def bounding_shape(rt_input, axis=None, name=None):
# ragged_gather
#===============================================================================
# TODO(edloper): Add an `axis` argument
def gather(params, indices, name=None):
def gather(params, indices, validate_indices=None, axis=0, name=None):
"""Gathers ragged slices from `params` axis `0` according to `indices`.
Returns `RaggedTensor` output, such that:
@ -347,6 +347,8 @@ def gather(params, indices, name=None):
indices: The potentially ragged tensor indicating which values to gather.
Must have dtype `int32` or `int64`. Values must be in the range `[0,
params.shape[0]]`.
validate_indices: Ignored.
axis: Must be zero.
name: A name for the operation (optional).
Returns:
@ -357,6 +359,9 @@ def gather(params, indices, name=None):
Raises:
ValueError: If indices.shape.ndims is not known statically.
"""
del validate_indices
if not isinstance(axis, int) or axis != 0:
raise ValueError('axis>0 is not supported for ragged gather yet.')
with ops.name_scope(name, 'RaggedGather', [params, indices]):
params = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
params, name='params')
@ -812,29 +817,29 @@ def boolean_mask(data, mask, keepdims=False, name=None):
#===============================================================================
# Concatenation and Stacking
#===============================================================================
def concat(rt_inputs, axis, name=None):
def concat(values, axis, name=None):
"""Concatenates potentially ragged tensors along one dimension.
Given a list of tensors with the same rank `K` (`K >= axis`), returns a
rank-`K` `RaggedTensor` `result` such that `result[i0...iaxis]` is the
concatenation of `[rt[i0...iaxis] for rt in rt_inputs]`.
concatenation of `[rt[i0...iaxis] for rt in values]`.
Args:
rt_inputs: A list of potentially ragged tensors. May not be empty. All
`rt_inputs` must have the same rank and the same dtype; but unlike
values: A list of potentially ragged tensors. May not be empty. All
`values` must have the same rank and the same dtype; but unlike
`tf.concat`, they can have arbitrary shapes.
axis: A python integer, indicating the dimension along which to concatenate.
(Note: Unlike `tf.concat`, the `axis` parameter must be statically known.)
Negative values are supported only if the rank of at least one
`rt_inputs` value is statically known.
`values` value is statically known.
name: A name prefix for the returned tensor (optional).
Returns:
A `RaggedTensor` with rank `K`.
`result.ragged_rank=max(axis, max(rt.ragged_rank for rt in rt_inputs]))`.
`result.ragged_rank=max(axis, max(rt.ragged_rank for rt in values]))`.
Raises:
ValueError: If `rt_inputs` is empty, if `axis` is out of bounds or if
ValueError: If `values` is empty, if `axis` is out of bounds or if
the input tensors have different ranks.
#### Example:
@ -847,35 +852,35 @@ def concat(rt_inputs, axis, name=None):
[[1, 2, 6], [3, 4, 5, 7, 8, 9]]
```
"""
if not isinstance(rt_inputs, (list, tuple)):
rt_inputs = [rt_inputs]
with ops.name_scope(name, 'RaggedConcat', rt_inputs):
return _ragged_stack_concat_helper(rt_inputs, axis, stack_values=False)
if not isinstance(values, (list, tuple)):
values = [values]
with ops.name_scope(name, 'RaggedConcat', values):
return _ragged_stack_concat_helper(values, axis, stack_values=False)
def stack(rt_inputs, axis, name=None):
def stack(values, axis, name=None):
"""Stacks potentially ragged tensors along one dimension.
Given a list of tensors with the same rank `K` (`K >= axis`), returns a
rank-`K+1` `RaggedTensor` `result` such that `result[i0...iaxis]` is the
list `[rt[i0...iaxis] for rt in rt_inputs]`.
list `[rt[i0...iaxis] for rt in values]`.
Args:
rt_inputs: A list of potentially ragged tensors. May not be empty. All
`rt_inputs` must have the same rank and the same dtype; but unlike
values: A list of potentially ragged tensors. May not be empty. All
`values` must have the same rank and the same dtype; but unlike
`tf.concat`, they can have arbitrary shapes.
axis: A python integer, indicating the dimension along which to stack.
(Note: Unlike `tf.stack`, the `axis` parameter must be statically known.)
Negative values are supported only if the rank of at least one
`rt_inputs` value is statically known.
`values` value is statically known.
name: A name prefix for the returned tensor (optional).
Returns:
A `RaggedTensor` with rank `K+1`.
`result.ragged_rank=max(axis, max(rt.ragged_rank for rt in rt_inputs]))`.
`result.ragged_rank=max(axis, max(rt.ragged_rank for rt in values]))`.
Raises:
ValueError: If `rt_inputs` is empty, if `axis` is out of bounds or if
ValueError: If `values` is empty, if `axis` is out of bounds or if
the input tensors have different ranks.
#### Example:
@ -888,10 +893,10 @@ def stack(rt_inputs, axis, name=None):
[[[1, 2], [6]], [[3, 4, 5], [7, 8, 9]]]
```
"""
if not isinstance(rt_inputs, (list, tuple)):
rt_inputs = [rt_inputs]
with ops.name_scope(name, 'RaggedConcat', rt_inputs):
return _ragged_stack_concat_helper(rt_inputs, axis, stack_values=True)
if not isinstance(values, (list, tuple)):
values = [values]
with ops.name_scope(name, 'RaggedConcat', values):
return _ragged_stack_concat_helper(values, axis, stack_values=True)
def _ragged_stack_concat_helper(rt_inputs, axis, stack_values):
@ -1065,22 +1070,22 @@ def _copy_row_shape(rt_inputs, splits):
#===============================================================================
# Tiling
#===============================================================================
def tile(rt_input, multiples, name=None):
def tile(input, multiples, name=None): # pylint: disable=redefined-builtin
"""Constructs a `RaggedTensor` by tiling a given `RaggedTensor`.
The values of `rt_input` are replicated `multiples[i]` times along the
The values of `input` are replicated `multiples[i]` times along the
`i`th dimension (for each dimension `i`). For every dimension `axis` in
`rt_input`, the length of each output element in that dimension is the
`input`, the length of each output element in that dimension is the
length of corresponding input element multiplied by `multiples[axis]`.
Args:
rt_input: A `RaggedTensor`.
input: A `RaggedTensor`.
multiples: A 1-D integer `Tensor`. Length must be the same as the number of
dimensions in `rt_input`.
dimensions in `input`.
name: A name for the operation (optional).
Returns:
A `RaggedTensor` with the same type, rank, and ragged_rank as `rt_input`.
A `RaggedTensor` with the same type, rank, and ragged_rank as `input`.
#### Example:
```python
@ -1089,22 +1094,22 @@ def tile(rt_input, multiples, name=None):
[[1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3]]
```
"""
with ops.name_scope(name, 'RaggedTile', [rt_input, multiples]):
rt_input = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
rt_input, name='rt_input')
with ops.name_scope(name, 'RaggedTile', [input, multiples]):
input = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
input, name='input')
multiples = ragged_util.convert_to_int_tensor(
multiples, name='multiples', dtype=dtypes.int64)
multiples.shape.assert_has_rank(1)
if not ragged_tensor.is_ragged(rt_input):
return array_ops.tile(rt_input, multiples, name)
if not ragged_tensor.is_ragged(input):
return array_ops.tile(input, multiples, name)
# If the constant value of `multiples` is available, then we can use it
# to skip tiling dimensions where `multiples=1`.
const_multiples = tensor_util.constant_value(multiples)
return ragged_factory_ops.from_nested_row_splits(
_tile_ragged_values(rt_input, multiples, const_multiples),
_tile_ragged_splits(rt_input, multiples, const_multiples))
_tile_ragged_values(input, multiples, const_multiples),
_tile_ragged_splits(input, multiples, const_multiples))
def _tile_ragged_values(rt_input, multiples, const_multiples=None):
@ -1240,26 +1245,26 @@ def _tile_ragged_splits(rt_input, multiples, const_multiples=None):
#===============================================================================
def expand_dims(rt_input, axis, name=None):
def expand_dims(input, axis, name=None): # pylint: disable=redefined-builtin
"""Inserts a dimension with shape 1 into a potentially ragged tensor's shape.
Given a potentially ragged tenor `rt_input`, this operation inserts a
dimension with size 1 at the dimension `axis` of `rt_input`'s shape.
Given a potentially ragged tenor `input`, this operation inserts a
dimension with size 1 at the dimension `axis` of `input`'s shape.
* If `rt_input` is a `Tensor`, then this is equivalent to
* If `input` is a `Tensor`, then this is equivalent to
`tf.expand_dims`.
* If `rt_input` is ragged, and `axis=0`, then the new dimension will be
* If `input` is ragged, and `axis=0`, then the new dimension will be
uniform; but the previously outermost dimension will become ragged.
* If `rt_input` is ragged, and `0 < axis < rt_input.ragged_rank`, then the
* If `input` is ragged, and `0 < axis < input.ragged_rank`, then the
new dimension will be ragged.
* If `rt_input` is ragged, and axis >= rt_input.ragged_rank`, then the new
* If `input` is ragged, and axis >= input.ragged_rank`, then the new
dimension will be uniform.
The following table gives some examples showing how `ragged.expand_dims`
impacts the shapes of different input tensors. Ragged dimensions are
indicated by enclosing them in parentheses.
rt_input.shape | axis | result.shape
input.shape | axis | result.shape
----------------------- | ---- | -----------------------------
`[D1, D2]` | `0` | `[1, D1, D2]`
`[D1, D2]` | `1` | `[D1, 1, D2]`
@ -1271,14 +1276,14 @@ def expand_dims(rt_input, axis, name=None):
`[D1, (D2), (D3), D4]` | `4` | `[D1, (D2), (D3), D4, 1]`
Args:
rt_input: The potentially tensor that should be expanded with a new
input: The potentially tensor that should be expanded with a new
dimension.
axis: An integer constant indicating where the new dimension should be
inserted.
name: A name for the operation (optional).
Returns:
A tensor with the same values as `rt_input`, with an added dimension of
A tensor with the same values as `input`, with an added dimension of
size 1 at `axis`.
#### Examples:
@ -1300,24 +1305,24 @@ def expand_dims(rt_input, axis, name=None):
TensorShape([2, None, 1]) [[[1], [2]], [[3]]]
```
"""
with ops.name_scope(name, 'RaggedExpandDims', [rt_input]):
rt_input = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
rt_input, name='rt_input')
with ops.name_scope(name, 'RaggedExpandDims', [input]):
input = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
input, name='input')
if not ragged_tensor.is_ragged(rt_input):
return array_ops.expand_dims(rt_input, axis)
if not ragged_tensor.is_ragged(input):
return array_ops.expand_dims(input, axis)
ndims = None if rt_input.shape.ndims is None else rt_input.shape.ndims + 1
ndims = None if input.shape.ndims is None else input.shape.ndims + 1
axis = ragged_util.get_positive_axis(axis, ndims)
if axis == 0:
values = rt_input
splits = array_ops.stack([0, nrows(rt_input)])
values = input
splits = array_ops.stack([0, nrows(input)])
elif axis == 1:
values = rt_input
splits = math_ops.range(nrows(rt_input) + 1)
values = input
splits = math_ops.range(nrows(input) + 1)
else:
values = expand_dims(rt_input.values, axis - 1)
splits = rt_input.row_splits
values = expand_dims(input.values, axis - 1)
splits = input.row_splits
return ragged_factory_ops.from_row_splits(values, splits)

View File

@ -0,0 +1,441 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operator dispatch for RaggedTensors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import numpy as np
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_shape
from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.util import dispatch
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_export
from tensorflow.python.util import tf_inspect
# @TODO(edloper): Set this to True in the CL that exports RaggedTensors.
_UPDATE_DOCSTRINGS = False
# Information about an argument to an operation: The name of the argument, its
# position in the argument list, and a boolean flag indicating whether it
# expects a list of tensors.
_ArgInfo = collections.namedtuple('ArgInfo', ['name', 'position', 'is_list'])
def _get_arg_infos(func, arg_names):
"""Returns an `_ArgInfo` for each argument of `func` specified by `arg_names`.
Args:
func: The function whose arguments should be described.
arg_names: The names of the arguments to get info for.
Returns:
A tuple of `_ArgInfo`s.
"""
arg_infos = []
# Inspect the func's argspec to find the position of each arg.
arg_spec = tf_inspect.getargspec(func)
for argname in arg_names:
assert isinstance(argname, str)
is_list = argname.startswith('[') and argname.endswith(']')
if is_list:
argname = argname[1:-1]
if argname not in arg_spec.args:
raise ValueError('Argument %r not found function in %s. Args=%s' %
(argname, func, arg_spec.args))
arg_infos.append(_ArgInfo(argname, arg_spec.args.index(argname), is_list))
return arg_infos
def _is_convertible_to_tensor(value):
"""Returns true if `value` is convertible to a `Tensor`."""
if isinstance(value,
(ops.Tensor, variables.Variable, np.ndarray, int, float, str)):
return True
elif isinstance(value, (sparse_tensor.SparseTensor,)):
return False
else:
try:
ops.convert_to_tensor(value)
return True
except (TypeError, ValueError):
return False
class UnaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
"""OpDispatcher for unary ops that map a base op across ragged values."""
def __init__(self, original_op, arg_is_list=False):
self._original_op = original_op
self._arg_is_list = arg_is_list
arg_names = tf_inspect.getfullargspec(original_op)[0]
self._x = arg_names[0]
if _UPDATE_DOCSTRINGS:
original_op.__doc__ = (
original_op.__doc__.rstrip() + '\n\n' +
' `{x}` may be a `tf.RaggedTensor`.\n'.format(x=self._x))
def handle(self, args, kwargs):
if args:
x, args = args[0], args[1:]
else:
kwargs = kwargs.copy()
x = kwargs.pop(self._x, None)
if x is None:
return self.NOT_SUPPORTED
if self._arg_is_list:
found_ragged = False
for elt in x:
if ragged_tensor.is_ragged(elt):
found_ragged = True
elif not _is_convertible_to_tensor(elt):
return self.NOT_SUPPORTED
if found_ragged:
nested_splits_lists = [
elt.nested_row_splits for elt in x if ragged_tensor.is_ragged(elt)
]
inner_values = [
elt.inner_values if ragged_tensor.is_ragged(elt) else elt
for elt in x
]
with ops.control_dependencies(
ragged_util.assert_splits_match(nested_splits_lists)):
return ragged_factory_ops.from_nested_row_splits(
self._original_op(inner_values, *args, **kwargs),
nested_splits_lists[0])
else:
return self.NOT_SUPPORTED
else:
found_ragged = ragged_tensor.is_ragged(x)
if found_ragged:
mapped_values = self._original_op(x.inner_values, *args, **kwargs)
return x.with_inner_values(mapped_values)
else:
return self.NOT_SUPPORTED
class BinaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
"""OpDispatcher for binary ops that map a base op across ragged values.
Supports broadcasting.
"""
def __init__(self, original_op):
self._original_op = original_op
arg_names = tf_inspect.getfullargspec(original_op)[0]
self._x = arg_names[0]
self._y = arg_names[1]
if _UPDATE_DOCSTRINGS:
original_op.__doc__ = (
original_op.__doc__.rstrip() + '\n\n' +
' `{x}` and `{y}` may be a `tf.RaggedTensor`.\n'.format(
x=self._x, y=self._y))
def handle(self, args, kwargs):
# Extract the binary args.
if len(args) > 1:
x = args[0]
y = args[1]
args = args[2:]
elif args:
kwargs = kwargs.copy()
x = args[0]
y = kwargs.pop(self._y, None)
args = args[1:]
else:
kwargs = kwargs.copy()
x = kwargs.pop(self._x, None)
y = kwargs.pop(self._y, None)
# Bail if we don't have at least one ragged argument.
x_is_ragged = ragged_tensor.is_ragged(x)
y_is_ragged = ragged_tensor.is_ragged(y)
if not (x_is_ragged or y_is_ragged):
return self.NOT_SUPPORTED
# Convert args to tensors. Bail if conversion fails.
try:
if not x_is_ragged:
x = ops.convert_to_tensor(x, name=self._x, preferred_dtype=y.dtype)
if not y_is_ragged:
y = ops.convert_to_tensor(y, name=self._y, preferred_dtype=x.dtype)
except (TypeError, ValueError):
return self.NOT_SUPPORTED
if ((x_is_ragged and y_is_ragged) or
(x_is_ragged and x.inner_values.shape.ndims <= y.shape.ndims) or
(y_is_ragged and y.inner_values.shape.ndims <= x.shape.ndims)):
bcast_shape = ragged_tensor_shape.broadcast_dynamic_shape(
ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(x),
ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(y))
x = ragged_tensor_shape.broadcast_to(
x, bcast_shape, broadcast_inner_dimensions=False)
y = ragged_tensor_shape.broadcast_to(
y, bcast_shape, broadcast_inner_dimensions=False)
x_values = x.inner_values if ragged_tensor.is_ragged(x) else x
y_values = y.inner_values if ragged_tensor.is_ragged(y) else y
mapped_values = self._original_op(x_values, y_values, *args, **kwargs)
if ragged_tensor.is_ragged(x):
return x.with_inner_values(mapped_values)
else:
return y.with_inner_values(mapped_values)
class RaggedDispatcher(dispatch.OpDispatcher):
"""OpDispatcher for ragged ops.
Dispatches to a wrapped op-handler if at least one of the `tensor_args`
arguments is a RaggedTensor or a RaggedTensorValue; and all of the
`tensor_args` arguments are convertible to Tensor or RaggedTensor.
"""
def __init__(self, original_op, ragged_op, ragged_args):
op_arg_names = tf_inspect.getfullargspec(original_op)[0]
ragged_arg_names = tf_inspect.getfullargspec(ragged_op)[0]
if op_arg_names != ragged_arg_names:
raise AssertionError(
'Signature must exactly match when overriding %s with %s: %s vs %s' %
(original_op, ragged_op, op_arg_names, ragged_arg_names))
self._ragged_op = ragged_op
self._ragged_args = _get_arg_infos(ragged_op, ragged_args)
if _UPDATE_DOCSTRINGS:
arg_list = ' and '.join('`%s`' % arg for arg in ragged_args)
original_op.__doc__ = (
original_op.__doc__.rstrip() + '\n\n' +
' {0} may be a `tf.RaggedTensor`.\n'.format(arg_list))
def handle(self, args, kwargs):
if self.is_supported(args, kwargs):
return self._ragged_op(*args, **kwargs)
else:
return self.NOT_SUPPORTED
def is_supported(self, args, kwargs):
found_ragged = False
for arg_info in self._ragged_args:
if arg_info.position < len(args):
arg = args[arg_info.position]
else:
arg = kwargs.get(arg_info.name, None)
if arg_info.is_list:
if not isinstance(arg, (list, tuple)):
return False
for elt in arg:
if ragged_tensor.is_ragged(elt):
found_ragged = True
elif not _is_convertible_to_tensor(elt):
return False
else:
if ragged_tensor.is_ragged(arg):
found_ragged = True
elif not _is_convertible_to_tensor(arg):
return False
return found_ragged
def ragged_dispatch(original_op, tensor_args):
def decorator(ragged_op):
dispatch.RaggedDispatcher(original_op, ragged_op,
tensor_args).register(original_op)
return ragged_op
return decorator
_UNARY_ELEMENTWISE_OPS = [
array_ops.check_numerics,
array_ops.identity,
array_ops.ones_like,
array_ops.ones_like_v2,
array_ops.zeros_like,
array_ops.zeros_like_v2,
clip_ops.clip_by_value,
math_ops.abs,
math_ops.acos,
math_ops.acosh,
math_ops.angle,
math_ops.asin,
math_ops.asinh,
math_ops.atan,
math_ops.atanh,
math_ops.cast,
math_ops.ceil,
math_ops.conj,
math_ops.cos,
math_ops.cosh,
math_ops.digamma,
math_ops.erf,
math_ops.erfc,
math_ops.exp,
math_ops.expm1,
math_ops.floor,
math_ops.imag,
math_ops.is_finite,
math_ops.is_inf,
math_ops.is_nan,
math_ops.lgamma,
math_ops.log,
math_ops.log1p,
math_ops.log_sigmoid,
math_ops.logical_not,
math_ops.negative,
math_ops.real,
math_ops.reciprocal,
math_ops.rint,
math_ops.round,
math_ops.rsqrt,
math_ops.saturate_cast,
math_ops.sign,
math_ops.sin,
math_ops.sinh,
math_ops.sqrt,
math_ops.square,
math_ops.tan,
parsing_ops.decode_compressed,
string_ops.string_to_number,
string_ops.string_to_hash_bucket,
string_ops.as_string,
string_ops.decode_base64,
string_ops.encode_base64,
string_ops.regex_full_match,
string_ops.regex_replace,
string_ops.string_strip,
string_ops.string_to_hash_bucket,
string_ops.string_to_hash_bucket_fast,
string_ops.string_to_hash_bucket_strong,
string_ops.substr,
string_ops.substr_v2,
string_ops.string_length,
string_ops.string_length_v2,
string_ops.unicode_script,
]
_UNARY_LIST_ELEMENTWISE_OPS = [
math_ops.add_n,
string_ops.string_join,
]
_BINARY_ELEMENTWISE_OPS = [
math_ops.add,
math_ops.atan2,
math_ops.complex,
math_ops.div_no_nan,
math_ops.divide,
math_ops.equal,
math_ops.floordiv,
math_ops.floormod,
math_ops.greater,
math_ops.greater_equal,
math_ops.less,
math_ops.less_equal,
math_ops.logical_and,
math_ops.logical_or,
math_ops.logical_xor,
math_ops.maximum,
math_ops.minimum,
math_ops.multiply,
math_ops.not_equal,
math_ops.pow,
math_ops.realdiv,
math_ops.squared_difference,
math_ops.subtract,
math_ops.truediv,
math_ops.truncatediv,
math_ops.truncatemod,
]
# (original_op, ragged_op, ragged_args)
_RAGGED_DISPATCH_OPS = [
(array_ops.batch_gather, ragged_array_ops.batch_gather,
['params', 'indices']),
(array_ops.concat, ragged_array_ops.concat, ['values']),
(array_ops.expand_dims_v2, ragged_array_ops.expand_dims, ['input']),
(array_ops.gather_v2, ragged_array_ops.gather, ['params', 'indices']),
(array_ops.gather_nd, ragged_array_ops.gather_nd, ['params', 'indices']),
(array_ops.stack, ragged_array_ops.stack, ['values']),
(array_ops.tile, ragged_array_ops.tile, ['input']),
(array_ops.where, ragged_array_ops.where, ['condition', 'x', 'y']),
(math_ops.unsorted_segment_sum, ragged_math_ops.segment_sum,
['data', 'segment_ids']),
(math_ops.unsorted_segment_prod, ragged_math_ops.segment_prod,
['data', 'segment_ids']),
(math_ops.unsorted_segment_min, ragged_math_ops.segment_min,
['data', 'segment_ids']),
(math_ops.unsorted_segment_max, ragged_math_ops.segment_max,
['data', 'segment_ids']),
(math_ops.unsorted_segment_mean, ragged_math_ops.segment_mean,
['data', 'segment_ids']),
(math_ops.unsorted_segment_sqrt_n, ragged_math_ops.segment_sqrt_n,
['data', 'segment_ids']),
(math_ops.reduce_sum, ragged_math_ops.reduce_sum, ['input_tensor']),
(math_ops.reduce_prod, ragged_math_ops.reduce_prod, ['input_tensor']),
(math_ops.reduce_min, ragged_math_ops.reduce_min, ['input_tensor']),
(math_ops.reduce_max, ragged_math_ops.reduce_max, ['input_tensor']),
(math_ops.reduce_mean, ragged_math_ops.reduce_mean, ['input_tensor']),
(math_ops.reduce_any, ragged_math_ops.reduce_any, ['input_tensor']),
(math_ops.reduce_all, ragged_math_ops.reduce_all, ['input_tensor']),
]
def register_dispatchers():
"""Constructs & registers OpDispatchers for ragged ops."""
op_list = (
_UNARY_ELEMENTWISE_OPS + _UNARY_LIST_ELEMENTWISE_OPS +
_BINARY_ELEMENTWISE_OPS + [x[0] for x in _RAGGED_DISPATCH_OPS])
for op in op_list:
_, undecorated_op = tf_decorator.unwrap(op)
if not hasattr(undecorated_op, tf_export.API_ATTRS['tensorflow'].names):
raise AssertionError('Expected %s to be an exported symbol '
'(while adding a RaggedTensor dispatcher)')
for op in _UNARY_ELEMENTWISE_OPS:
UnaryRaggedElementwiseDispatcher(op).register(op)
for op in _UNARY_LIST_ELEMENTWISE_OPS:
UnaryRaggedElementwiseDispatcher(op, True).register(op)
for op in _BINARY_ELEMENTWISE_OPS:
BinaryRaggedElementwiseDispatcher(op).register(op)
for (original_op, ragged_op, args) in _RAGGED_DISPATCH_OPS:
RaggedDispatcher(original_op, ragged_op, args).register(original_op)
docstring = (
'\n\n### Additional ops that support `RaggedTensor`\n\n' + '\n'.join([
'* `tf.%s`' % tf_export.get_canonical_name_for_symbol(op)
for op in op_list
]))
return docstring

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ragged.elementwise_ops."""
"""Tests for RaggedTensor operator dispatch."""
from __future__ import absolute_import
from __future__ import division
@ -21,106 +21,108 @@ from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import ragged
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import googletest
# Constants listing various op types to test. Each elementwise operation
# Constants listing various op types to test. Each operation
# should be included in at least one list below, or tested separately if
# necessary (e.g., because it expects additional arguments).
UNARY_FLOAT_OPS = [
ragged.abs,
ragged.acos,
ragged.acosh,
ragged.angle,
ragged.asin,
ragged.asinh,
ragged.atan,
ragged.atanh,
ragged.ceil,
ragged.conj,
ragged.cos,
ragged.cosh,
ragged.digamma,
ragged.erf,
ragged.erfc,
ragged.exp,
ragged.expm1,
ragged.floor,
ragged.imag,
ragged.is_finite,
ragged.is_inf,
ragged.is_nan,
ragged.lgamma,
ragged.log,
ragged.log1p,
ragged.log_sigmoid,
ragged.negative,
ragged.real,
ragged.reciprocal,
ragged.rint,
ragged.round,
ragged.rsqrt,
ragged.sign,
ragged.sin,
ragged.sinh,
ragged.sqrt,
ragged.square,
ragged.tan,
ragged.as_string,
ragged.identity,
ragged.ones_like,
ragged.zeros_like,
math_ops.abs,
math_ops.acos,
math_ops.acosh,
math_ops.angle,
math_ops.asin,
math_ops.asinh,
math_ops.atan,
math_ops.atanh,
math_ops.ceil,
math_ops.conj,
math_ops.cos,
math_ops.cosh,
math_ops.digamma,
math_ops.erf,
math_ops.erfc,
math_ops.exp,
math_ops.expm1,
math_ops.floor,
math_ops.imag,
math_ops.is_finite,
math_ops.is_inf,
math_ops.is_nan,
math_ops.lgamma,
math_ops.log,
math_ops.log1p,
math_ops.log_sigmoid,
math_ops.negative,
math_ops.real,
math_ops.reciprocal,
math_ops.rint,
math_ops.round,
math_ops.rsqrt,
math_ops.sign,
math_ops.sin,
math_ops.sinh,
math_ops.sqrt,
math_ops.square,
math_ops.tan,
array_ops.identity,
array_ops.ones_like,
array_ops.zeros_like,
]
UNARY_BOOL_OPS = [
ragged.logical_not,
math_ops.logical_not,
]
UNARY_STRING_OPS = [
ragged.decode_base64,
ragged.encode_base64,
ragged.string_strip,
ragged.decode_compressed,
string_ops.decode_base64,
string_ops.encode_base64,
string_ops.string_strip,
parsing_ops.decode_compressed,
]
BINARY_FLOAT_OPS = [
ragged.add,
ragged.atan2,
ragged.complex,
ragged.div,
ragged.div_no_nan,
ragged.divide,
ragged.equal,
ragged.floordiv,
ragged.floormod,
ragged.greater,
ragged.greater_equal,
ragged.less,
ragged.less_equal,
ragged.maximum,
ragged.minimum,
ragged.multiply,
ragged.not_equal,
ragged.pow,
ragged.realdiv,
ragged.squared_difference,
ragged.subtract,
ragged.truediv,
math_ops.add,
math_ops.atan2,
math_ops.complex,
math_ops.div_no_nan,
math_ops.divide,
math_ops.equal,
math_ops.floordiv,
math_ops.floormod,
math_ops.greater,
math_ops.greater_equal,
math_ops.less,
math_ops.less_equal,
math_ops.maximum,
math_ops.minimum,
math_ops.multiply,
math_ops.not_equal,
math_ops.pow,
math_ops.realdiv,
math_ops.squared_difference,
math_ops.subtract,
math_ops.truediv,
]
BINARY_BOOL_OPS = [
ragged.logical_and,
ragged.logical_or,
ragged.logical_xor,
math_ops.logical_and,
math_ops.logical_or,
math_ops.logical_xor,
]
UNARY_INT_OPS = [
ragged.unicode_script,
string_ops.unicode_script,
]
BINARY_INT_OPS = [
ragged.truncatediv,
ragged.truncatemod,
math_ops.truncatediv,
math_ops.truncatemod,
]
@ -171,50 +173,49 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
[{'x': ragged.constant_value([['abcd', 'efgh'], ['aabbccdd']]), 'op': op}
for op in UNARY_STRING_OPS] +
[
{'op': ragged.clip_by_value,
{'op': clip_ops.clip_by_value,
'x': ragged.constant_value([[-2.0, 3.0], [-3.0]]),
'clip_value_min': 0.1, 'clip_value_max': 4.0},
{'op': ragged.cast,
{'op': math_ops.cast,
'x': ragged.constant_value([[-2.0, 3.0], [-3.0]]),
'dtype': dtypes.int32},
{'op': ragged.saturate_cast,
{'op': math_ops.saturate_cast,
'x': ragged.constant_value([[-2.0, 3.0], [-3.0]]),
'dtype': dtypes.int32},
{'op': ragged.string_to_hash_bucket,
{'op': string_ops.string_to_hash_bucket,
'x': ragged.constant_value([['abcd', 'efgh'], ['aabbccdd']]),
'num_buckets': 1000},
{'op': ragged.string_to_hash_bucket_fast,
{'op': string_ops.string_to_hash_bucket_fast,
'x': ragged.constant_value([['abcd', 'efgh'], ['aabbccdd']]),
'num_buckets': 1000},
{'op': ragged.string_to_hash_bucket_strong,
{'op': string_ops.string_to_hash_bucket_strong,
'x': ragged.constant_value([['abcd', 'efgh'], ['aabbccdd']]),
'num_buckets': 1000,
'key': [1231, 12512]},
{'op': ragged.string_to_number,
{'op': string_ops.string_to_number,
'x': ragged.constant_value([['-2.0', '3.0'], ['-3.0']])},
{'op': ragged.regex_full_match,
{'op': string_ops.regex_full_match,
'x': ragged.constant_value([['hello', '123'], ['1+1']]),
'pattern': r'\w+'},
{'op': ragged.regex_replace,
{'op': string_ops.regex_replace,
'x': ragged.constant_value([['hello', '123'], ['1+1']]),
'pattern': r'\d',
'rewrite': '#'},
{'op': ragged.substr,
{'op': string_ops.substr,
'x': ragged.constant_value([['hello', '123'], ['1+1']]),
'pos': 2, 'len': 3},
{'op': ragged.check_numerics,
{'op': array_ops.check_numerics,
'x': ragged.constant_value([[-2.0, 3.0], [-3.0]]),
'message': 'check-numerics'},
]
) # pyformat: disable
def testUnaryOp(self, x, op=ragged.abs, **extra_args):
def testUnaryElementwiseOp(self, x, op=math_ops.abs, **extra_args):
x = ragged.convert_to_tensor_or_ragged_tensor(x)
result = op(x, **extra_args)
# Run the wrapped op on the dense values, for comparison.
dense_x = x.inner_values if isinstance(x, ragged.RaggedTensor) else x
expected_flat_values = array_ops.reshape(
op.__wrapped__(dense_x, **extra_args), [-1])
expected_flat_values = array_ops.reshape(op(dense_x, **extra_args), [-1])
with self.test_session():
# Check that the result has the expected shape.
@ -285,12 +286,17 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
#=====================================================================
{'x': ragged.constant_value([[[1, 2], [3], [4]], [[], [5, 7, 8]]]),
'y': ragged.constant_value([[[3, 8], [2], [5]], [[], [1, 9, 8]]]),
'use_kwargs': True},
'use_kwargs': ('x', 'y')},
{'x': ragged.constant_value([[[1, 2]], [[3, 4], [5, 6], [7, 8]]],
ragged_rank=1),
'y': ragged.constant_value([[[9, 3]], [[5, 2], [3, 4], [7, 6]]],
ragged_rank=1),
'use_kwargs': True},
'use_kwargs': ('x', 'y')},
{'x': ragged.constant_value([[[1, 2]], [[3, 4], [5, 6], [7, 8]]],
ragged_rank=1),
'y': ragged.constant_value([[[9, 3]], [[5, 2], [3, 4], [7, 6]]],
ragged_rank=1),
'use_kwargs': ('x',)},
] +
#=========================================================================
# Test each unary op.
@ -306,16 +312,16 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
[{'x': ragged.constant_value([[True, True], [False]]),
'y': ragged.constant_value([[False, True], [False]]),
'op': op}
for op in BINARY_BOOL_OPS] +
[
]
for op in BINARY_BOOL_OPS]
) # pyformat: disable
def testBinaryOp(self, x, y, op=ragged.add, **extra_args):
use_kwargs = extra_args.pop('use_kwargs', False)
def testBinaryElementwiseOp(self, x, y, op=math_ops.add, **extra_args):
use_kwargs = extra_args.pop('use_kwargs', ())
x = ragged.convert_to_tensor_or_ragged_tensor(x)
y = ragged.convert_to_tensor_or_ragged_tensor(y)
if use_kwargs:
if 'x' in use_kwargs and 'y' in use_kwargs:
result = op(x=x, y=y, **extra_args)
elif 'y' in use_kwargs:
result = op(x, y=y, **extra_args)
else:
result = op(x, y, **extra_args)
@ -323,7 +329,7 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
dense_x = x.inner_values if isinstance(x, ragged.RaggedTensor) else x
dense_y = y.inner_values if isinstance(y, ragged.RaggedTensor) else y
expected_flat_values = array_ops.reshape(
op.__wrapped__(dense_x, dense_y, **extra_args), [-1])
op(dense_x, dense_y, **extra_args), [-1])
with self.test_session():
# Check that the result has the expected shape.
@ -358,16 +364,17 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
ragged.constant_value([[[2, 9], [12]], [[8]]])),
'use_kwargs': True},
] + [
{'op': ragged.add_n,
{'op': math_ops.add_n,
'inputs': (ragged.constant_value([[1, 3], [-3]]),
ragged.constant_value([[4, 7], [88]]),
ragged.constant_value([[2, 9], [12]]))},
{'op': ragged.string_join,
{'op': string_ops.string_join,
'inputs': (ragged.constant_value([['a', 'b'], ['c']]),
ragged.constant_value([['foo', 'bar'], ['baz']]),
ragged.constant_value([['2', '9'], ['12']]))},
]) # pyformat: disable
def testListValuedOp(self, inputs, op=ragged.add_n, **extra_args):
def testListValuedElementwiseOp(self, inputs, op=math_ops.add_n,
**extra_args):
use_kwargs = extra_args.pop('use_kwargs', False)
inputs = [ragged.convert_to_tensor_or_ragged_tensor(x) for x in inputs]
if use_kwargs:
@ -381,7 +388,7 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
for x in inputs
]
expected_flat_values = array_ops.reshape(
op.__wrapped__(dense_inputs, **extra_args), [-1])
op(dense_inputs, **extra_args), [-1])
with self.test_session():
# Check that the result has the expected shape.
@ -395,13 +402,13 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
self.assertAllEqual(expected_flat_values, result_flat_values)
@test_util.run_deprecated_v1
def testUnknownRankError(self):
def testElementwiseOpUnknownRankError(self):
x = ragged.constant([[1, 2], [3]])
y = ragged.from_row_splits(
array_ops.placeholder_with_default([1, 2, 3], shape=None), x.row_splits)
with self.assertRaisesRegexp(
ValueError, r'Unable to broadcast: unknown rank'):
ragged.add(x, y)
with self.assertRaisesRegexp(ValueError,
r'Unable to broadcast: unknown rank'):
math_ops.add(x, y)
@parameterized.parameters([
dict(
@ -417,26 +424,31 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
y=ragged.constant_value([[1]]),
expected=[[[2]]]),
])
def testBroadcastAdd(self, x, y, expected):
def testElementwiseOpBroadcast(self, x, y, expected):
x = ragged.convert_to_tensor_or_ragged_tensor(x, dtype=dtypes.int32)
y = ragged.convert_to_tensor_or_ragged_tensor(y, dtype=dtypes.int32)
result = x + y
with self.cached_session():
self.assertEqual(result.eval().tolist(), expected)
def testShapeMismatch(self):
def testElementwiseOpShapeMismatch(self):
x = ragged.constant([[1, 2, 3], [4, 5]])
y = ragged.constant([[1, 2, 3], [4, 5, 6]])
with self.assertRaisesRegexp(errors.InvalidArgumentError,
'Incompatible shapes'):
with self.cached_session():
ragged.add(x, y).eval()
math_ops.add(x, y).eval()
def testDocstring(self):
self.assertRegexpMatches(
ragged.add.__doc__,
'Ragged version of the elementwise operation `tf.math.add`')
self.assertEqual(ragged.add.__name__, 'add')
def testBinaryOpSparseAndRagged(self):
x = ragged.constant([[1, 2, 3], [4, 5]])
y = sparse_tensor.SparseTensor([[0, 0], [0, 1], [2, 0]], [1, 2, 3], [3, 2])
with self.assertRaises(TypeError):
with self.cached_session():
math_ops.add(x, y).eval()
with self.assertRaises(TypeError):
with self.cached_session():
math_ops.add_n([x, y]).eval()
if __name__ == '__main__':

View File

@ -1,389 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Elementwise operations for RaggedTensors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_shape
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_export
from tensorflow.python.util import tf_inspect
# Information about an argument to an operation: The name of the argument, its
# position in the argument list, and a boolean flag indicating whether it
# expects a list of tensors.
_ArgInfo = collections.namedtuple('ArgInfo', ['name', 'position', 'is_list'])
def make_elementwise_op(op, *elementwise_args):
"""Returns a ragged-tensor version of the elementwise operation `op`.
The returned operation will:
1. Broadcast the elementwise arguments to have a compatible shape.
An exception is raised if the tensors not broadcast-compatible.
2. Call `op`, substituting the dense values of the broadcasted tensor for
each elementwise argument.
3. Return a potentially ragged tensor constructed from the output of `op`
and the broadcasted tensors' nested row splits.
For example, you can construct a ragged-tensor version of the standard
operation `tf.add` by calling `make_elementwise_op(tf.add, 'x', 'y')`.
Args:
op: The operation to wrap.
*elementwise_args: The names of arguments to `op` that are treated as
elementwise. Arguments that take a list of tensors should have their
names wrapped in square brackets (e.g. "[inputs]").
Raises:
ValueError: If any name specified in `elementwise_args` is not the name
of an argument to `op`.
"""
elementwise_arg_infos = _get_arg_infos(op, elementwise_args)
def ragged_op(*args, **kwargs):
"""Ragged version of `op`."""
args = list(args)
# Collect all of the elementwise arguments, and put them in a single
# dict whose values are the (potentially ragged) tensors that need to
# be broadcast to a common shape. The keys of this dict are tuples
# (argkey, index), where argkey is an int for poitional args or a string
# for keyword args; and index is None for non-list args and the index of the
# tensor for list args.
elementwise_args = {}
for (name, position, is_list) in elementwise_arg_infos.values():
if position < len(args):
if is_list:
args[position] = list(args[position])
for (index, arg) in enumerate(args[position]):
elementwise_args[position, index] = arg
else:
elementwise_args[position, None] = args[position]
elif name in kwargs:
if is_list:
kwargs[name] = list(kwargs[name])
for (i, arg) in enumerate(kwargs[name]):
elementwise_args[name, i] = arg
else:
elementwise_args[name, None] = kwargs[name]
with ops.name_scope(None, op.__name__, elementwise_args.values()):
# Convert all inputs to tensors or ragged tensors.
for ((key, index), tensor) in elementwise_args.items():
argname = elementwise_arg_infos[key].name
converted = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
tensor, name=argname)
elementwise_args[key, index] = converted
# Broadcast tensors to have compatible shapes.
broadcast_args, result_splits, broadcast_check_ops = \
_broadcast_elementwise_args(elementwise_args)
# Replace tensor arguments with their dense values.
for ((key, index), tensor) in broadcast_args.items():
if ragged_tensor.is_ragged(tensor):
if isinstance(key, int) and index is None:
args[key] = tensor.inner_values
elif isinstance(key, int) and index is not None:
args[key][index] = tensor.inner_values
elif isinstance(key, str) and index is None:
kwargs[key] = tensor.inner_values
else:
assert isinstance(key, str) and index is not None
kwargs[key][index] = tensor.inner_values
# Call the elementwise op on the broadcasted dense values.
with ops.control_dependencies(broadcast_check_ops):
result_values = op(*args, **kwargs)
# Restore any ragged dimensions that we stripped off, and return the
# result.
return ragged_factory_ops.from_nested_row_splits(result_values,
result_splits)
# Construct the docstring.
op_name = tf_export.get_canonical_name_for_symbol(op)
assert op_name is not None, op
argnames = ', '.join('`%s`' % s.strip('[]') for s in elementwise_args)
docstring = _ELEMENTWISE_DOCSTRING % dict(op_name=op_name, argnames=argnames)
# Update name, docstring, signature, etc., for the wrapper, and return it.
return tf_decorator.make_decorator(op, ragged_op, decorator_doc=docstring)
_ELEMENTWISE_DOCSTRING = """\
Ragged version of the elementwise operation `tf.%(op_name)s`.
The following elementwise arguments may be ragged or dense:
%(argnames)s.
These arguments will be broadcast to a compatible shape if necessary.
"""
def _get_arg_infos(func, elementwise_args):
"""Returns `_ArgInfo`s for each `func` arg specified by `elementwise_args`.
Args:
func: The function whose arguments should be described.
elementwise_args: The names of the arguments to get info for.
Returns:
A dictionary that maps both names and positions of arguments to
`_ArgInfo` tuples.
"""
arg_infos = {}
# Inspect the func's argspec to find the position of each arg.
arg_spec = tf_inspect.getargspec(func)
for argname in elementwise_args:
assert isinstance(argname, str)
is_list = argname.startswith('[') and argname.endswith(']')
if is_list:
argname = argname[1:-1]
assert argname in arg_spec.args, (func, argname, arg_spec.args)
arg_info = _ArgInfo(argname, arg_spec.args.index(argname), is_list)
arg_infos[arg_info.name] = arg_info
arg_infos[arg_info.position] = arg_info
return arg_infos
def _broadcast_elementwise_args(elementwise_args):
"""Broadcasts the values of `elementwise_args` to have compatible shapes.
Args:
elementwise_args: A dictionary whose keys are potentially ragged tensors.
Returns:
A tuple `(broadcast_args, broadcast_splits, checks)` where:
* `broadcast_args` is a dictionary with the same keys as
`elementwise_args`, mapping to broadcasted tensors.
* `broadcast_splits` is the broadcasted nested row splits.
* `checks` is a possibly empty tuple of assertion operations that should
be added as control dependencies.
Raises:
ValueError: If broadcasting fails.
"""
# No elementwise arguments were used: nothing to do!
if not elementwise_args:
return elementwise_args, (), ()
# A single elementwise argument was used: no broadcasting necessary.
if len(elementwise_args) == 1:
arg = list(elementwise_args.values())[0]
if ragged_tensor.is_ragged(arg):
return elementwise_args, arg.nested_row_splits, ()
else:
return elementwise_args, (), ()
# Multiple elementwise arguments.
else:
is_ragged = [ragged_tensor.is_ragged(t) for t in elementwise_args.values()]
if not any(is_ragged):
return elementwise_args, (), ()
# If we have a single ragged tensor plus a set of scalars, then we can
# rely on the underlying elementwise op to do broadcasting.
if (sum(is_ragged) == 1 and
all((ragged_tensor.is_ragged(t) or t.shape.ndims == 0)
for t in elementwise_args.values())):
nested_splits_lists = [
t.nested_row_splits
for t in elementwise_args.values()
if ragged_tensor.is_ragged(t)][0]
return elementwise_args, nested_splits_lists, ()
else:
# Get the shapes of all the elementwise arguments.
shapes = [ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(t)
for t in elementwise_args.values()]
# Broadcast the shapes to all have the same rank (the max rank).
ranks = [t.shape.ndims for t in elementwise_args.values()]
if any(rank is None for rank in ranks):
raise ValueError('Unable to broadcast: unknown rank')
broadcast_rank = max(ranks)
shapes = [shape.broadcast_to_rank(broadcast_rank) for shape in shapes]
# For each dimension, broadcast the shapes to be compatible.
for axis in range(broadcast_rank):
# For each i, broadcast shape[i+1] to be compatible with shape[i]; and
# then finally broadcast shape[0] to be compatible with shape[-1].
for i in range(len(shapes)):
j = (i + 1) % len(shapes)
dim_size = shapes[i].dimension_size(axis)
shapes[j] = shapes[j].broadcast_dimension(axis, dim_size)
broadcast_shape = shapes[0]
# Broadcast every elementwise arg to the shape that we calculated.
elementwise_args = dict([
(key, ragged_tensor_shape.broadcast_to(t, broadcast_shape, False))
for (key, t) in elementwise_args.items()])
nested_splits_lists = list(elementwise_args.values())[0].nested_row_splits
return elementwise_args, nested_splits_lists, ()
# A list of symbols that should be exported in the "ragged" package.
_symbols_to_export = []
def _add_elementwise_ops_to_this_module(specs, verbose=False):
"""Adds ragged versions of the given ops to this module.
Args:
specs: A list of tuples containing the arguments for `make_elementwise_op`.
verbose: If true, then display each op that gets added.
"""
for spec in specs:
original_op = spec[0]
ragged_op = make_elementwise_op(*spec)
canonical_name = tf_export.get_canonical_name_for_symbol(original_op)
if '.' not in canonical_name:
op_name = canonical_name
else:
op_name = original_op.__name__
# Temporary hack (will be removed once dispatch is added for RaggedTensors):
if op_name == 'neg': op_name = 'negative'
if verbose:
print('Adding ragged_elementwise_op: tf.ragged.%s (based on tf.%s)' %
(op_name, canonical_name))
globals()[op_name] = ragged_op
_symbols_to_export.append(op_name)
# A list of tuples containing arguments for `make_elementwise_op`, for each
# elementwise operation that should have a ragged version built. Each tuple
# contains a standard `Tensor` operation, and the names of any arguments
# that are processed in elementwise fashion.
_TF_ELEMENTWISE_OPS = [
# Unary math operations.
(clip_ops.clip_by_value, 't'),
(math_ops.abs, 'x'),
(math_ops.acos, 'x'),
(math_ops.acosh, 'x'),
(math_ops.angle, 'input'),
(math_ops.asin, 'x'),
(math_ops.asinh, 'x'),
(math_ops.atan, 'x'),
(math_ops.atanh, 'x'),
(math_ops.cast, 'x'),
(math_ops.ceil, 'x'),
(math_ops.conj, 'x'),
(math_ops.cos, 'x'),
(math_ops.cosh, 'x'),
(math_ops.digamma, 'x'),
(math_ops.erf, 'x'),
(math_ops.erfc, 'x'),
(math_ops.exp, 'x'),
(math_ops.expm1, 'x'),
(math_ops.floor, 'x'),
(math_ops.imag, 'input'),
(math_ops.is_finite, 'x'),
(math_ops.is_inf, 'x'),
(math_ops.is_nan, 'x'),
(math_ops.lgamma, 'x'),
(math_ops.log, 'x'),
(math_ops.log1p, 'x'),
(math_ops.log_sigmoid, 'x'),
(math_ops.logical_not, 'x'),
(math_ops.negative, 'x'),
(math_ops.real, 'input'),
(math_ops.reciprocal, 'x'),
(math_ops.rint, 'x'),
(math_ops.round, 'x'),
(math_ops.rsqrt, 'x'),
(math_ops.saturate_cast, 'value'),
(math_ops.sign, 'x'),
(math_ops.sin, 'x'),
(math_ops.sinh, 'x'),
(math_ops.sqrt, 'x'),
(math_ops.square, 'x'),
(math_ops.tan, 'x'),
# Binary math operations
(math_ops.add, 'x', 'y'),
(math_ops.atan2, 'y', 'x'),
(math_ops.complex, 'real', 'imag'),
(math_ops.div, 'x', 'y'),
(math_ops.div_no_nan, 'x', 'y'),
(math_ops.divide, 'x', 'y'),
(math_ops.equal, 'x', 'y'),
(math_ops.floordiv, 'x', 'y'),
(math_ops.floormod, 'x', 'y'),
(math_ops.greater, 'x', 'y'),
(math_ops.greater_equal, 'x', 'y'),
(math_ops.less, 'x', 'y'),
(math_ops.less_equal, 'x', 'y'),
(math_ops.logical_and, 'x', 'y'),
(math_ops.logical_or, 'x', 'y'),
(math_ops.logical_xor, 'x', 'y'),
(math_ops.maximum, 'x', 'y'),
(math_ops.minimum, 'x', 'y'),
(math_ops.multiply, 'x', 'y'),
(math_ops.not_equal, 'x', 'y'),
(math_ops.pow, 'x', 'y'),
(math_ops.realdiv, 'x', 'y'),
(math_ops.squared_difference, 'x', 'y'),
(math_ops.subtract, 'x', 'y'),
(math_ops.truediv, 'x', 'y'),
(math_ops.truncatediv, 'x', 'y'),
(math_ops.truncatemod, 'x', 'y'),
# N-ary math operations
(math_ops.add_n, '[inputs]'),
# String operations
(string_ops.as_string, 'input'),
(string_ops.decode_base64, 'input'),
(string_ops.encode_base64, 'input'),
(string_ops.regex_full_match, 'input'),
(string_ops.regex_replace, 'input'),
(string_ops.string_join, '[inputs]'),
(string_ops.string_strip, 'input'),
(string_ops.string_to_hash_bucket, 'input'),
(string_ops.string_to_hash_bucket_fast, 'input'),
(string_ops.string_to_hash_bucket_strong, 'input'),
(string_ops.substr, 'input'),
(string_ops.unicode_script, 'input'),
# Array ops
(array_ops.check_numerics, 'tensor'),
(array_ops.identity, 'input'),
(array_ops.ones_like, 'tensor'),
(array_ops.zeros_like, 'tensor'),
# Parsing ops
(parsing_ops.decode_compressed, 'bytes'),
(parsing_ops.string_to_number, 'string_tensor'),
]
_add_elementwise_ops_to_this_module(_TF_ELEMENTWISE_OPS)

View File

@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
@ -55,7 +56,7 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
),
# [d1, (d2)] -> [d1, (d2)]
dict(
fn=lambda x: x+1,
fn=lambda x: x + np.int64(1),
elems=[[1, 2, 3], [4, 5], [6, 7]],
expected_output=[[2, 3, 4], [5, 6], [7, 8]],
dtype=dtypes.int64,
@ -64,7 +65,7 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
),
# [d1, (d2), d3] -> [d1, (d2), d3]
dict(
fn=lambda x: x+1,
fn=lambda x: x + np.int64(1),
elems=[[[1, 2], [3, 4]], [], [[5, 6], [7, 8], [9, 0]]],
elems_ragged_rank=1,
expected_ragged_rank=1,
@ -131,7 +132,7 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
),
# [d1, (d2), (d3), (d4a), (d5)] -> [d1, (d2), (d3), (d4b), (d5)]
dict(
fn=lambda x: ragged.add(x, 1),
fn=lambda x: x + np.int64(1),
elems=[[[[[1, 2, 3]], [[4], [5]]]], [[[[6, 7]]], [[[8], []]]]],
expected_output=[[[[[2, 3, 4]], [[5], [6]]]],
[[[[7, 8]]], [[[9], []]]]],
@ -196,8 +197,8 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
def _increment(f):
return {
'batman': ragged.add(f['batman'], 1),
'robin': ragged.add(f['robin'], 1),
'batman': f['batman'] + 1,
'robin': f['robin'] + 1,
}
output = ragged.map_fn(

View File

@ -143,8 +143,11 @@ Computes the %(combination)s along segments of a RaggedTensor.
"""
def _ragged_segment_aggregate(unsorted_segment_op, data, segment_ids,
num_segments, name=None):
def _ragged_segment_aggregate(unsorted_segment_op,
data,
segment_ids,
num_segments,
name=None):
"""Aggregates along segments of a RaggedTensor using `unsorted_segment_op`.
Returns a RaggedTensor `output` with `num_segments` rows, where the row
@ -212,12 +215,11 @@ def _ragged_segment_aggregate(unsorted_segment_op, data, segment_ids,
assert output_row_lengths.dtype == dtypes.int64
# Build the splits tensor for the output RaggedTensor.
output_splits = array_ops.concat(
[
array_ops.zeros([1], dtypes.int64),
math_ops.cumsum(output_row_lengths)
],
axis=0)
output_splits = array_ops.concat([
array_ops.zeros([1], dtypes.int64),
math_ops.cumsum(output_row_lengths)
],
axis=0)
# For each row in `data`, find the start & limit position where that row's
# values will be aggregated in output.values.
@ -311,7 +313,7 @@ _set_ragged_segment_docstring(segment_sqrt_n, 'sum divided by sqrt(N)',
_RAGGED_REDUCE_DOCSTRING = """\
Computes the %(combination)s of elements across dimensions of a `RaggedTensor`.
Reduces `rt_input` along the dimensions given in `axis` by taking the
Reduces `input_tensor` along the dimensions given in `axis` by taking the
%(combination)s of values. If a reduced dimension has no elements for
some index, then the value for that index will be %(default)s.
@ -319,18 +321,18 @@ Computes the %(combination)s of elements across dimensions of a `RaggedTensor`.
`axis` is not specified, then all dimensions are reduced, and a scalar
value is returned.
Args:
rt_input: A `RaggedTensor` containing the values to be %(combined)s.
input_tensor: A `RaggedTensor` containing the values to be %(combined)s.
axis: The dimensions to reduce. May be `None` (to reduce all axes), an
`int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce
a given set of axes), or a `Tensor` with a constant value. Must be in
the range `[0, rt_input.rank]`.
the range `[0, input_tensor.rank]`.
name: A name prefix for the returned tensor (optional).
Returns:
A `RaggedTensor` containing the %(combined)s values. The returned tensor
has the same dtype as `data`, and its shape is given by removing the
dimensions specified in `axis` from `rt_input.shape`. The `ragged_rank`
dimensions specified in `axis` from `input_tensor.shape`. The `ragged_rank`
of the returned tensor is given by substracting any ragged dimensions
specified in `axis` from `rt_input.ragged_rank`.
specified in `axis` from `input_tensor.ragged_rank`.
Raises:
ValueError: If `axis` contains a `Tensor` whose value is not constant.
####Example:
@ -387,7 +389,11 @@ _RAGGED_REDUCE_ANY_EXAMPLE = """
"""
def _ragged_reduce_aggregate(reduce_op, unsorted_segment_op, rt_input, axis,
def _ragged_reduce_aggregate(reduce_op,
unsorted_segment_op,
rt_input,
axis,
keepdims,
name=None):
"""Aggregates across axes of a RaggedTensor using the given `Tensor` ops.
@ -412,6 +418,7 @@ def _ragged_reduce_aggregate(reduce_op, unsorted_segment_op, rt_input, axis,
`int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce a
given set of axes), or a `Tensor` with a constant value. Must be in the
range `[0, rt_input.rank)`.
keepdims: If true, retains reduced dimensions with length 1.
name: A name prefix for the returned tensor (optional).
Returns:
@ -426,6 +433,9 @@ def _ragged_reduce_aggregate(reduce_op, unsorted_segment_op, rt_input, axis,
if not ragged_tensor.is_ragged(rt_input):
return reduce_op(rt_input, axis, name=name)
if keepdims:
raise ValueError('keepdims=True is not supported for RaggedTensors.')
if isinstance(axis, ops.Tensor):
axis = tensor_util.constant_value(axis)
if axis is None:
@ -448,9 +458,9 @@ def _ragged_reduce_aggregate(reduce_op, unsorted_segment_op, rt_input, axis,
# once will probably require a nontrivial c++ op.
axis = sorted(axis)
inner_reduced = _ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
rt_input, axis[-1])
rt_input, axis[-1], keepdims)
return _ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
inner_reduced, axis[:-1])
inner_reduced, axis[:-1], keepdims)
axis = ragged_util.get_positive_axis(axis, rt_input.shape.ndims)
@ -476,48 +486,48 @@ def _ragged_reduce_aggregate(reduce_op, unsorted_segment_op, rt_input, axis,
# sum_{j} rt_input [i_0, ..., i_[axis-1], j, i_axis+1], ..., i_N]
return rt_input.with_values(
_ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
rt_input.values, axis - 1))
rt_input.values, axis - 1, keepdims))
def reduce_sum(rt_input, axis=None, name=None):
def reduce_sum(input_tensor, axis=None, keepdims=None, name=None):
"""For docs, see: _RAGGED_REDUCE_DOCSTRING."""
return _ragged_reduce_aggregate(math_ops.reduce_sum,
math_ops.unsorted_segment_sum, rt_input, axis,
name or 'RaggedReduceSum')
math_ops.unsorted_segment_sum, input_tensor,
axis, keepdims, name or 'RaggedReduceSum')
def reduce_prod(rt_input, axis=None, name=None):
def reduce_prod(input_tensor, axis=None, keepdims=None, name=None):
"""For docs, see: _RAGGED_REDUCE_DOCSTRING."""
return _ragged_reduce_aggregate(math_ops.reduce_prod,
math_ops.unsorted_segment_prod, rt_input,
axis, name or 'RaggedReduceProd')
math_ops.unsorted_segment_prod, input_tensor,
axis, keepdims, name or 'RaggedReduceProd')
def reduce_min(rt_input, axis=None, name=None):
def reduce_min(input_tensor, axis=None, keepdims=None, name=None):
"""For docs, see: _RAGGED_REDUCE_DOCSTRING."""
return _ragged_reduce_aggregate(math_ops.reduce_min,
math_ops.unsorted_segment_min, rt_input, axis,
name or 'RaggedReduceMin')
math_ops.unsorted_segment_min, input_tensor,
axis, keepdims, name or 'RaggedReduceMin')
def reduce_max(rt_input, axis=None, name=None):
def reduce_max(input_tensor, axis=None, keepdims=None, name=None):
"""For docs, see: _RAGGED_REDUCE_DOCSTRING."""
return _ragged_reduce_aggregate(math_ops.reduce_max,
math_ops.unsorted_segment_max, rt_input, axis,
name or 'RaggedReduceMax')
math_ops.unsorted_segment_max, input_tensor,
axis, keepdims, name or 'RaggedReduceMax')
def reduce_mean(rt_input, axis=None, name=None):
def reduce_mean(input_tensor, axis=None, keepdims=None, name=None):
"""For docs, see: _RAGGED_REDUCE_DOCSTRING."""
with ops.name_scope(name, 'RaggedReduceMean', [rt_input, axis]):
total = reduce_sum(rt_input, axis)
if ragged_tensor.is_ragged(rt_input):
with ops.name_scope(name, 'RaggedReduceMean', [input_tensor, axis]):
total = reduce_sum(input_tensor, axis, keepdims)
if ragged_tensor.is_ragged(input_tensor):
ones = ragged_factory_ops.from_nested_row_splits(
array_ops.ones_like(rt_input.inner_values),
rt_input.nested_row_splits)
array_ops.ones_like(input_tensor.inner_values),
input_tensor.nested_row_splits)
else:
ones = array_ops.ones_like(rt_input)
count = reduce_sum(ones, axis)
ones = array_ops.ones_like(input_tensor)
count = reduce_sum(ones, axis, keepdims)
if ragged_tensor.is_ragged(total):
return ragged_factory_ops.from_nested_row_splits(
total.inner_values / count.inner_values, total.nested_row_splits)
@ -525,20 +535,25 @@ def reduce_mean(rt_input, axis=None, name=None):
return total / count
def _cast(rt_input, dtype):
return ragged_functional_ops.map_inner_values(math_ops.cast, rt_input, dtype)
def _cast(input_tensor, dtype):
return ragged_functional_ops.map_inner_values(math_ops.cast, input_tensor,
dtype)
def reduce_all(rt_input, axis=None, name=None):
def reduce_all(input_tensor, axis=None, keepdims=None, name=None):
"""For docs, see: _RAGGED_REDUCE_DOCSTRING."""
with ops.name_scope(name, 'RaggedReduceAll', [rt_input, axis]):
return _cast(reduce_prod(_cast(rt_input, dtypes.int32), axis), dtypes.bool)
with ops.name_scope(name, 'RaggedReduceAll', [input_tensor, axis]):
return _cast(
reduce_prod(_cast(input_tensor, dtypes.int32), axis, keepdims),
dtypes.bool)
def reduce_any(rt_input, axis=None, name=None):
def reduce_any(input_tensor, axis=None, keepdims=None, name=None):
"""For docs, see: _RAGGED_REDUCE_DOCSTRING."""
with ops.name_scope(name, 'RaggedReduceAny', [rt_input, axis]):
return _cast(reduce_sum(_cast(rt_input, dtypes.int32), axis), dtypes.bool)
with ops.name_scope(name, 'RaggedReduceAny', [input_tensor, axis]):
return _cast(
reduce_sum(_cast(input_tensor, dtypes.int32), axis, keepdims),
dtypes.bool)
def _set_ragged_reduce_docstring(func, combination, combined, default, example):
@ -554,9 +569,11 @@ _set_ragged_reduce_docstring(reduce_sum, 'sum', 'summed', '0',
_set_ragged_reduce_docstring(reduce_prod, 'product', 'multiplied', '1',
_RAGGED_REDUCE_PROD_EXAMPLE)
_set_ragged_reduce_docstring(reduce_min, 'minimum', 'minimized',
'`rt_input.dtype.min`', _RAGGED_REDUCE_MIN_EXAMPLE)
'`input_tensor.dtype.min`',
_RAGGED_REDUCE_MIN_EXAMPLE)
_set_ragged_reduce_docstring(reduce_max, 'maximum', 'maximized',
'`rt_input.dtype.max`', _RAGGED_REDUCE_MAX_EXAMPLE)
'`input_tensor.dtype.max`',
_RAGGED_REDUCE_MAX_EXAMPLE)
_set_ragged_reduce_docstring(reduce_mean, 'mean', 'averaged', 'NaN',
_RAGGED_REDUCE_MEAN_EXAMPLE)

View File

@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.ops.ragged import ragged_elementwise_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_getitem
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import tf_decorator
@ -33,40 +33,39 @@ def _right(operator):
ragged_tensor.RaggedTensor.__getitem__ = ragged_getitem.ragged_tensor_getitem
# Ordering operators
ragged_tensor.RaggedTensor.__ge__ = ragged_elementwise_ops.greater_equal
ragged_tensor.RaggedTensor.__gt__ = ragged_elementwise_ops.greater
ragged_tensor.RaggedTensor.__le__ = ragged_elementwise_ops.less_equal
ragged_tensor.RaggedTensor.__lt__ = ragged_elementwise_ops.less
ragged_tensor.RaggedTensor.__ge__ = math_ops.greater_equal
ragged_tensor.RaggedTensor.__gt__ = math_ops.greater
ragged_tensor.RaggedTensor.__le__ = math_ops.less_equal
ragged_tensor.RaggedTensor.__lt__ = math_ops.less
# Logical operators
ragged_tensor.RaggedTensor.__and__ = ragged_elementwise_ops.logical_and
ragged_tensor.RaggedTensor.__rand__ = _right(ragged_elementwise_ops.logical_and)
ragged_tensor.RaggedTensor.__invert__ = ragged_elementwise_ops.logical_not
ragged_tensor.RaggedTensor.__ror__ = _right(ragged_elementwise_ops.logical_or)
ragged_tensor.RaggedTensor.__or__ = ragged_elementwise_ops.logical_or
ragged_tensor.RaggedTensor.__xor__ = ragged_elementwise_ops.logical_xor
ragged_tensor.RaggedTensor.__rxor__ = _right(ragged_elementwise_ops.logical_xor)
ragged_tensor.RaggedTensor.__and__ = math_ops.logical_and
ragged_tensor.RaggedTensor.__rand__ = _right(math_ops.logical_and)
ragged_tensor.RaggedTensor.__invert__ = math_ops.logical_not
ragged_tensor.RaggedTensor.__ror__ = _right(math_ops.logical_or)
ragged_tensor.RaggedTensor.__or__ = math_ops.logical_or
ragged_tensor.RaggedTensor.__xor__ = math_ops.logical_xor
ragged_tensor.RaggedTensor.__rxor__ = _right(math_ops.logical_xor)
# Arithmetic operators
ragged_tensor.RaggedTensor.__abs__ = ragged_elementwise_ops.abs
ragged_tensor.RaggedTensor.__add__ = ragged_elementwise_ops.add
ragged_tensor.RaggedTensor.__radd__ = _right(ragged_elementwise_ops.add)
ragged_tensor.RaggedTensor.__div__ = ragged_elementwise_ops.div
ragged_tensor.RaggedTensor.__rdiv__ = _right(ragged_elementwise_ops.div)
ragged_tensor.RaggedTensor.__floordiv__ = ragged_elementwise_ops.floordiv
ragged_tensor.RaggedTensor.__rfloordiv__ = _right(
ragged_elementwise_ops.floordiv)
ragged_tensor.RaggedTensor.__mod__ = ragged_elementwise_ops.floormod
ragged_tensor.RaggedTensor.__rmod__ = _right(ragged_elementwise_ops.floormod)
ragged_tensor.RaggedTensor.__mul__ = ragged_elementwise_ops.multiply
ragged_tensor.RaggedTensor.__rmul__ = _right(ragged_elementwise_ops.multiply)
ragged_tensor.RaggedTensor.__neg__ = ragged_elementwise_ops.negative
ragged_tensor.RaggedTensor.__pow__ = ragged_elementwise_ops.pow
ragged_tensor.RaggedTensor.__rpow__ = _right(ragged_elementwise_ops.pow)
ragged_tensor.RaggedTensor.__sub__ = ragged_elementwise_ops.subtract
ragged_tensor.RaggedTensor.__rsub__ = _right(ragged_elementwise_ops.subtract)
ragged_tensor.RaggedTensor.__truediv__ = ragged_elementwise_ops.truediv
ragged_tensor.RaggedTensor.__rtruediv__ = _right(ragged_elementwise_ops.truediv)
ragged_tensor.RaggedTensor.__abs__ = math_ops.abs
ragged_tensor.RaggedTensor.__add__ = math_ops.add
ragged_tensor.RaggedTensor.__radd__ = _right(math_ops.add)
ragged_tensor.RaggedTensor.__div__ = math_ops.div
ragged_tensor.RaggedTensor.__rdiv__ = _right(math_ops.div)
ragged_tensor.RaggedTensor.__floordiv__ = math_ops.floordiv
ragged_tensor.RaggedTensor.__rfloordiv__ = _right(math_ops.floordiv)
ragged_tensor.RaggedTensor.__mod__ = math_ops.floormod
ragged_tensor.RaggedTensor.__rmod__ = _right(math_ops.floormod)
ragged_tensor.RaggedTensor.__mul__ = math_ops.multiply
ragged_tensor.RaggedTensor.__rmul__ = _right(math_ops.multiply)
ragged_tensor.RaggedTensor.__neg__ = math_ops.negative
ragged_tensor.RaggedTensor.__pow__ = math_ops.pow
ragged_tensor.RaggedTensor.__rpow__ = _right(math_ops.pow)
ragged_tensor.RaggedTensor.__sub__ = math_ops.subtract
ragged_tensor.RaggedTensor.__rsub__ = _right(math_ops.subtract)
ragged_tensor.RaggedTensor.__truediv__ = math_ops.truediv
ragged_tensor.RaggedTensor.__rtruediv__ = _right(math_ops.truediv)
# Dummy methods

View File

@ -2696,6 +2696,7 @@ class _UnaryMapValueDispatcher(dispatch.OpDispatcher):
if args:
x, args = args[0], args[1:]
else:
kwargs = kwargs.copy()
x = kwargs.pop(self._x, None)
if isinstance(x, sparse_tensor.SparseTensor):
return sparse_tensor.SparseTensor(

View File

@ -38,6 +38,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops.gen_string_ops import *
from tensorflow.python.util import compat as util_compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
# pylint: enable=g-bad-import-order
# pylint: enable=wildcard-import
@ -45,6 +46,7 @@ from tensorflow.python.util.tf_export import tf_export
# pylint: disable=redefined-builtin
@tf_export("strings.regex_full_match")
@dispatch.add_dispatch_support
def regex_full_match(input, pattern, name=None):
r"""Match elements of `input` with regex `pattern`.
@ -76,6 +78,7 @@ regex_full_match.__doc__ = gen_string_ops.regex_full_match.__doc__
@tf_export(
"strings.regex_replace", v1=["strings.regex_replace", "regex_replace"])
@deprecation.deprecated_endpoints("regex_replace")
@dispatch.add_dispatch_support
def regex_replace(input, pattern, rewrite, replace_global=True, name=None):
r"""Replace elements of `input` matching regex `pattern` with `rewrite`.
@ -350,10 +353,13 @@ reduce_join.__doc__ = reduce_join.__doc__.replace("tf.reduce_join(",
# This wrapper provides backwards compatibility for code that predates the
# unit argument and that passed 'name' as a positional argument.
@tf_export(v1=["strings.length"])
@dispatch.add_dispatch_support
def string_length(input, name=None, unit="BYTE"):
return gen_string_ops.string_length(input, unit=unit, name=name)
@tf_export("strings.length", v1=[])
@dispatch.add_dispatch_support
def string_length_v2(input, unit="BYTE", name=None):
return string_length(input, name, unit)
@ -370,11 +376,13 @@ substr_deprecated.__doc__ = gen_string_ops.substr.__doc__
@tf_export(v1=["strings.substr"])
@dispatch.add_dispatch_support
def substr(input, pos, len, name=None, unit="BYTE"):
return gen_string_ops.substr(input, pos, len, unit=unit, name=name)
@tf_export("strings.substr", v1=[])
@dispatch.add_dispatch_support
def substr_v2(input, pos, len, unit="BYTE", name=None):
return substr(input, pos, len, name=name, unit=unit)
@ -395,6 +403,7 @@ ops.NotDifferentiable("DecodeBase64")
@tf_export("strings.to_number", v1=[])
@dispatch.add_dispatch_support
def string_to_number(input, out_type=dtypes.float32, name=None):
r"""Converts each string in the input Tensor to the specified numeric type.
@ -418,6 +427,7 @@ tf_export(v1=["strings.to_number", "string_to_number"])(
@tf_export("strings.to_hash_bucket", v1=[])
@dispatch.add_dispatch_support
def string_to_hash_bucket(input, num_buckets, name=None):
# pylint: disable=line-too-long
r"""Converts each string in the input Tensor to its hash mod by a number of buckets.

View File

@ -166,15 +166,14 @@ def dispatch_for_types(op, *types):
def add_dispatch_list(target):
"""Decorator that adds a dispatch_list attribute to an op."""
assert not hasattr(target, DISPATCH_ATTR)
if hasattr(target, DISPATCH_ATTR):
raise AssertionError("%s already has a dispatch list" % target)
setattr(target, DISPATCH_ATTR, [])
return target
def add_dispatch_support(target):
"""Decorator that adds a dispatch handling wrapper to an op."""
add_dispatch_list(target)
def wrapper(*args, **kwargs):
"""Call target, and fall back on dispatchers if there is a TypeError."""
try:
@ -188,5 +187,5 @@ def add_dispatch_support(target):
else:
raise
setattr(wrapper, DISPATCH_ATTR, [])
add_dispatch_list(wrapper)
return tf_decorator.make_decorator(target, wrapper)