Internal Change
PiperOrigin-RevId: 224225849
This commit is contained in:
parent
2b0fd9b66a
commit
45cfe71266
@ -1,4 +1,6 @@
|
||||
op {
|
||||
graph_op_name: "FloorDiv"
|
||||
visibility: HIDDEN
|
||||
endpoint {
|
||||
name: "floor_div"
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,9 @@
|
||||
op {
|
||||
graph_op_name: "FloorMod"
|
||||
visibility: HIDDEN
|
||||
endpoint {
|
||||
name: "floormod"
|
||||
}
|
||||
endpoint {
|
||||
name: "mod"
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,6 @@
|
||||
op {
|
||||
graph_op_name: "RealDiv"
|
||||
visibility: HIDDEN
|
||||
endpoint {
|
||||
name: "realdiv"
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,6 @@
|
||||
op {
|
||||
graph_op_name: "TruncateDiv"
|
||||
visibility: HIDDEN
|
||||
endpoint {
|
||||
name: "truncatediv"
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,6 @@
|
||||
op {
|
||||
graph_op_name: "TruncateMod"
|
||||
visibility: HIDDEN
|
||||
endpoint {
|
||||
name: "truncatemod"
|
||||
}
|
||||
}
|
||||
|
@ -634,7 +634,9 @@ void GenEagerPythonOp::AddEagerFunctionTeardown(
|
||||
bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
|
||||
const string& parameters, const std::vector<string>& output_sizes,
|
||||
const string& eager_not_allowed_error) {
|
||||
strings::StrAppend(&result_, "@_dispatch.add_dispatch_list\n");
|
||||
if (api_def_.visibility() == ApiDef::VISIBLE) {
|
||||
strings::StrAppend(&result_, "@_dispatch.add_dispatch_list\n");
|
||||
}
|
||||
AddExport();
|
||||
AddDefLine(function_name_, parameters);
|
||||
AddDocStringDescription();
|
||||
|
@ -56,6 +56,7 @@ _BaseSlice = slice
|
||||
|
||||
|
||||
@tf_export("identity")
|
||||
@dispatch.add_dispatch_support
|
||||
def identity(input, name=None): # pylint: disable=redefined-builtin
|
||||
r"""Return a tensor with the same shape and contents as input.
|
||||
|
||||
@ -139,6 +140,7 @@ def expand_dims(input, axis=None, name=None, dim=None):
|
||||
|
||||
|
||||
@tf_export("expand_dims", v1=[])
|
||||
@dispatch.add_dispatch_support
|
||||
def expand_dims_v2(input, axis, name=None):
|
||||
"""Inserts a dimension of 1 into a tensor's shape.
|
||||
|
||||
@ -941,6 +943,7 @@ def parallel_stack(values, name="parallel_stack"):
|
||||
|
||||
|
||||
@tf_export("stack")
|
||||
@dispatch.add_dispatch_support
|
||||
def stack(values, axis=0, name="stack"):
|
||||
"""Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor.
|
||||
|
||||
@ -1151,6 +1154,7 @@ def unstack(value, num=None, axis=0, name="unstack"):
|
||||
|
||||
|
||||
@tf_export("concat")
|
||||
@dispatch.add_dispatch_support
|
||||
def concat(values, axis, name="concat"):
|
||||
"""Concatenates tensors along one dimension.
|
||||
|
||||
@ -1328,6 +1332,7 @@ def boolean_mask(tensor, mask, name="boolean_mask", axis=None):
|
||||
|
||||
|
||||
@tf_export("boolean_mask", v1=[])
|
||||
@dispatch.add_dispatch_support
|
||||
def boolean_mask_v2(tensor, mask, axis=None, name="boolean_mask"):
|
||||
"""Apply boolean mask to tensor.
|
||||
|
||||
@ -1810,6 +1815,7 @@ def zeros(shape, dtype=dtypes.float32, name=None):
|
||||
|
||||
|
||||
@tf_export(v1=["zeros_like"])
|
||||
@dispatch.add_dispatch_support
|
||||
def zeros_like(tensor, dtype=None, name=None, optimize=True):
|
||||
"""Creates a tensor with all elements set to zero.
|
||||
|
||||
@ -1840,6 +1846,7 @@ def zeros_like(tensor, dtype=None, name=None, optimize=True):
|
||||
|
||||
|
||||
@tf_export("zeros_like", v1=[])
|
||||
@dispatch.add_dispatch_support
|
||||
def zeros_like_v2(
|
||||
input, # pylint: disable=redefined-builtin
|
||||
dtype=None,
|
||||
@ -1899,6 +1906,7 @@ def zeros_like_impl(tensor, dtype, name, optimize=True):
|
||||
|
||||
|
||||
@tf_export(v1=["ones_like"])
|
||||
@dispatch.add_dispatch_support
|
||||
def ones_like(tensor, dtype=None, name=None, optimize=True):
|
||||
"""Creates a tensor with all elements set to 1.
|
||||
|
||||
@ -1929,6 +1937,7 @@ def ones_like(tensor, dtype=None, name=None, optimize=True):
|
||||
|
||||
|
||||
@tf_export("ones_like", v1=[])
|
||||
@dispatch.add_dispatch_support
|
||||
def ones_like_v2(
|
||||
input, # pylint: disable=redefined-builtin
|
||||
dtype=None,
|
||||
@ -3115,6 +3124,7 @@ def squeeze_v2(input, axis=None, name=None):
|
||||
|
||||
|
||||
@tf_export("where")
|
||||
@dispatch.add_dispatch_support
|
||||
def where(condition, x=None, y=None, name=None):
|
||||
"""Return the elements, either from `x` or `y`, depending on the `condition`.
|
||||
|
||||
@ -3234,6 +3244,7 @@ def gather(params, indices, validate_indices=None, name=None, axis=0):
|
||||
|
||||
|
||||
@tf_export("gather", v1=[])
|
||||
@dispatch.add_dispatch_support
|
||||
def gather_v2(params, indices, validate_indices=None, axis=0, name=None):
|
||||
return gather(params, indices, validate_indices=validate_indices, name=name,
|
||||
axis=axis)
|
||||
|
@ -31,10 +31,12 @@ from tensorflow.python.ops import gen_nn_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import numerics
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util import dispatch
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@tf_export("clip_by_value")
|
||||
@dispatch.add_dispatch_support
|
||||
def clip_by_value(t, clip_value_min, clip_value_max,
|
||||
name=None):
|
||||
"""Clips tensor values to a specified min and max.
|
||||
|
@ -230,6 +230,7 @@ class DivideDelegateWithName(object):
|
||||
|
||||
|
||||
@tf_export("math.divide", "divide")
|
||||
@dispatch.add_dispatch_support
|
||||
def divide(x, y, name=None):
|
||||
"""Computes Python style division of `x` by `y`."""
|
||||
|
||||
@ -242,6 +243,7 @@ def divide(x, y, name=None):
|
||||
|
||||
|
||||
@tf_export("math.multiply", "multiply")
|
||||
@dispatch.add_dispatch_support
|
||||
def multiply(x, y, name=None):
|
||||
return gen_math_ops.mul(x, y, name)
|
||||
|
||||
@ -262,6 +264,7 @@ _mul.__doc__ = (
|
||||
|
||||
|
||||
@tf_export("math.subtract", "subtract")
|
||||
@dispatch.add_dispatch_support
|
||||
def subtract(x, y, name=None):
|
||||
return gen_math_ops.sub(x, y, name)
|
||||
|
||||
@ -347,6 +350,7 @@ def scalar_mul_v2(scalar, x, name=None):
|
||||
|
||||
|
||||
@tf_export("math.pow", "pow")
|
||||
@dispatch.add_dispatch_support
|
||||
def pow(x, y, name=None): # pylint: disable=redefined-builtin
|
||||
r"""Computes the power of one value to another.
|
||||
|
||||
@ -375,6 +379,7 @@ def pow(x, y, name=None): # pylint: disable=redefined-builtin
|
||||
|
||||
# pylint: disable=redefined-builtin,redefined-outer-name
|
||||
@tf_export("dtypes.complex", "complex")
|
||||
@dispatch.add_dispatch_support
|
||||
def complex(real, imag, name=None):
|
||||
r"""Converts two real numbers to a complex number.
|
||||
|
||||
@ -418,6 +423,7 @@ def complex(real, imag, name=None):
|
||||
|
||||
@tf_export("math.real", v1=["math.real", "real"])
|
||||
@deprecation.deprecated_endpoints("real")
|
||||
@dispatch.add_dispatch_support
|
||||
def real(input, name=None):
|
||||
r"""Returns the real part of a complex (or real) tensor.
|
||||
|
||||
@ -450,6 +456,7 @@ def real(input, name=None):
|
||||
|
||||
@tf_export("math.imag", v1=["math.imag", "imag"])
|
||||
@deprecation.deprecated_endpoints("imag")
|
||||
@dispatch.add_dispatch_support
|
||||
def imag(input, name=None):
|
||||
r"""Returns the imaginary part of a complex (or real) tensor.
|
||||
|
||||
@ -481,6 +488,7 @@ def imag(input, name=None):
|
||||
|
||||
@tf_export("math.angle", v1=["math.angle", "angle"])
|
||||
@deprecation.deprecated_endpoints("angle")
|
||||
@dispatch.add_dispatch_support
|
||||
def angle(input, name=None):
|
||||
r"""Returns the element-wise argument of a complex (or real) tensor.
|
||||
|
||||
@ -520,6 +528,7 @@ def angle(input, name=None):
|
||||
|
||||
|
||||
@tf_export("math.round", "round")
|
||||
@dispatch.add_dispatch_support
|
||||
def round(x, name=None): # pylint: disable=redefined-builtin
|
||||
"""Rounds the values of a tensor to the nearest integer, element-wise.
|
||||
|
||||
@ -547,6 +556,7 @@ def round(x, name=None): # pylint: disable=redefined-builtin
|
||||
|
||||
|
||||
@tf_export("dtypes.cast", "cast")
|
||||
@dispatch.add_dispatch_support
|
||||
def cast(x, dtype, name=None):
|
||||
"""Casts a tensor to a new type.
|
||||
|
||||
@ -610,6 +620,7 @@ def cast(x, dtype, name=None):
|
||||
|
||||
|
||||
@tf_export("dtypes.saturate_cast", "saturate_cast")
|
||||
@dispatch.add_dispatch_support
|
||||
def saturate_cast(value, dtype, name=None):
|
||||
"""Performs a safe saturating cast of `value` to `dtype`.
|
||||
|
||||
@ -935,6 +946,7 @@ def _div_python2(x, y, name=None):
|
||||
|
||||
|
||||
@tf_export("math.truediv", "truediv")
|
||||
@dispatch.add_dispatch_support
|
||||
def truediv(x, y, name=None):
|
||||
"""Divides x / y elementwise (using Python 3 division operator semantics).
|
||||
|
||||
@ -992,6 +1004,7 @@ def div(x, y, name=None):
|
||||
|
||||
|
||||
@tf_export("div_no_nan")
|
||||
@dispatch.add_dispatch_support
|
||||
def div_no_nan(x, y, name=None):
|
||||
"""Computes an unsafe divide which returns 0 if the y is zero.
|
||||
|
||||
@ -1021,6 +1034,7 @@ mod = gen_math_ops.floor_mod
|
||||
# TODO(aselle): Deprecate this once all internal functionality uses
|
||||
# tf.truncatediv
|
||||
@tf_export("math.floordiv", v1=["math.floordiv", "floordiv"])
|
||||
@dispatch.add_dispatch_support
|
||||
@deprecation.deprecated_endpoints("floordiv")
|
||||
def floordiv(x, y, name=None):
|
||||
"""Divides `x / y` elementwise, rounding toward the most negative integer.
|
||||
@ -1050,16 +1064,11 @@ def floordiv(x, y, name=None):
|
||||
|
||||
|
||||
realdiv = gen_math_ops.real_div
|
||||
tf_export("realdiv")(realdiv)
|
||||
truncatediv = gen_math_ops.truncate_div
|
||||
tf_export("truncatediv")(truncatediv)
|
||||
# TODO(aselle): Rename this to floordiv when we can.
|
||||
floor_div = gen_math_ops.floor_div
|
||||
tf_export("floor_div")(floor_div)
|
||||
truncatemod = gen_math_ops.truncate_mod
|
||||
tf_export("truncatemod")(truncatemod)
|
||||
floormod = gen_math_ops.floor_mod
|
||||
tf_export("floormod", "mod")(floormod)
|
||||
|
||||
|
||||
def _mul_dispatch(x, y, name=None):
|
||||
@ -1095,6 +1104,7 @@ _OverrideBinaryOperatorHelper(pow, "pow")
|
||||
|
||||
|
||||
@tf_export("math.logical_xor", v1=["math.logical_xor", "logical_xor"])
|
||||
@dispatch.add_dispatch_support
|
||||
@deprecation.deprecated_endpoints("logical_xor")
|
||||
def logical_xor(x, y, name="LogicalXor"):
|
||||
"""x ^ y = (x | y) & ~(x & y)."""
|
||||
@ -1277,6 +1287,7 @@ def reduce_sum_v1(input_tensor,
|
||||
|
||||
|
||||
@tf_export("math.reduce_sum", "reduce_sum", v1=[])
|
||||
@dispatch.add_dispatch_support
|
||||
def reduce_sum(input_tensor, axis=None, keepdims=False, name=None):
|
||||
"""Computes the sum of elements across dimensions of a tensor.
|
||||
|
||||
@ -1524,6 +1535,7 @@ def reduce_mean_v1(input_tensor,
|
||||
|
||||
|
||||
@tf_export("math.reduce_mean", "reduce_mean", v1=[])
|
||||
@dispatch.add_dispatch_support
|
||||
def reduce_mean(input_tensor, axis=None, keepdims=False, name=None):
|
||||
"""Computes the mean of elements across dimensions of a tensor.
|
||||
|
||||
@ -1675,6 +1687,7 @@ def reduce_std(input_tensor, axis=None, keepdims=False, name=None):
|
||||
|
||||
|
||||
@tf_export("math.reduce_prod", "reduce_prod", v1=[])
|
||||
@dispatch.add_dispatch_support
|
||||
def reduce_prod(input_tensor, axis=None, keepdims=False, name=None):
|
||||
"""Computes the product of elements across dimensions of a tensor.
|
||||
|
||||
@ -1796,6 +1809,7 @@ def reduce_min_v1(input_tensor,
|
||||
|
||||
|
||||
@tf_export("math.reduce_min", "reduce_min", v1=[])
|
||||
@dispatch.add_dispatch_support
|
||||
def reduce_min(input_tensor, axis=None, keepdims=False, name=None):
|
||||
"""Computes the minimum of elements across dimensions of a tensor.
|
||||
|
||||
@ -1874,6 +1888,7 @@ def reduce_max_v1(input_tensor,
|
||||
|
||||
|
||||
@tf_export("math.reduce_max", "reduce_max", v1=[])
|
||||
@dispatch.add_dispatch_support
|
||||
def reduce_max(input_tensor, axis=None, keepdims=False, name=None):
|
||||
"""Computes the maximum of elements across dimensions of a tensor.
|
||||
|
||||
@ -1961,6 +1976,7 @@ def reduce_all_v1(input_tensor,
|
||||
|
||||
|
||||
@tf_export("reduce_all", "math.reduce_all", v1=[])
|
||||
@dispatch.add_dispatch_support
|
||||
def reduce_all(input_tensor, axis=None, keepdims=False, name=None):
|
||||
"""Computes the "logical and" of elements across dimensions of a tensor.
|
||||
|
||||
@ -2057,6 +2073,7 @@ def reduce_any_v1(input_tensor,
|
||||
|
||||
|
||||
@tf_export("math.reduce_any", "reduce_any", v1=[])
|
||||
@dispatch.add_dispatch_support
|
||||
def reduce_any(input_tensor, axis=None, keepdims=False, name=None):
|
||||
"""Computes the "logical or" of elements across dimensions of a tensor.
|
||||
|
||||
@ -2619,6 +2636,7 @@ def _as_indexed_slices_list(inputs, optimize=True):
|
||||
|
||||
|
||||
@tf_export("math.add_n", "add_n")
|
||||
@dispatch.add_dispatch_support
|
||||
def add_n(inputs, name=None):
|
||||
"""Adds all input tensors element-wise.
|
||||
|
||||
@ -2764,6 +2782,7 @@ def sigmoid(x, name=None):
|
||||
|
||||
|
||||
@tf_export("math.log_sigmoid", v1=["math.log_sigmoid", "log_sigmoid"])
|
||||
@dispatch.add_dispatch_support
|
||||
@deprecation.deprecated_endpoints("log_sigmoid")
|
||||
def log_sigmoid(x, name=None):
|
||||
"""Computes log sigmoid of `x` element-wise.
|
||||
@ -2973,6 +2992,7 @@ def cumprod(x, axis=0, exclusive=False, reverse=False, name=None):
|
||||
|
||||
|
||||
@tf_export("math.conj", v1=["math.conj", "conj"])
|
||||
@dispatch.add_dispatch_support
|
||||
@deprecation.deprecated_endpoints("conj")
|
||||
def conj(x, name=None):
|
||||
r"""Returns the complex conjugate of a complex number.
|
||||
@ -3077,6 +3097,7 @@ def _unsorted_segment_N(data, segment_ids, num_segments):
|
||||
"math.unsorted_segment_mean",
|
||||
v1=["math.unsorted_segment_mean", "unsorted_segment_mean"])
|
||||
@deprecation.deprecated_endpoints("unsorted_segment_mean")
|
||||
@dispatch.add_dispatch_support
|
||||
def unsorted_segment_mean(data, segment_ids, num_segments, name=None):
|
||||
r"""Computes the mean along segments of a tensor.
|
||||
|
||||
@ -3122,6 +3143,7 @@ def unsorted_segment_mean(data, segment_ids, num_segments, name=None):
|
||||
"math.unsorted_segment_sqrt_n",
|
||||
v1=["math.unsorted_segment_sqrt_n", "unsorted_segment_sqrt_n"])
|
||||
@deprecation.deprecated_endpoints("unsorted_segment_sqrt_n")
|
||||
@dispatch.add_dispatch_support
|
||||
def unsorted_segment_sqrt_n(data, segment_ids, num_segments, name=None):
|
||||
r"""Computes the sum along segments of a tensor divided by the sqrt(N).
|
||||
|
||||
|
@ -25,7 +25,7 @@ py_library(
|
||||
deps = [
|
||||
":ragged_array_ops",
|
||||
":ragged_conversion_ops",
|
||||
":ragged_elementwise_ops",
|
||||
":ragged_dispatch",
|
||||
":ragged_factory_ops",
|
||||
":ragged_functional_ops",
|
||||
":ragged_getitem",
|
||||
@ -150,33 +150,14 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "ragged_elementwise_ops",
|
||||
srcs = ["ragged_elementwise_ops.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_factory_ops",
|
||||
":ragged_tensor",
|
||||
":ragged_tensor_shape",
|
||||
":ragged_util",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:clip_ops",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:parsing_ops",
|
||||
"//tensorflow/python:string_ops",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "ragged_operators",
|
||||
srcs = ["ragged_operators.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_elementwise_ops",
|
||||
":ragged_getitem",
|
||||
":ragged_tensor",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
@ -186,12 +167,13 @@ py_library(
|
||||
srcs = ["ragged_string_ops.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_array_ops",
|
||||
":ragged_conversion_ops",
|
||||
":ragged_factory_ops",
|
||||
":ragged_tensor",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:string_ops",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
@ -219,10 +201,11 @@ py_library(
|
||||
":ragged_tensor",
|
||||
":ragged_util",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:tensor_util",
|
||||
],
|
||||
)
|
||||
@ -285,6 +268,29 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "ragged_dispatch",
|
||||
srcs = ["ragged_dispatch.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_array_ops",
|
||||
":ragged_factory_ops",
|
||||
":ragged_math_ops",
|
||||
":ragged_tensor",
|
||||
":ragged_tensor_shape",
|
||||
":ragged_util",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:clip_ops",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:parsing_ops",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python:string_ops",
|
||||
"//tensorflow/python:util",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# RaggedTensor Tests
|
||||
#-------------------------------------------------------------------------------
|
||||
@ -458,6 +464,7 @@ py_test(
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:gradients_impl",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
@ -684,17 +691,21 @@ py_test(
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "ragged_elementwise_ops_test",
|
||||
srcs = ["ragged_elementwise_ops_test.py"],
|
||||
name = "ragged_dispatch_test",
|
||||
srcs = ["ragged_dispatch_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:clip_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:parsing_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:string_ops",
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
@ -725,6 +736,7 @@ py_test(
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:string_ops",
|
||||
"//tensorflow/python/keras:backend",
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
@ -735,8 +747,10 @@ py_test(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
@ -1,76 +1,53 @@
|
||||
"""Ragged Tensors.
|
||||
|
||||
This package defines the [`RaggedTensor`](ragged/RaggedTensor.md) class, which
|
||||
This package defines the `tf.RaggedTensor` class, which
|
||||
represents tensors with non-uniform shapes. In particular, each `RaggedTensor`
|
||||
has one or more *ragged dimensions*, which are dimensions whose slices may have
|
||||
different lengths. For example, the inner (column) dimension of
|
||||
`rt=[[3, 1, 4, 1], [], [5, 9, 2], [6], []]` is ragged, since the column slices
|
||||
(`rt[0, :]`, ..., `rt[4, :]`) have different lengths. For a more detailed
|
||||
description of ragged tensors, see the [`RaggedTensor`](ragged/RaggedTensor.md)
|
||||
description of ragged tensors, see the `tf.RaggedTensor`
|
||||
class documentation.
|
||||
|
||||
## RaggedTensor Operations
|
||||
## `RaggedTensor` Operations
|
||||
|
||||
This package also defines a collection of operations for manipulating
|
||||
ragged tensors.
|
||||
### `RaggedTensor` Factory ops
|
||||
|
||||
### RaggedTensor Versions of Standard Tensor Operations
|
||||
* `tf.ragged.constant`
|
||||
* `tf.ragged.from_row_splits`
|
||||
* `tf.ragged.from_row_splits`
|
||||
* `tf.ragged.from_row_lengths`
|
||||
* `tf.ragged.from_row_starts`
|
||||
* `tf.ragged.from_row_limits`
|
||||
* `tf.ragged.from_value_rowids`
|
||||
* `tf.ragged.from_nested_row_splits`
|
||||
* `tf.ragged.from_nested_value_rowids`
|
||||
|
||||
Many of the operations defined by this package are analogous to
|
||||
[`Tensor`](https://www.tensorflow.org/api_docs/python/tf/Tensor)
|
||||
operations, but they accept `RaggedTensor`s as input and can return
|
||||
`RaggedTensor`s as output. For example, `ragged.add` performs elementwise
|
||||
addition just like `tf.add`, but can be used on `RaggedTensor`s.
|
||||
### `RaggedTensor` Conversion ops
|
||||
|
||||
These `RaggedTensor` versions of the standard `Tensor` operations can also be
|
||||
used with standard `Tensors`; and for the most part, they will return the same
|
||||
value that the standard `Tensor` operation would return. However, there are
|
||||
a few notable exceptions:
|
||||
* `tf.ragged.from_tensor`
|
||||
* `tf.ragged.to_tensor`
|
||||
* `tf.ragged.from_sparse`
|
||||
* `tf.ragged.to_sparse`
|
||||
* `tf.ragged.from_variant`
|
||||
* `tf.ragged.to_variant`
|
||||
* `tf.ragged.convert_to_tensor_or_ragged_tensor`
|
||||
|
||||
* For [`ragged.stack(...)`](ragged/stack.md) and
|
||||
[`ragged.concat(...)`](ragged/concat.md), the input tensors are not required
|
||||
to have matching shapes. In the returned tensor, all dimensions up to the
|
||||
`axis` dimension will be ragged.
|
||||
### `RaggedTensor` Shape ops
|
||||
|
||||
### Ragged-Tensor Specific Operations
|
||||
* `tf.ragged.row_splits`
|
||||
* `tf.ragged.row_lengths`
|
||||
* `tf.ragged.row_starts`
|
||||
* `tf.ragged.row_limits`
|
||||
* `tf.ragged.value_rowids`
|
||||
* `tf.ragged.nrows`
|
||||
* `tf.ragged.nested_row_splits`
|
||||
* `tf.ragged.row_splits_to_segment_ids`
|
||||
* `tf.ragged.segment_ids_to_row_splits`
|
||||
* `tf.ragged.bounding_shape`
|
||||
|
||||
The following operations are specific to ragged tensors:
|
||||
|
||||
* **Factory ops**:
|
||||
[`constant(...)`](ragged/constant.md),
|
||||
[`from_row_splits(...)`](ragged/from_row_splits.md),
|
||||
[`from_row_lengths(...)`](ragged/from_row_lengths.md),
|
||||
[`from_row_starts(...)`](ragged/from_row_starts.md),
|
||||
[`from_row_limits(...)`](ragged/from_row_limits.md),
|
||||
[`from_value_rowids(...)`](ragged/from_value_rowids.md),
|
||||
[`from_nested_row_splits(...)`](ragged/from_nested_row_splits.md),
|
||||
[`from_nested_value_rowids(...)`](ragged/from_nested_value_rowids.md).
|
||||
|
||||
* **Conversion ops**:
|
||||
[`from_tensor(...)`](ragged/from_tensor.md),
|
||||
[`to_tensor(...)`](ragged/to_tensor.md),
|
||||
[`from_sparse(...)`](ragged/from_sparse.md),
|
||||
[`to_sparse(...)`](ragged/to_sparse.md),
|
||||
[`from_variant(...)`](ragged/from_variant.md),
|
||||
[`to_variant(...)`](ragged/to_variant.md),
|
||||
[`convert_to_tensor_or_ragged_tensor(...)`](
|
||||
ragged/convert_to_tensor_or_ragged_tensor.md).
|
||||
|
||||
* **Shape ops**:
|
||||
[`row_splits(...)`](ragged/row_splits.md),
|
||||
[`row_lengths(...)`](ragged/row_lengths.md),
|
||||
[`row_starts(...)`](ragged/row_starts.md),
|
||||
[`row_limits(...)`](ragged/row_limits.md),
|
||||
[`value_rowids(...)`](ragged/value_rowids.md),
|
||||
[`nrows(...)`](ragged/nrows.md),
|
||||
[`nested_row_splits(...)`](ragged/nested_row_splits.md),
|
||||
[`row_splits_to_segment_ids(...)`](ragged/row_splits_to_segment_ids.md),
|
||||
[`segment_ids_to_row_splits(...)`](ragged/segment_ids_to_row_splits.md),
|
||||
[`bounding_shape(...)`](ragged/bounding_shape.md).
|
||||
|
||||
* **Functional ops**:
|
||||
[`map_inner_values(...)`](ragged/map_inner_values.md),
|
||||
[`make_elementwise_op(...)`](ragged/make_elementwise_op.md).
|
||||
### Functional ops
|
||||
* `tf.ragged.map_inner_values`
|
||||
|
||||
|
||||
<!-- Ragged Classes & related helper functions -->
|
||||
@ -140,21 +117,17 @@ The following operations are specific to ragged tensors:
|
||||
@@map_inner_values
|
||||
@@map_fn
|
||||
|
||||
<!-- Elementwise Ops -->
|
||||
@@make_elementwise_op
|
||||
|
||||
<!-- Shape & broadcasting -->
|
||||
@@RaggedTensorDynamicShape
|
||||
@@broadcast_to
|
||||
@@broadcast_dynamic_shape
|
||||
|
||||
<!-- Symbols from ragged_elementwise_ops._symbols_to_export are whitelisted -->
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.ops.ragged import ragged_dispatch
|
||||
from tensorflow.python.ops.ragged import ragged_operators
|
||||
from tensorflow.python.ops.ragged import ragged_string_ops
|
||||
|
||||
@ -179,11 +152,6 @@ from tensorflow.python.ops.ragged.ragged_conversion_ops import from_tensor
|
||||
from tensorflow.python.ops.ragged.ragged_conversion_ops import to_sparse
|
||||
from tensorflow.python.ops.ragged.ragged_conversion_ops import to_tensor
|
||||
|
||||
# pylint: disable=protected-access, wildcard-import
|
||||
from tensorflow.python.ops.ragged.ragged_elementwise_ops import *
|
||||
from tensorflow.python.ops.ragged.ragged_elementwise_ops import _symbols_to_export as _elementwise_ops
|
||||
# pylint: enable=protected-access, wildcard-import
|
||||
|
||||
from tensorflow.python.ops.ragged.ragged_factory_ops import constant
|
||||
from tensorflow.python.ops.ragged.ragged_factory_ops import constant_value
|
||||
from tensorflow.python.ops.ragged.ragged_factory_ops import convert_to_tensor_or_ragged_tensor
|
||||
@ -231,6 +199,10 @@ from tensorflow.python.ops.ragged.segment_id_ops import segment_ids_to_row_split
|
||||
|
||||
from tensorflow.python.util import all_util as _all_util
|
||||
|
||||
|
||||
# Register OpDispatchers that override standard TF ops to work w/ RaggedTensors.
|
||||
__doc__ += ragged_dispatch.register_dispatchers() # pylint: disable=redefined-builtin
|
||||
|
||||
# Any symbol that is not referenced (with "@@name") in the module docstring
|
||||
# above, or included in the "_elementwise_ops" whitelist, will be removed.
|
||||
_all_util.remove_undocumented(__name__, _elementwise_ops)
|
||||
# above will be removed.
|
||||
_all_util.remove_undocumented(__name__)
|
||||
|
@ -308,7 +308,7 @@ def bounding_shape(rt_input, axis=None, name=None):
|
||||
# ragged_gather
|
||||
#===============================================================================
|
||||
# TODO(edloper): Add an `axis` argument
|
||||
def gather(params, indices, name=None):
|
||||
def gather(params, indices, validate_indices=None, axis=0, name=None):
|
||||
"""Gathers ragged slices from `params` axis `0` according to `indices`.
|
||||
|
||||
Returns `RaggedTensor` output, such that:
|
||||
@ -347,6 +347,8 @@ def gather(params, indices, name=None):
|
||||
indices: The potentially ragged tensor indicating which values to gather.
|
||||
Must have dtype `int32` or `int64`. Values must be in the range `[0,
|
||||
params.shape[0]]`.
|
||||
validate_indices: Ignored.
|
||||
axis: Must be zero.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
@ -357,6 +359,9 @@ def gather(params, indices, name=None):
|
||||
Raises:
|
||||
ValueError: If indices.shape.ndims is not known statically.
|
||||
"""
|
||||
del validate_indices
|
||||
if not isinstance(axis, int) or axis != 0:
|
||||
raise ValueError('axis>0 is not supported for ragged gather yet.')
|
||||
with ops.name_scope(name, 'RaggedGather', [params, indices]):
|
||||
params = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
|
||||
params, name='params')
|
||||
@ -812,29 +817,29 @@ def boolean_mask(data, mask, keepdims=False, name=None):
|
||||
#===============================================================================
|
||||
# Concatenation and Stacking
|
||||
#===============================================================================
|
||||
def concat(rt_inputs, axis, name=None):
|
||||
def concat(values, axis, name=None):
|
||||
"""Concatenates potentially ragged tensors along one dimension.
|
||||
|
||||
Given a list of tensors with the same rank `K` (`K >= axis`), returns a
|
||||
rank-`K` `RaggedTensor` `result` such that `result[i0...iaxis]` is the
|
||||
concatenation of `[rt[i0...iaxis] for rt in rt_inputs]`.
|
||||
concatenation of `[rt[i0...iaxis] for rt in values]`.
|
||||
|
||||
Args:
|
||||
rt_inputs: A list of potentially ragged tensors. May not be empty. All
|
||||
`rt_inputs` must have the same rank and the same dtype; but unlike
|
||||
values: A list of potentially ragged tensors. May not be empty. All
|
||||
`values` must have the same rank and the same dtype; but unlike
|
||||
`tf.concat`, they can have arbitrary shapes.
|
||||
axis: A python integer, indicating the dimension along which to concatenate.
|
||||
(Note: Unlike `tf.concat`, the `axis` parameter must be statically known.)
|
||||
Negative values are supported only if the rank of at least one
|
||||
`rt_inputs` value is statically known.
|
||||
`values` value is statically known.
|
||||
name: A name prefix for the returned tensor (optional).
|
||||
|
||||
Returns:
|
||||
A `RaggedTensor` with rank `K`.
|
||||
`result.ragged_rank=max(axis, max(rt.ragged_rank for rt in rt_inputs]))`.
|
||||
`result.ragged_rank=max(axis, max(rt.ragged_rank for rt in values]))`.
|
||||
|
||||
Raises:
|
||||
ValueError: If `rt_inputs` is empty, if `axis` is out of bounds or if
|
||||
ValueError: If `values` is empty, if `axis` is out of bounds or if
|
||||
the input tensors have different ranks.
|
||||
|
||||
#### Example:
|
||||
@ -847,35 +852,35 @@ def concat(rt_inputs, axis, name=None):
|
||||
[[1, 2, 6], [3, 4, 5, 7, 8, 9]]
|
||||
```
|
||||
"""
|
||||
if not isinstance(rt_inputs, (list, tuple)):
|
||||
rt_inputs = [rt_inputs]
|
||||
with ops.name_scope(name, 'RaggedConcat', rt_inputs):
|
||||
return _ragged_stack_concat_helper(rt_inputs, axis, stack_values=False)
|
||||
if not isinstance(values, (list, tuple)):
|
||||
values = [values]
|
||||
with ops.name_scope(name, 'RaggedConcat', values):
|
||||
return _ragged_stack_concat_helper(values, axis, stack_values=False)
|
||||
|
||||
|
||||
def stack(rt_inputs, axis, name=None):
|
||||
def stack(values, axis, name=None):
|
||||
"""Stacks potentially ragged tensors along one dimension.
|
||||
|
||||
Given a list of tensors with the same rank `K` (`K >= axis`), returns a
|
||||
rank-`K+1` `RaggedTensor` `result` such that `result[i0...iaxis]` is the
|
||||
list `[rt[i0...iaxis] for rt in rt_inputs]`.
|
||||
list `[rt[i0...iaxis] for rt in values]`.
|
||||
|
||||
Args:
|
||||
rt_inputs: A list of potentially ragged tensors. May not be empty. All
|
||||
`rt_inputs` must have the same rank and the same dtype; but unlike
|
||||
values: A list of potentially ragged tensors. May not be empty. All
|
||||
`values` must have the same rank and the same dtype; but unlike
|
||||
`tf.concat`, they can have arbitrary shapes.
|
||||
axis: A python integer, indicating the dimension along which to stack.
|
||||
(Note: Unlike `tf.stack`, the `axis` parameter must be statically known.)
|
||||
Negative values are supported only if the rank of at least one
|
||||
`rt_inputs` value is statically known.
|
||||
`values` value is statically known.
|
||||
name: A name prefix for the returned tensor (optional).
|
||||
|
||||
Returns:
|
||||
A `RaggedTensor` with rank `K+1`.
|
||||
`result.ragged_rank=max(axis, max(rt.ragged_rank for rt in rt_inputs]))`.
|
||||
`result.ragged_rank=max(axis, max(rt.ragged_rank for rt in values]))`.
|
||||
|
||||
Raises:
|
||||
ValueError: If `rt_inputs` is empty, if `axis` is out of bounds or if
|
||||
ValueError: If `values` is empty, if `axis` is out of bounds or if
|
||||
the input tensors have different ranks.
|
||||
|
||||
#### Example:
|
||||
@ -888,10 +893,10 @@ def stack(rt_inputs, axis, name=None):
|
||||
[[[1, 2], [6]], [[3, 4, 5], [7, 8, 9]]]
|
||||
```
|
||||
"""
|
||||
if not isinstance(rt_inputs, (list, tuple)):
|
||||
rt_inputs = [rt_inputs]
|
||||
with ops.name_scope(name, 'RaggedConcat', rt_inputs):
|
||||
return _ragged_stack_concat_helper(rt_inputs, axis, stack_values=True)
|
||||
if not isinstance(values, (list, tuple)):
|
||||
values = [values]
|
||||
with ops.name_scope(name, 'RaggedConcat', values):
|
||||
return _ragged_stack_concat_helper(values, axis, stack_values=True)
|
||||
|
||||
|
||||
def _ragged_stack_concat_helper(rt_inputs, axis, stack_values):
|
||||
@ -1065,22 +1070,22 @@ def _copy_row_shape(rt_inputs, splits):
|
||||
#===============================================================================
|
||||
# Tiling
|
||||
#===============================================================================
|
||||
def tile(rt_input, multiples, name=None):
|
||||
def tile(input, multiples, name=None): # pylint: disable=redefined-builtin
|
||||
"""Constructs a `RaggedTensor` by tiling a given `RaggedTensor`.
|
||||
|
||||
The values of `rt_input` are replicated `multiples[i]` times along the
|
||||
The values of `input` are replicated `multiples[i]` times along the
|
||||
`i`th dimension (for each dimension `i`). For every dimension `axis` in
|
||||
`rt_input`, the length of each output element in that dimension is the
|
||||
`input`, the length of each output element in that dimension is the
|
||||
length of corresponding input element multiplied by `multiples[axis]`.
|
||||
|
||||
Args:
|
||||
rt_input: A `RaggedTensor`.
|
||||
input: A `RaggedTensor`.
|
||||
multiples: A 1-D integer `Tensor`. Length must be the same as the number of
|
||||
dimensions in `rt_input`.
|
||||
dimensions in `input`.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
A `RaggedTensor` with the same type, rank, and ragged_rank as `rt_input`.
|
||||
A `RaggedTensor` with the same type, rank, and ragged_rank as `input`.
|
||||
|
||||
#### Example:
|
||||
```python
|
||||
@ -1089,22 +1094,22 @@ def tile(rt_input, multiples, name=None):
|
||||
[[1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3]]
|
||||
```
|
||||
"""
|
||||
with ops.name_scope(name, 'RaggedTile', [rt_input, multiples]):
|
||||
rt_input = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
|
||||
rt_input, name='rt_input')
|
||||
with ops.name_scope(name, 'RaggedTile', [input, multiples]):
|
||||
input = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
|
||||
input, name='input')
|
||||
multiples = ragged_util.convert_to_int_tensor(
|
||||
multiples, name='multiples', dtype=dtypes.int64)
|
||||
multiples.shape.assert_has_rank(1)
|
||||
if not ragged_tensor.is_ragged(rt_input):
|
||||
return array_ops.tile(rt_input, multiples, name)
|
||||
if not ragged_tensor.is_ragged(input):
|
||||
return array_ops.tile(input, multiples, name)
|
||||
|
||||
# If the constant value of `multiples` is available, then we can use it
|
||||
# to skip tiling dimensions where `multiples=1`.
|
||||
const_multiples = tensor_util.constant_value(multiples)
|
||||
|
||||
return ragged_factory_ops.from_nested_row_splits(
|
||||
_tile_ragged_values(rt_input, multiples, const_multiples),
|
||||
_tile_ragged_splits(rt_input, multiples, const_multiples))
|
||||
_tile_ragged_values(input, multiples, const_multiples),
|
||||
_tile_ragged_splits(input, multiples, const_multiples))
|
||||
|
||||
|
||||
def _tile_ragged_values(rt_input, multiples, const_multiples=None):
|
||||
@ -1240,26 +1245,26 @@ def _tile_ragged_splits(rt_input, multiples, const_multiples=None):
|
||||
#===============================================================================
|
||||
|
||||
|
||||
def expand_dims(rt_input, axis, name=None):
|
||||
def expand_dims(input, axis, name=None): # pylint: disable=redefined-builtin
|
||||
"""Inserts a dimension with shape 1 into a potentially ragged tensor's shape.
|
||||
|
||||
Given a potentially ragged tenor `rt_input`, this operation inserts a
|
||||
dimension with size 1 at the dimension `axis` of `rt_input`'s shape.
|
||||
Given a potentially ragged tenor `input`, this operation inserts a
|
||||
dimension with size 1 at the dimension `axis` of `input`'s shape.
|
||||
|
||||
* If `rt_input` is a `Tensor`, then this is equivalent to
|
||||
* If `input` is a `Tensor`, then this is equivalent to
|
||||
`tf.expand_dims`.
|
||||
* If `rt_input` is ragged, and `axis=0`, then the new dimension will be
|
||||
* If `input` is ragged, and `axis=0`, then the new dimension will be
|
||||
uniform; but the previously outermost dimension will become ragged.
|
||||
* If `rt_input` is ragged, and `0 < axis < rt_input.ragged_rank`, then the
|
||||
* If `input` is ragged, and `0 < axis < input.ragged_rank`, then the
|
||||
new dimension will be ragged.
|
||||
* If `rt_input` is ragged, and axis >= rt_input.ragged_rank`, then the new
|
||||
* If `input` is ragged, and axis >= input.ragged_rank`, then the new
|
||||
dimension will be uniform.
|
||||
|
||||
The following table gives some examples showing how `ragged.expand_dims`
|
||||
impacts the shapes of different input tensors. Ragged dimensions are
|
||||
indicated by enclosing them in parentheses.
|
||||
|
||||
rt_input.shape | axis | result.shape
|
||||
input.shape | axis | result.shape
|
||||
----------------------- | ---- | -----------------------------
|
||||
`[D1, D2]` | `0` | `[1, D1, D2]`
|
||||
`[D1, D2]` | `1` | `[D1, 1, D2]`
|
||||
@ -1271,14 +1276,14 @@ def expand_dims(rt_input, axis, name=None):
|
||||
`[D1, (D2), (D3), D4]` | `4` | `[D1, (D2), (D3), D4, 1]`
|
||||
|
||||
Args:
|
||||
rt_input: The potentially tensor that should be expanded with a new
|
||||
input: The potentially tensor that should be expanded with a new
|
||||
dimension.
|
||||
axis: An integer constant indicating where the new dimension should be
|
||||
inserted.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
A tensor with the same values as `rt_input`, with an added dimension of
|
||||
A tensor with the same values as `input`, with an added dimension of
|
||||
size 1 at `axis`.
|
||||
|
||||
#### Examples:
|
||||
@ -1300,24 +1305,24 @@ def expand_dims(rt_input, axis, name=None):
|
||||
TensorShape([2, None, 1]) [[[1], [2]], [[3]]]
|
||||
```
|
||||
"""
|
||||
with ops.name_scope(name, 'RaggedExpandDims', [rt_input]):
|
||||
rt_input = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
|
||||
rt_input, name='rt_input')
|
||||
with ops.name_scope(name, 'RaggedExpandDims', [input]):
|
||||
input = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
|
||||
input, name='input')
|
||||
|
||||
if not ragged_tensor.is_ragged(rt_input):
|
||||
return array_ops.expand_dims(rt_input, axis)
|
||||
if not ragged_tensor.is_ragged(input):
|
||||
return array_ops.expand_dims(input, axis)
|
||||
|
||||
ndims = None if rt_input.shape.ndims is None else rt_input.shape.ndims + 1
|
||||
ndims = None if input.shape.ndims is None else input.shape.ndims + 1
|
||||
axis = ragged_util.get_positive_axis(axis, ndims)
|
||||
if axis == 0:
|
||||
values = rt_input
|
||||
splits = array_ops.stack([0, nrows(rt_input)])
|
||||
values = input
|
||||
splits = array_ops.stack([0, nrows(input)])
|
||||
elif axis == 1:
|
||||
values = rt_input
|
||||
splits = math_ops.range(nrows(rt_input) + 1)
|
||||
values = input
|
||||
splits = math_ops.range(nrows(input) + 1)
|
||||
else:
|
||||
values = expand_dims(rt_input.values, axis - 1)
|
||||
splits = rt_input.row_splits
|
||||
values = expand_dims(input.values, axis - 1)
|
||||
splits = input.row_splits
|
||||
|
||||
return ragged_factory_ops.from_row_splits(values, splits)
|
||||
|
||||
|
441
tensorflow/python/ops/ragged/ragged_dispatch.py
Normal file
441
tensorflow/python/ops/ragged/ragged_dispatch.py
Normal file
@ -0,0 +1,441 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Operator dispatch for RaggedTensors."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import clip_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops.ragged import ragged_array_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_math_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import ragged_tensor_shape
|
||||
from tensorflow.python.ops.ragged import ragged_util
|
||||
from tensorflow.python.util import dispatch
|
||||
from tensorflow.python.util import tf_decorator
|
||||
from tensorflow.python.util import tf_export
|
||||
from tensorflow.python.util import tf_inspect
|
||||
|
||||
# @TODO(edloper): Set this to True in the CL that exports RaggedTensors.
|
||||
_UPDATE_DOCSTRINGS = False
|
||||
|
||||
# Information about an argument to an operation: The name of the argument, its
|
||||
# position in the argument list, and a boolean flag indicating whether it
|
||||
# expects a list of tensors.
|
||||
_ArgInfo = collections.namedtuple('ArgInfo', ['name', 'position', 'is_list'])
|
||||
|
||||
|
||||
def _get_arg_infos(func, arg_names):
|
||||
"""Returns an `_ArgInfo` for each argument of `func` specified by `arg_names`.
|
||||
|
||||
Args:
|
||||
func: The function whose arguments should be described.
|
||||
arg_names: The names of the arguments to get info for.
|
||||
|
||||
Returns:
|
||||
A tuple of `_ArgInfo`s.
|
||||
"""
|
||||
arg_infos = []
|
||||
|
||||
# Inspect the func's argspec to find the position of each arg.
|
||||
arg_spec = tf_inspect.getargspec(func)
|
||||
for argname in arg_names:
|
||||
assert isinstance(argname, str)
|
||||
is_list = argname.startswith('[') and argname.endswith(']')
|
||||
if is_list:
|
||||
argname = argname[1:-1]
|
||||
if argname not in arg_spec.args:
|
||||
raise ValueError('Argument %r not found function in %s. Args=%s' %
|
||||
(argname, func, arg_spec.args))
|
||||
arg_infos.append(_ArgInfo(argname, arg_spec.args.index(argname), is_list))
|
||||
return arg_infos
|
||||
|
||||
|
||||
def _is_convertible_to_tensor(value):
|
||||
"""Returns true if `value` is convertible to a `Tensor`."""
|
||||
if isinstance(value,
|
||||
(ops.Tensor, variables.Variable, np.ndarray, int, float, str)):
|
||||
return True
|
||||
elif isinstance(value, (sparse_tensor.SparseTensor,)):
|
||||
return False
|
||||
else:
|
||||
try:
|
||||
ops.convert_to_tensor(value)
|
||||
return True
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
|
||||
|
||||
class UnaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
|
||||
"""OpDispatcher for unary ops that map a base op across ragged values."""
|
||||
|
||||
def __init__(self, original_op, arg_is_list=False):
|
||||
self._original_op = original_op
|
||||
self._arg_is_list = arg_is_list
|
||||
arg_names = tf_inspect.getfullargspec(original_op)[0]
|
||||
self._x = arg_names[0]
|
||||
if _UPDATE_DOCSTRINGS:
|
||||
original_op.__doc__ = (
|
||||
original_op.__doc__.rstrip() + '\n\n' +
|
||||
' `{x}` may be a `tf.RaggedTensor`.\n'.format(x=self._x))
|
||||
|
||||
def handle(self, args, kwargs):
|
||||
if args:
|
||||
x, args = args[0], args[1:]
|
||||
else:
|
||||
kwargs = kwargs.copy()
|
||||
x = kwargs.pop(self._x, None)
|
||||
if x is None:
|
||||
return self.NOT_SUPPORTED
|
||||
if self._arg_is_list:
|
||||
found_ragged = False
|
||||
for elt in x:
|
||||
if ragged_tensor.is_ragged(elt):
|
||||
found_ragged = True
|
||||
elif not _is_convertible_to_tensor(elt):
|
||||
return self.NOT_SUPPORTED
|
||||
if found_ragged:
|
||||
nested_splits_lists = [
|
||||
elt.nested_row_splits for elt in x if ragged_tensor.is_ragged(elt)
|
||||
]
|
||||
inner_values = [
|
||||
elt.inner_values if ragged_tensor.is_ragged(elt) else elt
|
||||
for elt in x
|
||||
]
|
||||
with ops.control_dependencies(
|
||||
ragged_util.assert_splits_match(nested_splits_lists)):
|
||||
return ragged_factory_ops.from_nested_row_splits(
|
||||
self._original_op(inner_values, *args, **kwargs),
|
||||
nested_splits_lists[0])
|
||||
else:
|
||||
return self.NOT_SUPPORTED
|
||||
else:
|
||||
found_ragged = ragged_tensor.is_ragged(x)
|
||||
if found_ragged:
|
||||
mapped_values = self._original_op(x.inner_values, *args, **kwargs)
|
||||
return x.with_inner_values(mapped_values)
|
||||
else:
|
||||
return self.NOT_SUPPORTED
|
||||
|
||||
|
||||
class BinaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
|
||||
"""OpDispatcher for binary ops that map a base op across ragged values.
|
||||
|
||||
Supports broadcasting.
|
||||
"""
|
||||
|
||||
def __init__(self, original_op):
|
||||
self._original_op = original_op
|
||||
arg_names = tf_inspect.getfullargspec(original_op)[0]
|
||||
self._x = arg_names[0]
|
||||
self._y = arg_names[1]
|
||||
if _UPDATE_DOCSTRINGS:
|
||||
original_op.__doc__ = (
|
||||
original_op.__doc__.rstrip() + '\n\n' +
|
||||
' `{x}` and `{y}` may be a `tf.RaggedTensor`.\n'.format(
|
||||
x=self._x, y=self._y))
|
||||
|
||||
def handle(self, args, kwargs):
|
||||
# Extract the binary args.
|
||||
if len(args) > 1:
|
||||
x = args[0]
|
||||
y = args[1]
|
||||
args = args[2:]
|
||||
elif args:
|
||||
kwargs = kwargs.copy()
|
||||
x = args[0]
|
||||
y = kwargs.pop(self._y, None)
|
||||
args = args[1:]
|
||||
else:
|
||||
kwargs = kwargs.copy()
|
||||
x = kwargs.pop(self._x, None)
|
||||
y = kwargs.pop(self._y, None)
|
||||
|
||||
# Bail if we don't have at least one ragged argument.
|
||||
x_is_ragged = ragged_tensor.is_ragged(x)
|
||||
y_is_ragged = ragged_tensor.is_ragged(y)
|
||||
if not (x_is_ragged or y_is_ragged):
|
||||
return self.NOT_SUPPORTED
|
||||
|
||||
# Convert args to tensors. Bail if conversion fails.
|
||||
try:
|
||||
if not x_is_ragged:
|
||||
x = ops.convert_to_tensor(x, name=self._x, preferred_dtype=y.dtype)
|
||||
if not y_is_ragged:
|
||||
y = ops.convert_to_tensor(y, name=self._y, preferred_dtype=x.dtype)
|
||||
except (TypeError, ValueError):
|
||||
return self.NOT_SUPPORTED
|
||||
|
||||
if ((x_is_ragged and y_is_ragged) or
|
||||
(x_is_ragged and x.inner_values.shape.ndims <= y.shape.ndims) or
|
||||
(y_is_ragged and y.inner_values.shape.ndims <= x.shape.ndims)):
|
||||
bcast_shape = ragged_tensor_shape.broadcast_dynamic_shape(
|
||||
ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(x),
|
||||
ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(y))
|
||||
x = ragged_tensor_shape.broadcast_to(
|
||||
x, bcast_shape, broadcast_inner_dimensions=False)
|
||||
y = ragged_tensor_shape.broadcast_to(
|
||||
y, bcast_shape, broadcast_inner_dimensions=False)
|
||||
|
||||
x_values = x.inner_values if ragged_tensor.is_ragged(x) else x
|
||||
y_values = y.inner_values if ragged_tensor.is_ragged(y) else y
|
||||
mapped_values = self._original_op(x_values, y_values, *args, **kwargs)
|
||||
if ragged_tensor.is_ragged(x):
|
||||
return x.with_inner_values(mapped_values)
|
||||
else:
|
||||
return y.with_inner_values(mapped_values)
|
||||
|
||||
|
||||
class RaggedDispatcher(dispatch.OpDispatcher):
|
||||
"""OpDispatcher for ragged ops.
|
||||
|
||||
Dispatches to a wrapped op-handler if at least one of the `tensor_args`
|
||||
arguments is a RaggedTensor or a RaggedTensorValue; and all of the
|
||||
`tensor_args` arguments are convertible to Tensor or RaggedTensor.
|
||||
"""
|
||||
|
||||
def __init__(self, original_op, ragged_op, ragged_args):
|
||||
op_arg_names = tf_inspect.getfullargspec(original_op)[0]
|
||||
ragged_arg_names = tf_inspect.getfullargspec(ragged_op)[0]
|
||||
if op_arg_names != ragged_arg_names:
|
||||
raise AssertionError(
|
||||
'Signature must exactly match when overriding %s with %s: %s vs %s' %
|
||||
(original_op, ragged_op, op_arg_names, ragged_arg_names))
|
||||
self._ragged_op = ragged_op
|
||||
self._ragged_args = _get_arg_infos(ragged_op, ragged_args)
|
||||
if _UPDATE_DOCSTRINGS:
|
||||
arg_list = ' and '.join('`%s`' % arg for arg in ragged_args)
|
||||
original_op.__doc__ = (
|
||||
original_op.__doc__.rstrip() + '\n\n' +
|
||||
' {0} may be a `tf.RaggedTensor`.\n'.format(arg_list))
|
||||
|
||||
def handle(self, args, kwargs):
|
||||
if self.is_supported(args, kwargs):
|
||||
return self._ragged_op(*args, **kwargs)
|
||||
else:
|
||||
return self.NOT_SUPPORTED
|
||||
|
||||
def is_supported(self, args, kwargs):
|
||||
found_ragged = False
|
||||
for arg_info in self._ragged_args:
|
||||
if arg_info.position < len(args):
|
||||
arg = args[arg_info.position]
|
||||
else:
|
||||
arg = kwargs.get(arg_info.name, None)
|
||||
|
||||
if arg_info.is_list:
|
||||
if not isinstance(arg, (list, tuple)):
|
||||
return False
|
||||
for elt in arg:
|
||||
if ragged_tensor.is_ragged(elt):
|
||||
found_ragged = True
|
||||
elif not _is_convertible_to_tensor(elt):
|
||||
return False
|
||||
else:
|
||||
if ragged_tensor.is_ragged(arg):
|
||||
found_ragged = True
|
||||
elif not _is_convertible_to_tensor(arg):
|
||||
return False
|
||||
return found_ragged
|
||||
|
||||
|
||||
def ragged_dispatch(original_op, tensor_args):
|
||||
|
||||
def decorator(ragged_op):
|
||||
dispatch.RaggedDispatcher(original_op, ragged_op,
|
||||
tensor_args).register(original_op)
|
||||
return ragged_op
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
_UNARY_ELEMENTWISE_OPS = [
|
||||
array_ops.check_numerics,
|
||||
array_ops.identity,
|
||||
array_ops.ones_like,
|
||||
array_ops.ones_like_v2,
|
||||
array_ops.zeros_like,
|
||||
array_ops.zeros_like_v2,
|
||||
clip_ops.clip_by_value,
|
||||
math_ops.abs,
|
||||
math_ops.acos,
|
||||
math_ops.acosh,
|
||||
math_ops.angle,
|
||||
math_ops.asin,
|
||||
math_ops.asinh,
|
||||
math_ops.atan,
|
||||
math_ops.atanh,
|
||||
math_ops.cast,
|
||||
math_ops.ceil,
|
||||
math_ops.conj,
|
||||
math_ops.cos,
|
||||
math_ops.cosh,
|
||||
math_ops.digamma,
|
||||
math_ops.erf,
|
||||
math_ops.erfc,
|
||||
math_ops.exp,
|
||||
math_ops.expm1,
|
||||
math_ops.floor,
|
||||
math_ops.imag,
|
||||
math_ops.is_finite,
|
||||
math_ops.is_inf,
|
||||
math_ops.is_nan,
|
||||
math_ops.lgamma,
|
||||
math_ops.log,
|
||||
math_ops.log1p,
|
||||
math_ops.log_sigmoid,
|
||||
math_ops.logical_not,
|
||||
math_ops.negative,
|
||||
math_ops.real,
|
||||
math_ops.reciprocal,
|
||||
math_ops.rint,
|
||||
math_ops.round,
|
||||
math_ops.rsqrt,
|
||||
math_ops.saturate_cast,
|
||||
math_ops.sign,
|
||||
math_ops.sin,
|
||||
math_ops.sinh,
|
||||
math_ops.sqrt,
|
||||
math_ops.square,
|
||||
math_ops.tan,
|
||||
parsing_ops.decode_compressed,
|
||||
string_ops.string_to_number,
|
||||
string_ops.string_to_hash_bucket,
|
||||
string_ops.as_string,
|
||||
string_ops.decode_base64,
|
||||
string_ops.encode_base64,
|
||||
string_ops.regex_full_match,
|
||||
string_ops.regex_replace,
|
||||
string_ops.string_strip,
|
||||
string_ops.string_to_hash_bucket,
|
||||
string_ops.string_to_hash_bucket_fast,
|
||||
string_ops.string_to_hash_bucket_strong,
|
||||
string_ops.substr,
|
||||
string_ops.substr_v2,
|
||||
string_ops.string_length,
|
||||
string_ops.string_length_v2,
|
||||
string_ops.unicode_script,
|
||||
]
|
||||
|
||||
_UNARY_LIST_ELEMENTWISE_OPS = [
|
||||
math_ops.add_n,
|
||||
string_ops.string_join,
|
||||
]
|
||||
|
||||
_BINARY_ELEMENTWISE_OPS = [
|
||||
math_ops.add,
|
||||
math_ops.atan2,
|
||||
math_ops.complex,
|
||||
math_ops.div_no_nan,
|
||||
math_ops.divide,
|
||||
math_ops.equal,
|
||||
math_ops.floordiv,
|
||||
math_ops.floormod,
|
||||
math_ops.greater,
|
||||
math_ops.greater_equal,
|
||||
math_ops.less,
|
||||
math_ops.less_equal,
|
||||
math_ops.logical_and,
|
||||
math_ops.logical_or,
|
||||
math_ops.logical_xor,
|
||||
math_ops.maximum,
|
||||
math_ops.minimum,
|
||||
math_ops.multiply,
|
||||
math_ops.not_equal,
|
||||
math_ops.pow,
|
||||
math_ops.realdiv,
|
||||
math_ops.squared_difference,
|
||||
math_ops.subtract,
|
||||
math_ops.truediv,
|
||||
math_ops.truncatediv,
|
||||
math_ops.truncatemod,
|
||||
]
|
||||
|
||||
# (original_op, ragged_op, ragged_args)
|
||||
_RAGGED_DISPATCH_OPS = [
|
||||
(array_ops.batch_gather, ragged_array_ops.batch_gather,
|
||||
['params', 'indices']),
|
||||
(array_ops.concat, ragged_array_ops.concat, ['values']),
|
||||
(array_ops.expand_dims_v2, ragged_array_ops.expand_dims, ['input']),
|
||||
(array_ops.gather_v2, ragged_array_ops.gather, ['params', 'indices']),
|
||||
(array_ops.gather_nd, ragged_array_ops.gather_nd, ['params', 'indices']),
|
||||
(array_ops.stack, ragged_array_ops.stack, ['values']),
|
||||
(array_ops.tile, ragged_array_ops.tile, ['input']),
|
||||
(array_ops.where, ragged_array_ops.where, ['condition', 'x', 'y']),
|
||||
(math_ops.unsorted_segment_sum, ragged_math_ops.segment_sum,
|
||||
['data', 'segment_ids']),
|
||||
(math_ops.unsorted_segment_prod, ragged_math_ops.segment_prod,
|
||||
['data', 'segment_ids']),
|
||||
(math_ops.unsorted_segment_min, ragged_math_ops.segment_min,
|
||||
['data', 'segment_ids']),
|
||||
(math_ops.unsorted_segment_max, ragged_math_ops.segment_max,
|
||||
['data', 'segment_ids']),
|
||||
(math_ops.unsorted_segment_mean, ragged_math_ops.segment_mean,
|
||||
['data', 'segment_ids']),
|
||||
(math_ops.unsorted_segment_sqrt_n, ragged_math_ops.segment_sqrt_n,
|
||||
['data', 'segment_ids']),
|
||||
(math_ops.reduce_sum, ragged_math_ops.reduce_sum, ['input_tensor']),
|
||||
(math_ops.reduce_prod, ragged_math_ops.reduce_prod, ['input_tensor']),
|
||||
(math_ops.reduce_min, ragged_math_ops.reduce_min, ['input_tensor']),
|
||||
(math_ops.reduce_max, ragged_math_ops.reduce_max, ['input_tensor']),
|
||||
(math_ops.reduce_mean, ragged_math_ops.reduce_mean, ['input_tensor']),
|
||||
(math_ops.reduce_any, ragged_math_ops.reduce_any, ['input_tensor']),
|
||||
(math_ops.reduce_all, ragged_math_ops.reduce_all, ['input_tensor']),
|
||||
]
|
||||
|
||||
|
||||
def register_dispatchers():
|
||||
"""Constructs & registers OpDispatchers for ragged ops."""
|
||||
|
||||
op_list = (
|
||||
_UNARY_ELEMENTWISE_OPS + _UNARY_LIST_ELEMENTWISE_OPS +
|
||||
_BINARY_ELEMENTWISE_OPS + [x[0] for x in _RAGGED_DISPATCH_OPS])
|
||||
for op in op_list:
|
||||
_, undecorated_op = tf_decorator.unwrap(op)
|
||||
if not hasattr(undecorated_op, tf_export.API_ATTRS['tensorflow'].names):
|
||||
raise AssertionError('Expected %s to be an exported symbol '
|
||||
'(while adding a RaggedTensor dispatcher)')
|
||||
|
||||
for op in _UNARY_ELEMENTWISE_OPS:
|
||||
UnaryRaggedElementwiseDispatcher(op).register(op)
|
||||
|
||||
for op in _UNARY_LIST_ELEMENTWISE_OPS:
|
||||
UnaryRaggedElementwiseDispatcher(op, True).register(op)
|
||||
|
||||
for op in _BINARY_ELEMENTWISE_OPS:
|
||||
BinaryRaggedElementwiseDispatcher(op).register(op)
|
||||
|
||||
for (original_op, ragged_op, args) in _RAGGED_DISPATCH_OPS:
|
||||
RaggedDispatcher(original_op, ragged_op, args).register(original_op)
|
||||
|
||||
docstring = (
|
||||
'\n\n### Additional ops that support `RaggedTensor`\n\n' + '\n'.join([
|
||||
'* `tf.%s`' % tf_export.get_canonical_name_for_symbol(op)
|
||||
for op in op_list
|
||||
]))
|
||||
|
||||
return docstring
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for ragged.elementwise_ops."""
|
||||
"""Tests for RaggedTensor operator dispatch."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
@ -21,106 +21,108 @@ from __future__ import print_function
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import clip_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.ops import ragged
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
# Constants listing various op types to test. Each elementwise operation
|
||||
# Constants listing various op types to test. Each operation
|
||||
# should be included in at least one list below, or tested separately if
|
||||
# necessary (e.g., because it expects additional arguments).
|
||||
UNARY_FLOAT_OPS = [
|
||||
ragged.abs,
|
||||
ragged.acos,
|
||||
ragged.acosh,
|
||||
ragged.angle,
|
||||
ragged.asin,
|
||||
ragged.asinh,
|
||||
ragged.atan,
|
||||
ragged.atanh,
|
||||
ragged.ceil,
|
||||
ragged.conj,
|
||||
ragged.cos,
|
||||
ragged.cosh,
|
||||
ragged.digamma,
|
||||
ragged.erf,
|
||||
ragged.erfc,
|
||||
ragged.exp,
|
||||
ragged.expm1,
|
||||
ragged.floor,
|
||||
ragged.imag,
|
||||
ragged.is_finite,
|
||||
ragged.is_inf,
|
||||
ragged.is_nan,
|
||||
ragged.lgamma,
|
||||
ragged.log,
|
||||
ragged.log1p,
|
||||
ragged.log_sigmoid,
|
||||
ragged.negative,
|
||||
ragged.real,
|
||||
ragged.reciprocal,
|
||||
ragged.rint,
|
||||
ragged.round,
|
||||
ragged.rsqrt,
|
||||
ragged.sign,
|
||||
ragged.sin,
|
||||
ragged.sinh,
|
||||
ragged.sqrt,
|
||||
ragged.square,
|
||||
ragged.tan,
|
||||
ragged.as_string,
|
||||
ragged.identity,
|
||||
ragged.ones_like,
|
||||
ragged.zeros_like,
|
||||
math_ops.abs,
|
||||
math_ops.acos,
|
||||
math_ops.acosh,
|
||||
math_ops.angle,
|
||||
math_ops.asin,
|
||||
math_ops.asinh,
|
||||
math_ops.atan,
|
||||
math_ops.atanh,
|
||||
math_ops.ceil,
|
||||
math_ops.conj,
|
||||
math_ops.cos,
|
||||
math_ops.cosh,
|
||||
math_ops.digamma,
|
||||
math_ops.erf,
|
||||
math_ops.erfc,
|
||||
math_ops.exp,
|
||||
math_ops.expm1,
|
||||
math_ops.floor,
|
||||
math_ops.imag,
|
||||
math_ops.is_finite,
|
||||
math_ops.is_inf,
|
||||
math_ops.is_nan,
|
||||
math_ops.lgamma,
|
||||
math_ops.log,
|
||||
math_ops.log1p,
|
||||
math_ops.log_sigmoid,
|
||||
math_ops.negative,
|
||||
math_ops.real,
|
||||
math_ops.reciprocal,
|
||||
math_ops.rint,
|
||||
math_ops.round,
|
||||
math_ops.rsqrt,
|
||||
math_ops.sign,
|
||||
math_ops.sin,
|
||||
math_ops.sinh,
|
||||
math_ops.sqrt,
|
||||
math_ops.square,
|
||||
math_ops.tan,
|
||||
array_ops.identity,
|
||||
array_ops.ones_like,
|
||||
array_ops.zeros_like,
|
||||
]
|
||||
UNARY_BOOL_OPS = [
|
||||
ragged.logical_not,
|
||||
math_ops.logical_not,
|
||||
]
|
||||
UNARY_STRING_OPS = [
|
||||
ragged.decode_base64,
|
||||
ragged.encode_base64,
|
||||
ragged.string_strip,
|
||||
ragged.decode_compressed,
|
||||
string_ops.decode_base64,
|
||||
string_ops.encode_base64,
|
||||
string_ops.string_strip,
|
||||
parsing_ops.decode_compressed,
|
||||
]
|
||||
BINARY_FLOAT_OPS = [
|
||||
ragged.add,
|
||||
ragged.atan2,
|
||||
ragged.complex,
|
||||
ragged.div,
|
||||
ragged.div_no_nan,
|
||||
ragged.divide,
|
||||
ragged.equal,
|
||||
ragged.floordiv,
|
||||
ragged.floormod,
|
||||
ragged.greater,
|
||||
ragged.greater_equal,
|
||||
ragged.less,
|
||||
ragged.less_equal,
|
||||
ragged.maximum,
|
||||
ragged.minimum,
|
||||
ragged.multiply,
|
||||
ragged.not_equal,
|
||||
ragged.pow,
|
||||
ragged.realdiv,
|
||||
ragged.squared_difference,
|
||||
ragged.subtract,
|
||||
ragged.truediv,
|
||||
math_ops.add,
|
||||
math_ops.atan2,
|
||||
math_ops.complex,
|
||||
math_ops.div_no_nan,
|
||||
math_ops.divide,
|
||||
math_ops.equal,
|
||||
math_ops.floordiv,
|
||||
math_ops.floormod,
|
||||
math_ops.greater,
|
||||
math_ops.greater_equal,
|
||||
math_ops.less,
|
||||
math_ops.less_equal,
|
||||
math_ops.maximum,
|
||||
math_ops.minimum,
|
||||
math_ops.multiply,
|
||||
math_ops.not_equal,
|
||||
math_ops.pow,
|
||||
math_ops.realdiv,
|
||||
math_ops.squared_difference,
|
||||
math_ops.subtract,
|
||||
math_ops.truediv,
|
||||
]
|
||||
BINARY_BOOL_OPS = [
|
||||
ragged.logical_and,
|
||||
ragged.logical_or,
|
||||
ragged.logical_xor,
|
||||
math_ops.logical_and,
|
||||
math_ops.logical_or,
|
||||
math_ops.logical_xor,
|
||||
]
|
||||
UNARY_INT_OPS = [
|
||||
ragged.unicode_script,
|
||||
string_ops.unicode_script,
|
||||
]
|
||||
BINARY_INT_OPS = [
|
||||
ragged.truncatediv,
|
||||
ragged.truncatemod,
|
||||
math_ops.truncatediv,
|
||||
math_ops.truncatemod,
|
||||
]
|
||||
|
||||
|
||||
@ -171,50 +173,49 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
||||
[{'x': ragged.constant_value([['abcd', 'efgh'], ['aabbccdd']]), 'op': op}
|
||||
for op in UNARY_STRING_OPS] +
|
||||
[
|
||||
{'op': ragged.clip_by_value,
|
||||
{'op': clip_ops.clip_by_value,
|
||||
'x': ragged.constant_value([[-2.0, 3.0], [-3.0]]),
|
||||
'clip_value_min': 0.1, 'clip_value_max': 4.0},
|
||||
{'op': ragged.cast,
|
||||
{'op': math_ops.cast,
|
||||
'x': ragged.constant_value([[-2.0, 3.0], [-3.0]]),
|
||||
'dtype': dtypes.int32},
|
||||
{'op': ragged.saturate_cast,
|
||||
{'op': math_ops.saturate_cast,
|
||||
'x': ragged.constant_value([[-2.0, 3.0], [-3.0]]),
|
||||
'dtype': dtypes.int32},
|
||||
{'op': ragged.string_to_hash_bucket,
|
||||
{'op': string_ops.string_to_hash_bucket,
|
||||
'x': ragged.constant_value([['abcd', 'efgh'], ['aabbccdd']]),
|
||||
'num_buckets': 1000},
|
||||
{'op': ragged.string_to_hash_bucket_fast,
|
||||
{'op': string_ops.string_to_hash_bucket_fast,
|
||||
'x': ragged.constant_value([['abcd', 'efgh'], ['aabbccdd']]),
|
||||
'num_buckets': 1000},
|
||||
{'op': ragged.string_to_hash_bucket_strong,
|
||||
{'op': string_ops.string_to_hash_bucket_strong,
|
||||
'x': ragged.constant_value([['abcd', 'efgh'], ['aabbccdd']]),
|
||||
'num_buckets': 1000,
|
||||
'key': [1231, 12512]},
|
||||
{'op': ragged.string_to_number,
|
||||
{'op': string_ops.string_to_number,
|
||||
'x': ragged.constant_value([['-2.0', '3.0'], ['-3.0']])},
|
||||
{'op': ragged.regex_full_match,
|
||||
{'op': string_ops.regex_full_match,
|
||||
'x': ragged.constant_value([['hello', '123'], ['1+1']]),
|
||||
'pattern': r'\w+'},
|
||||
{'op': ragged.regex_replace,
|
||||
{'op': string_ops.regex_replace,
|
||||
'x': ragged.constant_value([['hello', '123'], ['1+1']]),
|
||||
'pattern': r'\d',
|
||||
'rewrite': '#'},
|
||||
{'op': ragged.substr,
|
||||
{'op': string_ops.substr,
|
||||
'x': ragged.constant_value([['hello', '123'], ['1+1']]),
|
||||
'pos': 2, 'len': 3},
|
||||
{'op': ragged.check_numerics,
|
||||
{'op': array_ops.check_numerics,
|
||||
'x': ragged.constant_value([[-2.0, 3.0], [-3.0]]),
|
||||
'message': 'check-numerics'},
|
||||
]
|
||||
) # pyformat: disable
|
||||
def testUnaryOp(self, x, op=ragged.abs, **extra_args):
|
||||
def testUnaryElementwiseOp(self, x, op=math_ops.abs, **extra_args):
|
||||
x = ragged.convert_to_tensor_or_ragged_tensor(x)
|
||||
result = op(x, **extra_args)
|
||||
|
||||
# Run the wrapped op on the dense values, for comparison.
|
||||
dense_x = x.inner_values if isinstance(x, ragged.RaggedTensor) else x
|
||||
expected_flat_values = array_ops.reshape(
|
||||
op.__wrapped__(dense_x, **extra_args), [-1])
|
||||
expected_flat_values = array_ops.reshape(op(dense_x, **extra_args), [-1])
|
||||
|
||||
with self.test_session():
|
||||
# Check that the result has the expected shape.
|
||||
@ -285,12 +286,17 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
||||
#=====================================================================
|
||||
{'x': ragged.constant_value([[[1, 2], [3], [4]], [[], [5, 7, 8]]]),
|
||||
'y': ragged.constant_value([[[3, 8], [2], [5]], [[], [1, 9, 8]]]),
|
||||
'use_kwargs': True},
|
||||
'use_kwargs': ('x', 'y')},
|
||||
{'x': ragged.constant_value([[[1, 2]], [[3, 4], [5, 6], [7, 8]]],
|
||||
ragged_rank=1),
|
||||
'y': ragged.constant_value([[[9, 3]], [[5, 2], [3, 4], [7, 6]]],
|
||||
ragged_rank=1),
|
||||
'use_kwargs': True},
|
||||
'use_kwargs': ('x', 'y')},
|
||||
{'x': ragged.constant_value([[[1, 2]], [[3, 4], [5, 6], [7, 8]]],
|
||||
ragged_rank=1),
|
||||
'y': ragged.constant_value([[[9, 3]], [[5, 2], [3, 4], [7, 6]]],
|
||||
ragged_rank=1),
|
||||
'use_kwargs': ('x',)},
|
||||
] +
|
||||
#=========================================================================
|
||||
# Test each unary op.
|
||||
@ -306,16 +312,16 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
||||
[{'x': ragged.constant_value([[True, True], [False]]),
|
||||
'y': ragged.constant_value([[False, True], [False]]),
|
||||
'op': op}
|
||||
for op in BINARY_BOOL_OPS] +
|
||||
[
|
||||
]
|
||||
for op in BINARY_BOOL_OPS]
|
||||
) # pyformat: disable
|
||||
def testBinaryOp(self, x, y, op=ragged.add, **extra_args):
|
||||
use_kwargs = extra_args.pop('use_kwargs', False)
|
||||
def testBinaryElementwiseOp(self, x, y, op=math_ops.add, **extra_args):
|
||||
use_kwargs = extra_args.pop('use_kwargs', ())
|
||||
x = ragged.convert_to_tensor_or_ragged_tensor(x)
|
||||
y = ragged.convert_to_tensor_or_ragged_tensor(y)
|
||||
if use_kwargs:
|
||||
if 'x' in use_kwargs and 'y' in use_kwargs:
|
||||
result = op(x=x, y=y, **extra_args)
|
||||
elif 'y' in use_kwargs:
|
||||
result = op(x, y=y, **extra_args)
|
||||
else:
|
||||
result = op(x, y, **extra_args)
|
||||
|
||||
@ -323,7 +329,7 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
||||
dense_x = x.inner_values if isinstance(x, ragged.RaggedTensor) else x
|
||||
dense_y = y.inner_values if isinstance(y, ragged.RaggedTensor) else y
|
||||
expected_flat_values = array_ops.reshape(
|
||||
op.__wrapped__(dense_x, dense_y, **extra_args), [-1])
|
||||
op(dense_x, dense_y, **extra_args), [-1])
|
||||
|
||||
with self.test_session():
|
||||
# Check that the result has the expected shape.
|
||||
@ -358,16 +364,17 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
||||
ragged.constant_value([[[2, 9], [12]], [[8]]])),
|
||||
'use_kwargs': True},
|
||||
] + [
|
||||
{'op': ragged.add_n,
|
||||
{'op': math_ops.add_n,
|
||||
'inputs': (ragged.constant_value([[1, 3], [-3]]),
|
||||
ragged.constant_value([[4, 7], [88]]),
|
||||
ragged.constant_value([[2, 9], [12]]))},
|
||||
{'op': ragged.string_join,
|
||||
{'op': string_ops.string_join,
|
||||
'inputs': (ragged.constant_value([['a', 'b'], ['c']]),
|
||||
ragged.constant_value([['foo', 'bar'], ['baz']]),
|
||||
ragged.constant_value([['2', '9'], ['12']]))},
|
||||
]) # pyformat: disable
|
||||
def testListValuedOp(self, inputs, op=ragged.add_n, **extra_args):
|
||||
def testListValuedElementwiseOp(self, inputs, op=math_ops.add_n,
|
||||
**extra_args):
|
||||
use_kwargs = extra_args.pop('use_kwargs', False)
|
||||
inputs = [ragged.convert_to_tensor_or_ragged_tensor(x) for x in inputs]
|
||||
if use_kwargs:
|
||||
@ -381,7 +388,7 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
||||
for x in inputs
|
||||
]
|
||||
expected_flat_values = array_ops.reshape(
|
||||
op.__wrapped__(dense_inputs, **extra_args), [-1])
|
||||
op(dense_inputs, **extra_args), [-1])
|
||||
|
||||
with self.test_session():
|
||||
# Check that the result has the expected shape.
|
||||
@ -395,13 +402,13 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
||||
self.assertAllEqual(expected_flat_values, result_flat_values)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testUnknownRankError(self):
|
||||
def testElementwiseOpUnknownRankError(self):
|
||||
x = ragged.constant([[1, 2], [3]])
|
||||
y = ragged.from_row_splits(
|
||||
array_ops.placeholder_with_default([1, 2, 3], shape=None), x.row_splits)
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, r'Unable to broadcast: unknown rank'):
|
||||
ragged.add(x, y)
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
r'Unable to broadcast: unknown rank'):
|
||||
math_ops.add(x, y)
|
||||
|
||||
@parameterized.parameters([
|
||||
dict(
|
||||
@ -417,26 +424,31 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
||||
y=ragged.constant_value([[1]]),
|
||||
expected=[[[2]]]),
|
||||
])
|
||||
def testBroadcastAdd(self, x, y, expected):
|
||||
def testElementwiseOpBroadcast(self, x, y, expected):
|
||||
x = ragged.convert_to_tensor_or_ragged_tensor(x, dtype=dtypes.int32)
|
||||
y = ragged.convert_to_tensor_or_ragged_tensor(y, dtype=dtypes.int32)
|
||||
result = x + y
|
||||
with self.cached_session():
|
||||
self.assertEqual(result.eval().tolist(), expected)
|
||||
|
||||
def testShapeMismatch(self):
|
||||
def testElementwiseOpShapeMismatch(self):
|
||||
x = ragged.constant([[1, 2, 3], [4, 5]])
|
||||
y = ragged.constant([[1, 2, 3], [4, 5, 6]])
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
'Incompatible shapes'):
|
||||
with self.cached_session():
|
||||
ragged.add(x, y).eval()
|
||||
math_ops.add(x, y).eval()
|
||||
|
||||
def testDocstring(self):
|
||||
self.assertRegexpMatches(
|
||||
ragged.add.__doc__,
|
||||
'Ragged version of the elementwise operation `tf.math.add`')
|
||||
self.assertEqual(ragged.add.__name__, 'add')
|
||||
def testBinaryOpSparseAndRagged(self):
|
||||
x = ragged.constant([[1, 2, 3], [4, 5]])
|
||||
y = sparse_tensor.SparseTensor([[0, 0], [0, 1], [2, 0]], [1, 2, 3], [3, 2])
|
||||
with self.assertRaises(TypeError):
|
||||
with self.cached_session():
|
||||
math_ops.add(x, y).eval()
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
with self.cached_session():
|
||||
math_ops.add_n([x, y]).eval()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
@ -1,389 +0,0 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Elementwise operations for RaggedTensors."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import clip_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import ragged_tensor_shape
|
||||
from tensorflow.python.util import tf_decorator
|
||||
from tensorflow.python.util import tf_export
|
||||
from tensorflow.python.util import tf_inspect
|
||||
|
||||
# Information about an argument to an operation: The name of the argument, its
|
||||
# position in the argument list, and a boolean flag indicating whether it
|
||||
# expects a list of tensors.
|
||||
_ArgInfo = collections.namedtuple('ArgInfo', ['name', 'position', 'is_list'])
|
||||
|
||||
|
||||
def make_elementwise_op(op, *elementwise_args):
|
||||
"""Returns a ragged-tensor version of the elementwise operation `op`.
|
||||
|
||||
The returned operation will:
|
||||
|
||||
1. Broadcast the elementwise arguments to have a compatible shape.
|
||||
An exception is raised if the tensors not broadcast-compatible.
|
||||
2. Call `op`, substituting the dense values of the broadcasted tensor for
|
||||
each elementwise argument.
|
||||
3. Return a potentially ragged tensor constructed from the output of `op`
|
||||
and the broadcasted tensors' nested row splits.
|
||||
|
||||
For example, you can construct a ragged-tensor version of the standard
|
||||
operation `tf.add` by calling `make_elementwise_op(tf.add, 'x', 'y')`.
|
||||
|
||||
Args:
|
||||
op: The operation to wrap.
|
||||
*elementwise_args: The names of arguments to `op` that are treated as
|
||||
elementwise. Arguments that take a list of tensors should have their
|
||||
names wrapped in square brackets (e.g. "[inputs]").
|
||||
|
||||
Raises:
|
||||
ValueError: If any name specified in `elementwise_args` is not the name
|
||||
of an argument to `op`.
|
||||
"""
|
||||
elementwise_arg_infos = _get_arg_infos(op, elementwise_args)
|
||||
|
||||
def ragged_op(*args, **kwargs):
|
||||
"""Ragged version of `op`."""
|
||||
args = list(args)
|
||||
|
||||
# Collect all of the elementwise arguments, and put them in a single
|
||||
# dict whose values are the (potentially ragged) tensors that need to
|
||||
# be broadcast to a common shape. The keys of this dict are tuples
|
||||
# (argkey, index), where argkey is an int for poitional args or a string
|
||||
# for keyword args; and index is None for non-list args and the index of the
|
||||
# tensor for list args.
|
||||
elementwise_args = {}
|
||||
for (name, position, is_list) in elementwise_arg_infos.values():
|
||||
if position < len(args):
|
||||
if is_list:
|
||||
args[position] = list(args[position])
|
||||
for (index, arg) in enumerate(args[position]):
|
||||
elementwise_args[position, index] = arg
|
||||
else:
|
||||
elementwise_args[position, None] = args[position]
|
||||
elif name in kwargs:
|
||||
if is_list:
|
||||
kwargs[name] = list(kwargs[name])
|
||||
for (i, arg) in enumerate(kwargs[name]):
|
||||
elementwise_args[name, i] = arg
|
||||
else:
|
||||
elementwise_args[name, None] = kwargs[name]
|
||||
|
||||
with ops.name_scope(None, op.__name__, elementwise_args.values()):
|
||||
# Convert all inputs to tensors or ragged tensors.
|
||||
for ((key, index), tensor) in elementwise_args.items():
|
||||
argname = elementwise_arg_infos[key].name
|
||||
converted = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
|
||||
tensor, name=argname)
|
||||
elementwise_args[key, index] = converted
|
||||
|
||||
# Broadcast tensors to have compatible shapes.
|
||||
broadcast_args, result_splits, broadcast_check_ops = \
|
||||
_broadcast_elementwise_args(elementwise_args)
|
||||
|
||||
# Replace tensor arguments with their dense values.
|
||||
for ((key, index), tensor) in broadcast_args.items():
|
||||
if ragged_tensor.is_ragged(tensor):
|
||||
if isinstance(key, int) and index is None:
|
||||
args[key] = tensor.inner_values
|
||||
elif isinstance(key, int) and index is not None:
|
||||
args[key][index] = tensor.inner_values
|
||||
elif isinstance(key, str) and index is None:
|
||||
kwargs[key] = tensor.inner_values
|
||||
else:
|
||||
assert isinstance(key, str) and index is not None
|
||||
kwargs[key][index] = tensor.inner_values
|
||||
|
||||
# Call the elementwise op on the broadcasted dense values.
|
||||
with ops.control_dependencies(broadcast_check_ops):
|
||||
result_values = op(*args, **kwargs)
|
||||
|
||||
# Restore any ragged dimensions that we stripped off, and return the
|
||||
# result.
|
||||
return ragged_factory_ops.from_nested_row_splits(result_values,
|
||||
result_splits)
|
||||
|
||||
# Construct the docstring.
|
||||
op_name = tf_export.get_canonical_name_for_symbol(op)
|
||||
assert op_name is not None, op
|
||||
argnames = ', '.join('`%s`' % s.strip('[]') for s in elementwise_args)
|
||||
docstring = _ELEMENTWISE_DOCSTRING % dict(op_name=op_name, argnames=argnames)
|
||||
|
||||
# Update name, docstring, signature, etc., for the wrapper, and return it.
|
||||
return tf_decorator.make_decorator(op, ragged_op, decorator_doc=docstring)
|
||||
|
||||
|
||||
_ELEMENTWISE_DOCSTRING = """\
|
||||
Ragged version of the elementwise operation `tf.%(op_name)s`.
|
||||
|
||||
The following elementwise arguments may be ragged or dense:
|
||||
%(argnames)s.
|
||||
These arguments will be broadcast to a compatible shape if necessary.
|
||||
"""
|
||||
|
||||
|
||||
def _get_arg_infos(func, elementwise_args):
|
||||
"""Returns `_ArgInfo`s for each `func` arg specified by `elementwise_args`.
|
||||
|
||||
Args:
|
||||
func: The function whose arguments should be described.
|
||||
elementwise_args: The names of the arguments to get info for.
|
||||
|
||||
Returns:
|
||||
A dictionary that maps both names and positions of arguments to
|
||||
`_ArgInfo` tuples.
|
||||
"""
|
||||
arg_infos = {}
|
||||
|
||||
# Inspect the func's argspec to find the position of each arg.
|
||||
arg_spec = tf_inspect.getargspec(func)
|
||||
for argname in elementwise_args:
|
||||
assert isinstance(argname, str)
|
||||
is_list = argname.startswith('[') and argname.endswith(']')
|
||||
if is_list:
|
||||
argname = argname[1:-1]
|
||||
assert argname in arg_spec.args, (func, argname, arg_spec.args)
|
||||
arg_info = _ArgInfo(argname, arg_spec.args.index(argname), is_list)
|
||||
arg_infos[arg_info.name] = arg_info
|
||||
arg_infos[arg_info.position] = arg_info
|
||||
return arg_infos
|
||||
|
||||
|
||||
def _broadcast_elementwise_args(elementwise_args):
|
||||
"""Broadcasts the values of `elementwise_args` to have compatible shapes.
|
||||
|
||||
Args:
|
||||
elementwise_args: A dictionary whose keys are potentially ragged tensors.
|
||||
|
||||
Returns:
|
||||
A tuple `(broadcast_args, broadcast_splits, checks)` where:
|
||||
|
||||
* `broadcast_args` is a dictionary with the same keys as
|
||||
`elementwise_args`, mapping to broadcasted tensors.
|
||||
* `broadcast_splits` is the broadcasted nested row splits.
|
||||
* `checks` is a possibly empty tuple of assertion operations that should
|
||||
be added as control dependencies.
|
||||
|
||||
Raises:
|
||||
ValueError: If broadcasting fails.
|
||||
"""
|
||||
# No elementwise arguments were used: nothing to do!
|
||||
if not elementwise_args:
|
||||
return elementwise_args, (), ()
|
||||
|
||||
# A single elementwise argument was used: no broadcasting necessary.
|
||||
if len(elementwise_args) == 1:
|
||||
arg = list(elementwise_args.values())[0]
|
||||
if ragged_tensor.is_ragged(arg):
|
||||
return elementwise_args, arg.nested_row_splits, ()
|
||||
else:
|
||||
return elementwise_args, (), ()
|
||||
|
||||
# Multiple elementwise arguments.
|
||||
else:
|
||||
is_ragged = [ragged_tensor.is_ragged(t) for t in elementwise_args.values()]
|
||||
if not any(is_ragged):
|
||||
return elementwise_args, (), ()
|
||||
|
||||
# If we have a single ragged tensor plus a set of scalars, then we can
|
||||
# rely on the underlying elementwise op to do broadcasting.
|
||||
if (sum(is_ragged) == 1 and
|
||||
all((ragged_tensor.is_ragged(t) or t.shape.ndims == 0)
|
||||
for t in elementwise_args.values())):
|
||||
nested_splits_lists = [
|
||||
t.nested_row_splits
|
||||
for t in elementwise_args.values()
|
||||
if ragged_tensor.is_ragged(t)][0]
|
||||
return elementwise_args, nested_splits_lists, ()
|
||||
|
||||
else:
|
||||
# Get the shapes of all the elementwise arguments.
|
||||
shapes = [ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(t)
|
||||
for t in elementwise_args.values()]
|
||||
|
||||
# Broadcast the shapes to all have the same rank (the max rank).
|
||||
ranks = [t.shape.ndims for t in elementwise_args.values()]
|
||||
if any(rank is None for rank in ranks):
|
||||
raise ValueError('Unable to broadcast: unknown rank')
|
||||
broadcast_rank = max(ranks)
|
||||
shapes = [shape.broadcast_to_rank(broadcast_rank) for shape in shapes]
|
||||
|
||||
# For each dimension, broadcast the shapes to be compatible.
|
||||
for axis in range(broadcast_rank):
|
||||
# For each i, broadcast shape[i+1] to be compatible with shape[i]; and
|
||||
# then finally broadcast shape[0] to be compatible with shape[-1].
|
||||
for i in range(len(shapes)):
|
||||
j = (i + 1) % len(shapes)
|
||||
dim_size = shapes[i].dimension_size(axis)
|
||||
shapes[j] = shapes[j].broadcast_dimension(axis, dim_size)
|
||||
broadcast_shape = shapes[0]
|
||||
|
||||
# Broadcast every elementwise arg to the shape that we calculated.
|
||||
elementwise_args = dict([
|
||||
(key, ragged_tensor_shape.broadcast_to(t, broadcast_shape, False))
|
||||
for (key, t) in elementwise_args.items()])
|
||||
nested_splits_lists = list(elementwise_args.values())[0].nested_row_splits
|
||||
return elementwise_args, nested_splits_lists, ()
|
||||
|
||||
|
||||
# A list of symbols that should be exported in the "ragged" package.
|
||||
_symbols_to_export = []
|
||||
|
||||
|
||||
def _add_elementwise_ops_to_this_module(specs, verbose=False):
|
||||
"""Adds ragged versions of the given ops to this module.
|
||||
|
||||
Args:
|
||||
specs: A list of tuples containing the arguments for `make_elementwise_op`.
|
||||
verbose: If true, then display each op that gets added.
|
||||
"""
|
||||
for spec in specs:
|
||||
original_op = spec[0]
|
||||
ragged_op = make_elementwise_op(*spec)
|
||||
canonical_name = tf_export.get_canonical_name_for_symbol(original_op)
|
||||
if '.' not in canonical_name:
|
||||
op_name = canonical_name
|
||||
else:
|
||||
op_name = original_op.__name__
|
||||
|
||||
# Temporary hack (will be removed once dispatch is added for RaggedTensors):
|
||||
if op_name == 'neg': op_name = 'negative'
|
||||
|
||||
if verbose:
|
||||
print('Adding ragged_elementwise_op: tf.ragged.%s (based on tf.%s)' %
|
||||
(op_name, canonical_name))
|
||||
globals()[op_name] = ragged_op
|
||||
_symbols_to_export.append(op_name)
|
||||
|
||||
|
||||
# A list of tuples containing arguments for `make_elementwise_op`, for each
|
||||
# elementwise operation that should have a ragged version built. Each tuple
|
||||
# contains a standard `Tensor` operation, and the names of any arguments
|
||||
# that are processed in elementwise fashion.
|
||||
_TF_ELEMENTWISE_OPS = [
|
||||
# Unary math operations.
|
||||
(clip_ops.clip_by_value, 't'),
|
||||
(math_ops.abs, 'x'),
|
||||
(math_ops.acos, 'x'),
|
||||
(math_ops.acosh, 'x'),
|
||||
(math_ops.angle, 'input'),
|
||||
(math_ops.asin, 'x'),
|
||||
(math_ops.asinh, 'x'),
|
||||
(math_ops.atan, 'x'),
|
||||
(math_ops.atanh, 'x'),
|
||||
(math_ops.cast, 'x'),
|
||||
(math_ops.ceil, 'x'),
|
||||
(math_ops.conj, 'x'),
|
||||
(math_ops.cos, 'x'),
|
||||
(math_ops.cosh, 'x'),
|
||||
(math_ops.digamma, 'x'),
|
||||
(math_ops.erf, 'x'),
|
||||
(math_ops.erfc, 'x'),
|
||||
(math_ops.exp, 'x'),
|
||||
(math_ops.expm1, 'x'),
|
||||
(math_ops.floor, 'x'),
|
||||
(math_ops.imag, 'input'),
|
||||
(math_ops.is_finite, 'x'),
|
||||
(math_ops.is_inf, 'x'),
|
||||
(math_ops.is_nan, 'x'),
|
||||
(math_ops.lgamma, 'x'),
|
||||
(math_ops.log, 'x'),
|
||||
(math_ops.log1p, 'x'),
|
||||
(math_ops.log_sigmoid, 'x'),
|
||||
(math_ops.logical_not, 'x'),
|
||||
(math_ops.negative, 'x'),
|
||||
(math_ops.real, 'input'),
|
||||
(math_ops.reciprocal, 'x'),
|
||||
(math_ops.rint, 'x'),
|
||||
(math_ops.round, 'x'),
|
||||
(math_ops.rsqrt, 'x'),
|
||||
(math_ops.saturate_cast, 'value'),
|
||||
(math_ops.sign, 'x'),
|
||||
(math_ops.sin, 'x'),
|
||||
(math_ops.sinh, 'x'),
|
||||
(math_ops.sqrt, 'x'),
|
||||
(math_ops.square, 'x'),
|
||||
(math_ops.tan, 'x'),
|
||||
|
||||
# Binary math operations
|
||||
(math_ops.add, 'x', 'y'),
|
||||
(math_ops.atan2, 'y', 'x'),
|
||||
(math_ops.complex, 'real', 'imag'),
|
||||
(math_ops.div, 'x', 'y'),
|
||||
(math_ops.div_no_nan, 'x', 'y'),
|
||||
(math_ops.divide, 'x', 'y'),
|
||||
(math_ops.equal, 'x', 'y'),
|
||||
(math_ops.floordiv, 'x', 'y'),
|
||||
(math_ops.floormod, 'x', 'y'),
|
||||
(math_ops.greater, 'x', 'y'),
|
||||
(math_ops.greater_equal, 'x', 'y'),
|
||||
(math_ops.less, 'x', 'y'),
|
||||
(math_ops.less_equal, 'x', 'y'),
|
||||
(math_ops.logical_and, 'x', 'y'),
|
||||
(math_ops.logical_or, 'x', 'y'),
|
||||
(math_ops.logical_xor, 'x', 'y'),
|
||||
(math_ops.maximum, 'x', 'y'),
|
||||
(math_ops.minimum, 'x', 'y'),
|
||||
(math_ops.multiply, 'x', 'y'),
|
||||
(math_ops.not_equal, 'x', 'y'),
|
||||
(math_ops.pow, 'x', 'y'),
|
||||
(math_ops.realdiv, 'x', 'y'),
|
||||
(math_ops.squared_difference, 'x', 'y'),
|
||||
(math_ops.subtract, 'x', 'y'),
|
||||
(math_ops.truediv, 'x', 'y'),
|
||||
(math_ops.truncatediv, 'x', 'y'),
|
||||
(math_ops.truncatemod, 'x', 'y'),
|
||||
|
||||
# N-ary math operations
|
||||
(math_ops.add_n, '[inputs]'),
|
||||
|
||||
# String operations
|
||||
(string_ops.as_string, 'input'),
|
||||
(string_ops.decode_base64, 'input'),
|
||||
(string_ops.encode_base64, 'input'),
|
||||
(string_ops.regex_full_match, 'input'),
|
||||
(string_ops.regex_replace, 'input'),
|
||||
(string_ops.string_join, '[inputs]'),
|
||||
(string_ops.string_strip, 'input'),
|
||||
(string_ops.string_to_hash_bucket, 'input'),
|
||||
(string_ops.string_to_hash_bucket_fast, 'input'),
|
||||
(string_ops.string_to_hash_bucket_strong, 'input'),
|
||||
(string_ops.substr, 'input'),
|
||||
(string_ops.unicode_script, 'input'),
|
||||
|
||||
# Array ops
|
||||
(array_ops.check_numerics, 'tensor'),
|
||||
(array_ops.identity, 'input'),
|
||||
(array_ops.ones_like, 'tensor'),
|
||||
(array_ops.zeros_like, 'tensor'),
|
||||
|
||||
# Parsing ops
|
||||
(parsing_ops.decode_compressed, 'bytes'),
|
||||
(parsing_ops.string_to_number, 'string_tensor'),
|
||||
]
|
||||
_add_elementwise_ops_to_this_module(_TF_ELEMENTWISE_OPS)
|
||||
|
@ -18,6 +18,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
@ -55,7 +56,7 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
),
|
||||
# [d1, (d2)] -> [d1, (d2)]
|
||||
dict(
|
||||
fn=lambda x: x+1,
|
||||
fn=lambda x: x + np.int64(1),
|
||||
elems=[[1, 2, 3], [4, 5], [6, 7]],
|
||||
expected_output=[[2, 3, 4], [5, 6], [7, 8]],
|
||||
dtype=dtypes.int64,
|
||||
@ -64,7 +65,7 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
),
|
||||
# [d1, (d2), d3] -> [d1, (d2), d3]
|
||||
dict(
|
||||
fn=lambda x: x+1,
|
||||
fn=lambda x: x + np.int64(1),
|
||||
elems=[[[1, 2], [3, 4]], [], [[5, 6], [7, 8], [9, 0]]],
|
||||
elems_ragged_rank=1,
|
||||
expected_ragged_rank=1,
|
||||
@ -131,7 +132,7 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
),
|
||||
# [d1, (d2), (d3), (d4a), (d5)] -> [d1, (d2), (d3), (d4b), (d5)]
|
||||
dict(
|
||||
fn=lambda x: ragged.add(x, 1),
|
||||
fn=lambda x: x + np.int64(1),
|
||||
elems=[[[[[1, 2, 3]], [[4], [5]]]], [[[[6, 7]]], [[[8], []]]]],
|
||||
expected_output=[[[[[2, 3, 4]], [[5], [6]]]],
|
||||
[[[[7, 8]]], [[[9], []]]]],
|
||||
@ -196,8 +197,8 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
|
||||
def _increment(f):
|
||||
return {
|
||||
'batman': ragged.add(f['batman'], 1),
|
||||
'robin': ragged.add(f['robin'], 1),
|
||||
'batman': f['batman'] + 1,
|
||||
'robin': f['robin'] + 1,
|
||||
}
|
||||
|
||||
output = ragged.map_fn(
|
||||
|
@ -143,8 +143,11 @@ Computes the %(combination)s along segments of a RaggedTensor.
|
||||
"""
|
||||
|
||||
|
||||
def _ragged_segment_aggregate(unsorted_segment_op, data, segment_ids,
|
||||
num_segments, name=None):
|
||||
def _ragged_segment_aggregate(unsorted_segment_op,
|
||||
data,
|
||||
segment_ids,
|
||||
num_segments,
|
||||
name=None):
|
||||
"""Aggregates along segments of a RaggedTensor using `unsorted_segment_op`.
|
||||
|
||||
Returns a RaggedTensor `output` with `num_segments` rows, where the row
|
||||
@ -212,12 +215,11 @@ def _ragged_segment_aggregate(unsorted_segment_op, data, segment_ids,
|
||||
assert output_row_lengths.dtype == dtypes.int64
|
||||
|
||||
# Build the splits tensor for the output RaggedTensor.
|
||||
output_splits = array_ops.concat(
|
||||
[
|
||||
array_ops.zeros([1], dtypes.int64),
|
||||
math_ops.cumsum(output_row_lengths)
|
||||
],
|
||||
axis=0)
|
||||
output_splits = array_ops.concat([
|
||||
array_ops.zeros([1], dtypes.int64),
|
||||
math_ops.cumsum(output_row_lengths)
|
||||
],
|
||||
axis=0)
|
||||
|
||||
# For each row in `data`, find the start & limit position where that row's
|
||||
# values will be aggregated in output.values.
|
||||
@ -311,7 +313,7 @@ _set_ragged_segment_docstring(segment_sqrt_n, 'sum divided by sqrt(N)',
|
||||
_RAGGED_REDUCE_DOCSTRING = """\
|
||||
Computes the %(combination)s of elements across dimensions of a `RaggedTensor`.
|
||||
|
||||
Reduces `rt_input` along the dimensions given in `axis` by taking the
|
||||
Reduces `input_tensor` along the dimensions given in `axis` by taking the
|
||||
%(combination)s of values. If a reduced dimension has no elements for
|
||||
some index, then the value for that index will be %(default)s.
|
||||
|
||||
@ -319,18 +321,18 @@ Computes the %(combination)s of elements across dimensions of a `RaggedTensor`.
|
||||
`axis` is not specified, then all dimensions are reduced, and a scalar
|
||||
value is returned.
|
||||
Args:
|
||||
rt_input: A `RaggedTensor` containing the values to be %(combined)s.
|
||||
input_tensor: A `RaggedTensor` containing the values to be %(combined)s.
|
||||
axis: The dimensions to reduce. May be `None` (to reduce all axes), an
|
||||
`int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce
|
||||
a given set of axes), or a `Tensor` with a constant value. Must be in
|
||||
the range `[0, rt_input.rank]`.
|
||||
the range `[0, input_tensor.rank]`.
|
||||
name: A name prefix for the returned tensor (optional).
|
||||
Returns:
|
||||
A `RaggedTensor` containing the %(combined)s values. The returned tensor
|
||||
has the same dtype as `data`, and its shape is given by removing the
|
||||
dimensions specified in `axis` from `rt_input.shape`. The `ragged_rank`
|
||||
dimensions specified in `axis` from `input_tensor.shape`. The `ragged_rank`
|
||||
of the returned tensor is given by substracting any ragged dimensions
|
||||
specified in `axis` from `rt_input.ragged_rank`.
|
||||
specified in `axis` from `input_tensor.ragged_rank`.
|
||||
Raises:
|
||||
ValueError: If `axis` contains a `Tensor` whose value is not constant.
|
||||
####Example:
|
||||
@ -387,7 +389,11 @@ _RAGGED_REDUCE_ANY_EXAMPLE = """
|
||||
"""
|
||||
|
||||
|
||||
def _ragged_reduce_aggregate(reduce_op, unsorted_segment_op, rt_input, axis,
|
||||
def _ragged_reduce_aggregate(reduce_op,
|
||||
unsorted_segment_op,
|
||||
rt_input,
|
||||
axis,
|
||||
keepdims,
|
||||
name=None):
|
||||
"""Aggregates across axes of a RaggedTensor using the given `Tensor` ops.
|
||||
|
||||
@ -412,6 +418,7 @@ def _ragged_reduce_aggregate(reduce_op, unsorted_segment_op, rt_input, axis,
|
||||
`int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce a
|
||||
given set of axes), or a `Tensor` with a constant value. Must be in the
|
||||
range `[0, rt_input.rank)`.
|
||||
keepdims: If true, retains reduced dimensions with length 1.
|
||||
name: A name prefix for the returned tensor (optional).
|
||||
|
||||
Returns:
|
||||
@ -426,6 +433,9 @@ def _ragged_reduce_aggregate(reduce_op, unsorted_segment_op, rt_input, axis,
|
||||
if not ragged_tensor.is_ragged(rt_input):
|
||||
return reduce_op(rt_input, axis, name=name)
|
||||
|
||||
if keepdims:
|
||||
raise ValueError('keepdims=True is not supported for RaggedTensors.')
|
||||
|
||||
if isinstance(axis, ops.Tensor):
|
||||
axis = tensor_util.constant_value(axis)
|
||||
if axis is None:
|
||||
@ -448,9 +458,9 @@ def _ragged_reduce_aggregate(reduce_op, unsorted_segment_op, rt_input, axis,
|
||||
# once will probably require a nontrivial c++ op.
|
||||
axis = sorted(axis)
|
||||
inner_reduced = _ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
|
||||
rt_input, axis[-1])
|
||||
rt_input, axis[-1], keepdims)
|
||||
return _ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
|
||||
inner_reduced, axis[:-1])
|
||||
inner_reduced, axis[:-1], keepdims)
|
||||
|
||||
axis = ragged_util.get_positive_axis(axis, rt_input.shape.ndims)
|
||||
|
||||
@ -476,48 +486,48 @@ def _ragged_reduce_aggregate(reduce_op, unsorted_segment_op, rt_input, axis,
|
||||
# sum_{j} rt_input [i_0, ..., i_[axis-1], j, i_axis+1], ..., i_N]
|
||||
return rt_input.with_values(
|
||||
_ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
|
||||
rt_input.values, axis - 1))
|
||||
rt_input.values, axis - 1, keepdims))
|
||||
|
||||
|
||||
def reduce_sum(rt_input, axis=None, name=None):
|
||||
def reduce_sum(input_tensor, axis=None, keepdims=None, name=None):
|
||||
"""For docs, see: _RAGGED_REDUCE_DOCSTRING."""
|
||||
return _ragged_reduce_aggregate(math_ops.reduce_sum,
|
||||
math_ops.unsorted_segment_sum, rt_input, axis,
|
||||
name or 'RaggedReduceSum')
|
||||
math_ops.unsorted_segment_sum, input_tensor,
|
||||
axis, keepdims, name or 'RaggedReduceSum')
|
||||
|
||||
|
||||
def reduce_prod(rt_input, axis=None, name=None):
|
||||
def reduce_prod(input_tensor, axis=None, keepdims=None, name=None):
|
||||
"""For docs, see: _RAGGED_REDUCE_DOCSTRING."""
|
||||
return _ragged_reduce_aggregate(math_ops.reduce_prod,
|
||||
math_ops.unsorted_segment_prod, rt_input,
|
||||
axis, name or 'RaggedReduceProd')
|
||||
math_ops.unsorted_segment_prod, input_tensor,
|
||||
axis, keepdims, name or 'RaggedReduceProd')
|
||||
|
||||
|
||||
def reduce_min(rt_input, axis=None, name=None):
|
||||
def reduce_min(input_tensor, axis=None, keepdims=None, name=None):
|
||||
"""For docs, see: _RAGGED_REDUCE_DOCSTRING."""
|
||||
return _ragged_reduce_aggregate(math_ops.reduce_min,
|
||||
math_ops.unsorted_segment_min, rt_input, axis,
|
||||
name or 'RaggedReduceMin')
|
||||
math_ops.unsorted_segment_min, input_tensor,
|
||||
axis, keepdims, name or 'RaggedReduceMin')
|
||||
|
||||
|
||||
def reduce_max(rt_input, axis=None, name=None):
|
||||
def reduce_max(input_tensor, axis=None, keepdims=None, name=None):
|
||||
"""For docs, see: _RAGGED_REDUCE_DOCSTRING."""
|
||||
return _ragged_reduce_aggregate(math_ops.reduce_max,
|
||||
math_ops.unsorted_segment_max, rt_input, axis,
|
||||
name or 'RaggedReduceMax')
|
||||
math_ops.unsorted_segment_max, input_tensor,
|
||||
axis, keepdims, name or 'RaggedReduceMax')
|
||||
|
||||
|
||||
def reduce_mean(rt_input, axis=None, name=None):
|
||||
def reduce_mean(input_tensor, axis=None, keepdims=None, name=None):
|
||||
"""For docs, see: _RAGGED_REDUCE_DOCSTRING."""
|
||||
with ops.name_scope(name, 'RaggedReduceMean', [rt_input, axis]):
|
||||
total = reduce_sum(rt_input, axis)
|
||||
if ragged_tensor.is_ragged(rt_input):
|
||||
with ops.name_scope(name, 'RaggedReduceMean', [input_tensor, axis]):
|
||||
total = reduce_sum(input_tensor, axis, keepdims)
|
||||
if ragged_tensor.is_ragged(input_tensor):
|
||||
ones = ragged_factory_ops.from_nested_row_splits(
|
||||
array_ops.ones_like(rt_input.inner_values),
|
||||
rt_input.nested_row_splits)
|
||||
array_ops.ones_like(input_tensor.inner_values),
|
||||
input_tensor.nested_row_splits)
|
||||
else:
|
||||
ones = array_ops.ones_like(rt_input)
|
||||
count = reduce_sum(ones, axis)
|
||||
ones = array_ops.ones_like(input_tensor)
|
||||
count = reduce_sum(ones, axis, keepdims)
|
||||
if ragged_tensor.is_ragged(total):
|
||||
return ragged_factory_ops.from_nested_row_splits(
|
||||
total.inner_values / count.inner_values, total.nested_row_splits)
|
||||
@ -525,20 +535,25 @@ def reduce_mean(rt_input, axis=None, name=None):
|
||||
return total / count
|
||||
|
||||
|
||||
def _cast(rt_input, dtype):
|
||||
return ragged_functional_ops.map_inner_values(math_ops.cast, rt_input, dtype)
|
||||
def _cast(input_tensor, dtype):
|
||||
return ragged_functional_ops.map_inner_values(math_ops.cast, input_tensor,
|
||||
dtype)
|
||||
|
||||
|
||||
def reduce_all(rt_input, axis=None, name=None):
|
||||
def reduce_all(input_tensor, axis=None, keepdims=None, name=None):
|
||||
"""For docs, see: _RAGGED_REDUCE_DOCSTRING."""
|
||||
with ops.name_scope(name, 'RaggedReduceAll', [rt_input, axis]):
|
||||
return _cast(reduce_prod(_cast(rt_input, dtypes.int32), axis), dtypes.bool)
|
||||
with ops.name_scope(name, 'RaggedReduceAll', [input_tensor, axis]):
|
||||
return _cast(
|
||||
reduce_prod(_cast(input_tensor, dtypes.int32), axis, keepdims),
|
||||
dtypes.bool)
|
||||
|
||||
|
||||
def reduce_any(rt_input, axis=None, name=None):
|
||||
def reduce_any(input_tensor, axis=None, keepdims=None, name=None):
|
||||
"""For docs, see: _RAGGED_REDUCE_DOCSTRING."""
|
||||
with ops.name_scope(name, 'RaggedReduceAny', [rt_input, axis]):
|
||||
return _cast(reduce_sum(_cast(rt_input, dtypes.int32), axis), dtypes.bool)
|
||||
with ops.name_scope(name, 'RaggedReduceAny', [input_tensor, axis]):
|
||||
return _cast(
|
||||
reduce_sum(_cast(input_tensor, dtypes.int32), axis, keepdims),
|
||||
dtypes.bool)
|
||||
|
||||
|
||||
def _set_ragged_reduce_docstring(func, combination, combined, default, example):
|
||||
@ -554,9 +569,11 @@ _set_ragged_reduce_docstring(reduce_sum, 'sum', 'summed', '0',
|
||||
_set_ragged_reduce_docstring(reduce_prod, 'product', 'multiplied', '1',
|
||||
_RAGGED_REDUCE_PROD_EXAMPLE)
|
||||
_set_ragged_reduce_docstring(reduce_min, 'minimum', 'minimized',
|
||||
'`rt_input.dtype.min`', _RAGGED_REDUCE_MIN_EXAMPLE)
|
||||
'`input_tensor.dtype.min`',
|
||||
_RAGGED_REDUCE_MIN_EXAMPLE)
|
||||
_set_ragged_reduce_docstring(reduce_max, 'maximum', 'maximized',
|
||||
'`rt_input.dtype.max`', _RAGGED_REDUCE_MAX_EXAMPLE)
|
||||
'`input_tensor.dtype.max`',
|
||||
_RAGGED_REDUCE_MAX_EXAMPLE)
|
||||
_set_ragged_reduce_docstring(reduce_mean, 'mean', 'averaged', 'NaN',
|
||||
_RAGGED_REDUCE_MEAN_EXAMPLE)
|
||||
|
||||
|
@ -18,7 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.ops.ragged import ragged_elementwise_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.ragged import ragged_getitem
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.util import tf_decorator
|
||||
@ -33,40 +33,39 @@ def _right(operator):
|
||||
ragged_tensor.RaggedTensor.__getitem__ = ragged_getitem.ragged_tensor_getitem
|
||||
|
||||
# Ordering operators
|
||||
ragged_tensor.RaggedTensor.__ge__ = ragged_elementwise_ops.greater_equal
|
||||
ragged_tensor.RaggedTensor.__gt__ = ragged_elementwise_ops.greater
|
||||
ragged_tensor.RaggedTensor.__le__ = ragged_elementwise_ops.less_equal
|
||||
ragged_tensor.RaggedTensor.__lt__ = ragged_elementwise_ops.less
|
||||
ragged_tensor.RaggedTensor.__ge__ = math_ops.greater_equal
|
||||
ragged_tensor.RaggedTensor.__gt__ = math_ops.greater
|
||||
ragged_tensor.RaggedTensor.__le__ = math_ops.less_equal
|
||||
ragged_tensor.RaggedTensor.__lt__ = math_ops.less
|
||||
|
||||
# Logical operators
|
||||
ragged_tensor.RaggedTensor.__and__ = ragged_elementwise_ops.logical_and
|
||||
ragged_tensor.RaggedTensor.__rand__ = _right(ragged_elementwise_ops.logical_and)
|
||||
ragged_tensor.RaggedTensor.__invert__ = ragged_elementwise_ops.logical_not
|
||||
ragged_tensor.RaggedTensor.__ror__ = _right(ragged_elementwise_ops.logical_or)
|
||||
ragged_tensor.RaggedTensor.__or__ = ragged_elementwise_ops.logical_or
|
||||
ragged_tensor.RaggedTensor.__xor__ = ragged_elementwise_ops.logical_xor
|
||||
ragged_tensor.RaggedTensor.__rxor__ = _right(ragged_elementwise_ops.logical_xor)
|
||||
ragged_tensor.RaggedTensor.__and__ = math_ops.logical_and
|
||||
ragged_tensor.RaggedTensor.__rand__ = _right(math_ops.logical_and)
|
||||
ragged_tensor.RaggedTensor.__invert__ = math_ops.logical_not
|
||||
ragged_tensor.RaggedTensor.__ror__ = _right(math_ops.logical_or)
|
||||
ragged_tensor.RaggedTensor.__or__ = math_ops.logical_or
|
||||
ragged_tensor.RaggedTensor.__xor__ = math_ops.logical_xor
|
||||
ragged_tensor.RaggedTensor.__rxor__ = _right(math_ops.logical_xor)
|
||||
|
||||
# Arithmetic operators
|
||||
ragged_tensor.RaggedTensor.__abs__ = ragged_elementwise_ops.abs
|
||||
ragged_tensor.RaggedTensor.__add__ = ragged_elementwise_ops.add
|
||||
ragged_tensor.RaggedTensor.__radd__ = _right(ragged_elementwise_ops.add)
|
||||
ragged_tensor.RaggedTensor.__div__ = ragged_elementwise_ops.div
|
||||
ragged_tensor.RaggedTensor.__rdiv__ = _right(ragged_elementwise_ops.div)
|
||||
ragged_tensor.RaggedTensor.__floordiv__ = ragged_elementwise_ops.floordiv
|
||||
ragged_tensor.RaggedTensor.__rfloordiv__ = _right(
|
||||
ragged_elementwise_ops.floordiv)
|
||||
ragged_tensor.RaggedTensor.__mod__ = ragged_elementwise_ops.floormod
|
||||
ragged_tensor.RaggedTensor.__rmod__ = _right(ragged_elementwise_ops.floormod)
|
||||
ragged_tensor.RaggedTensor.__mul__ = ragged_elementwise_ops.multiply
|
||||
ragged_tensor.RaggedTensor.__rmul__ = _right(ragged_elementwise_ops.multiply)
|
||||
ragged_tensor.RaggedTensor.__neg__ = ragged_elementwise_ops.negative
|
||||
ragged_tensor.RaggedTensor.__pow__ = ragged_elementwise_ops.pow
|
||||
ragged_tensor.RaggedTensor.__rpow__ = _right(ragged_elementwise_ops.pow)
|
||||
ragged_tensor.RaggedTensor.__sub__ = ragged_elementwise_ops.subtract
|
||||
ragged_tensor.RaggedTensor.__rsub__ = _right(ragged_elementwise_ops.subtract)
|
||||
ragged_tensor.RaggedTensor.__truediv__ = ragged_elementwise_ops.truediv
|
||||
ragged_tensor.RaggedTensor.__rtruediv__ = _right(ragged_elementwise_ops.truediv)
|
||||
ragged_tensor.RaggedTensor.__abs__ = math_ops.abs
|
||||
ragged_tensor.RaggedTensor.__add__ = math_ops.add
|
||||
ragged_tensor.RaggedTensor.__radd__ = _right(math_ops.add)
|
||||
ragged_tensor.RaggedTensor.__div__ = math_ops.div
|
||||
ragged_tensor.RaggedTensor.__rdiv__ = _right(math_ops.div)
|
||||
ragged_tensor.RaggedTensor.__floordiv__ = math_ops.floordiv
|
||||
ragged_tensor.RaggedTensor.__rfloordiv__ = _right(math_ops.floordiv)
|
||||
ragged_tensor.RaggedTensor.__mod__ = math_ops.floormod
|
||||
ragged_tensor.RaggedTensor.__rmod__ = _right(math_ops.floormod)
|
||||
ragged_tensor.RaggedTensor.__mul__ = math_ops.multiply
|
||||
ragged_tensor.RaggedTensor.__rmul__ = _right(math_ops.multiply)
|
||||
ragged_tensor.RaggedTensor.__neg__ = math_ops.negative
|
||||
ragged_tensor.RaggedTensor.__pow__ = math_ops.pow
|
||||
ragged_tensor.RaggedTensor.__rpow__ = _right(math_ops.pow)
|
||||
ragged_tensor.RaggedTensor.__sub__ = math_ops.subtract
|
||||
ragged_tensor.RaggedTensor.__rsub__ = _right(math_ops.subtract)
|
||||
ragged_tensor.RaggedTensor.__truediv__ = math_ops.truediv
|
||||
ragged_tensor.RaggedTensor.__rtruediv__ = _right(math_ops.truediv)
|
||||
|
||||
|
||||
# Dummy methods
|
||||
|
@ -2696,6 +2696,7 @@ class _UnaryMapValueDispatcher(dispatch.OpDispatcher):
|
||||
if args:
|
||||
x, args = args[0], args[1:]
|
||||
else:
|
||||
kwargs = kwargs.copy()
|
||||
x = kwargs.pop(self._x, None)
|
||||
if isinstance(x, sparse_tensor.SparseTensor):
|
||||
return sparse_tensor.SparseTensor(
|
||||
|
@ -38,6 +38,7 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.gen_string_ops import *
|
||||
from tensorflow.python.util import compat as util_compat
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util import dispatch
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
# pylint: enable=g-bad-import-order
|
||||
# pylint: enable=wildcard-import
|
||||
@ -45,6 +46,7 @@ from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
# pylint: disable=redefined-builtin
|
||||
@tf_export("strings.regex_full_match")
|
||||
@dispatch.add_dispatch_support
|
||||
def regex_full_match(input, pattern, name=None):
|
||||
r"""Match elements of `input` with regex `pattern`.
|
||||
|
||||
@ -76,6 +78,7 @@ regex_full_match.__doc__ = gen_string_ops.regex_full_match.__doc__
|
||||
@tf_export(
|
||||
"strings.regex_replace", v1=["strings.regex_replace", "regex_replace"])
|
||||
@deprecation.deprecated_endpoints("regex_replace")
|
||||
@dispatch.add_dispatch_support
|
||||
def regex_replace(input, pattern, rewrite, replace_global=True, name=None):
|
||||
r"""Replace elements of `input` matching regex `pattern` with `rewrite`.
|
||||
|
||||
@ -350,10 +353,13 @@ reduce_join.__doc__ = reduce_join.__doc__.replace("tf.reduce_join(",
|
||||
# This wrapper provides backwards compatibility for code that predates the
|
||||
# unit argument and that passed 'name' as a positional argument.
|
||||
@tf_export(v1=["strings.length"])
|
||||
@dispatch.add_dispatch_support
|
||||
def string_length(input, name=None, unit="BYTE"):
|
||||
return gen_string_ops.string_length(input, unit=unit, name=name)
|
||||
|
||||
|
||||
@tf_export("strings.length", v1=[])
|
||||
@dispatch.add_dispatch_support
|
||||
def string_length_v2(input, unit="BYTE", name=None):
|
||||
return string_length(input, name, unit)
|
||||
|
||||
@ -370,11 +376,13 @@ substr_deprecated.__doc__ = gen_string_ops.substr.__doc__
|
||||
|
||||
|
||||
@tf_export(v1=["strings.substr"])
|
||||
@dispatch.add_dispatch_support
|
||||
def substr(input, pos, len, name=None, unit="BYTE"):
|
||||
return gen_string_ops.substr(input, pos, len, unit=unit, name=name)
|
||||
|
||||
|
||||
@tf_export("strings.substr", v1=[])
|
||||
@dispatch.add_dispatch_support
|
||||
def substr_v2(input, pos, len, unit="BYTE", name=None):
|
||||
return substr(input, pos, len, name=name, unit=unit)
|
||||
|
||||
@ -395,6 +403,7 @@ ops.NotDifferentiable("DecodeBase64")
|
||||
|
||||
|
||||
@tf_export("strings.to_number", v1=[])
|
||||
@dispatch.add_dispatch_support
|
||||
def string_to_number(input, out_type=dtypes.float32, name=None):
|
||||
r"""Converts each string in the input Tensor to the specified numeric type.
|
||||
|
||||
@ -418,6 +427,7 @@ tf_export(v1=["strings.to_number", "string_to_number"])(
|
||||
|
||||
|
||||
@tf_export("strings.to_hash_bucket", v1=[])
|
||||
@dispatch.add_dispatch_support
|
||||
def string_to_hash_bucket(input, num_buckets, name=None):
|
||||
# pylint: disable=line-too-long
|
||||
r"""Converts each string in the input Tensor to its hash mod by a number of buckets.
|
||||
|
@ -166,15 +166,14 @@ def dispatch_for_types(op, *types):
|
||||
|
||||
def add_dispatch_list(target):
|
||||
"""Decorator that adds a dispatch_list attribute to an op."""
|
||||
assert not hasattr(target, DISPATCH_ATTR)
|
||||
if hasattr(target, DISPATCH_ATTR):
|
||||
raise AssertionError("%s already has a dispatch list" % target)
|
||||
setattr(target, DISPATCH_ATTR, [])
|
||||
return target
|
||||
|
||||
|
||||
def add_dispatch_support(target):
|
||||
"""Decorator that adds a dispatch handling wrapper to an op."""
|
||||
add_dispatch_list(target)
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
"""Call target, and fall back on dispatchers if there is a TypeError."""
|
||||
try:
|
||||
@ -188,5 +187,5 @@ def add_dispatch_support(target):
|
||||
else:
|
||||
raise
|
||||
|
||||
setattr(wrapper, DISPATCH_ATTR, [])
|
||||
add_dispatch_list(wrapper)
|
||||
return tf_decorator.make_decorator(target, wrapper)
|
||||
|
Loading…
x
Reference in New Issue
Block a user