Internal Change

PiperOrigin-RevId: 224225849
A. Unique TensorFlower 2018-12-05 14:51:36 -08:00 committed by TensorFlower Gardener
parent 2b0fd9b66a
commit 45cfe71266
21 changed files with 889 additions and 757 deletions
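
The net effect of this change: standard TensorFlow ops gain the ability to accept `RaggedTensor` inputs via the new `OpDispatcher` registrations, replacing the separate `ragged_elementwise_ops` wrappers. A minimal sketch of the intended behavior (assuming the `tf.ragged` endpoints listed in the updated package docstring below):

```python
import tensorflow as tf

rt = tf.ragged.constant([[1, 2], [3]])  # factory op re-exported below

# After the dispatchers in ragged_dispatch.py are registered, standard ops
# accept RaggedTensor arguments directly:
tf.add(rt, 10)      # elementwise dispatch -> [[11, 12], [13]]
tf.reduce_sum(rt)   # reduction dispatch   -> 6
```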

View File

@@ -1,4 +1,6 @@
 op {
   graph_op_name: "FloorDiv"
-  visibility: HIDDEN
+  endpoint {
+    name: "floor_div"
+  }
 }

View File

@@ -1,4 +1,9 @@
 op {
   graph_op_name: "FloorMod"
-  visibility: HIDDEN
+  endpoint {
+    name: "floormod"
+  }
+  endpoint {
+    name: "mod"
+  }
 }

View File

@@ -1,4 +1,6 @@
 op {
   graph_op_name: "RealDiv"
-  visibility: HIDDEN
+  endpoint {
+    name: "realdiv"
+  }
 }

View File

@@ -1,4 +1,6 @@
 op {
   graph_op_name: "TruncateDiv"
-  visibility: HIDDEN
+  endpoint {
+    name: "truncatediv"
+  }
 }

View File

@@ -1,4 +1,6 @@
 op {
   graph_op_name: "TruncateMod"
-  visibility: HIDDEN
+  endpoint {
+    name: "truncatemod"
+  }
 }

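The five `api_def` changes above flip these ops from `visibility: HIDDEN` to explicit endpoints, so their generated Python wrappers are exported under the listed names; the matching manual `tf_export(...)` re-exports are deleted from `math_ops.py` further down. For example (user-visible behavior intended to be unchanged, assuming the endpoints above):

```python
import tensorflow as tf

# Endpoints now come from the generated wrappers rather than math_ops.py:
tf.floor_div(7, 2)      # FloorDiv endpoint "floor_div"        -> 3
tf.mod(7, 2)            # FloorMod endpoints "floormod", "mod" -> 1
tf.truncatediv(-7, 2)   # TruncateDiv endpoint "truncatediv"   -> -3
```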
View File

@@ -634,7 +634,9 @@ void GenEagerPythonOp::AddEagerFunctionTeardown(
 bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
     const string& parameters, const std::vector<string>& output_sizes,
     const string& eager_not_allowed_error) {
-  strings::StrAppend(&result_, "@_dispatch.add_dispatch_list\n");
+  if (api_def_.visibility() == ApiDef::VISIBLE) {
+    strings::StrAppend(&result_, "@_dispatch.add_dispatch_list\n");
+  }
   AddExport();
   AddDefLine(function_name_, parameters);
   AddDocStringDescription();

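With this guard, the generated wrapper for a hidden op no longer carries a dispatch list. A minimal sketch of the dispatch-list mechanism itself (simplified stand-ins, not TensorFlow's actual helpers; the attribute name is illustrative):

```python
_NOT_SUPPORTED = object()

def add_dispatch_list(fn):
  """Mirrors @_dispatch.add_dispatch_list: attach a dispatcher registry."""
  fn._tf_dispatchers = []  # illustrative attribute name
  return fn

def dispatch(fn, args, kwargs):
  """Offer a failed call to each registered dispatcher in turn."""
  for dispatcher in fn._tf_dispatchers:
    result = dispatcher(args, kwargs)
    if result is not _NOT_SUPPORTED:
      return result
  return _NOT_SUPPORTED
```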
View File

@@ -56,6 +56,7 @@ _BaseSlice = slice
 @tf_export("identity")
+@dispatch.add_dispatch_support
 def identity(input, name=None):  # pylint: disable=redefined-builtin
   r"""Return a tensor with the same shape and contents as input.

@@ -139,6 +140,7 @@ def expand_dims(input, axis=None, name=None, dim=None):
 @tf_export("expand_dims", v1=[])
+@dispatch.add_dispatch_support
 def expand_dims_v2(input, axis, name=None):
   """Inserts a dimension of 1 into a tensor's shape.

@@ -941,6 +943,7 @@ def parallel_stack(values, name="parallel_stack"):
 @tf_export("stack")
+@dispatch.add_dispatch_support
 def stack(values, axis=0, name="stack"):
   """Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor.

@@ -1151,6 +1154,7 @@ def unstack(value, num=None, axis=0, name="unstack"):
 @tf_export("concat")
+@dispatch.add_dispatch_support
 def concat(values, axis, name="concat"):
   """Concatenates tensors along one dimension.

@@ -1328,6 +1332,7 @@ def boolean_mask(tensor, mask, name="boolean_mask", axis=None):
 @tf_export("boolean_mask", v1=[])
+@dispatch.add_dispatch_support
 def boolean_mask_v2(tensor, mask, axis=None, name="boolean_mask"):
   """Apply boolean mask to tensor.

@@ -1810,6 +1815,7 @@ def zeros(shape, dtype=dtypes.float32, name=None):
 @tf_export(v1=["zeros_like"])
+@dispatch.add_dispatch_support
 def zeros_like(tensor, dtype=None, name=None, optimize=True):
   """Creates a tensor with all elements set to zero.

@@ -1840,6 +1846,7 @@ def zeros_like(tensor, dtype=None, name=None, optimize=True):
 @tf_export("zeros_like", v1=[])
+@dispatch.add_dispatch_support
 def zeros_like_v2(
     input,  # pylint: disable=redefined-builtin
     dtype=None,

@@ -1899,6 +1906,7 @@ def zeros_like_impl(tensor, dtype, name, optimize=True):
 @tf_export(v1=["ones_like"])
+@dispatch.add_dispatch_support
 def ones_like(tensor, dtype=None, name=None, optimize=True):
   """Creates a tensor with all elements set to 1.

@@ -1929,6 +1937,7 @@ def ones_like(tensor, dtype=None, name=None, optimize=True):
 @tf_export("ones_like", v1=[])
+@dispatch.add_dispatch_support
 def ones_like_v2(
     input,  # pylint: disable=redefined-builtin
     dtype=None,

@@ -3115,6 +3124,7 @@ def squeeze_v2(input, axis=None, name=None):
 @tf_export("where")
+@dispatch.add_dispatch_support
 def where(condition, x=None, y=None, name=None):
   """Return the elements, either from `x` or `y`, depending on the `condition`.

@@ -3234,6 +3244,7 @@ def gather(params, indices, validate_indices=None, name=None, axis=0):
 @tf_export("gather", v1=[])
+@dispatch.add_dispatch_support
 def gather_v2(params, indices, validate_indices=None, axis=0, name=None):
   return gather(params, indices, validate_indices=validate_indices, name=name,
                 axis=axis)

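`@dispatch.add_dispatch_support` plays the same role for these hand-written ops that the generated `@_dispatch.add_dispatch_list` plays for generated wrappers. A sketch of what the decorator does (simplified from `tensorflow/python/util/dispatch.py`; `run_dispatchers` and `_NOT_SUPPORTED` are illustrative stand-ins):

```python
import functools

_NOT_SUPPORTED = object()

def add_dispatch_support(op):
  """Let registered dispatchers handle inputs the op itself rejects."""
  @functools.wraps(op)
  def wrapper(*args, **kwargs):
    try:
      return op(*args, **kwargs)
    except (TypeError, ValueError):
      # E.g. a RaggedTensor argument: consult dispatchers before re-raising.
      result = run_dispatchers(wrapper, args, kwargs)  # illustrative helper
      if result is not _NOT_SUPPORTED:
        return result
      raise
  return wrapper
```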
View File

@@ -31,10 +31,12 @@ from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import numerics
 from tensorflow.python.util import deprecation
+from tensorflow.python.util import dispatch
 from tensorflow.python.util.tf_export import tf_export

 @tf_export("clip_by_value")
+@dispatch.add_dispatch_support
 def clip_by_value(t, clip_value_min, clip_value_max,
                   name=None):
   """Clips tensor values to a specified min and max.

View File

@@ -230,6 +230,7 @@ class DivideDelegateWithName(object):
 @tf_export("math.divide", "divide")
+@dispatch.add_dispatch_support
 def divide(x, y, name=None):
   """Computes Python style division of `x` by `y`."""

@@ -242,6 +243,7 @@ def divide(x, y, name=None):
 @tf_export("math.multiply", "multiply")
+@dispatch.add_dispatch_support
 def multiply(x, y, name=None):
   return gen_math_ops.mul(x, y, name)

@@ -262,6 +264,7 @@ _mul.__doc__ = (
 @tf_export("math.subtract", "subtract")
+@dispatch.add_dispatch_support
 def subtract(x, y, name=None):
   return gen_math_ops.sub(x, y, name)

@@ -347,6 +350,7 @@ def scalar_mul_v2(scalar, x, name=None):
 @tf_export("math.pow", "pow")
+@dispatch.add_dispatch_support
 def pow(x, y, name=None):  # pylint: disable=redefined-builtin
   r"""Computes the power of one value to another.

@@ -375,6 +379,7 @@ def pow(x, y, name=None):  # pylint: disable=redefined-builtin
 # pylint: disable=redefined-builtin,redefined-outer-name
 @tf_export("dtypes.complex", "complex")
+@dispatch.add_dispatch_support
 def complex(real, imag, name=None):
   r"""Converts two real numbers to a complex number.

@@ -418,6 +423,7 @@ def complex(real, imag, name=None):
 @tf_export("math.real", v1=["math.real", "real"])
 @deprecation.deprecated_endpoints("real")
+@dispatch.add_dispatch_support
 def real(input, name=None):
   r"""Returns the real part of a complex (or real) tensor.

@@ -450,6 +456,7 @@ def real(input, name=None):
 @tf_export("math.imag", v1=["math.imag", "imag"])
 @deprecation.deprecated_endpoints("imag")
+@dispatch.add_dispatch_support
 def imag(input, name=None):
   r"""Returns the imaginary part of a complex (or real) tensor.

@@ -481,6 +488,7 @@ def imag(input, name=None):
 @tf_export("math.angle", v1=["math.angle", "angle"])
 @deprecation.deprecated_endpoints("angle")
+@dispatch.add_dispatch_support
 def angle(input, name=None):
   r"""Returns the element-wise argument of a complex (or real) tensor.

@@ -520,6 +528,7 @@ def angle(input, name=None):
 @tf_export("math.round", "round")
+@dispatch.add_dispatch_support
 def round(x, name=None):  # pylint: disable=redefined-builtin
   """Rounds the values of a tensor to the nearest integer, element-wise.

@@ -547,6 +556,7 @@ def round(x, name=None):  # pylint: disable=redefined-builtin
 @tf_export("dtypes.cast", "cast")
+@dispatch.add_dispatch_support
 def cast(x, dtype, name=None):
   """Casts a tensor to a new type.

@@ -610,6 +620,7 @@ def cast(x, dtype, name=None):
 @tf_export("dtypes.saturate_cast", "saturate_cast")
+@dispatch.add_dispatch_support
 def saturate_cast(value, dtype, name=None):
   """Performs a safe saturating cast of `value` to `dtype`.

@@ -935,6 +946,7 @@ def _div_python2(x, y, name=None):
 @tf_export("math.truediv", "truediv")
+@dispatch.add_dispatch_support
 def truediv(x, y, name=None):
   """Divides x / y elementwise (using Python 3 division operator semantics).

@@ -992,6 +1004,7 @@ def div(x, y, name=None):
 @tf_export("div_no_nan")
+@dispatch.add_dispatch_support
 def div_no_nan(x, y, name=None):
   """Computes an unsafe divide which returns 0 if the y is zero.

@@ -1021,6 +1034,7 @@ mod = gen_math_ops.floor_mod
 # TODO(aselle): Deprecate this once all internal functionality uses
 # tf.truncatediv
 @tf_export("math.floordiv", v1=["math.floordiv", "floordiv"])
+@dispatch.add_dispatch_support
 @deprecation.deprecated_endpoints("floordiv")
 def floordiv(x, y, name=None):
   """Divides `x / y` elementwise, rounding toward the most negative integer.

@@ -1050,16 +1064,11 @@ def floordiv(x, y, name=None):
 realdiv = gen_math_ops.real_div
-tf_export("realdiv")(realdiv)
 truncatediv = gen_math_ops.truncate_div
-tf_export("truncatediv")(truncatediv)
 # TODO(aselle): Rename this to floordiv when we can.
 floor_div = gen_math_ops.floor_div
-tf_export("floor_div")(floor_div)
 truncatemod = gen_math_ops.truncate_mod
-tf_export("truncatemod")(truncatemod)
 floormod = gen_math_ops.floor_mod
-tf_export("floormod", "mod")(floormod)

 def _mul_dispatch(x, y, name=None):

@@ -1095,6 +1104,7 @@ _OverrideBinaryOperatorHelper(pow, "pow")
 @tf_export("math.logical_xor", v1=["math.logical_xor", "logical_xor"])
+@dispatch.add_dispatch_support
 @deprecation.deprecated_endpoints("logical_xor")
 def logical_xor(x, y, name="LogicalXor"):
   """x ^ y = (x | y) & ~(x & y)."""

@@ -1277,6 +1287,7 @@ def reduce_sum_v1(input_tensor,
 @tf_export("math.reduce_sum", "reduce_sum", v1=[])
+@dispatch.add_dispatch_support
 def reduce_sum(input_tensor, axis=None, keepdims=False, name=None):
   """Computes the sum of elements across dimensions of a tensor.

@@ -1524,6 +1535,7 @@ def reduce_mean_v1(input_tensor,
 @tf_export("math.reduce_mean", "reduce_mean", v1=[])
+@dispatch.add_dispatch_support
 def reduce_mean(input_tensor, axis=None, keepdims=False, name=None):
   """Computes the mean of elements across dimensions of a tensor.

@@ -1675,6 +1687,7 @@ def reduce_std(input_tensor, axis=None, keepdims=False, name=None):
 @tf_export("math.reduce_prod", "reduce_prod", v1=[])
+@dispatch.add_dispatch_support
 def reduce_prod(input_tensor, axis=None, keepdims=False, name=None):
   """Computes the product of elements across dimensions of a tensor.

@@ -1796,6 +1809,7 @@ def reduce_min_v1(input_tensor,
 @tf_export("math.reduce_min", "reduce_min", v1=[])
+@dispatch.add_dispatch_support
 def reduce_min(input_tensor, axis=None, keepdims=False, name=None):
   """Computes the minimum of elements across dimensions of a tensor.

@@ -1874,6 +1888,7 @@ def reduce_max_v1(input_tensor,
 @tf_export("math.reduce_max", "reduce_max", v1=[])
+@dispatch.add_dispatch_support
 def reduce_max(input_tensor, axis=None, keepdims=False, name=None):
   """Computes the maximum of elements across dimensions of a tensor.

@@ -1961,6 +1976,7 @@ def reduce_all_v1(input_tensor,
 @tf_export("reduce_all", "math.reduce_all", v1=[])
+@dispatch.add_dispatch_support
 def reduce_all(input_tensor, axis=None, keepdims=False, name=None):
   """Computes the "logical and" of elements across dimensions of a tensor.

@@ -2057,6 +2073,7 @@ def reduce_any_v1(input_tensor,
 @tf_export("math.reduce_any", "reduce_any", v1=[])
+@dispatch.add_dispatch_support
 def reduce_any(input_tensor, axis=None, keepdims=False, name=None):
   """Computes the "logical or" of elements across dimensions of a tensor.

@@ -2619,6 +2636,7 @@ def _as_indexed_slices_list(inputs, optimize=True):
 @tf_export("math.add_n", "add_n")
+@dispatch.add_dispatch_support
 def add_n(inputs, name=None):
   """Adds all input tensors element-wise.

@@ -2764,6 +2782,7 @@ def sigmoid(x, name=None):
 @tf_export("math.log_sigmoid", v1=["math.log_sigmoid", "log_sigmoid"])
+@dispatch.add_dispatch_support
 @deprecation.deprecated_endpoints("log_sigmoid")
 def log_sigmoid(x, name=None):
   """Computes log sigmoid of `x` element-wise.

@@ -2973,6 +2992,7 @@ def cumprod(x, axis=0, exclusive=False, reverse=False, name=None):
 @tf_export("math.conj", v1=["math.conj", "conj"])
+@dispatch.add_dispatch_support
 @deprecation.deprecated_endpoints("conj")
 def conj(x, name=None):
   r"""Returns the complex conjugate of a complex number.

@@ -3077,6 +3097,7 @@ def _unsorted_segment_N(data, segment_ids, num_segments):
     "math.unsorted_segment_mean",
     v1=["math.unsorted_segment_mean", "unsorted_segment_mean"])
 @deprecation.deprecated_endpoints("unsorted_segment_mean")
+@dispatch.add_dispatch_support
 def unsorted_segment_mean(data, segment_ids, num_segments, name=None):
   r"""Computes the mean along segments of a tensor.

@@ -3122,6 +3143,7 @@ def unsorted_segment_mean(data, segment_ids, num_segments, name=None):
     "math.unsorted_segment_sqrt_n",
     v1=["math.unsorted_segment_sqrt_n", "unsorted_segment_sqrt_n"])
 @deprecation.deprecated_endpoints("unsorted_segment_sqrt_n")
+@dispatch.add_dispatch_support
 def unsorted_segment_sqrt_n(data, segment_ids, num_segments, name=None):
   r"""Computes the sum along segments of a tensor divided by the sqrt(N).

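Note the symmetry with the `api_def` changes at the top of this commit: the block of manual module-level exports (`tf_export("realdiv")(realdiv)` and friends) is deleted here precisely because those names are now exported by the generated wrappers via their `endpoint` entries. User-visible behavior is intended to stay the same:

```python
import tensorflow as tf

# Same public names before and after this commit (sketch):
tf.realdiv(7.0, 2.0)   # -> 3.5
tf.floormod(7, 2)      # -> 1, still aliased as tf.mod
```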
View File

@@ -25,7 +25,7 @@ py_library(
     deps = [
         ":ragged_array_ops",
         ":ragged_conversion_ops",
-        ":ragged_elementwise_ops",
+        ":ragged_dispatch",
         ":ragged_factory_ops",
         ":ragged_functional_ops",
         ":ragged_getitem",

@@ -150,33 +150,14 @@ py_library(
     ],
 )

-py_library(
-    name = "ragged_elementwise_ops",
-    srcs = ["ragged_elementwise_ops.py"],
-    srcs_version = "PY2AND3",
-    deps = [
-        ":ragged_factory_ops",
-        ":ragged_tensor",
-        ":ragged_tensor_shape",
-        ":ragged_util",
-        "//tensorflow/python:array_ops",
-        "//tensorflow/python:clip_ops",
-        "//tensorflow/python:framework_ops",
-        "//tensorflow/python:math_ops",
-        "//tensorflow/python:parsing_ops",
-        "//tensorflow/python:string_ops",
-        "//tensorflow/python:util",
-    ],
-)
-
 py_library(
     name = "ragged_operators",
     srcs = ["ragged_operators.py"],
     srcs_version = "PY2AND3",
     deps = [
-        ":ragged_elementwise_ops",
         ":ragged_getitem",
         ":ragged_tensor",
+        "//tensorflow/python:math_ops",
         "//tensorflow/python:util",
     ],
 )

@@ -186,12 +167,13 @@ py_library(
     srcs = ["ragged_string_ops.py"],
     srcs_version = "PY2AND3",
     deps = [
-        ":ragged_array_ops",
         ":ragged_conversion_ops",
         ":ragged_factory_ops",
         ":ragged_tensor",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:math_ops",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:string_ops",
         "//tensorflow/python:util",
     ],
 )

@@ -219,10 +201,11 @@ py_library(
         ":ragged_tensor",
         ":ragged_util",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:math_ops",
-        "//tensorflow/python:tensor_shape",
         "//tensorflow/python:tensor_util",
     ],
 )

@@ -285,6 +268,29 @@ py_library(
     ],
 )

+py_library(
+    name = "ragged_dispatch",
+    srcs = ["ragged_dispatch.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":ragged_array_ops",
+        ":ragged_factory_ops",
+        ":ragged_math_ops",
+        ":ragged_tensor",
+        ":ragged_tensor_shape",
+        ":ragged_util",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:clip_ops",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:parsing_ops",
+        "//tensorflow/python:sparse_tensor",
+        "//tensorflow/python:string_ops",
+        "//tensorflow/python:util",
+        "//third_party/py/numpy",
+    ],
+)
+
 #-------------------------------------------------------------------------------
 # RaggedTensor Tests
 #-------------------------------------------------------------------------------

@@ -458,6 +464,7 @@ py_test(
         "//tensorflow/python:errors",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:gradients_impl",
+        "//tensorflow/python:math_ops",
         "//tensorflow/python:platform_test",
     ],
 )

@@ -684,17 +691,21 @@ py_test(
 )

 py_test(
-    name = "ragged_elementwise_ops_test",
-    srcs = ["ragged_elementwise_ops_test.py"],
+    name = "ragged_dispatch_test",
+    srcs = ["ragged_dispatch_test.py"],
     srcs_version = "PY2AND3",
     deps = [
         ":ragged",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python:clip_ops",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:errors",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:parsing_ops",
         "//tensorflow/python:platform_test",
+        "//tensorflow/python:string_ops",
         "//third_party/py/numpy",
         "@absl_py//absl/testing:parameterized",
     ],

@@ -725,6 +736,7 @@ py_test(
         "//tensorflow/python:platform_test",
         "//tensorflow/python:string_ops",
         "//tensorflow/python/keras:backend",
+        "//third_party/py/numpy",
         "@absl_py//absl/testing:parameterized",
     ],
 )

@@ -735,8 +747,10 @@ py_test(
     srcs_version = "PY2AND3",
     deps = [
         ":ragged",
+        "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:platform_test",
+        "//third_party/py/numpy",
         "@absl_py//absl/testing:parameterized",
     ],
 )

View File

@@ -1,76 +1,53 @@
 """Ragged Tensors.

-This package defines the [`RaggedTensor`](ragged/RaggedTensor.md) class, which
+This package defines the `tf.RaggedTensor` class, which
 represents tensors with non-uniform shapes.  In particular, each `RaggedTensor`
 has one or more *ragged dimensions*, which are dimensions whose slices may have
 different lengths.  For example, the inner (column) dimension of
 `rt=[[3, 1, 4, 1], [], [5, 9, 2], [6], []]` is ragged, since the column slices
 (`rt[0, :]`, ..., `rt[4, :]`) have different lengths.  For a more detailed
-description of ragged tensors, see the [`RaggedTensor`](ragged/RaggedTensor.md)
+description of ragged tensors, see the `tf.RaggedTensor`
 class documentation.

-## RaggedTensor Operations
-
-This package also defines a collection of operations for manipulating
-ragged tensors.
-
-### RaggedTensor Versions of Standard Tensor Operations
-
-Many of the operations defined by this package are analogous to
-[`Tensor`](https://www.tensorflow.org/api_docs/python/tf/Tensor)
-operations, but they accept `RaggedTensor`s as input and can return
-`RaggedTensor`s as output.  For example, `ragged.add` performs elementwise
-addition just like `tf.add`, but can be used on `RaggedTensor`s.
-
-These `RaggedTensor` versions of the standard `Tensor` operations can also be
-used with standard `Tensors`; and for the most part, they will return the same
-value that the standard `Tensor` operation would return.  However, there are
-a few notable exceptions:
-
-  * For [`ragged.stack(...)`](ragged/stack.md) and
-    [`ragged.concat(...)`](ragged/concat.md), the input tensors are not required
-    to have matching shapes.  In the returned tensor, all dimensions up to the
-    `axis` dimension will be ragged.
-
-### Ragged-Tensor Specific Operations
-
-The following operations are specific to ragged tensors:
-
-  * **Factory ops**:
-    [`constant(...)`](ragged/constant.md),
-    [`from_row_splits(...)`](ragged/from_row_splits.md),
-    [`from_row_lengths(...)`](ragged/from_row_lengths.md),
-    [`from_row_starts(...)`](ragged/from_row_starts.md),
-    [`from_row_limits(...)`](ragged/from_row_limits.md),
-    [`from_value_rowids(...)`](ragged/from_value_rowids.md),
-    [`from_nested_row_splits(...)`](ragged/from_nested_row_splits.md),
-    [`from_nested_value_rowids(...)`](ragged/from_nested_value_rowids.md).
-  * **Conversion ops**:
-    [`from_tensor(...)`](ragged/from_tensor.md),
-    [`to_tensor(...)`](ragged/to_tensor.md),
-    [`from_sparse(...)`](ragged/from_sparse.md),
-    [`to_sparse(...)`](ragged/to_sparse.md),
-    [`from_variant(...)`](ragged/from_variant.md),
-    [`to_variant(...)`](ragged/to_variant.md),
-    [`convert_to_tensor_or_ragged_tensor(...)`](
-      ragged/convert_to_tensor_or_ragged_tensor.md).
-  * **Shape ops**:
-    [`row_splits(...)`](ragged/row_splits.md),
-    [`row_lengths(...)`](ragged/row_lengths.md),
-    [`row_starts(...)`](ragged/row_starts.md),
-    [`row_limits(...)`](ragged/row_limits.md),
-    [`value_rowids(...)`](ragged/value_rowids.md),
-    [`nrows(...)`](ragged/nrows.md),
-    [`nested_row_splits(...)`](ragged/nested_row_splits.md),
-    [`row_splits_to_segment_ids(...)`](ragged/row_splits_to_segment_ids.md),
-    [`segment_ids_to_row_splits(...)`](ragged/segment_ids_to_row_splits.md),
-    [`bounding_shape(...)`](ragged/bounding_shape.md).
-  * **Functional ops**:
-    [`map_inner_values(...)`](ragged/map_inner_values.md),
-    [`make_elementwise_op(...)`](ragged/make_elementwise_op.md).
+## `RaggedTensor` Operations
+
+### `RaggedTensor` Factory ops
+
+* `tf.ragged.constant`
+* `tf.ragged.from_row_splits`
+* `tf.ragged.from_row_lengths`
+* `tf.ragged.from_row_starts`
+* `tf.ragged.from_row_limits`
+* `tf.ragged.from_value_rowids`
+* `tf.ragged.from_nested_row_splits`
+* `tf.ragged.from_nested_value_rowids`
+
+### `RaggedTensor` Conversion ops
+
+* `tf.ragged.from_tensor`
+* `tf.ragged.to_tensor`
+* `tf.ragged.from_sparse`
+* `tf.ragged.to_sparse`
+* `tf.ragged.from_variant`
+* `tf.ragged.to_variant`
+* `tf.ragged.convert_to_tensor_or_ragged_tensor`
+
+### `RaggedTensor` Shape ops
+
+* `tf.ragged.row_splits`
+* `tf.ragged.row_lengths`
+* `tf.ragged.row_starts`
+* `tf.ragged.row_limits`
+* `tf.ragged.value_rowids`
+* `tf.ragged.nrows`
+* `tf.ragged.nested_row_splits`
+* `tf.ragged.row_splits_to_segment_ids`
+* `tf.ragged.segment_ids_to_row_splits`
+* `tf.ragged.bounding_shape`
+
+### Functional ops
+
+* `tf.ragged.map_inner_values`

 <!-- Ragged Classes & related helper functions -->

@@ -140,21 +117,17 @@ The following operations are specific to ragged tensors:
 @@map_inner_values
 @@map_fn

-<!-- Elementwise Ops -->
-@@make_elementwise_op
-
 <!-- Shape & broadcasting -->
 @@RaggedTensorDynamicShape
 @@broadcast_to
 @@broadcast_dynamic_shape
-
-<!-- Symbols from ragged_elementwise_ops._symbols_to_export are whitelisted -->
 """

 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from tensorflow.python.ops.ragged import ragged_dispatch
 from tensorflow.python.ops.ragged import ragged_operators
 from tensorflow.python.ops.ragged import ragged_string_ops

@@ -179,11 +152,6 @@ from tensorflow.python.ops.ragged.ragged_conversion_ops import from_tensor
 from tensorflow.python.ops.ragged.ragged_conversion_ops import to_sparse
 from tensorflow.python.ops.ragged.ragged_conversion_ops import to_tensor

-# pylint: disable=protected-access, wildcard-import
-from tensorflow.python.ops.ragged.ragged_elementwise_ops import *
-from tensorflow.python.ops.ragged.ragged_elementwise_ops import _symbols_to_export as _elementwise_ops
-# pylint: enable=protected-access, wildcard-import
-
 from tensorflow.python.ops.ragged.ragged_factory_ops import constant
 from tensorflow.python.ops.ragged.ragged_factory_ops import constant_value
 from tensorflow.python.ops.ragged.ragged_factory_ops import convert_to_tensor_or_ragged_tensor

@@ -231,6 +199,10 @@ from tensorflow.python.ops.ragged.segment_id_ops import segment_ids_to_row_split
 from tensorflow.python.util import all_util as _all_util

+# Register OpDispatchers that override standard TF ops to work w/ RaggedTensors.
+__doc__ += ragged_dispatch.register_dispatchers()  # pylint: disable=redefined-builtin
+
 # Any symbol that is not referenced (with "@@name") in the module docstring
-# above, or included in the "_elementwise_ops" whitelist, will be removed.
-_all_util.remove_undocumented(__name__, _elementwise_ops)
+# above will be removed.
+_all_util.remove_undocumented(__name__)

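Importing the package now registers the dispatchers as a side effect, and `register_dispatchers()` returns a docstring fragment listing the overridden ops, which is appended to `__doc__`. A usage sketch (internal import path, assuming the exports listed in the docstring above):

```python
from tensorflow.python.ops import ragged  # side effect: dispatch registered

rt = ragged.constant([[1, 2], [3]])  # factory op kept by remove_undocumented
print(ragged.__doc__)                # ends with the auto-generated op list
```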
View File

@@ -308,7 +308,7 @@ def bounding_shape(rt_input, axis=None, name=None):
 # ragged_gather
 #===============================================================================
 # TODO(edloper): Add an `axis` argument
-def gather(params, indices, name=None):
+def gather(params, indices, validate_indices=None, axis=0, name=None):
   """Gathers ragged slices from `params` axis `0` according to `indices`.

   Returns `RaggedTensor` output, such that:

@@ -347,6 +347,8 @@ def gather(params, indices, name=None):
     indices: The potentially ragged tensor indicating which values to gather.
       Must have dtype `int32` or `int64`.  Values must be in the range `[0,
       params.shape[0]]`.
+    validate_indices: Ignored.
+    axis: Must be zero.
     name: A name for the operation (optional).

   Returns:

@@ -357,6 +359,9 @@ def gather(params, indices, name=None):
   Raises:
     ValueError: If indices.shape.ndims is not known statically.
   """
+  del validate_indices
+  if not isinstance(axis, int) or axis != 0:
+    raise ValueError('axis>0 is not supported for ragged gather yet.')
   with ops.name_scope(name, 'RaggedGather', [params, indices]):
     params = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
         params, name='params')
@@ -812,29 +817,29 @@ def boolean_mask(data, mask, keepdims=False, name=None):
 #===============================================================================
 # Concatenation and Stacking
 #===============================================================================
-def concat(rt_inputs, axis, name=None):
+def concat(values, axis, name=None):
   """Concatenates potentially ragged tensors along one dimension.

   Given a list of tensors with the same rank `K` (`K >= axis`), returns a
   rank-`K` `RaggedTensor` `result` such that `result[i0...iaxis]` is the
-  concatenation of `[rt[i0...iaxis] for rt in rt_inputs]`.
+  concatenation of `[rt[i0...iaxis] for rt in values]`.

   Args:
-    rt_inputs: A list of potentially ragged tensors.  May not be empty. All
-      `rt_inputs` must have the same rank and the same dtype; but unlike
+    values: A list of potentially ragged tensors.  May not be empty. All
+      `values` must have the same rank and the same dtype; but unlike
       `tf.concat`, they can have arbitrary shapes.
     axis: A python integer, indicating the dimension along which to concatenate.
       (Note: Unlike `tf.concat`, the `axis` parameter must be statically known.)
       Negative values are supported only if the rank of at least one
-      `rt_inputs` value is statically known.
+      `values` value is statically known.
     name: A name prefix for the returned tensor (optional).

   Returns:
     A `RaggedTensor` with rank `K`.
-    `result.ragged_rank=max(axis, max(rt.ragged_rank for rt in rt_inputs]))`.
+    `result.ragged_rank=max(axis, max(rt.ragged_rank for rt in values]))`.

   Raises:
-    ValueError: If `rt_inputs` is empty, if `axis` is out of bounds or if
+    ValueError: If `values` is empty, if `axis` is out of bounds or if
       the input tensors have different ranks.

   #### Example:

@@ -847,35 +852,35 @@ def concat(rt_inputs, axis, name=None):
     [[1, 2, 6], [3, 4, 5, 7, 8, 9]]
     ```
   """
-  if not isinstance(rt_inputs, (list, tuple)):
-    rt_inputs = [rt_inputs]
-  with ops.name_scope(name, 'RaggedConcat', rt_inputs):
-    return _ragged_stack_concat_helper(rt_inputs, axis, stack_values=False)
+  if not isinstance(values, (list, tuple)):
+    values = [values]
+  with ops.name_scope(name, 'RaggedConcat', values):
+    return _ragged_stack_concat_helper(values, axis, stack_values=False)

-def stack(rt_inputs, axis, name=None):
+def stack(values, axis, name=None):
   """Stacks potentially ragged tensors along one dimension.

   Given a list of tensors with the same rank `K` (`K >= axis`), returns a
   rank-`K+1` `RaggedTensor` `result` such that `result[i0...iaxis]` is the
-  list `[rt[i0...iaxis] for rt in rt_inputs]`.
+  list `[rt[i0...iaxis] for rt in values]`.

   Args:
-    rt_inputs: A list of potentially ragged tensors.  May not be empty. All
-      `rt_inputs` must have the same rank and the same dtype; but unlike
+    values: A list of potentially ragged tensors.  May not be empty. All
+      `values` must have the same rank and the same dtype; but unlike
       `tf.concat`, they can have arbitrary shapes.
     axis: A python integer, indicating the dimension along which to stack.
       (Note: Unlike `tf.stack`, the `axis` parameter must be statically known.)
       Negative values are supported only if the rank of at least one
-      `rt_inputs` value is statically known.
+      `values` value is statically known.
     name: A name prefix for the returned tensor (optional).

   Returns:
     A `RaggedTensor` with rank `K+1`.
-    `result.ragged_rank=max(axis, max(rt.ragged_rank for rt in rt_inputs]))`.
+    `result.ragged_rank=max(axis, max(rt.ragged_rank for rt in values]))`.

   Raises:
-    ValueError: If `rt_inputs` is empty, if `axis` is out of bounds or if
+    ValueError: If `values` is empty, if `axis` is out of bounds or if
       the input tensors have different ranks.

   #### Example:

@@ -888,10 +893,10 @@ def stack(rt_inputs, axis, name=None):
     [[[1, 2], [6]], [[3, 4, 5], [7, 8, 9]]]
     ```
   """
-  if not isinstance(rt_inputs, (list, tuple)):
-    rt_inputs = [rt_inputs]
-  with ops.name_scope(name, 'RaggedConcat', rt_inputs):
-    return _ragged_stack_concat_helper(rt_inputs, axis, stack_values=True)
+  if not isinstance(values, (list, tuple)):
+    values = [values]
+  with ops.name_scope(name, 'RaggedConcat', values):
+    return _ragged_stack_concat_helper(values, axis, stack_values=True)

 def _ragged_stack_concat_helper(rt_inputs, axis, stack_values):
@@ -1065,22 +1070,22 @@ def _copy_row_shape(rt_inputs, splits):
 #===============================================================================
 # Tiling
 #===============================================================================
-def tile(rt_input, multiples, name=None):
+def tile(input, multiples, name=None):  # pylint: disable=redefined-builtin
   """Constructs a `RaggedTensor` by tiling a given `RaggedTensor`.

-  The values of `rt_input` are replicated `multiples[i]` times along the
+  The values of `input` are replicated `multiples[i]` times along the
   `i`th dimension (for each dimension `i`).  For every dimension `axis` in
-  `rt_input`, the length of each output element in that dimension is the
+  `input`, the length of each output element in that dimension is the
   length of corresponding input element multiplied by `multiples[axis]`.

   Args:
-    rt_input: A `RaggedTensor`.
+    input: A `RaggedTensor`.
     multiples: A 1-D integer `Tensor`.  Length must be the same as the number of
-      dimensions in `rt_input`.
+      dimensions in `input`.
     name: A name for the operation (optional).

   Returns:
-    A `RaggedTensor` with the same type, rank, and ragged_rank as `rt_input`.
+    A `RaggedTensor` with the same type, rank, and ragged_rank as `input`.

   #### Example:
     ```python

@@ -1089,22 +1094,22 @@ def tile(rt_input, multiples, name=None):
     [[1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3]]
     ```
   """
-  with ops.name_scope(name, 'RaggedTile', [rt_input, multiples]):
-    rt_input = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
-        rt_input, name='rt_input')
+  with ops.name_scope(name, 'RaggedTile', [input, multiples]):
+    input = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
+        input, name='input')
     multiples = ragged_util.convert_to_int_tensor(
         multiples, name='multiples', dtype=dtypes.int64)
     multiples.shape.assert_has_rank(1)
-    if not ragged_tensor.is_ragged(rt_input):
-      return array_ops.tile(rt_input, multiples, name)
+    if not ragged_tensor.is_ragged(input):
+      return array_ops.tile(input, multiples, name)

     # If the constant value of `multiples` is available, then we can use it
     # to skip tiling dimensions where `multiples=1`.
     const_multiples = tensor_util.constant_value(multiples)

     return ragged_factory_ops.from_nested_row_splits(
-        _tile_ragged_values(rt_input, multiples, const_multiples),
-        _tile_ragged_splits(rt_input, multiples, const_multiples))
+        _tile_ragged_values(input, multiples, const_multiples),
+        _tile_ragged_splits(input, multiples, const_multiples))

 def _tile_ragged_values(rt_input, multiples, const_multiples=None):
@@ -1240,26 +1245,26 @@ def _tile_ragged_splits(rt_input, multiples, const_multiples=None):
 #===============================================================================
-def expand_dims(rt_input, axis, name=None):
+def expand_dims(input, axis, name=None):  # pylint: disable=redefined-builtin
   """Inserts a dimension with shape 1 into a potentially ragged tensor's shape.

-  Given a potentially ragged tenor `rt_input`, this operation inserts a
-  dimension with size 1 at the dimension `axis` of `rt_input`'s shape.
+  Given a potentially ragged tenor `input`, this operation inserts a
+  dimension with size 1 at the dimension `axis` of `input`'s shape.

-  * If `rt_input` is a `Tensor`, then this is equivalent to
+  * If `input` is a `Tensor`, then this is equivalent to
     `tf.expand_dims`.
-  * If `rt_input` is ragged, and `axis=0`, then the new dimension will be
+  * If `input` is ragged, and `axis=0`, then the new dimension will be
     uniform; but the previously outermost dimension will become ragged.
-  * If `rt_input` is ragged, and `0 < axis < rt_input.ragged_rank`, then the
+  * If `input` is ragged, and `0 < axis < input.ragged_rank`, then the
     new dimension will be ragged.
-  * If `rt_input` is ragged, and axis >= rt_input.ragged_rank`, then the new
+  * If `input` is ragged, and axis >= input.ragged_rank`, then the new
     dimension will be uniform.

   The following table gives some examples showing how `ragged.expand_dims`
   impacts the shapes of different input tensors.  Ragged dimensions are
   indicated by enclosing them in parentheses.

-  rt_input.shape             | axis | result.shape
+  input.shape                | axis | result.shape
   -------------------------- | ---- | -----------------------------
   `[D1, D2]`                 |  `0` | `[1, D1, D2]`
   `[D1, D2]`                 |  `1` | `[D1, 1, D2]`

@@ -1271,14 +1276,14 @@ def expand_dims(rt_input, axis, name=None):
   `[D1, (D2), (D3), D4]`     |  `4` | `[D1, (D2), (D3), D4, 1]`

   Args:
-    rt_input: The potentially tensor that should be expanded with a new
+    input: The potentially tensor that should be expanded with a new
       dimension.
     axis: An integer constant indicating where the new dimension should be
       inserted.
     name: A name for the operation (optional).

   Returns:
-    A tensor with the same values as `rt_input`, with an added dimension of
+    A tensor with the same values as `input`, with an added dimension of
     size 1 at `axis`.

   #### Examples:

@@ -1300,24 +1305,24 @@ def expand_dims(rt_input, axis, name=None):
     TensorShape([2, None, 1]) [[[1], [2]], [[3]]]
     ```
   """
-  with ops.name_scope(name, 'RaggedExpandDims', [rt_input]):
-    rt_input = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
-        rt_input, name='rt_input')
+  with ops.name_scope(name, 'RaggedExpandDims', [input]):
+    input = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
+        input, name='input')

-    if not ragged_tensor.is_ragged(rt_input):
-      return array_ops.expand_dims(rt_input, axis)
+    if not ragged_tensor.is_ragged(input):
+      return array_ops.expand_dims(input, axis)

-    ndims = None if rt_input.shape.ndims is None else rt_input.shape.ndims + 1
+    ndims = None if input.shape.ndims is None else input.shape.ndims + 1
     axis = ragged_util.get_positive_axis(axis, ndims)

     if axis == 0:
-      values = rt_input
-      splits = array_ops.stack([0, nrows(rt_input)])
+      values = input
+      splits = array_ops.stack([0, nrows(input)])
     elif axis == 1:
-      values = rt_input
-      splits = math_ops.range(nrows(rt_input) + 1)
+      values = input
+      splits = math_ops.range(nrows(input) + 1)
     else:
-      values = expand_dims(rt_input.values, axis - 1)
-      splits = rt_input.row_splits
+      values = expand_dims(input.values, axis - 1)
+      splits = input.row_splits

     return ragged_factory_ops.from_row_splits(values, splits)

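The parameter renames in this file (`rt_inputs` to `values`, `rt_input` to `input`, plus the extra `validate_indices`/`axis` arguments on `gather`) are not cosmetic: `RaggedDispatcher` in `ragged_dispatch.py` below asserts that a ragged op's argument names exactly match the standard op it overrides. A sketch of the invariant these renames establish:

```python
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.util import tf_inspect

# Argument names must line up for registration to succeed (sketch):
assert (tf_inspect.getfullargspec(array_ops.concat)[0] ==
        tf_inspect.getfullargspec(ragged_array_ops.concat)[0])
```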
View File

@@ -0,0 +1,441 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operator dispatch for RaggedTensors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import numpy as np
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_shape
from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.util import dispatch
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_export
from tensorflow.python.util import tf_inspect
# TODO(edloper): Set this to True in the CL that exports RaggedTensors.
_UPDATE_DOCSTRINGS = False
# Information about an argument to an operation: The name of the argument, its
# position in the argument list, and a boolean flag indicating whether it
# expects a list of tensors.
_ArgInfo = collections.namedtuple('ArgInfo', ['name', 'position', 'is_list'])
def _get_arg_infos(func, arg_names):
  """Returns an `_ArgInfo` for each argument of `func` specified by `arg_names`.

  Args:
    func: The function whose arguments should be described.
    arg_names: The names of the arguments to get info for.

  Returns:
    A tuple of `_ArgInfo`s.
  """
  arg_infos = []

  # Inspect the func's argspec to find the position of each arg.
  arg_spec = tf_inspect.getargspec(func)
  for argname in arg_names:
    assert isinstance(argname, str)
    is_list = argname.startswith('[') and argname.endswith(']')
    if is_list:
      argname = argname[1:-1]
    if argname not in arg_spec.args:
      raise ValueError('Argument %r not found in function %s. Args=%s' %
                       (argname, func, arg_spec.args))
    arg_infos.append(_ArgInfo(argname, arg_spec.args.index(argname), is_list))
  return arg_infos
def _is_convertible_to_tensor(value):
  """Returns true if `value` is convertible to a `Tensor`."""
  if isinstance(value,
                (ops.Tensor, variables.Variable, np.ndarray, int, float, str)):
    return True
  elif isinstance(value, (sparse_tensor.SparseTensor,)):
    return False
  else:
    try:
      ops.convert_to_tensor(value)
      return True
    except (TypeError, ValueError):
      return False
class UnaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
  """OpDispatcher for unary ops that map a base op across ragged values."""

  def __init__(self, original_op, arg_is_list=False):
    self._original_op = original_op
    self._arg_is_list = arg_is_list
    arg_names = tf_inspect.getfullargspec(original_op)[0]
    self._x = arg_names[0]
    if _UPDATE_DOCSTRINGS:
      original_op.__doc__ = (
          original_op.__doc__.rstrip() + '\n\n' +
          '    `{x}` may be a `tf.RaggedTensor`.\n'.format(x=self._x))

  def handle(self, args, kwargs):
    if args:
      x, args = args[0], args[1:]
    else:
      kwargs = kwargs.copy()
      x = kwargs.pop(self._x, None)
    if x is None:
      return self.NOT_SUPPORTED
    if self._arg_is_list:
      found_ragged = False
      for elt in x:
        if ragged_tensor.is_ragged(elt):
          found_ragged = True
        elif not _is_convertible_to_tensor(elt):
          return self.NOT_SUPPORTED
      if found_ragged:
        nested_splits_lists = [
            elt.nested_row_splits for elt in x if ragged_tensor.is_ragged(elt)
        ]
        inner_values = [
            elt.inner_values if ragged_tensor.is_ragged(elt) else elt
            for elt in x
        ]
        with ops.control_dependencies(
            ragged_util.assert_splits_match(nested_splits_lists)):
          return ragged_factory_ops.from_nested_row_splits(
              self._original_op(inner_values, *args, **kwargs),
              nested_splits_lists[0])
      else:
        return self.NOT_SUPPORTED
    else:
      found_ragged = ragged_tensor.is_ragged(x)
      if found_ragged:
        mapped_values = self._original_op(x.inner_values, *args, **kwargs)
        return x.with_inner_values(mapped_values)
      else:
        return self.NOT_SUPPORTED
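
# Example (sketch, not part of the original file): once registered for a
# unary op such as math_ops.abs, a call like
#   tf.abs(tf.ragged.constant([[-1, 2], [-3]]))
# applies abs to inner_values and rebuilds the ragged structure, giving
# [[1, 2], [3]].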
class BinaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
  """OpDispatcher for binary ops that map a base op across ragged values.

  Supports broadcasting.
  """

  def __init__(self, original_op):
    self._original_op = original_op
    arg_names = tf_inspect.getfullargspec(original_op)[0]
    self._x = arg_names[0]
    self._y = arg_names[1]
    if _UPDATE_DOCSTRINGS:
      original_op.__doc__ = (
          original_op.__doc__.rstrip() + '\n\n' +
          '    `{x}` and `{y}` may be a `tf.RaggedTensor`.\n'.format(
              x=self._x, y=self._y))

  def handle(self, args, kwargs):
    # Extract the binary args.
    if len(args) > 1:
      x = args[0]
      y = args[1]
      args = args[2:]
    elif args:
      kwargs = kwargs.copy()
      x = args[0]
      y = kwargs.pop(self._y, None)
      args = args[1:]
    else:
      kwargs = kwargs.copy()
      x = kwargs.pop(self._x, None)
      y = kwargs.pop(self._y, None)

    # Bail if we don't have at least one ragged argument.
    x_is_ragged = ragged_tensor.is_ragged(x)
    y_is_ragged = ragged_tensor.is_ragged(y)
    if not (x_is_ragged or y_is_ragged):
      return self.NOT_SUPPORTED

    # Convert args to tensors.  Bail if conversion fails.
    try:
      if not x_is_ragged:
        x = ops.convert_to_tensor(x, name=self._x, preferred_dtype=y.dtype)
      if not y_is_ragged:
        y = ops.convert_to_tensor(y, name=self._y, preferred_dtype=x.dtype)
    except (TypeError, ValueError):
      return self.NOT_SUPPORTED

    if ((x_is_ragged and y_is_ragged) or
        (x_is_ragged and x.inner_values.shape.ndims <= y.shape.ndims) or
        (y_is_ragged and y.inner_values.shape.ndims <= x.shape.ndims)):
      bcast_shape = ragged_tensor_shape.broadcast_dynamic_shape(
          ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(x),
          ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(y))
      x = ragged_tensor_shape.broadcast_to(
          x, bcast_shape, broadcast_inner_dimensions=False)
      y = ragged_tensor_shape.broadcast_to(
          y, bcast_shape, broadcast_inner_dimensions=False)

    x_values = x.inner_values if ragged_tensor.is_ragged(x) else x
    y_values = y.inner_values if ragged_tensor.is_ragged(y) else y
    mapped_values = self._original_op(x_values, y_values, *args, **kwargs)
    if ragged_tensor.is_ragged(x):
      return x.with_inner_values(mapped_values)
    else:
      return y.with_inner_values(mapped_values)
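
# Example (sketch, not part of the original file): with this dispatcher
# registered for math_ops.add, tf.add(rt, 10) broadcasts the scalar against
# the ragged shape ([[1, 2], [3]] + 10 -> [[11, 12], [13]]), while two
# ragged operands must have matching nested row splits.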
class RaggedDispatcher(dispatch.OpDispatcher):
  """OpDispatcher for ragged ops.

  Dispatches to a wrapped op-handler if at least one of the `tensor_args`
  arguments is a RaggedTensor or a RaggedTensorValue; and all of the
  `tensor_args` arguments are convertible to Tensor or RaggedTensor.
  """

  def __init__(self, original_op, ragged_op, ragged_args):
    op_arg_names = tf_inspect.getfullargspec(original_op)[0]
    ragged_arg_names = tf_inspect.getfullargspec(ragged_op)[0]
    if op_arg_names != ragged_arg_names:
      raise AssertionError(
          'Signature must exactly match when overriding %s with %s: %s vs %s' %
          (original_op, ragged_op, op_arg_names, ragged_arg_names))
    self._ragged_op = ragged_op
    self._ragged_args = _get_arg_infos(ragged_op, ragged_args)
    if _UPDATE_DOCSTRINGS:
      arg_list = ' and '.join('`%s`' % arg for arg in ragged_args)
      original_op.__doc__ = (
          original_op.__doc__.rstrip() + '\n\n' +
          '    {0} may be a `tf.RaggedTensor`.\n'.format(arg_list))

  def handle(self, args, kwargs):
    if self.is_supported(args, kwargs):
      return self._ragged_op(*args, **kwargs)
    else:
      return self.NOT_SUPPORTED

  def is_supported(self, args, kwargs):
    found_ragged = False
    for arg_info in self._ragged_args:
      if arg_info.position < len(args):
        arg = args[arg_info.position]
      else:
        arg = kwargs.get(arg_info.name, None)

      if arg_info.is_list:
        if not isinstance(arg, (list, tuple)):
          return False
        for elt in arg:
          if ragged_tensor.is_ragged(elt):
            found_ragged = True
          elif not _is_convertible_to_tensor(elt):
            return False
      else:
        if ragged_tensor.is_ragged(arg):
          found_ragged = True
        elif not _is_convertible_to_tensor(arg):
          return False
    return found_ragged
def ragged_dispatch(original_op, tensor_args):
  def decorator(ragged_op):
    RaggedDispatcher(original_op, ragged_op,
                     tensor_args).register(original_op)
    return ragged_op
  return decorator
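
# Example usage (sketch, not part of the original file): bracketed names
# such as '[inputs]' mark arguments that take a list of tensors (see
# _get_arg_infos above):
#
#   @ragged_dispatch(math_ops.add_n, ['[inputs]'])
#   def my_ragged_add_n(inputs, name=None):
#     ...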
_UNARY_ELEMENTWISE_OPS = [
array_ops.check_numerics,
array_ops.identity,
array_ops.ones_like,
array_ops.ones_like_v2,
array_ops.zeros_like,
array_ops.zeros_like_v2,
clip_ops.clip_by_value,
math_ops.abs,
math_ops.acos,
math_ops.acosh,
math_ops.angle,
math_ops.asin,
math_ops.asinh,
math_ops.atan,
math_ops.atanh,
math_ops.cast,
math_ops.ceil,
math_ops.conj,
math_ops.cos,
math_ops.cosh,
math_ops.digamma,
math_ops.erf,
math_ops.erfc,
math_ops.exp,
math_ops.expm1,
math_ops.floor,
math_ops.imag,
math_ops.is_finite,
math_ops.is_inf,
math_ops.is_nan,
math_ops.lgamma,
math_ops.log,
math_ops.log1p,
math_ops.log_sigmoid,
math_ops.logical_not,
math_ops.negative,
math_ops.real,
math_ops.reciprocal,
math_ops.rint,
math_ops.round,
math_ops.rsqrt,
math_ops.saturate_cast,
math_ops.sign,
math_ops.sin,
math_ops.sinh,
math_ops.sqrt,
math_ops.square,
math_ops.tan,
parsing_ops.decode_compressed,
string_ops.string_to_number,
string_ops.string_to_hash_bucket,
string_ops.as_string,
string_ops.decode_base64,
string_ops.encode_base64,
string_ops.regex_full_match,
string_ops.regex_replace,
string_ops.string_strip,
string_ops.string_to_hash_bucket_fast,
string_ops.string_to_hash_bucket_strong,
string_ops.substr,
string_ops.substr_v2,
string_ops.string_length,
string_ops.string_length_v2,
string_ops.unicode_script,
]
_UNARY_LIST_ELEMENTWISE_OPS = [
math_ops.add_n,
string_ops.string_join,
]
_BINARY_ELEMENTWISE_OPS = [
math_ops.add,
math_ops.atan2,
math_ops.complex,
math_ops.div_no_nan,
math_ops.divide,
math_ops.equal,
math_ops.floordiv,
math_ops.floormod,
math_ops.greater,
math_ops.greater_equal,
math_ops.less,
math_ops.less_equal,
math_ops.logical_and,
math_ops.logical_or,
math_ops.logical_xor,
math_ops.maximum,
math_ops.minimum,
math_ops.multiply,
math_ops.not_equal,
math_ops.pow,
math_ops.realdiv,
math_ops.squared_difference,
math_ops.subtract,
math_ops.truediv,
math_ops.truncatediv,
math_ops.truncatemod,
]
# (original_op, ragged_op, ragged_args)
_RAGGED_DISPATCH_OPS = [
(array_ops.batch_gather, ragged_array_ops.batch_gather,
['params', 'indices']),
(array_ops.concat, ragged_array_ops.concat, ['values']),
(array_ops.expand_dims_v2, ragged_array_ops.expand_dims, ['input']),
(array_ops.gather_v2, ragged_array_ops.gather, ['params', 'indices']),
(array_ops.gather_nd, ragged_array_ops.gather_nd, ['params', 'indices']),
(array_ops.stack, ragged_array_ops.stack, ['values']),
(array_ops.tile, ragged_array_ops.tile, ['input']),
(array_ops.where, ragged_array_ops.where, ['condition', 'x', 'y']),
(math_ops.unsorted_segment_sum, ragged_math_ops.segment_sum,
['data', 'segment_ids']),
(math_ops.unsorted_segment_prod, ragged_math_ops.segment_prod,
['data', 'segment_ids']),
(math_ops.unsorted_segment_min, ragged_math_ops.segment_min,
['data', 'segment_ids']),
(math_ops.unsorted_segment_max, ragged_math_ops.segment_max,
['data', 'segment_ids']),
(math_ops.unsorted_segment_mean, ragged_math_ops.segment_mean,
['data', 'segment_ids']),
(math_ops.unsorted_segment_sqrt_n, ragged_math_ops.segment_sqrt_n,
['data', 'segment_ids']),
(math_ops.reduce_sum, ragged_math_ops.reduce_sum, ['input_tensor']),
(math_ops.reduce_prod, ragged_math_ops.reduce_prod, ['input_tensor']),
(math_ops.reduce_min, ragged_math_ops.reduce_min, ['input_tensor']),
(math_ops.reduce_max, ragged_math_ops.reduce_max, ['input_tensor']),
(math_ops.reduce_mean, ragged_math_ops.reduce_mean, ['input_tensor']),
(math_ops.reduce_any, ragged_math_ops.reduce_any, ['input_tensor']),
(math_ops.reduce_all, ragged_math_ops.reduce_all, ['input_tensor']),
]
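
Each row of this table reroutes a dense API to its ragged implementation once register_dispatchers() below has run. A sketch for one entry (internal module paths; the commented results are post-evaluation values):

    from tensorflow.python.ops import math_ops
    from tensorflow.python.ops import ragged

    rt = ragged.constant([[3, 1, 4], [], [5, 9]])
    math_ops.reduce_sum(rt, axis=1)  # -> [8, 0, 14]
    math_ops.reduce_sum(rt, axis=0)  # -> [8, 10, 4]
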
def register_dispatchers():
"""Constructs & registers OpDispatchers for ragged ops."""
op_list = (
_UNARY_ELEMENTWISE_OPS + _UNARY_LIST_ELEMENTWISE_OPS +
_BINARY_ELEMENTWISE_OPS + [x[0] for x in _RAGGED_DISPATCH_OPS])
for op in op_list:
_, undecorated_op = tf_decorator.unwrap(op)
    if not hasattr(undecorated_op, tf_export.API_ATTRS['tensorflow'].names):
      raise AssertionError('Expected %s to be an exported symbol '
                           '(while adding a RaggedTensor dispatcher)' % op)
for op in _UNARY_ELEMENTWISE_OPS:
UnaryRaggedElementwiseDispatcher(op).register(op)
for op in _UNARY_LIST_ELEMENTWISE_OPS:
UnaryRaggedElementwiseDispatcher(op, True).register(op)
for op in _BINARY_ELEMENTWISE_OPS:
BinaryRaggedElementwiseDispatcher(op).register(op)
for (original_op, ragged_op, args) in _RAGGED_DISPATCH_OPS:
RaggedDispatcher(original_op, ragged_op, args).register(original_op)
docstring = (
'\n\n### Additional ops that support `RaggedTensor`\n\n' + '\n'.join([
'* `tf.%s`' % tf_export.get_canonical_name_for_symbol(op)
for op in op_list
]))
return docstring

View File

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for ragged.elementwise_ops."""
+"""Tests for RaggedTensor operator dispatch."""
 from __future__ import absolute_import
 from __future__ import division
@@ -21,106 +21,108 @@ from __future__ import print_function
 from absl.testing import parameterized
 import numpy as np
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import parsing_ops
 from tensorflow.python.ops import ragged
+from tensorflow.python.ops import string_ops
 from tensorflow.python.platform import googletest
-# Constants listing various op types to test. Each elementwise operation
+# Constants listing various op types to test. Each operation
 # should be included in at least one list below, or tested separately if
 # necessary (e.g., because it expects additional arguments).
 UNARY_FLOAT_OPS = [
-    ragged.abs,
-    ragged.acos,
-    ragged.acosh,
-    ragged.angle,
-    ragged.asin,
-    ragged.asinh,
-    ragged.atan,
-    ragged.atanh,
-    ragged.ceil,
-    ragged.conj,
-    ragged.cos,
-    ragged.cosh,
-    ragged.digamma,
-    ragged.erf,
-    ragged.erfc,
-    ragged.exp,
-    ragged.expm1,
-    ragged.floor,
-    ragged.imag,
-    ragged.is_finite,
-    ragged.is_inf,
-    ragged.is_nan,
-    ragged.lgamma,
-    ragged.log,
-    ragged.log1p,
-    ragged.log_sigmoid,
-    ragged.negative,
-    ragged.real,
-    ragged.reciprocal,
-    ragged.rint,
-    ragged.round,
-    ragged.rsqrt,
-    ragged.sign,
-    ragged.sin,
-    ragged.sinh,
-    ragged.sqrt,
-    ragged.square,
-    ragged.tan,
-    ragged.as_string,
-    ragged.identity,
-    ragged.ones_like,
-    ragged.zeros_like,
+    math_ops.abs,
+    math_ops.acos,
+    math_ops.acosh,
+    math_ops.angle,
+    math_ops.asin,
+    math_ops.asinh,
+    math_ops.atan,
+    math_ops.atanh,
+    math_ops.ceil,
+    math_ops.conj,
+    math_ops.cos,
+    math_ops.cosh,
+    math_ops.digamma,
+    math_ops.erf,
+    math_ops.erfc,
+    math_ops.exp,
+    math_ops.expm1,
+    math_ops.floor,
+    math_ops.imag,
+    math_ops.is_finite,
+    math_ops.is_inf,
+    math_ops.is_nan,
+    math_ops.lgamma,
+    math_ops.log,
+    math_ops.log1p,
+    math_ops.log_sigmoid,
+    math_ops.negative,
+    math_ops.real,
+    math_ops.reciprocal,
+    math_ops.rint,
+    math_ops.round,
+    math_ops.rsqrt,
+    math_ops.sign,
+    math_ops.sin,
+    math_ops.sinh,
+    math_ops.sqrt,
+    math_ops.square,
+    math_ops.tan,
+    array_ops.identity,
+    array_ops.ones_like,
+    array_ops.zeros_like,
 ]
 UNARY_BOOL_OPS = [
-    ragged.logical_not,
+    math_ops.logical_not,
 ]
 UNARY_STRING_OPS = [
-    ragged.decode_base64,
-    ragged.encode_base64,
-    ragged.string_strip,
-    ragged.decode_compressed,
+    string_ops.decode_base64,
+    string_ops.encode_base64,
+    string_ops.string_strip,
+    parsing_ops.decode_compressed,
 ]
 BINARY_FLOAT_OPS = [
-    ragged.add,
-    ragged.atan2,
-    ragged.complex,
-    ragged.div,
-    ragged.div_no_nan,
-    ragged.divide,
-    ragged.equal,
-    ragged.floordiv,
-    ragged.floormod,
-    ragged.greater,
-    ragged.greater_equal,
-    ragged.less,
-    ragged.less_equal,
-    ragged.maximum,
-    ragged.minimum,
-    ragged.multiply,
-    ragged.not_equal,
-    ragged.pow,
-    ragged.realdiv,
-    ragged.squared_difference,
-    ragged.subtract,
-    ragged.truediv,
+    math_ops.add,
+    math_ops.atan2,
+    math_ops.complex,
+    math_ops.div_no_nan,
+    math_ops.divide,
+    math_ops.equal,
+    math_ops.floordiv,
+    math_ops.floormod,
+    math_ops.greater,
+    math_ops.greater_equal,
+    math_ops.less,
+    math_ops.less_equal,
+    math_ops.maximum,
+    math_ops.minimum,
+    math_ops.multiply,
+    math_ops.not_equal,
+    math_ops.pow,
+    math_ops.realdiv,
+    math_ops.squared_difference,
+    math_ops.subtract,
+    math_ops.truediv,
 ]
 BINARY_BOOL_OPS = [
-    ragged.logical_and,
-    ragged.logical_or,
-    ragged.logical_xor,
+    math_ops.logical_and,
+    math_ops.logical_or,
+    math_ops.logical_xor,
 ]
 UNARY_INT_OPS = [
-    ragged.unicode_script,
+    string_ops.unicode_script,
 ]
 BINARY_INT_OPS = [
-    ragged.truncatediv,
-    ragged.truncatemod,
+    math_ops.truncatediv,
+    math_ops.truncatemod,
 ]
@@ -171,50 +173,49 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
      [{'x': ragged.constant_value([['abcd', 'efgh'], ['aabbccdd']]), 'op': op}
       for op in UNARY_STRING_OPS] +
      [
-         {'op': ragged.clip_by_value,
+         {'op': clip_ops.clip_by_value,
          'x': ragged.constant_value([[-2.0, 3.0], [-3.0]]),
          'clip_value_min': 0.1, 'clip_value_max': 4.0},
-         {'op': ragged.cast,
+         {'op': math_ops.cast,
          'x': ragged.constant_value([[-2.0, 3.0], [-3.0]]),
          'dtype': dtypes.int32},
-         {'op': ragged.saturate_cast,
+         {'op': math_ops.saturate_cast,
          'x': ragged.constant_value([[-2.0, 3.0], [-3.0]]),
          'dtype': dtypes.int32},
-         {'op': ragged.string_to_hash_bucket,
+         {'op': string_ops.string_to_hash_bucket,
          'x': ragged.constant_value([['abcd', 'efgh'], ['aabbccdd']]),
          'num_buckets': 1000},
-         {'op': ragged.string_to_hash_bucket_fast,
+         {'op': string_ops.string_to_hash_bucket_fast,
         'x': ragged.constant_value([['abcd', 'efgh'], ['aabbccdd']]),
          'num_buckets': 1000},
-         {'op': ragged.string_to_hash_bucket_strong,
+         {'op': string_ops.string_to_hash_bucket_strong,
          'x': ragged.constant_value([['abcd', 'efgh'], ['aabbccdd']]),
          'num_buckets': 1000,
          'key': [1231, 12512]},
-         {'op': ragged.string_to_number,
+         {'op': string_ops.string_to_number,
          'x': ragged.constant_value([['-2.0', '3.0'], ['-3.0']])},
-         {'op': ragged.regex_full_match,
+         {'op': string_ops.regex_full_match,
          'x': ragged.constant_value([['hello', '123'], ['1+1']]),
          'pattern': r'\w+'},
-         {'op': ragged.regex_replace,
+         {'op': string_ops.regex_replace,
          'x': ragged.constant_value([['hello', '123'], ['1+1']]),
          'pattern': r'\d',
          'rewrite': '#'},
-         {'op': ragged.substr,
+         {'op': string_ops.substr,
          'x': ragged.constant_value([['hello', '123'], ['1+1']]),
          'pos': 2, 'len': 3},
-         {'op': ragged.check_numerics,
+         {'op': array_ops.check_numerics,
          'x': ragged.constant_value([[-2.0, 3.0], [-3.0]]),
          'message': 'check-numerics'},
      ]
  ) # pyformat: disable
-  def testUnaryOp(self, x, op=ragged.abs, **extra_args):
+  def testUnaryElementwiseOp(self, x, op=math_ops.abs, **extra_args):
    x = ragged.convert_to_tensor_or_ragged_tensor(x)
    result = op(x, **extra_args)
    # Run the wrapped op on the dense values, for comparison.
    dense_x = x.inner_values if isinstance(x, ragged.RaggedTensor) else x
-    expected_flat_values = array_ops.reshape(
-        op.__wrapped__(dense_x, **extra_args), [-1])
+    expected_flat_values = array_ops.reshape(op(dense_x, **extra_args), [-1])
    with self.test_session():
      # Check that the result has the expected shape.
@@ -285,12 +286,17 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
      #=====================================================================
      {'x': ragged.constant_value([[[1, 2], [3], [4]], [[], [5, 7, 8]]]),
       'y': ragged.constant_value([[[3, 8], [2], [5]], [[], [1, 9, 8]]]),
-       'use_kwargs': True},
+       'use_kwargs': ('x', 'y')},
      {'x': ragged.constant_value([[[1, 2]], [[3, 4], [5, 6], [7, 8]]],
                                  ragged_rank=1),
       'y': ragged.constant_value([[[9, 3]], [[5, 2], [3, 4], [7, 6]]],
                                  ragged_rank=1),
-       'use_kwargs': True},
+       'use_kwargs': ('x', 'y')},
+      {'x': ragged.constant_value([[[1, 2]], [[3, 4], [5, 6], [7, 8]]],
+                                  ragged_rank=1),
+       'y': ragged.constant_value([[[9, 3]], [[5, 2], [3, 4], [7, 6]]],
+                                  ragged_rank=1),
+       'use_kwargs': ('x',)},
      ] +
      #=========================================================================
      # Test each unary op.
@@ -306,16 +312,16 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
      [{'x': ragged.constant_value([[True, True], [False]]),
        'y': ragged.constant_value([[False, True], [False]]),
        'op': op}
-      for op in BINARY_BOOL_OPS] +
-     [
-     ]
+      for op in BINARY_BOOL_OPS]
  ) # pyformat: disable
-  def testBinaryOp(self, x, y, op=ragged.add, **extra_args):
-    use_kwargs = extra_args.pop('use_kwargs', False)
+  def testBinaryElementwiseOp(self, x, y, op=math_ops.add, **extra_args):
+    use_kwargs = extra_args.pop('use_kwargs', ())
    x = ragged.convert_to_tensor_or_ragged_tensor(x)
    y = ragged.convert_to_tensor_or_ragged_tensor(y)
-    if use_kwargs:
+    if 'x' in use_kwargs and 'y' in use_kwargs:
      result = op(x=x, y=y, **extra_args)
+    elif 'y' in use_kwargs:
+      result = op(x, y=y, **extra_args)
    else:
      result = op(x, y, **extra_args)
@@ -323,7 +329,7 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
    dense_x = x.inner_values if isinstance(x, ragged.RaggedTensor) else x
    dense_y = y.inner_values if isinstance(y, ragged.RaggedTensor) else y
    expected_flat_values = array_ops.reshape(
-        op.__wrapped__(dense_x, dense_y, **extra_args), [-1])
+        op(dense_x, dense_y, **extra_args), [-1])
    with self.test_session():
      # Check that the result has the expected shape.
@@ -358,16 +364,17 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
                 ragged.constant_value([[[2, 9], [12]], [[8]]])),
       'use_kwargs': True},
      ] + [
-      {'op': ragged.add_n,
+      {'op': math_ops.add_n,
       'inputs': (ragged.constant_value([[1, 3], [-3]]),
                  ragged.constant_value([[4, 7], [88]]),
                  ragged.constant_value([[2, 9], [12]]))},
-      {'op': ragged.string_join,
+      {'op': string_ops.string_join,
       'inputs': (ragged.constant_value([['a', 'b'], ['c']]),
                  ragged.constant_value([['foo', 'bar'], ['baz']]),
                  ragged.constant_value([['2', '9'], ['12']]))},
  ]) # pyformat: disable
-  def testListValuedOp(self, inputs, op=ragged.add_n, **extra_args):
+  def testListValuedElementwiseOp(self, inputs, op=math_ops.add_n,
+                                  **extra_args):
    use_kwargs = extra_args.pop('use_kwargs', False)
    inputs = [ragged.convert_to_tensor_or_ragged_tensor(x) for x in inputs]
    if use_kwargs:
@@ -381,7 +388,7 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
        for x in inputs
    ]
    expected_flat_values = array_ops.reshape(
-        op.__wrapped__(dense_inputs, **extra_args), [-1])
+        op(dense_inputs, **extra_args), [-1])
    with self.test_session():
      # Check that the result has the expected shape.
@@ -395,13 +402,13 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
      self.assertAllEqual(expected_flat_values, result_flat_values)
  @test_util.run_deprecated_v1
-  def testUnknownRankError(self):
+  def testElementwiseOpUnknownRankError(self):
    x = ragged.constant([[1, 2], [3]])
    y = ragged.from_row_splits(
        array_ops.placeholder_with_default([1, 2, 3], shape=None), x.row_splits)
-    with self.assertRaisesRegexp(
-        ValueError, r'Unable to broadcast: unknown rank'):
-      ragged.add(x, y)
+    with self.assertRaisesRegexp(ValueError,
+                                 r'Unable to broadcast: unknown rank'):
+      math_ops.add(x, y)
  @parameterized.parameters([
      dict(
@@ -417,26 +424,31 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
          y=ragged.constant_value([[1]]),
          expected=[[[2]]]),
  ])
-  def testBroadcastAdd(self, x, y, expected):
+  def testElementwiseOpBroadcast(self, x, y, expected):
    x = ragged.convert_to_tensor_or_ragged_tensor(x, dtype=dtypes.int32)
    y = ragged.convert_to_tensor_or_ragged_tensor(y, dtype=dtypes.int32)
    result = x + y
    with self.cached_session():
      self.assertEqual(result.eval().tolist(), expected)
-  def testShapeMismatch(self):
+  def testElementwiseOpShapeMismatch(self):
    x = ragged.constant([[1, 2, 3], [4, 5]])
    y = ragged.constant([[1, 2, 3], [4, 5, 6]])
    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 'Incompatible shapes'):
      with self.cached_session():
-        ragged.add(x, y).eval()
-  def testDocstring(self):
-    self.assertRegexpMatches(
-        ragged.add.__doc__,
-        'Ragged version of the elementwise operation `tf.math.add`')
-    self.assertEqual(ragged.add.__name__, 'add')
+        math_ops.add(x, y).eval()
+  def testBinaryOpSparseAndRagged(self):
+    x = ragged.constant([[1, 2, 3], [4, 5]])
+    y = sparse_tensor.SparseTensor([[0, 0], [0, 1], [2, 0]], [1, 2, 3], [3, 2])
+    with self.assertRaises(TypeError):
+      with self.cached_session():
+        math_ops.add(x, y).eval()
+    with self.assertRaises(TypeError):
+      with self.cached_session():
+        math_ops.add_n([x, y]).eval()
 if __name__ == '__main__':

View File

@@ -1,389 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Elementwise operations for RaggedTensors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_shape
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_export
from tensorflow.python.util import tf_inspect
# Information about an argument to an operation: The name of the argument, its
# position in the argument list, and a boolean flag indicating whether it
# expects a list of tensors.
_ArgInfo = collections.namedtuple('ArgInfo', ['name', 'position', 'is_list'])
def make_elementwise_op(op, *elementwise_args):
"""Returns a ragged-tensor version of the elementwise operation `op`.
The returned operation will:
1. Broadcast the elementwise arguments to have a compatible shape.
An exception is raised if the tensors are not broadcast-compatible.
2. Call `op`, substituting the dense values of the broadcasted tensor for
each elementwise argument.
3. Return a potentially ragged tensor constructed from the output of `op`
and the broadcasted tensors' nested row splits.
For example, you can construct a ragged-tensor version of the standard
operation `tf.add` by calling `make_elementwise_op(tf.add, 'x', 'y')`.
Args:
op: The operation to wrap.
*elementwise_args: The names of arguments to `op` that are treated as
elementwise. Arguments that take a list of tensors should have their
names wrapped in square brackets (e.g. "[inputs]").
Raises:
ValueError: If any name specified in `elementwise_args` is not the name
of an argument to `op`.
"""
elementwise_arg_infos = _get_arg_infos(op, elementwise_args)
def ragged_op(*args, **kwargs):
"""Ragged version of `op`."""
args = list(args)
# Collect all of the elementwise arguments, and put them in a single
# dict whose values are the (potentially ragged) tensors that need to
# be broadcast to a common shape. The keys of this dict are tuples
# (argkey, index), where argkey is an int for positional args or a string
# for keyword args; and index is None for non-list args and the index of the
# tensor for list args.
elementwise_args = {}
for (name, position, is_list) in elementwise_arg_infos.values():
if position < len(args):
if is_list:
args[position] = list(args[position])
for (index, arg) in enumerate(args[position]):
elementwise_args[position, index] = arg
else:
elementwise_args[position, None] = args[position]
elif name in kwargs:
if is_list:
kwargs[name] = list(kwargs[name])
for (i, arg) in enumerate(kwargs[name]):
elementwise_args[name, i] = arg
else:
elementwise_args[name, None] = kwargs[name]
with ops.name_scope(None, op.__name__, elementwise_args.values()):
# Convert all inputs to tensors or ragged tensors.
for ((key, index), tensor) in elementwise_args.items():
argname = elementwise_arg_infos[key].name
converted = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
tensor, name=argname)
elementwise_args[key, index] = converted
# Broadcast tensors to have compatible shapes.
broadcast_args, result_splits, broadcast_check_ops = \
_broadcast_elementwise_args(elementwise_args)
# Replace tensor arguments with their dense values.
for ((key, index), tensor) in broadcast_args.items():
if ragged_tensor.is_ragged(tensor):
if isinstance(key, int) and index is None:
args[key] = tensor.inner_values
elif isinstance(key, int) and index is not None:
args[key][index] = tensor.inner_values
elif isinstance(key, str) and index is None:
kwargs[key] = tensor.inner_values
else:
assert isinstance(key, str) and index is not None
kwargs[key][index] = tensor.inner_values
# Call the elementwise op on the broadcasted dense values.
with ops.control_dependencies(broadcast_check_ops):
result_values = op(*args, **kwargs)
# Restore any ragged dimensions that we stripped off, and return the
# result.
return ragged_factory_ops.from_nested_row_splits(result_values,
result_splits)
# Construct the docstring.
op_name = tf_export.get_canonical_name_for_symbol(op)
assert op_name is not None, op
argnames = ', '.join('`%s`' % s.strip('[]') for s in elementwise_args)
docstring = _ELEMENTWISE_DOCSTRING % dict(op_name=op_name, argnames=argnames)
# Update name, docstring, signature, etc., for the wrapper, and return it.
return tf_decorator.make_decorator(op, ragged_op, decorator_doc=docstring)
_ELEMENTWISE_DOCSTRING = """\
Ragged version of the elementwise operation `tf.%(op_name)s`.
The following elementwise arguments may be ragged or dense:
%(argnames)s.
These arguments will be broadcast to a compatible shape if necessary.
"""
def _get_arg_infos(func, elementwise_args):
"""Returns `_ArgInfo`s for each `func` arg specified by `elementwise_args`.
Args:
func: The function whose arguments should be described.
elementwise_args: The names of the arguments to get info for.
Returns:
A dictionary that maps both names and positions of arguments to
`_ArgInfo` tuples.
"""
arg_infos = {}
# Inspect the func's argspec to find the position of each arg.
arg_spec = tf_inspect.getargspec(func)
for argname in elementwise_args:
assert isinstance(argname, str)
is_list = argname.startswith('[') and argname.endswith(']')
if is_list:
argname = argname[1:-1]
assert argname in arg_spec.args, (func, argname, arg_spec.args)
arg_info = _ArgInfo(argname, arg_spec.args.index(argname), is_list)
arg_infos[arg_info.name] = arg_info
arg_infos[arg_info.position] = arg_info
return arg_infos
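
A small sketch of the mapping this returns, keyed both by name and by position (using ops referenced elsewhere in this file):

    infos = _get_arg_infos(math_ops.add, ('x', 'y'))
    infos['x']       # ArgInfo(name='x', position=0, is_list=False)
    infos[1]         # ArgInfo(name='y', position=1, is_list=False)

    infos = _get_arg_infos(math_ops.add_n, ('[inputs]',))
    infos['inputs']  # ArgInfo(name='inputs', position=0, is_list=True)
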
def _broadcast_elementwise_args(elementwise_args):
"""Broadcasts the values of `elementwise_args` to have compatible shapes.
Args:
elementwise_args: A dictionary whose values are potentially ragged tensors.
Returns:
A tuple `(broadcast_args, broadcast_splits, checks)` where:
* `broadcast_args` is a dictionary with the same keys as
`elementwise_args`, mapping to broadcasted tensors.
* `broadcast_splits` is the broadcasted nested row splits.
* `checks` is a possibly empty tuple of assertion operations that should
be added as control dependencies.
Raises:
ValueError: If broadcasting fails.
"""
# No elementwise arguments were used: nothing to do!
if not elementwise_args:
return elementwise_args, (), ()
# A single elementwise argument was used: no broadcasting necessary.
if len(elementwise_args) == 1:
arg = list(elementwise_args.values())[0]
if ragged_tensor.is_ragged(arg):
return elementwise_args, arg.nested_row_splits, ()
else:
return elementwise_args, (), ()
# Multiple elementwise arguments.
else:
is_ragged = [ragged_tensor.is_ragged(t) for t in elementwise_args.values()]
if not any(is_ragged):
return elementwise_args, (), ()
# If we have a single ragged tensor plus a set of scalars, then we can
# rely on the underlying elementwise op to do broadcasting.
if (sum(is_ragged) == 1 and
all((ragged_tensor.is_ragged(t) or t.shape.ndims == 0)
for t in elementwise_args.values())):
nested_splits_lists = [
t.nested_row_splits
for t in elementwise_args.values()
if ragged_tensor.is_ragged(t)][0]
return elementwise_args, nested_splits_lists, ()
else:
# Get the shapes of all the elementwise arguments.
shapes = [ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(t)
for t in elementwise_args.values()]
# Broadcast the shapes to all have the same rank (the max rank).
ranks = [t.shape.ndims for t in elementwise_args.values()]
if any(rank is None for rank in ranks):
raise ValueError('Unable to broadcast: unknown rank')
broadcast_rank = max(ranks)
shapes = [shape.broadcast_to_rank(broadcast_rank) for shape in shapes]
# For each dimension, broadcast the shapes to be compatible.
for axis in range(broadcast_rank):
# For each i, broadcast shape[i+1] to be compatible with shape[i]; and
# then finally broadcast shape[0] to be compatible with shape[-1].
for i in range(len(shapes)):
j = (i + 1) % len(shapes)
dim_size = shapes[i].dimension_size(axis)
shapes[j] = shapes[j].broadcast_dimension(axis, dim_size)
broadcast_shape = shapes[0]
# Broadcast every elementwise arg to the shape that we calculated.
elementwise_args = dict([
(key, ragged_tensor_shape.broadcast_to(t, broadcast_shape, False))
for (key, t) in elementwise_args.items()])
nested_splits_lists = list(elementwise_args.values())[0].nested_row_splits
return elementwise_args, nested_splits_lists, ()
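
The pairwise loop above works because broadcasting each shape against its successor, with j = (i + 1) % len(shapes), carries every known dimension size once around the ring, so shapes[0] ends up compatible with all inputs. A plain-Python illustration of that cycling trick on ordinary integer shapes (helper names invented for illustration; NumPy-style rules):

    def _broadcast_dim(a, b):
      # One dimension broadcasts if the sizes match or one of them is 1.
      if a == b or b == 1:
        return a
      if a == 1:
        return b
      raise ValueError('incompatible sizes %s and %s' % (a, b))

    def common_shape(shapes):
      rank = max(len(s) for s in shapes)
      # Left-pad with 1s so every shape has the same rank.
      shapes = [[1] * (rank - len(s)) + list(s) for s in shapes]
      for axis in range(rank):
        for i in range(len(shapes)):
          j = (i + 1) % len(shapes)
          shapes[j][axis] = _broadcast_dim(shapes[i][axis], shapes[j][axis])
      return shapes[0]

    print(common_shape([[3, 1], [1, 4], [4]]))  # [3, 4]
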
# A list of symbols that should be exported in the "ragged" package.
_symbols_to_export = []
def _add_elementwise_ops_to_this_module(specs, verbose=False):
"""Adds ragged versions of the given ops to this module.
Args:
specs: A list of tuples containing the arguments for `make_elementwise_op`.
verbose: If true, then display each op that gets added.
"""
for spec in specs:
original_op = spec[0]
ragged_op = make_elementwise_op(*spec)
canonical_name = tf_export.get_canonical_name_for_symbol(original_op)
if '.' not in canonical_name:
op_name = canonical_name
else:
op_name = original_op.__name__
# Temporary hack (will be removed once dispatch is added for RaggedTensors):
if op_name == 'neg': op_name = 'negative'
if verbose:
print('Adding ragged_elementwise_op: tf.ragged.%s (based on tf.%s)' %
(op_name, canonical_name))
globals()[op_name] = ragged_op
_symbols_to_export.append(op_name)
# A list of tuples containing arguments for `make_elementwise_op`, for each
# elementwise operation that should have a ragged version built. Each tuple
# contains a standard `Tensor` operation, and the names of any arguments
# that are processed in elementwise fashion.
_TF_ELEMENTWISE_OPS = [
# Unary math operations.
(clip_ops.clip_by_value, 't'),
(math_ops.abs, 'x'),
(math_ops.acos, 'x'),
(math_ops.acosh, 'x'),
(math_ops.angle, 'input'),
(math_ops.asin, 'x'),
(math_ops.asinh, 'x'),
(math_ops.atan, 'x'),
(math_ops.atanh, 'x'),
(math_ops.cast, 'x'),
(math_ops.ceil, 'x'),
(math_ops.conj, 'x'),
(math_ops.cos, 'x'),
(math_ops.cosh, 'x'),
(math_ops.digamma, 'x'),
(math_ops.erf, 'x'),
(math_ops.erfc, 'x'),
(math_ops.exp, 'x'),
(math_ops.expm1, 'x'),
(math_ops.floor, 'x'),
(math_ops.imag, 'input'),
(math_ops.is_finite, 'x'),
(math_ops.is_inf, 'x'),
(math_ops.is_nan, 'x'),
(math_ops.lgamma, 'x'),
(math_ops.log, 'x'),
(math_ops.log1p, 'x'),
(math_ops.log_sigmoid, 'x'),
(math_ops.logical_not, 'x'),
(math_ops.negative, 'x'),
(math_ops.real, 'input'),
(math_ops.reciprocal, 'x'),
(math_ops.rint, 'x'),
(math_ops.round, 'x'),
(math_ops.rsqrt, 'x'),
(math_ops.saturate_cast, 'value'),
(math_ops.sign, 'x'),
(math_ops.sin, 'x'),
(math_ops.sinh, 'x'),
(math_ops.sqrt, 'x'),
(math_ops.square, 'x'),
(math_ops.tan, 'x'),
# Binary math operations
(math_ops.add, 'x', 'y'),
(math_ops.atan2, 'y', 'x'),
(math_ops.complex, 'real', 'imag'),
(math_ops.div, 'x', 'y'),
(math_ops.div_no_nan, 'x', 'y'),
(math_ops.divide, 'x', 'y'),
(math_ops.equal, 'x', 'y'),
(math_ops.floordiv, 'x', 'y'),
(math_ops.floormod, 'x', 'y'),
(math_ops.greater, 'x', 'y'),
(math_ops.greater_equal, 'x', 'y'),
(math_ops.less, 'x', 'y'),
(math_ops.less_equal, 'x', 'y'),
(math_ops.logical_and, 'x', 'y'),
(math_ops.logical_or, 'x', 'y'),
(math_ops.logical_xor, 'x', 'y'),
(math_ops.maximum, 'x', 'y'),
(math_ops.minimum, 'x', 'y'),
(math_ops.multiply, 'x', 'y'),
(math_ops.not_equal, 'x', 'y'),
(math_ops.pow, 'x', 'y'),
(math_ops.realdiv, 'x', 'y'),
(math_ops.squared_difference, 'x', 'y'),
(math_ops.subtract, 'x', 'y'),
(math_ops.truediv, 'x', 'y'),
(math_ops.truncatediv, 'x', 'y'),
(math_ops.truncatemod, 'x', 'y'),
# N-ary math operations
(math_ops.add_n, '[inputs]'),
# String operations
(string_ops.as_string, 'input'),
(string_ops.decode_base64, 'input'),
(string_ops.encode_base64, 'input'),
(string_ops.regex_full_match, 'input'),
(string_ops.regex_replace, 'input'),
(string_ops.string_join, '[inputs]'),
(string_ops.string_strip, 'input'),
(string_ops.string_to_hash_bucket, 'input'),
(string_ops.string_to_hash_bucket_fast, 'input'),
(string_ops.string_to_hash_bucket_strong, 'input'),
(string_ops.substr, 'input'),
(string_ops.unicode_script, 'input'),
# Array ops
(array_ops.check_numerics, 'tensor'),
(array_ops.identity, 'input'),
(array_ops.ones_like, 'tensor'),
(array_ops.zeros_like, 'tensor'),
# Parsing ops
(parsing_ops.decode_compressed, 'bytes'),
(parsing_ops.string_to_number, 'string_tensor'),
]
_add_elementwise_ops_to_this_module(_TF_ELEMENTWISE_OPS)

View File

@@ -18,6 +18,7 @@ from __future__ import division
 from __future__ import print_function
 from absl.testing import parameterized
+import numpy as np
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_util
@@ -55,7 +56,7 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
      ),
      # [d1, (d2)] -> [d1, (d2)]
      dict(
-          fn=lambda x: x+1,
+          fn=lambda x: x + np.int64(1),
          elems=[[1, 2, 3], [4, 5], [6, 7]],
          expected_output=[[2, 3, 4], [5, 6], [7, 8]],
          dtype=dtypes.int64,
@@ -64,7 +65,7 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
      ),
      # [d1, (d2), d3] -> [d1, (d2), d3]
      dict(
-          fn=lambda x: x+1,
+          fn=lambda x: x + np.int64(1),
          elems=[[[1, 2], [3, 4]], [], [[5, 6], [7, 8], [9, 0]]],
          elems_ragged_rank=1,
          expected_ragged_rank=1,
@@ -131,7 +132,7 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
      ),
      # [d1, (d2), (d3), (d4a), (d5)] -> [d1, (d2), (d3), (d4b), (d5)]
      dict(
-          fn=lambda x: ragged.add(x, 1),
+          fn=lambda x: x + np.int64(1),
          elems=[[[[[1, 2, 3]], [[4], [5]]]], [[[[6, 7]]], [[[8], []]]]],
          expected_output=[[[[[2, 3, 4]], [[5], [6]]]],
                           [[[[7, 8]]], [[[9], []]]]],
@@ -196,8 +197,8 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
    def _increment(f):
      return {
-          'batman': ragged.add(f['batman'], 1),
-          'robin': ragged.add(f['robin'], 1),
+          'batman': f['batman'] + 1,
+          'robin': f['robin'] + 1,
      }
    output = ragged.map_fn(

View File

@@ -143,8 +143,11 @@ Computes the %(combination)s along segments of a RaggedTensor.
 """
-def _ragged_segment_aggregate(unsorted_segment_op, data, segment_ids,
-                              num_segments, name=None):
+def _ragged_segment_aggregate(unsorted_segment_op,
+                              data,
+                              segment_ids,
+                              num_segments,
+                              name=None):
   """Aggregates along segments of a RaggedTensor using `unsorted_segment_op`.
   Returns a RaggedTensor `output` with `num_segments` rows, where the row
@@ -212,8 +215,7 @@
   assert output_row_lengths.dtype == dtypes.int64
   # Build the splits tensor for the output RaggedTensor.
-  output_splits = array_ops.concat(
-      [
+  output_splits = array_ops.concat([
       array_ops.zeros([1], dtypes.int64),
       math_ops.cumsum(output_row_lengths)
   ],
@@ -311,7 +313,7 @@ _set_ragged_segment_docstring(segment_sqrt_n, 'sum divided by sqrt(N)',
 _RAGGED_REDUCE_DOCSTRING = """\
 Computes the %(combination)s of elements across dimensions of a `RaggedTensor`.
-Reduces `rt_input` along the dimensions given in `axis` by taking the
+Reduces `input_tensor` along the dimensions given in `axis` by taking the
 %(combination)s of values. If a reduced dimension has no elements for
 some index, then the value for that index will be %(default)s.
@@ -319,18 +321,18 @@ Computes the %(combination)s of elements across dimensions of a `RaggedTensor`.
 `axis` is not specified, then all dimensions are reduced, and a scalar
 value is returned.
 Args:
-  rt_input: A `RaggedTensor` containing the values to be %(combined)s.
+  input_tensor: A `RaggedTensor` containing the values to be %(combined)s.
   axis: The dimensions to reduce. May be `None` (to reduce all axes), an
     `int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce
     a given set of axes), or a `Tensor` with a constant value. Must be in
-    the range `[0, rt_input.rank]`.
+    the range `[0, input_tensor.rank]`.
   name: A name prefix for the returned tensor (optional).
 Returns:
   A `RaggedTensor` containing the %(combined)s values. The returned tensor
   has the same dtype as `data`, and its shape is given by removing the
-  dimensions specified in `axis` from `rt_input.shape`. The `ragged_rank`
+  dimensions specified in `axis` from `input_tensor.shape`. The `ragged_rank`
   of the returned tensor is given by subtracting any ragged dimensions
-  specified in `axis` from `rt_input.ragged_rank`.
+  specified in `axis` from `input_tensor.ragged_rank`.
 Raises:
   ValueError: If `axis` contains a `Tensor` whose value is not constant.
 ####Example:
@@ -387,7 +389,11 @@ _RAGGED_REDUCE_ANY_EXAMPLE = """
 """
-def _ragged_reduce_aggregate(reduce_op, unsorted_segment_op, rt_input, axis,
-                             name=None):
+def _ragged_reduce_aggregate(reduce_op,
+                             unsorted_segment_op,
+                             rt_input,
+                             axis,
+                             keepdims,
+                             name=None):
   """Aggregates across axes of a RaggedTensor using the given `Tensor` ops.
@@ -412,6 +418,7 @@ def _ragged_reduce_aggregate(reduce_op,
     `int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce a
     given set of axes), or a `Tensor` with a constant value. Must be in the
     range `[0, rt_input.rank)`.
+   keepdims: If true, retains reduced dimensions with length 1.
    name: A name prefix for the returned tensor (optional).
 Returns:
@@ -426,6 +433,9 @@
   if not ragged_tensor.is_ragged(rt_input):
     return reduce_op(rt_input, axis, name=name)
+  if keepdims:
+    raise ValueError('keepdims=True is not supported for RaggedTensors.')
   if isinstance(axis, ops.Tensor):
     axis = tensor_util.constant_value(axis)
   if axis is None:
@@ -448,9 +458,9 @@
     # once will probably require a nontrivial c++ op.
     axis = sorted(axis)
     inner_reduced = _ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
-                                             rt_input, axis[-1])
+                                             rt_input, axis[-1], keepdims)
     return _ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
-                                    inner_reduced, axis[:-1])
+                                    inner_reduced, axis[:-1], keepdims)
   axis = ragged_util.get_positive_axis(axis, rt_input.shape.ndims)
@@ -476,48 +486,48 @@
   #     sum_{j} rt_input [i_0, ..., i_[axis-1], j, i_axis+1], ..., i_N]
   return rt_input.with_values(
       _ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
-                               rt_input.values, axis - 1))
+                               rt_input.values, axis - 1, keepdims))
-def reduce_sum(rt_input, axis=None, name=None):
+def reduce_sum(input_tensor, axis=None, keepdims=None, name=None):
   """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
   return _ragged_reduce_aggregate(math_ops.reduce_sum,
-                                  math_ops.unsorted_segment_sum, rt_input, axis,
-                                  name or 'RaggedReduceSum')
+                                  math_ops.unsorted_segment_sum, input_tensor,
+                                  axis, keepdims, name or 'RaggedReduceSum')
-def reduce_prod(rt_input, axis=None, name=None):
+def reduce_prod(input_tensor, axis=None, keepdims=None, name=None):
   """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
   return _ragged_reduce_aggregate(math_ops.reduce_prod,
-                                  math_ops.unsorted_segment_prod, rt_input,
-                                  axis, name or 'RaggedReduceProd')
+                                  math_ops.unsorted_segment_prod, input_tensor,
+                                  axis, keepdims, name or 'RaggedReduceProd')
-def reduce_min(rt_input, axis=None, name=None):
+def reduce_min(input_tensor, axis=None, keepdims=None, name=None):
   """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
   return _ragged_reduce_aggregate(math_ops.reduce_min,
-                                  math_ops.unsorted_segment_min, rt_input, axis,
-                                  name or 'RaggedReduceMin')
+                                  math_ops.unsorted_segment_min, input_tensor,
+                                  axis, keepdims, name or 'RaggedReduceMin')
-def reduce_max(rt_input, axis=None, name=None):
+def reduce_max(input_tensor, axis=None, keepdims=None, name=None):
   """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
   return _ragged_reduce_aggregate(math_ops.reduce_max,
-                                  math_ops.unsorted_segment_max, rt_input, axis,
-                                  name or 'RaggedReduceMax')
+                                  math_ops.unsorted_segment_max, input_tensor,
+                                  axis, keepdims, name or 'RaggedReduceMax')
-def reduce_mean(rt_input, axis=None, name=None):
+def reduce_mean(input_tensor, axis=None, keepdims=None, name=None):
   """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
-  with ops.name_scope(name, 'RaggedReduceMean', [rt_input, axis]):
-    total = reduce_sum(rt_input, axis)
-    if ragged_tensor.is_ragged(rt_input):
+  with ops.name_scope(name, 'RaggedReduceMean', [input_tensor, axis]):
+    total = reduce_sum(input_tensor, axis, keepdims)
+    if ragged_tensor.is_ragged(input_tensor):
       ones = ragged_factory_ops.from_nested_row_splits(
-          array_ops.ones_like(rt_input.inner_values),
-          rt_input.nested_row_splits)
+          array_ops.ones_like(input_tensor.inner_values),
+          input_tensor.nested_row_splits)
     else:
-      ones = array_ops.ones_like(rt_input)
-    count = reduce_sum(ones, axis)
+      ones = array_ops.ones_like(input_tensor)
+    count = reduce_sum(ones, axis, keepdims)
    if ragged_tensor.is_ragged(total):
      return ragged_factory_ops.from_nested_row_splits(
          total.inner_values / count.inner_values, total.nested_row_splits)
@@ -525,20 +535,25 @@ def reduce_mean(rt_input, axis=None, name=None):
     return total / count
-def _cast(rt_input, dtype):
-  return ragged_functional_ops.map_inner_values(math_ops.cast, rt_input, dtype)
+def _cast(input_tensor, dtype):
+  return ragged_functional_ops.map_inner_values(math_ops.cast, input_tensor,
+                                                dtype)
-def reduce_all(rt_input, axis=None, name=None):
+def reduce_all(input_tensor, axis=None, keepdims=None, name=None):
   """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
-  with ops.name_scope(name, 'RaggedReduceAll', [rt_input, axis]):
-    return _cast(reduce_prod(_cast(rt_input, dtypes.int32), axis), dtypes.bool)
+  with ops.name_scope(name, 'RaggedReduceAll', [input_tensor, axis]):
+    return _cast(
+        reduce_prod(_cast(input_tensor, dtypes.int32), axis, keepdims),
+        dtypes.bool)
-def reduce_any(rt_input, axis=None, name=None):
+def reduce_any(input_tensor, axis=None, keepdims=None, name=None):
   """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
-  with ops.name_scope(name, 'RaggedReduceAny', [rt_input, axis]):
-    return _cast(reduce_sum(_cast(rt_input, dtypes.int32), axis), dtypes.bool)
+  with ops.name_scope(name, 'RaggedReduceAny', [input_tensor, axis]):
+    return _cast(
+        reduce_sum(_cast(input_tensor, dtypes.int32), axis, keepdims),
+        dtypes.bool)
 def _set_ragged_reduce_docstring(func, combination, combined, default, example):
@@ -554,9 +569,11 @@ _set_ragged_reduce_docstring(reduce_sum, 'sum', 'summed', '0',
 _set_ragged_reduce_docstring(reduce_prod, 'product', 'multiplied', '1',
                              _RAGGED_REDUCE_PROD_EXAMPLE)
 _set_ragged_reduce_docstring(reduce_min, 'minimum', 'minimized',
-                             '`rt_input.dtype.min`', _RAGGED_REDUCE_MIN_EXAMPLE)
+                             '`input_tensor.dtype.min`',
+                             _RAGGED_REDUCE_MIN_EXAMPLE)
 _set_ragged_reduce_docstring(reduce_max, 'maximum', 'maximized',
-                             '`rt_input.dtype.max`', _RAGGED_REDUCE_MAX_EXAMPLE)
+                             '`input_tensor.dtype.max`',
+                             _RAGGED_REDUCE_MAX_EXAMPLE)
 _set_ragged_reduce_docstring(reduce_mean, 'mean', 'averaged', 'NaN',
                              _RAGGED_REDUCE_MEAN_EXAMPLE)
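
A short sketch of the reduction semantics defined above (internal module paths; `ragged_factory_ops.constant` is assumed; results are post-evaluation values):

    from tensorflow.python.ops.ragged import ragged_factory_ops
    from tensorflow.python.ops.ragged import ragged_math_ops

    rt = ragged_factory_ops.constant([[3., 1., 4.], [1.], []])
    ragged_math_ops.reduce_sum(rt, axis=1)   # -> [8.0, 1.0, 0.0]
    ragged_math_ops.reduce_mean(rt, axis=1)  # -> [8/3, 1.0, nan] (empty row)
    # keepdims is accepted for API parity only; any truthy value raises:
    # ragged_math_ops.reduce_sum(rt, axis=1, keepdims=True)  # ValueError
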

View File

@@ -18,7 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-from tensorflow.python.ops.ragged import ragged_elementwise_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops.ragged import ragged_getitem
 from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.util import tf_decorator
@@ -33,40 +33,39 @@ def _right(operator):
 ragged_tensor.RaggedTensor.__getitem__ = ragged_getitem.ragged_tensor_getitem
 # Ordering operators
-ragged_tensor.RaggedTensor.__ge__ = ragged_elementwise_ops.greater_equal
-ragged_tensor.RaggedTensor.__gt__ = ragged_elementwise_ops.greater
-ragged_tensor.RaggedTensor.__le__ = ragged_elementwise_ops.less_equal
-ragged_tensor.RaggedTensor.__lt__ = ragged_elementwise_ops.less
+ragged_tensor.RaggedTensor.__ge__ = math_ops.greater_equal
+ragged_tensor.RaggedTensor.__gt__ = math_ops.greater
+ragged_tensor.RaggedTensor.__le__ = math_ops.less_equal
+ragged_tensor.RaggedTensor.__lt__ = math_ops.less
 # Logical operators
-ragged_tensor.RaggedTensor.__and__ = ragged_elementwise_ops.logical_and
-ragged_tensor.RaggedTensor.__rand__ = _right(ragged_elementwise_ops.logical_and)
-ragged_tensor.RaggedTensor.__invert__ = ragged_elementwise_ops.logical_not
-ragged_tensor.RaggedTensor.__ror__ = _right(ragged_elementwise_ops.logical_or)
-ragged_tensor.RaggedTensor.__or__ = ragged_elementwise_ops.logical_or
-ragged_tensor.RaggedTensor.__xor__ = ragged_elementwise_ops.logical_xor
-ragged_tensor.RaggedTensor.__rxor__ = _right(ragged_elementwise_ops.logical_xor)
+ragged_tensor.RaggedTensor.__and__ = math_ops.logical_and
+ragged_tensor.RaggedTensor.__rand__ = _right(math_ops.logical_and)
+ragged_tensor.RaggedTensor.__invert__ = math_ops.logical_not
+ragged_tensor.RaggedTensor.__ror__ = _right(math_ops.logical_or)
+ragged_tensor.RaggedTensor.__or__ = math_ops.logical_or
+ragged_tensor.RaggedTensor.__xor__ = math_ops.logical_xor
+ragged_tensor.RaggedTensor.__rxor__ = _right(math_ops.logical_xor)
 # Arithmetic operators
-ragged_tensor.RaggedTensor.__abs__ = ragged_elementwise_ops.abs
-ragged_tensor.RaggedTensor.__add__ = ragged_elementwise_ops.add
-ragged_tensor.RaggedTensor.__radd__ = _right(ragged_elementwise_ops.add)
-ragged_tensor.RaggedTensor.__div__ = ragged_elementwise_ops.div
-ragged_tensor.RaggedTensor.__rdiv__ = _right(ragged_elementwise_ops.div)
-ragged_tensor.RaggedTensor.__floordiv__ = ragged_elementwise_ops.floordiv
-ragged_tensor.RaggedTensor.__rfloordiv__ = _right(
-    ragged_elementwise_ops.floordiv)
-ragged_tensor.RaggedTensor.__mod__ = ragged_elementwise_ops.floormod
-ragged_tensor.RaggedTensor.__rmod__ = _right(ragged_elementwise_ops.floormod)
-ragged_tensor.RaggedTensor.__mul__ = ragged_elementwise_ops.multiply
-ragged_tensor.RaggedTensor.__rmul__ = _right(ragged_elementwise_ops.multiply)
-ragged_tensor.RaggedTensor.__neg__ = ragged_elementwise_ops.negative
-ragged_tensor.RaggedTensor.__pow__ = ragged_elementwise_ops.pow
-ragged_tensor.RaggedTensor.__rpow__ = _right(ragged_elementwise_ops.pow)
-ragged_tensor.RaggedTensor.__sub__ = ragged_elementwise_ops.subtract
-ragged_tensor.RaggedTensor.__rsub__ = _right(ragged_elementwise_ops.subtract)
-ragged_tensor.RaggedTensor.__truediv__ = ragged_elementwise_ops.truediv
-ragged_tensor.RaggedTensor.__rtruediv__ = _right(ragged_elementwise_ops.truediv)
+ragged_tensor.RaggedTensor.__abs__ = math_ops.abs
+ragged_tensor.RaggedTensor.__add__ = math_ops.add
+ragged_tensor.RaggedTensor.__radd__ = _right(math_ops.add)
+ragged_tensor.RaggedTensor.__div__ = math_ops.div
+ragged_tensor.RaggedTensor.__rdiv__ = _right(math_ops.div)
+ragged_tensor.RaggedTensor.__floordiv__ = math_ops.floordiv
+ragged_tensor.RaggedTensor.__rfloordiv__ = _right(math_ops.floordiv)
+ragged_tensor.RaggedTensor.__mod__ = math_ops.floormod
+ragged_tensor.RaggedTensor.__rmod__ = _right(math_ops.floormod)
+ragged_tensor.RaggedTensor.__mul__ = math_ops.multiply
+ragged_tensor.RaggedTensor.__rmul__ = _right(math_ops.multiply)
+ragged_tensor.RaggedTensor.__neg__ = math_ops.negative
+ragged_tensor.RaggedTensor.__pow__ = math_ops.pow
+ragged_tensor.RaggedTensor.__rpow__ = _right(math_ops.pow)
+ragged_tensor.RaggedTensor.__sub__ = math_ops.subtract
+ragged_tensor.RaggedTensor.__rsub__ = _right(math_ops.subtract)
+ragged_tensor.RaggedTensor.__truediv__ = math_ops.truediv
+ragged_tensor.RaggedTensor.__rtruediv__ = _right(math_ops.truediv)
 # Dummy methods
View File

@@ -2696,6 +2696,7 @@ class _UnaryMapValueDispatcher(dispatch.OpDispatcher):
     if args:
       x, args = args[0], args[1:]
     else:
+      kwargs = kwargs.copy()
       x = kwargs.pop(self._x, None)
     if isinstance(x, sparse_tensor.SparseTensor):
       return sparse_tensor.SparseTensor(

View File

@@ -38,6 +38,7 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops.gen_string_ops import *
 from tensorflow.python.util import compat as util_compat
 from tensorflow.python.util import deprecation
+from tensorflow.python.util import dispatch
 from tensorflow.python.util.tf_export import tf_export
 # pylint: enable=g-bad-import-order
 # pylint: enable=wildcard-import
@@ -45,6 +46,7 @@ from tensorflow.python.util.tf_export import tf_export
 # pylint: disable=redefined-builtin
 @tf_export("strings.regex_full_match")
+@dispatch.add_dispatch_support
 def regex_full_match(input, pattern, name=None):
   r"""Match elements of `input` with regex `pattern`.
@@ -76,6 +78,7 @@ regex_full_match.__doc__ = gen_string_ops.regex_full_match.__doc__
 @tf_export(
     "strings.regex_replace", v1=["strings.regex_replace", "regex_replace"])
 @deprecation.deprecated_endpoints("regex_replace")
+@dispatch.add_dispatch_support
 def regex_replace(input, pattern, rewrite, replace_global=True, name=None):
   r"""Replace elements of `input` matching regex `pattern` with `rewrite`.
@@ -350,10 +353,13 @@ reduce_join.__doc__ = reduce_join.__doc__.replace("tf.reduce_join(",
 # This wrapper provides backwards compatibility for code that predates the
 # unit argument and that passed 'name' as a positional argument.
 @tf_export(v1=["strings.length"])
+@dispatch.add_dispatch_support
 def string_length(input, name=None, unit="BYTE"):
   return gen_string_ops.string_length(input, unit=unit, name=name)
 @tf_export("strings.length", v1=[])
+@dispatch.add_dispatch_support
 def string_length_v2(input, unit="BYTE", name=None):
   return string_length(input, name, unit)
@@ -370,11 +376,13 @@ substr_deprecated.__doc__ = gen_string_ops.substr.__doc__
 @tf_export(v1=["strings.substr"])
+@dispatch.add_dispatch_support
 def substr(input, pos, len, name=None, unit="BYTE"):
   return gen_string_ops.substr(input, pos, len, unit=unit, name=name)
 @tf_export("strings.substr", v1=[])
+@dispatch.add_dispatch_support
 def substr_v2(input, pos, len, unit="BYTE", name=None):
   return substr(input, pos, len, name=name, unit=unit)
@@ -395,6 +403,7 @@ ops.NotDifferentiable("DecodeBase64")
 @tf_export("strings.to_number", v1=[])
+@dispatch.add_dispatch_support
 def string_to_number(input, out_type=dtypes.float32, name=None):
   r"""Converts each string in the input Tensor to the specified numeric type.
@@ -418,6 +427,7 @@ tf_export(v1=["strings.to_number", "string_to_number"])(
 @tf_export("strings.to_hash_bucket", v1=[])
+@dispatch.add_dispatch_support
 def string_to_hash_bucket(input, num_buckets, name=None):
   # pylint: disable=line-too-long
   r"""Converts each string in the input Tensor to its hash mod by a number of buckets.

View File

@@ -166,15 +166,14 @@ def dispatch_for_types(op, *types):
 def add_dispatch_list(target):
   """Decorator that adds a dispatch_list attribute to an op."""
-  assert not hasattr(target, DISPATCH_ATTR)
+  if hasattr(target, DISPATCH_ATTR):
+    raise AssertionError("%s already has a dispatch list" % target)
   setattr(target, DISPATCH_ATTR, [])
   return target
 def add_dispatch_support(target):
   """Decorator that adds a dispatch handling wrapper to an op."""
-  add_dispatch_list(target)
   def wrapper(*args, **kwargs):
     """Call target, and fall back on dispatchers if there is a TypeError."""
     try:
@@ -188,5 +187,5 @@ def add_dispatch_support(target):
     else:
       raise
-  setattr(wrapper, DISPATCH_ATTR, [])
+  add_dispatch_list(wrapper)
   return tf_decorator.make_decorator(target, wrapper)
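
For downstream code, the pattern this enables is: decorate an op with @add_dispatch_support, then register an OpDispatcher for a new tensor-like type. A minimal sketch with a hypothetical MaskedTensor (type and dispatcher names invented for illustration; array_ops.identity gains dispatch support in this change):

    from tensorflow.python.ops import array_ops
    from tensorflow.python.util import dispatch

    class MaskedTensor(object):
      def __init__(self, values, mask):
        self.values, self.mask = values, mask

    class MaskedIdentityDispatcher(dispatch.OpDispatcher):
      """Handles array_ops.identity(MaskedTensor)."""

      def handle(self, args, kwargs):
        x = args[0] if args else kwargs.get('input')
        if not isinstance(x, MaskedTensor):
          return self.NOT_SUPPORTED  # let other dispatchers try
        return MaskedTensor(array_ops.identity(x.values),
                            array_ops.identity(x.mask))

    MaskedIdentityDispatcher().register(array_ops.identity)
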