Internal Change
PiperOrigin-RevId: 224225849
This commit is contained in:
parent
2b0fd9b66a
commit
45cfe71266
@ -1,4 +1,6 @@
|
|||||||
op {
|
op {
|
||||||
graph_op_name: "FloorDiv"
|
graph_op_name: "FloorDiv"
|
||||||
visibility: HIDDEN
|
endpoint {
|
||||||
|
name: "floor_div"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,9 @@
|
|||||||
op {
|
op {
|
||||||
graph_op_name: "FloorMod"
|
graph_op_name: "FloorMod"
|
||||||
visibility: HIDDEN
|
endpoint {
|
||||||
|
name: "floormod"
|
||||||
|
}
|
||||||
|
endpoint {
|
||||||
|
name: "mod"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
op {
|
op {
|
||||||
graph_op_name: "RealDiv"
|
graph_op_name: "RealDiv"
|
||||||
visibility: HIDDEN
|
endpoint {
|
||||||
|
name: "realdiv"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
op {
|
op {
|
||||||
graph_op_name: "TruncateDiv"
|
graph_op_name: "TruncateDiv"
|
||||||
visibility: HIDDEN
|
endpoint {
|
||||||
|
name: "truncatediv"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
op {
|
op {
|
||||||
graph_op_name: "TruncateMod"
|
graph_op_name: "TruncateMod"
|
||||||
visibility: HIDDEN
|
endpoint {
|
||||||
|
name: "truncatemod"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -634,7 +634,9 @@ void GenEagerPythonOp::AddEagerFunctionTeardown(
|
|||||||
bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
|
bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
|
||||||
const string& parameters, const std::vector<string>& output_sizes,
|
const string& parameters, const std::vector<string>& output_sizes,
|
||||||
const string& eager_not_allowed_error) {
|
const string& eager_not_allowed_error) {
|
||||||
|
if (api_def_.visibility() == ApiDef::VISIBLE) {
|
||||||
strings::StrAppend(&result_, "@_dispatch.add_dispatch_list\n");
|
strings::StrAppend(&result_, "@_dispatch.add_dispatch_list\n");
|
||||||
|
}
|
||||||
AddExport();
|
AddExport();
|
||||||
AddDefLine(function_name_, parameters);
|
AddDefLine(function_name_, parameters);
|
||||||
AddDocStringDescription();
|
AddDocStringDescription();
|
||||||
|
@ -56,6 +56,7 @@ _BaseSlice = slice
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("identity")
|
@tf_export("identity")
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def identity(input, name=None): # pylint: disable=redefined-builtin
|
def identity(input, name=None): # pylint: disable=redefined-builtin
|
||||||
r"""Return a tensor with the same shape and contents as input.
|
r"""Return a tensor with the same shape and contents as input.
|
||||||
|
|
||||||
@ -139,6 +140,7 @@ def expand_dims(input, axis=None, name=None, dim=None):
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("expand_dims", v1=[])
|
@tf_export("expand_dims", v1=[])
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def expand_dims_v2(input, axis, name=None):
|
def expand_dims_v2(input, axis, name=None):
|
||||||
"""Inserts a dimension of 1 into a tensor's shape.
|
"""Inserts a dimension of 1 into a tensor's shape.
|
||||||
|
|
||||||
@ -941,6 +943,7 @@ def parallel_stack(values, name="parallel_stack"):
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("stack")
|
@tf_export("stack")
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def stack(values, axis=0, name="stack"):
|
def stack(values, axis=0, name="stack"):
|
||||||
"""Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor.
|
"""Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor.
|
||||||
|
|
||||||
@ -1151,6 +1154,7 @@ def unstack(value, num=None, axis=0, name="unstack"):
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("concat")
|
@tf_export("concat")
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def concat(values, axis, name="concat"):
|
def concat(values, axis, name="concat"):
|
||||||
"""Concatenates tensors along one dimension.
|
"""Concatenates tensors along one dimension.
|
||||||
|
|
||||||
@ -1328,6 +1332,7 @@ def boolean_mask(tensor, mask, name="boolean_mask", axis=None):
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("boolean_mask", v1=[])
|
@tf_export("boolean_mask", v1=[])
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def boolean_mask_v2(tensor, mask, axis=None, name="boolean_mask"):
|
def boolean_mask_v2(tensor, mask, axis=None, name="boolean_mask"):
|
||||||
"""Apply boolean mask to tensor.
|
"""Apply boolean mask to tensor.
|
||||||
|
|
||||||
@ -1810,6 +1815,7 @@ def zeros(shape, dtype=dtypes.float32, name=None):
|
|||||||
|
|
||||||
|
|
||||||
@tf_export(v1=["zeros_like"])
|
@tf_export(v1=["zeros_like"])
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def zeros_like(tensor, dtype=None, name=None, optimize=True):
|
def zeros_like(tensor, dtype=None, name=None, optimize=True):
|
||||||
"""Creates a tensor with all elements set to zero.
|
"""Creates a tensor with all elements set to zero.
|
||||||
|
|
||||||
@ -1840,6 +1846,7 @@ def zeros_like(tensor, dtype=None, name=None, optimize=True):
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("zeros_like", v1=[])
|
@tf_export("zeros_like", v1=[])
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def zeros_like_v2(
|
def zeros_like_v2(
|
||||||
input, # pylint: disable=redefined-builtin
|
input, # pylint: disable=redefined-builtin
|
||||||
dtype=None,
|
dtype=None,
|
||||||
@ -1899,6 +1906,7 @@ def zeros_like_impl(tensor, dtype, name, optimize=True):
|
|||||||
|
|
||||||
|
|
||||||
@tf_export(v1=["ones_like"])
|
@tf_export(v1=["ones_like"])
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def ones_like(tensor, dtype=None, name=None, optimize=True):
|
def ones_like(tensor, dtype=None, name=None, optimize=True):
|
||||||
"""Creates a tensor with all elements set to 1.
|
"""Creates a tensor with all elements set to 1.
|
||||||
|
|
||||||
@ -1929,6 +1937,7 @@ def ones_like(tensor, dtype=None, name=None, optimize=True):
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("ones_like", v1=[])
|
@tf_export("ones_like", v1=[])
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def ones_like_v2(
|
def ones_like_v2(
|
||||||
input, # pylint: disable=redefined-builtin
|
input, # pylint: disable=redefined-builtin
|
||||||
dtype=None,
|
dtype=None,
|
||||||
@ -3115,6 +3124,7 @@ def squeeze_v2(input, axis=None, name=None):
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("where")
|
@tf_export("where")
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def where(condition, x=None, y=None, name=None):
|
def where(condition, x=None, y=None, name=None):
|
||||||
"""Return the elements, either from `x` or `y`, depending on the `condition`.
|
"""Return the elements, either from `x` or `y`, depending on the `condition`.
|
||||||
|
|
||||||
@ -3234,6 +3244,7 @@ def gather(params, indices, validate_indices=None, name=None, axis=0):
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("gather", v1=[])
|
@tf_export("gather", v1=[])
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def gather_v2(params, indices, validate_indices=None, axis=0, name=None):
|
def gather_v2(params, indices, validate_indices=None, axis=0, name=None):
|
||||||
return gather(params, indices, validate_indices=validate_indices, name=name,
|
return gather(params, indices, validate_indices=validate_indices, name=name,
|
||||||
axis=axis)
|
axis=axis)
|
||||||
|
@ -31,10 +31,12 @@ from tensorflow.python.ops import gen_nn_ops
|
|||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import numerics
|
from tensorflow.python.ops import numerics
|
||||||
from tensorflow.python.util import deprecation
|
from tensorflow.python.util import deprecation
|
||||||
|
from tensorflow.python.util import dispatch
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
@tf_export("clip_by_value")
|
@tf_export("clip_by_value")
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def clip_by_value(t, clip_value_min, clip_value_max,
|
def clip_by_value(t, clip_value_min, clip_value_max,
|
||||||
name=None):
|
name=None):
|
||||||
"""Clips tensor values to a specified min and max.
|
"""Clips tensor values to a specified min and max.
|
||||||
|
@ -230,6 +230,7 @@ class DivideDelegateWithName(object):
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("math.divide", "divide")
|
@tf_export("math.divide", "divide")
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def divide(x, y, name=None):
|
def divide(x, y, name=None):
|
||||||
"""Computes Python style division of `x` by `y`."""
|
"""Computes Python style division of `x` by `y`."""
|
||||||
|
|
||||||
@ -242,6 +243,7 @@ def divide(x, y, name=None):
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("math.multiply", "multiply")
|
@tf_export("math.multiply", "multiply")
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def multiply(x, y, name=None):
|
def multiply(x, y, name=None):
|
||||||
return gen_math_ops.mul(x, y, name)
|
return gen_math_ops.mul(x, y, name)
|
||||||
|
|
||||||
@ -262,6 +264,7 @@ _mul.__doc__ = (
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("math.subtract", "subtract")
|
@tf_export("math.subtract", "subtract")
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def subtract(x, y, name=None):
|
def subtract(x, y, name=None):
|
||||||
return gen_math_ops.sub(x, y, name)
|
return gen_math_ops.sub(x, y, name)
|
||||||
|
|
||||||
@ -347,6 +350,7 @@ def scalar_mul_v2(scalar, x, name=None):
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("math.pow", "pow")
|
@tf_export("math.pow", "pow")
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def pow(x, y, name=None): # pylint: disable=redefined-builtin
|
def pow(x, y, name=None): # pylint: disable=redefined-builtin
|
||||||
r"""Computes the power of one value to another.
|
r"""Computes the power of one value to another.
|
||||||
|
|
||||||
@ -375,6 +379,7 @@ def pow(x, y, name=None): # pylint: disable=redefined-builtin
|
|||||||
|
|
||||||
# pylint: disable=redefined-builtin,redefined-outer-name
|
# pylint: disable=redefined-builtin,redefined-outer-name
|
||||||
@tf_export("dtypes.complex", "complex")
|
@tf_export("dtypes.complex", "complex")
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def complex(real, imag, name=None):
|
def complex(real, imag, name=None):
|
||||||
r"""Converts two real numbers to a complex number.
|
r"""Converts two real numbers to a complex number.
|
||||||
|
|
||||||
@ -418,6 +423,7 @@ def complex(real, imag, name=None):
|
|||||||
|
|
||||||
@tf_export("math.real", v1=["math.real", "real"])
|
@tf_export("math.real", v1=["math.real", "real"])
|
||||||
@deprecation.deprecated_endpoints("real")
|
@deprecation.deprecated_endpoints("real")
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def real(input, name=None):
|
def real(input, name=None):
|
||||||
r"""Returns the real part of a complex (or real) tensor.
|
r"""Returns the real part of a complex (or real) tensor.
|
||||||
|
|
||||||
@ -450,6 +456,7 @@ def real(input, name=None):
|
|||||||
|
|
||||||
@tf_export("math.imag", v1=["math.imag", "imag"])
|
@tf_export("math.imag", v1=["math.imag", "imag"])
|
||||||
@deprecation.deprecated_endpoints("imag")
|
@deprecation.deprecated_endpoints("imag")
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def imag(input, name=None):
|
def imag(input, name=None):
|
||||||
r"""Returns the imaginary part of a complex (or real) tensor.
|
r"""Returns the imaginary part of a complex (or real) tensor.
|
||||||
|
|
||||||
@ -481,6 +488,7 @@ def imag(input, name=None):
|
|||||||
|
|
||||||
@tf_export("math.angle", v1=["math.angle", "angle"])
|
@tf_export("math.angle", v1=["math.angle", "angle"])
|
||||||
@deprecation.deprecated_endpoints("angle")
|
@deprecation.deprecated_endpoints("angle")
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def angle(input, name=None):
|
def angle(input, name=None):
|
||||||
r"""Returns the element-wise argument of a complex (or real) tensor.
|
r"""Returns the element-wise argument of a complex (or real) tensor.
|
||||||
|
|
||||||
@ -520,6 +528,7 @@ def angle(input, name=None):
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("math.round", "round")
|
@tf_export("math.round", "round")
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def round(x, name=None): # pylint: disable=redefined-builtin
|
def round(x, name=None): # pylint: disable=redefined-builtin
|
||||||
"""Rounds the values of a tensor to the nearest integer, element-wise.
|
"""Rounds the values of a tensor to the nearest integer, element-wise.
|
||||||
|
|
||||||
@ -547,6 +556,7 @@ def round(x, name=None): # pylint: disable=redefined-builtin
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("dtypes.cast", "cast")
|
@tf_export("dtypes.cast", "cast")
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def cast(x, dtype, name=None):
|
def cast(x, dtype, name=None):
|
||||||
"""Casts a tensor to a new type.
|
"""Casts a tensor to a new type.
|
||||||
|
|
||||||
@ -610,6 +620,7 @@ def cast(x, dtype, name=None):
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("dtypes.saturate_cast", "saturate_cast")
|
@tf_export("dtypes.saturate_cast", "saturate_cast")
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def saturate_cast(value, dtype, name=None):
|
def saturate_cast(value, dtype, name=None):
|
||||||
"""Performs a safe saturating cast of `value` to `dtype`.
|
"""Performs a safe saturating cast of `value` to `dtype`.
|
||||||
|
|
||||||
@ -935,6 +946,7 @@ def _div_python2(x, y, name=None):
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("math.truediv", "truediv")
|
@tf_export("math.truediv", "truediv")
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def truediv(x, y, name=None):
|
def truediv(x, y, name=None):
|
||||||
"""Divides x / y elementwise (using Python 3 division operator semantics).
|
"""Divides x / y elementwise (using Python 3 division operator semantics).
|
||||||
|
|
||||||
@ -992,6 +1004,7 @@ def div(x, y, name=None):
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("div_no_nan")
|
@tf_export("div_no_nan")
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def div_no_nan(x, y, name=None):
|
def div_no_nan(x, y, name=None):
|
||||||
"""Computes an unsafe divide which returns 0 if the y is zero.
|
"""Computes an unsafe divide which returns 0 if the y is zero.
|
||||||
|
|
||||||
@ -1021,6 +1034,7 @@ mod = gen_math_ops.floor_mod
|
|||||||
# TODO(aselle): Deprecate this once all internal functionality uses
|
# TODO(aselle): Deprecate this once all internal functionality uses
|
||||||
# tf.truncatediv
|
# tf.truncatediv
|
||||||
@tf_export("math.floordiv", v1=["math.floordiv", "floordiv"])
|
@tf_export("math.floordiv", v1=["math.floordiv", "floordiv"])
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
@deprecation.deprecated_endpoints("floordiv")
|
@deprecation.deprecated_endpoints("floordiv")
|
||||||
def floordiv(x, y, name=None):
|
def floordiv(x, y, name=None):
|
||||||
"""Divides `x / y` elementwise, rounding toward the most negative integer.
|
"""Divides `x / y` elementwise, rounding toward the most negative integer.
|
||||||
@ -1050,16 +1064,11 @@ def floordiv(x, y, name=None):
|
|||||||
|
|
||||||
|
|
||||||
realdiv = gen_math_ops.real_div
|
realdiv = gen_math_ops.real_div
|
||||||
tf_export("realdiv")(realdiv)
|
|
||||||
truncatediv = gen_math_ops.truncate_div
|
truncatediv = gen_math_ops.truncate_div
|
||||||
tf_export("truncatediv")(truncatediv)
|
|
||||||
# TODO(aselle): Rename this to floordiv when we can.
|
# TODO(aselle): Rename this to floordiv when we can.
|
||||||
floor_div = gen_math_ops.floor_div
|
floor_div = gen_math_ops.floor_div
|
||||||
tf_export("floor_div")(floor_div)
|
|
||||||
truncatemod = gen_math_ops.truncate_mod
|
truncatemod = gen_math_ops.truncate_mod
|
||||||
tf_export("truncatemod")(truncatemod)
|
|
||||||
floormod = gen_math_ops.floor_mod
|
floormod = gen_math_ops.floor_mod
|
||||||
tf_export("floormod", "mod")(floormod)
|
|
||||||
|
|
||||||
|
|
||||||
def _mul_dispatch(x, y, name=None):
|
def _mul_dispatch(x, y, name=None):
|
||||||
@ -1095,6 +1104,7 @@ _OverrideBinaryOperatorHelper(pow, "pow")
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("math.logical_xor", v1=["math.logical_xor", "logical_xor"])
|
@tf_export("math.logical_xor", v1=["math.logical_xor", "logical_xor"])
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
@deprecation.deprecated_endpoints("logical_xor")
|
@deprecation.deprecated_endpoints("logical_xor")
|
||||||
def logical_xor(x, y, name="LogicalXor"):
|
def logical_xor(x, y, name="LogicalXor"):
|
||||||
"""x ^ y = (x | y) & ~(x & y)."""
|
"""x ^ y = (x | y) & ~(x & y)."""
|
||||||
@ -1277,6 +1287,7 @@ def reduce_sum_v1(input_tensor,
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("math.reduce_sum", "reduce_sum", v1=[])
|
@tf_export("math.reduce_sum", "reduce_sum", v1=[])
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def reduce_sum(input_tensor, axis=None, keepdims=False, name=None):
|
def reduce_sum(input_tensor, axis=None, keepdims=False, name=None):
|
||||||
"""Computes the sum of elements across dimensions of a tensor.
|
"""Computes the sum of elements across dimensions of a tensor.
|
||||||
|
|
||||||
@ -1524,6 +1535,7 @@ def reduce_mean_v1(input_tensor,
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("math.reduce_mean", "reduce_mean", v1=[])
|
@tf_export("math.reduce_mean", "reduce_mean", v1=[])
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def reduce_mean(input_tensor, axis=None, keepdims=False, name=None):
|
def reduce_mean(input_tensor, axis=None, keepdims=False, name=None):
|
||||||
"""Computes the mean of elements across dimensions of a tensor.
|
"""Computes the mean of elements across dimensions of a tensor.
|
||||||
|
|
||||||
@ -1675,6 +1687,7 @@ def reduce_std(input_tensor, axis=None, keepdims=False, name=None):
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("math.reduce_prod", "reduce_prod", v1=[])
|
@tf_export("math.reduce_prod", "reduce_prod", v1=[])
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def reduce_prod(input_tensor, axis=None, keepdims=False, name=None):
|
def reduce_prod(input_tensor, axis=None, keepdims=False, name=None):
|
||||||
"""Computes the product of elements across dimensions of a tensor.
|
"""Computes the product of elements across dimensions of a tensor.
|
||||||
|
|
||||||
@ -1796,6 +1809,7 @@ def reduce_min_v1(input_tensor,
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("math.reduce_min", "reduce_min", v1=[])
|
@tf_export("math.reduce_min", "reduce_min", v1=[])
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def reduce_min(input_tensor, axis=None, keepdims=False, name=None):
|
def reduce_min(input_tensor, axis=None, keepdims=False, name=None):
|
||||||
"""Computes the minimum of elements across dimensions of a tensor.
|
"""Computes the minimum of elements across dimensions of a tensor.
|
||||||
|
|
||||||
@ -1874,6 +1888,7 @@ def reduce_max_v1(input_tensor,
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("math.reduce_max", "reduce_max", v1=[])
|
@tf_export("math.reduce_max", "reduce_max", v1=[])
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def reduce_max(input_tensor, axis=None, keepdims=False, name=None):
|
def reduce_max(input_tensor, axis=None, keepdims=False, name=None):
|
||||||
"""Computes the maximum of elements across dimensions of a tensor.
|
"""Computes the maximum of elements across dimensions of a tensor.
|
||||||
|
|
||||||
@ -1961,6 +1976,7 @@ def reduce_all_v1(input_tensor,
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("reduce_all", "math.reduce_all", v1=[])
|
@tf_export("reduce_all", "math.reduce_all", v1=[])
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def reduce_all(input_tensor, axis=None, keepdims=False, name=None):
|
def reduce_all(input_tensor, axis=None, keepdims=False, name=None):
|
||||||
"""Computes the "logical and" of elements across dimensions of a tensor.
|
"""Computes the "logical and" of elements across dimensions of a tensor.
|
||||||
|
|
||||||
@ -2057,6 +2073,7 @@ def reduce_any_v1(input_tensor,
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("math.reduce_any", "reduce_any", v1=[])
|
@tf_export("math.reduce_any", "reduce_any", v1=[])
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def reduce_any(input_tensor, axis=None, keepdims=False, name=None):
|
def reduce_any(input_tensor, axis=None, keepdims=False, name=None):
|
||||||
"""Computes the "logical or" of elements across dimensions of a tensor.
|
"""Computes the "logical or" of elements across dimensions of a tensor.
|
||||||
|
|
||||||
@ -2619,6 +2636,7 @@ def _as_indexed_slices_list(inputs, optimize=True):
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("math.add_n", "add_n")
|
@tf_export("math.add_n", "add_n")
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def add_n(inputs, name=None):
|
def add_n(inputs, name=None):
|
||||||
"""Adds all input tensors element-wise.
|
"""Adds all input tensors element-wise.
|
||||||
|
|
||||||
@ -2764,6 +2782,7 @@ def sigmoid(x, name=None):
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("math.log_sigmoid", v1=["math.log_sigmoid", "log_sigmoid"])
|
@tf_export("math.log_sigmoid", v1=["math.log_sigmoid", "log_sigmoid"])
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
@deprecation.deprecated_endpoints("log_sigmoid")
|
@deprecation.deprecated_endpoints("log_sigmoid")
|
||||||
def log_sigmoid(x, name=None):
|
def log_sigmoid(x, name=None):
|
||||||
"""Computes log sigmoid of `x` element-wise.
|
"""Computes log sigmoid of `x` element-wise.
|
||||||
@ -2973,6 +2992,7 @@ def cumprod(x, axis=0, exclusive=False, reverse=False, name=None):
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("math.conj", v1=["math.conj", "conj"])
|
@tf_export("math.conj", v1=["math.conj", "conj"])
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
@deprecation.deprecated_endpoints("conj")
|
@deprecation.deprecated_endpoints("conj")
|
||||||
def conj(x, name=None):
|
def conj(x, name=None):
|
||||||
r"""Returns the complex conjugate of a complex number.
|
r"""Returns the complex conjugate of a complex number.
|
||||||
@ -3077,6 +3097,7 @@ def _unsorted_segment_N(data, segment_ids, num_segments):
|
|||||||
"math.unsorted_segment_mean",
|
"math.unsorted_segment_mean",
|
||||||
v1=["math.unsorted_segment_mean", "unsorted_segment_mean"])
|
v1=["math.unsorted_segment_mean", "unsorted_segment_mean"])
|
||||||
@deprecation.deprecated_endpoints("unsorted_segment_mean")
|
@deprecation.deprecated_endpoints("unsorted_segment_mean")
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def unsorted_segment_mean(data, segment_ids, num_segments, name=None):
|
def unsorted_segment_mean(data, segment_ids, num_segments, name=None):
|
||||||
r"""Computes the mean along segments of a tensor.
|
r"""Computes the mean along segments of a tensor.
|
||||||
|
|
||||||
@ -3122,6 +3143,7 @@ def unsorted_segment_mean(data, segment_ids, num_segments, name=None):
|
|||||||
"math.unsorted_segment_sqrt_n",
|
"math.unsorted_segment_sqrt_n",
|
||||||
v1=["math.unsorted_segment_sqrt_n", "unsorted_segment_sqrt_n"])
|
v1=["math.unsorted_segment_sqrt_n", "unsorted_segment_sqrt_n"])
|
||||||
@deprecation.deprecated_endpoints("unsorted_segment_sqrt_n")
|
@deprecation.deprecated_endpoints("unsorted_segment_sqrt_n")
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def unsorted_segment_sqrt_n(data, segment_ids, num_segments, name=None):
|
def unsorted_segment_sqrt_n(data, segment_ids, num_segments, name=None):
|
||||||
r"""Computes the sum along segments of a tensor divided by the sqrt(N).
|
r"""Computes the sum along segments of a tensor divided by the sqrt(N).
|
||||||
|
|
||||||
|
@ -25,7 +25,7 @@ py_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":ragged_array_ops",
|
":ragged_array_ops",
|
||||||
":ragged_conversion_ops",
|
":ragged_conversion_ops",
|
||||||
":ragged_elementwise_ops",
|
":ragged_dispatch",
|
||||||
":ragged_factory_ops",
|
":ragged_factory_ops",
|
||||||
":ragged_functional_ops",
|
":ragged_functional_ops",
|
||||||
":ragged_getitem",
|
":ragged_getitem",
|
||||||
@ -150,33 +150,14 @@ py_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_library(
|
|
||||||
name = "ragged_elementwise_ops",
|
|
||||||
srcs = ["ragged_elementwise_ops.py"],
|
|
||||||
srcs_version = "PY2AND3",
|
|
||||||
deps = [
|
|
||||||
":ragged_factory_ops",
|
|
||||||
":ragged_tensor",
|
|
||||||
":ragged_tensor_shape",
|
|
||||||
":ragged_util",
|
|
||||||
"//tensorflow/python:array_ops",
|
|
||||||
"//tensorflow/python:clip_ops",
|
|
||||||
"//tensorflow/python:framework_ops",
|
|
||||||
"//tensorflow/python:math_ops",
|
|
||||||
"//tensorflow/python:parsing_ops",
|
|
||||||
"//tensorflow/python:string_ops",
|
|
||||||
"//tensorflow/python:util",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "ragged_operators",
|
name = "ragged_operators",
|
||||||
srcs = ["ragged_operators.py"],
|
srcs = ["ragged_operators.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
":ragged_elementwise_ops",
|
|
||||||
":ragged_getitem",
|
":ragged_getitem",
|
||||||
":ragged_tensor",
|
":ragged_tensor",
|
||||||
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -186,12 +167,13 @@ py_library(
|
|||||||
srcs = ["ragged_string_ops.py"],
|
srcs = ["ragged_string_ops.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
":ragged_array_ops",
|
|
||||||
":ragged_conversion_ops",
|
":ragged_conversion_ops",
|
||||||
":ragged_factory_ops",
|
":ragged_factory_ops",
|
||||||
":ragged_tensor",
|
":ragged_tensor",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:dtypes",
|
||||||
|
"//tensorflow/python:framework_ops",
|
||||||
|
"//tensorflow/python:string_ops",
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -219,10 +201,11 @@ py_library(
|
|||||||
":ragged_tensor",
|
":ragged_tensor",
|
||||||
":ragged_util",
|
":ragged_util",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
|
"//tensorflow/python:constant_op",
|
||||||
|
"//tensorflow/python:control_flow_ops",
|
||||||
"//tensorflow/python:dtypes",
|
"//tensorflow/python:dtypes",
|
||||||
"//tensorflow/python:framework_ops",
|
"//tensorflow/python:framework_ops",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:tensor_shape",
|
|
||||||
"//tensorflow/python:tensor_util",
|
"//tensorflow/python:tensor_util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -285,6 +268,29 @@ py_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "ragged_dispatch",
|
||||||
|
srcs = ["ragged_dispatch.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":ragged_array_ops",
|
||||||
|
":ragged_factory_ops",
|
||||||
|
":ragged_math_ops",
|
||||||
|
":ragged_tensor",
|
||||||
|
":ragged_tensor_shape",
|
||||||
|
":ragged_util",
|
||||||
|
"//tensorflow/python:array_ops",
|
||||||
|
"//tensorflow/python:clip_ops",
|
||||||
|
"//tensorflow/python:framework_ops",
|
||||||
|
"//tensorflow/python:math_ops",
|
||||||
|
"//tensorflow/python:parsing_ops",
|
||||||
|
"//tensorflow/python:sparse_tensor",
|
||||||
|
"//tensorflow/python:string_ops",
|
||||||
|
"//tensorflow/python:util",
|
||||||
|
"//third_party/py/numpy",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
#-------------------------------------------------------------------------------
|
#-------------------------------------------------------------------------------
|
||||||
# RaggedTensor Tests
|
# RaggedTensor Tests
|
||||||
#-------------------------------------------------------------------------------
|
#-------------------------------------------------------------------------------
|
||||||
@ -458,6 +464,7 @@ py_test(
|
|||||||
"//tensorflow/python:errors",
|
"//tensorflow/python:errors",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:gradients_impl",
|
"//tensorflow/python:gradients_impl",
|
||||||
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -684,17 +691,21 @@ py_test(
|
|||||||
)
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "ragged_elementwise_ops_test",
|
name = "ragged_dispatch_test",
|
||||||
srcs = ["ragged_elementwise_ops_test.py"],
|
srcs = ["ragged_dispatch_test.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
":ragged",
|
":ragged",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
|
"//tensorflow/python:clip_ops",
|
||||||
"//tensorflow/python:dtypes",
|
"//tensorflow/python:dtypes",
|
||||||
"//tensorflow/python:errors",
|
"//tensorflow/python:errors",
|
||||||
"//tensorflow/python:framework_ops",
|
"//tensorflow/python:framework_ops",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
|
"//tensorflow/python:math_ops",
|
||||||
|
"//tensorflow/python:parsing_ops",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
|
"//tensorflow/python:string_ops",
|
||||||
"//third_party/py/numpy",
|
"//third_party/py/numpy",
|
||||||
"@absl_py//absl/testing:parameterized",
|
"@absl_py//absl/testing:parameterized",
|
||||||
],
|
],
|
||||||
@ -725,6 +736,7 @@ py_test(
|
|||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
"//tensorflow/python:string_ops",
|
"//tensorflow/python:string_ops",
|
||||||
"//tensorflow/python/keras:backend",
|
"//tensorflow/python/keras:backend",
|
||||||
|
"//third_party/py/numpy",
|
||||||
"@absl_py//absl/testing:parameterized",
|
"@absl_py//absl/testing:parameterized",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -735,8 +747,10 @@ py_test(
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
":ragged",
|
":ragged",
|
||||||
|
"//tensorflow/python:dtypes",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
|
"//third_party/py/numpy",
|
||||||
"@absl_py//absl/testing:parameterized",
|
"@absl_py//absl/testing:parameterized",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -1,76 +1,53 @@
|
|||||||
"""Ragged Tensors.
|
"""Ragged Tensors.
|
||||||
|
|
||||||
This package defines the [`RaggedTensor`](ragged/RaggedTensor.md) class, which
|
This package defines the `tf.RaggedTensor` class, which
|
||||||
represents tensors with non-uniform shapes. In particular, each `RaggedTensor`
|
represents tensors with non-uniform shapes. In particular, each `RaggedTensor`
|
||||||
has one or more *ragged dimensions*, which are dimensions whose slices may have
|
has one or more *ragged dimensions*, which are dimensions whose slices may have
|
||||||
different lengths. For example, the inner (column) dimension of
|
different lengths. For example, the inner (column) dimension of
|
||||||
`rt=[[3, 1, 4, 1], [], [5, 9, 2], [6], []]` is ragged, since the column slices
|
`rt=[[3, 1, 4, 1], [], [5, 9, 2], [6], []]` is ragged, since the column slices
|
||||||
(`rt[0, :]`, ..., `rt[4, :]`) have different lengths. For a more detailed
|
(`rt[0, :]`, ..., `rt[4, :]`) have different lengths. For a more detailed
|
||||||
description of ragged tensors, see the [`RaggedTensor`](ragged/RaggedTensor.md)
|
description of ragged tensors, see the `tf.RaggedTensor`
|
||||||
class documentation.
|
class documentation.
|
||||||
|
|
||||||
## RaggedTensor Operations
|
## `RaggedTensor` Operations
|
||||||
|
|
||||||
This package also defines a collection of operations for manipulating
|
### `RaggedTensor` Factory ops
|
||||||
ragged tensors.
|
|
||||||
|
|
||||||
### RaggedTensor Versions of Standard Tensor Operations
|
* `tf.ragged.constant`
|
||||||
|
* `tf.ragged.from_row_splits`
|
||||||
|
* `tf.ragged.from_row_splits`
|
||||||
|
* `tf.ragged.from_row_lengths`
|
||||||
|
* `tf.ragged.from_row_starts`
|
||||||
|
* `tf.ragged.from_row_limits`
|
||||||
|
* `tf.ragged.from_value_rowids`
|
||||||
|
* `tf.ragged.from_nested_row_splits`
|
||||||
|
* `tf.ragged.from_nested_value_rowids`
|
||||||
|
|
||||||
Many of the operations defined by this package are analogous to
|
### `RaggedTensor` Conversion ops
|
||||||
[`Tensor`](https://www.tensorflow.org/api_docs/python/tf/Tensor)
|
|
||||||
operations, but they accept `RaggedTensor`s as input and can return
|
|
||||||
`RaggedTensor`s as output. For example, `ragged.add` performs elementwise
|
|
||||||
addition just like `tf.add`, but can be used on `RaggedTensor`s.
|
|
||||||
|
|
||||||
These `RaggedTensor` versions of the standard `Tensor` operations can also be
|
* `tf.ragged.from_tensor`
|
||||||
used with standard `Tensors`; and for the most part, they will return the same
|
* `tf.ragged.to_tensor`
|
||||||
value that the standard `Tensor` operation would return. However, there are
|
* `tf.ragged.from_sparse`
|
||||||
a few notable exceptions:
|
* `tf.ragged.to_sparse`
|
||||||
|
* `tf.ragged.from_variant`
|
||||||
|
* `tf.ragged.to_variant`
|
||||||
|
* `tf.ragged.convert_to_tensor_or_ragged_tensor`
|
||||||
|
|
||||||
* For [`ragged.stack(...)`](ragged/stack.md) and
|
### `RaggedTensor` Shape ops
|
||||||
[`ragged.concat(...)`](ragged/concat.md), the input tensors are not required
|
|
||||||
to have matching shapes. In the returned tensor, all dimensions up to the
|
|
||||||
`axis` dimension will be ragged.
|
|
||||||
|
|
||||||
### Ragged-Tensor Specific Operations
|
* `tf.ragged.row_splits`
|
||||||
|
* `tf.ragged.row_lengths`
|
||||||
|
* `tf.ragged.row_starts`
|
||||||
|
* `tf.ragged.row_limits`
|
||||||
|
* `tf.ragged.value_rowids`
|
||||||
|
* `tf.ragged.nrows`
|
||||||
|
* `tf.ragged.nested_row_splits`
|
||||||
|
* `tf.ragged.row_splits_to_segment_ids`
|
||||||
|
* `tf.ragged.segment_ids_to_row_splits`
|
||||||
|
* `tf.ragged.bounding_shape`
|
||||||
|
|
||||||
The following operations are specific to ragged tensors:
|
### Functional ops
|
||||||
|
* `tf.ragged.map_inner_values`
|
||||||
* **Factory ops**:
|
|
||||||
[`constant(...)`](ragged/constant.md),
|
|
||||||
[`from_row_splits(...)`](ragged/from_row_splits.md),
|
|
||||||
[`from_row_lengths(...)`](ragged/from_row_lengths.md),
|
|
||||||
[`from_row_starts(...)`](ragged/from_row_starts.md),
|
|
||||||
[`from_row_limits(...)`](ragged/from_row_limits.md),
|
|
||||||
[`from_value_rowids(...)`](ragged/from_value_rowids.md),
|
|
||||||
[`from_nested_row_splits(...)`](ragged/from_nested_row_splits.md),
|
|
||||||
[`from_nested_value_rowids(...)`](ragged/from_nested_value_rowids.md).
|
|
||||||
|
|
||||||
* **Conversion ops**:
|
|
||||||
[`from_tensor(...)`](ragged/from_tensor.md),
|
|
||||||
[`to_tensor(...)`](ragged/to_tensor.md),
|
|
||||||
[`from_sparse(...)`](ragged/from_sparse.md),
|
|
||||||
[`to_sparse(...)`](ragged/to_sparse.md),
|
|
||||||
[`from_variant(...)`](ragged/from_variant.md),
|
|
||||||
[`to_variant(...)`](ragged/to_variant.md),
|
|
||||||
[`convert_to_tensor_or_ragged_tensor(...)`](
|
|
||||||
ragged/convert_to_tensor_or_ragged_tensor.md).
|
|
||||||
|
|
||||||
* **Shape ops**:
|
|
||||||
[`row_splits(...)`](ragged/row_splits.md),
|
|
||||||
[`row_lengths(...)`](ragged/row_lengths.md),
|
|
||||||
[`row_starts(...)`](ragged/row_starts.md),
|
|
||||||
[`row_limits(...)`](ragged/row_limits.md),
|
|
||||||
[`value_rowids(...)`](ragged/value_rowids.md),
|
|
||||||
[`nrows(...)`](ragged/nrows.md),
|
|
||||||
[`nested_row_splits(...)`](ragged/nested_row_splits.md),
|
|
||||||
[`row_splits_to_segment_ids(...)`](ragged/row_splits_to_segment_ids.md),
|
|
||||||
[`segment_ids_to_row_splits(...)`](ragged/segment_ids_to_row_splits.md),
|
|
||||||
[`bounding_shape(...)`](ragged/bounding_shape.md).
|
|
||||||
|
|
||||||
* **Functional ops**:
|
|
||||||
[`map_inner_values(...)`](ragged/map_inner_values.md),
|
|
||||||
[`make_elementwise_op(...)`](ragged/make_elementwise_op.md).
|
|
||||||
|
|
||||||
|
|
||||||
<!-- Ragged Classes & related helper functions -->
|
<!-- Ragged Classes & related helper functions -->
|
||||||
@ -140,21 +117,17 @@ The following operations are specific to ragged tensors:
|
|||||||
@@map_inner_values
|
@@map_inner_values
|
||||||
@@map_fn
|
@@map_fn
|
||||||
|
|
||||||
<!-- Elementwise Ops -->
|
|
||||||
@@make_elementwise_op
|
|
||||||
|
|
||||||
<!-- Shape & broadcasting -->
|
<!-- Shape & broadcasting -->
|
||||||
@@RaggedTensorDynamicShape
|
@@RaggedTensorDynamicShape
|
||||||
@@broadcast_to
|
@@broadcast_to
|
||||||
@@broadcast_dynamic_shape
|
@@broadcast_dynamic_shape
|
||||||
|
|
||||||
<!-- Symbols from ragged_elementwise_ops._symbols_to_export are whitelisted -->
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.ops.ragged import ragged_dispatch
|
||||||
from tensorflow.python.ops.ragged import ragged_operators
|
from tensorflow.python.ops.ragged import ragged_operators
|
||||||
from tensorflow.python.ops.ragged import ragged_string_ops
|
from tensorflow.python.ops.ragged import ragged_string_ops
|
||||||
|
|
||||||
@ -179,11 +152,6 @@ from tensorflow.python.ops.ragged.ragged_conversion_ops import from_tensor
|
|||||||
from tensorflow.python.ops.ragged.ragged_conversion_ops import to_sparse
|
from tensorflow.python.ops.ragged.ragged_conversion_ops import to_sparse
|
||||||
from tensorflow.python.ops.ragged.ragged_conversion_ops import to_tensor
|
from tensorflow.python.ops.ragged.ragged_conversion_ops import to_tensor
|
||||||
|
|
||||||
# pylint: disable=protected-access, wildcard-import
|
|
||||||
from tensorflow.python.ops.ragged.ragged_elementwise_ops import *
|
|
||||||
from tensorflow.python.ops.ragged.ragged_elementwise_ops import _symbols_to_export as _elementwise_ops
|
|
||||||
# pylint: enable=protected-access, wildcard-import
|
|
||||||
|
|
||||||
from tensorflow.python.ops.ragged.ragged_factory_ops import constant
|
from tensorflow.python.ops.ragged.ragged_factory_ops import constant
|
||||||
from tensorflow.python.ops.ragged.ragged_factory_ops import constant_value
|
from tensorflow.python.ops.ragged.ragged_factory_ops import constant_value
|
||||||
from tensorflow.python.ops.ragged.ragged_factory_ops import convert_to_tensor_or_ragged_tensor
|
from tensorflow.python.ops.ragged.ragged_factory_ops import convert_to_tensor_or_ragged_tensor
|
||||||
@ -231,6 +199,10 @@ from tensorflow.python.ops.ragged.segment_id_ops import segment_ids_to_row_split
|
|||||||
|
|
||||||
from tensorflow.python.util import all_util as _all_util
|
from tensorflow.python.util import all_util as _all_util
|
||||||
|
|
||||||
|
|
||||||
|
# Register OpDispatchers that override standard TF ops to work w/ RaggedTensors.
|
||||||
|
__doc__ += ragged_dispatch.register_dispatchers() # pylint: disable=redefined-builtin
|
||||||
|
|
||||||
# Any symbol that is not referenced (with "@@name") in the module docstring
|
# Any symbol that is not referenced (with "@@name") in the module docstring
|
||||||
# above, or included in the "_elementwise_ops" whitelist, will be removed.
|
# above will be removed.
|
||||||
_all_util.remove_undocumented(__name__, _elementwise_ops)
|
_all_util.remove_undocumented(__name__)
|
||||||
|
@ -308,7 +308,7 @@ def bounding_shape(rt_input, axis=None, name=None):
|
|||||||
# ragged_gather
|
# ragged_gather
|
||||||
#===============================================================================
|
#===============================================================================
|
||||||
# TODO(edloper): Add an `axis` argument
|
# TODO(edloper): Add an `axis` argument
|
||||||
def gather(params, indices, name=None):
|
def gather(params, indices, validate_indices=None, axis=0, name=None):
|
||||||
"""Gathers ragged slices from `params` axis `0` according to `indices`.
|
"""Gathers ragged slices from `params` axis `0` according to `indices`.
|
||||||
|
|
||||||
Returns `RaggedTensor` output, such that:
|
Returns `RaggedTensor` output, such that:
|
||||||
@ -347,6 +347,8 @@ def gather(params, indices, name=None):
|
|||||||
indices: The potentially ragged tensor indicating which values to gather.
|
indices: The potentially ragged tensor indicating which values to gather.
|
||||||
Must have dtype `int32` or `int64`. Values must be in the range `[0,
|
Must have dtype `int32` or `int64`. Values must be in the range `[0,
|
||||||
params.shape[0]]`.
|
params.shape[0]]`.
|
||||||
|
validate_indices: Ignored.
|
||||||
|
axis: Must be zero.
|
||||||
name: A name for the operation (optional).
|
name: A name for the operation (optional).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -357,6 +359,9 @@ def gather(params, indices, name=None):
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If indices.shape.ndims is not known statically.
|
ValueError: If indices.shape.ndims is not known statically.
|
||||||
"""
|
"""
|
||||||
|
del validate_indices
|
||||||
|
if not isinstance(axis, int) or axis != 0:
|
||||||
|
raise ValueError('axis>0 is not supported for ragged gather yet.')
|
||||||
with ops.name_scope(name, 'RaggedGather', [params, indices]):
|
with ops.name_scope(name, 'RaggedGather', [params, indices]):
|
||||||
params = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
|
params = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
|
||||||
params, name='params')
|
params, name='params')
|
||||||
@ -812,29 +817,29 @@ def boolean_mask(data, mask, keepdims=False, name=None):
|
|||||||
#===============================================================================
|
#===============================================================================
|
||||||
# Concatenation and Stacking
|
# Concatenation and Stacking
|
||||||
#===============================================================================
|
#===============================================================================
|
||||||
def concat(rt_inputs, axis, name=None):
|
def concat(values, axis, name=None):
|
||||||
"""Concatenates potentially ragged tensors along one dimension.
|
"""Concatenates potentially ragged tensors along one dimension.
|
||||||
|
|
||||||
Given a list of tensors with the same rank `K` (`K >= axis`), returns a
|
Given a list of tensors with the same rank `K` (`K >= axis`), returns a
|
||||||
rank-`K` `RaggedTensor` `result` such that `result[i0...iaxis]` is the
|
rank-`K` `RaggedTensor` `result` such that `result[i0...iaxis]` is the
|
||||||
concatenation of `[rt[i0...iaxis] for rt in rt_inputs]`.
|
concatenation of `[rt[i0...iaxis] for rt in values]`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
rt_inputs: A list of potentially ragged tensors. May not be empty. All
|
values: A list of potentially ragged tensors. May not be empty. All
|
||||||
`rt_inputs` must have the same rank and the same dtype; but unlike
|
`values` must have the same rank and the same dtype; but unlike
|
||||||
`tf.concat`, they can have arbitrary shapes.
|
`tf.concat`, they can have arbitrary shapes.
|
||||||
axis: A python integer, indicating the dimension along which to concatenate.
|
axis: A python integer, indicating the dimension along which to concatenate.
|
||||||
(Note: Unlike `tf.concat`, the `axis` parameter must be statically known.)
|
(Note: Unlike `tf.concat`, the `axis` parameter must be statically known.)
|
||||||
Negative values are supported only if the rank of at least one
|
Negative values are supported only if the rank of at least one
|
||||||
`rt_inputs` value is statically known.
|
`values` value is statically known.
|
||||||
name: A name prefix for the returned tensor (optional).
|
name: A name prefix for the returned tensor (optional).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A `RaggedTensor` with rank `K`.
|
A `RaggedTensor` with rank `K`.
|
||||||
`result.ragged_rank=max(axis, max(rt.ragged_rank for rt in rt_inputs]))`.
|
`result.ragged_rank=max(axis, max(rt.ragged_rank for rt in values]))`.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If `rt_inputs` is empty, if `axis` is out of bounds or if
|
ValueError: If `values` is empty, if `axis` is out of bounds or if
|
||||||
the input tensors have different ranks.
|
the input tensors have different ranks.
|
||||||
|
|
||||||
#### Example:
|
#### Example:
|
||||||
@ -847,35 +852,35 @@ def concat(rt_inputs, axis, name=None):
|
|||||||
[[1, 2, 6], [3, 4, 5, 7, 8, 9]]
|
[[1, 2, 6], [3, 4, 5, 7, 8, 9]]
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
if not isinstance(rt_inputs, (list, tuple)):
|
if not isinstance(values, (list, tuple)):
|
||||||
rt_inputs = [rt_inputs]
|
values = [values]
|
||||||
with ops.name_scope(name, 'RaggedConcat', rt_inputs):
|
with ops.name_scope(name, 'RaggedConcat', values):
|
||||||
return _ragged_stack_concat_helper(rt_inputs, axis, stack_values=False)
|
return _ragged_stack_concat_helper(values, axis, stack_values=False)
|
||||||
|
|
||||||
|
|
||||||
def stack(rt_inputs, axis, name=None):
|
def stack(values, axis, name=None):
|
||||||
"""Stacks potentially ragged tensors along one dimension.
|
"""Stacks potentially ragged tensors along one dimension.
|
||||||
|
|
||||||
Given a list of tensors with the same rank `K` (`K >= axis`), returns a
|
Given a list of tensors with the same rank `K` (`K >= axis`), returns a
|
||||||
rank-`K+1` `RaggedTensor` `result` such that `result[i0...iaxis]` is the
|
rank-`K+1` `RaggedTensor` `result` such that `result[i0...iaxis]` is the
|
||||||
list `[rt[i0...iaxis] for rt in rt_inputs]`.
|
list `[rt[i0...iaxis] for rt in values]`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
rt_inputs: A list of potentially ragged tensors. May not be empty. All
|
values: A list of potentially ragged tensors. May not be empty. All
|
||||||
`rt_inputs` must have the same rank and the same dtype; but unlike
|
`values` must have the same rank and the same dtype; but unlike
|
||||||
`tf.concat`, they can have arbitrary shapes.
|
`tf.concat`, they can have arbitrary shapes.
|
||||||
axis: A python integer, indicating the dimension along which to stack.
|
axis: A python integer, indicating the dimension along which to stack.
|
||||||
(Note: Unlike `tf.stack`, the `axis` parameter must be statically known.)
|
(Note: Unlike `tf.stack`, the `axis` parameter must be statically known.)
|
||||||
Negative values are supported only if the rank of at least one
|
Negative values are supported only if the rank of at least one
|
||||||
`rt_inputs` value is statically known.
|
`values` value is statically known.
|
||||||
name: A name prefix for the returned tensor (optional).
|
name: A name prefix for the returned tensor (optional).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A `RaggedTensor` with rank `K+1`.
|
A `RaggedTensor` with rank `K+1`.
|
||||||
`result.ragged_rank=max(axis, max(rt.ragged_rank for rt in rt_inputs]))`.
|
`result.ragged_rank=max(axis, max(rt.ragged_rank for rt in values]))`.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If `rt_inputs` is empty, if `axis` is out of bounds or if
|
ValueError: If `values` is empty, if `axis` is out of bounds or if
|
||||||
the input tensors have different ranks.
|
the input tensors have different ranks.
|
||||||
|
|
||||||
#### Example:
|
#### Example:
|
||||||
@ -888,10 +893,10 @@ def stack(rt_inputs, axis, name=None):
|
|||||||
[[[1, 2], [6]], [[3, 4, 5], [7, 8, 9]]]
|
[[[1, 2], [6]], [[3, 4, 5], [7, 8, 9]]]
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
if not isinstance(rt_inputs, (list, tuple)):
|
if not isinstance(values, (list, tuple)):
|
||||||
rt_inputs = [rt_inputs]
|
values = [values]
|
||||||
with ops.name_scope(name, 'RaggedConcat', rt_inputs):
|
with ops.name_scope(name, 'RaggedConcat', values):
|
||||||
return _ragged_stack_concat_helper(rt_inputs, axis, stack_values=True)
|
return _ragged_stack_concat_helper(values, axis, stack_values=True)
|
||||||
|
|
||||||
|
|
||||||
def _ragged_stack_concat_helper(rt_inputs, axis, stack_values):
|
def _ragged_stack_concat_helper(rt_inputs, axis, stack_values):
|
||||||
@ -1065,22 +1070,22 @@ def _copy_row_shape(rt_inputs, splits):
|
|||||||
#===============================================================================
|
#===============================================================================
|
||||||
# Tiling
|
# Tiling
|
||||||
#===============================================================================
|
#===============================================================================
|
||||||
def tile(rt_input, multiples, name=None):
|
def tile(input, multiples, name=None): # pylint: disable=redefined-builtin
|
||||||
"""Constructs a `RaggedTensor` by tiling a given `RaggedTensor`.
|
"""Constructs a `RaggedTensor` by tiling a given `RaggedTensor`.
|
||||||
|
|
||||||
The values of `rt_input` are replicated `multiples[i]` times along the
|
The values of `input` are replicated `multiples[i]` times along the
|
||||||
`i`th dimension (for each dimension `i`). For every dimension `axis` in
|
`i`th dimension (for each dimension `i`). For every dimension `axis` in
|
||||||
`rt_input`, the length of each output element in that dimension is the
|
`input`, the length of each output element in that dimension is the
|
||||||
length of corresponding input element multiplied by `multiples[axis]`.
|
length of corresponding input element multiplied by `multiples[axis]`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
rt_input: A `RaggedTensor`.
|
input: A `RaggedTensor`.
|
||||||
multiples: A 1-D integer `Tensor`. Length must be the same as the number of
|
multiples: A 1-D integer `Tensor`. Length must be the same as the number of
|
||||||
dimensions in `rt_input`.
|
dimensions in `input`.
|
||||||
name: A name for the operation (optional).
|
name: A name for the operation (optional).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A `RaggedTensor` with the same type, rank, and ragged_rank as `rt_input`.
|
A `RaggedTensor` with the same type, rank, and ragged_rank as `input`.
|
||||||
|
|
||||||
#### Example:
|
#### Example:
|
||||||
```python
|
```python
|
||||||
@ -1089,22 +1094,22 @@ def tile(rt_input, multiples, name=None):
|
|||||||
[[1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3]]
|
[[1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3]]
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
with ops.name_scope(name, 'RaggedTile', [rt_input, multiples]):
|
with ops.name_scope(name, 'RaggedTile', [input, multiples]):
|
||||||
rt_input = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
|
input = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
|
||||||
rt_input, name='rt_input')
|
input, name='input')
|
||||||
multiples = ragged_util.convert_to_int_tensor(
|
multiples = ragged_util.convert_to_int_tensor(
|
||||||
multiples, name='multiples', dtype=dtypes.int64)
|
multiples, name='multiples', dtype=dtypes.int64)
|
||||||
multiples.shape.assert_has_rank(1)
|
multiples.shape.assert_has_rank(1)
|
||||||
if not ragged_tensor.is_ragged(rt_input):
|
if not ragged_tensor.is_ragged(input):
|
||||||
return array_ops.tile(rt_input, multiples, name)
|
return array_ops.tile(input, multiples, name)
|
||||||
|
|
||||||
# If the constant value of `multiples` is available, then we can use it
|
# If the constant value of `multiples` is available, then we can use it
|
||||||
# to skip tiling dimensions where `multiples=1`.
|
# to skip tiling dimensions where `multiples=1`.
|
||||||
const_multiples = tensor_util.constant_value(multiples)
|
const_multiples = tensor_util.constant_value(multiples)
|
||||||
|
|
||||||
return ragged_factory_ops.from_nested_row_splits(
|
return ragged_factory_ops.from_nested_row_splits(
|
||||||
_tile_ragged_values(rt_input, multiples, const_multiples),
|
_tile_ragged_values(input, multiples, const_multiples),
|
||||||
_tile_ragged_splits(rt_input, multiples, const_multiples))
|
_tile_ragged_splits(input, multiples, const_multiples))
|
||||||
|
|
||||||
|
|
||||||
def _tile_ragged_values(rt_input, multiples, const_multiples=None):
|
def _tile_ragged_values(rt_input, multiples, const_multiples=None):
|
||||||
@ -1240,26 +1245,26 @@ def _tile_ragged_splits(rt_input, multiples, const_multiples=None):
|
|||||||
#===============================================================================
|
#===============================================================================
|
||||||
|
|
||||||
|
|
||||||
def expand_dims(rt_input, axis, name=None):
|
def expand_dims(input, axis, name=None): # pylint: disable=redefined-builtin
|
||||||
"""Inserts a dimension with shape 1 into a potentially ragged tensor's shape.
|
"""Inserts a dimension with shape 1 into a potentially ragged tensor's shape.
|
||||||
|
|
||||||
Given a potentially ragged tenor `rt_input`, this operation inserts a
|
Given a potentially ragged tenor `input`, this operation inserts a
|
||||||
dimension with size 1 at the dimension `axis` of `rt_input`'s shape.
|
dimension with size 1 at the dimension `axis` of `input`'s shape.
|
||||||
|
|
||||||
* If `rt_input` is a `Tensor`, then this is equivalent to
|
* If `input` is a `Tensor`, then this is equivalent to
|
||||||
`tf.expand_dims`.
|
`tf.expand_dims`.
|
||||||
* If `rt_input` is ragged, and `axis=0`, then the new dimension will be
|
* If `input` is ragged, and `axis=0`, then the new dimension will be
|
||||||
uniform; but the previously outermost dimension will become ragged.
|
uniform; but the previously outermost dimension will become ragged.
|
||||||
* If `rt_input` is ragged, and `0 < axis < rt_input.ragged_rank`, then the
|
* If `input` is ragged, and `0 < axis < input.ragged_rank`, then the
|
||||||
new dimension will be ragged.
|
new dimension will be ragged.
|
||||||
* If `rt_input` is ragged, and axis >= rt_input.ragged_rank`, then the new
|
* If `input` is ragged, and axis >= input.ragged_rank`, then the new
|
||||||
dimension will be uniform.
|
dimension will be uniform.
|
||||||
|
|
||||||
The following table gives some examples showing how `ragged.expand_dims`
|
The following table gives some examples showing how `ragged.expand_dims`
|
||||||
impacts the shapes of different input tensors. Ragged dimensions are
|
impacts the shapes of different input tensors. Ragged dimensions are
|
||||||
indicated by enclosing them in parentheses.
|
indicated by enclosing them in parentheses.
|
||||||
|
|
||||||
rt_input.shape | axis | result.shape
|
input.shape | axis | result.shape
|
||||||
----------------------- | ---- | -----------------------------
|
----------------------- | ---- | -----------------------------
|
||||||
`[D1, D2]` | `0` | `[1, D1, D2]`
|
`[D1, D2]` | `0` | `[1, D1, D2]`
|
||||||
`[D1, D2]` | `1` | `[D1, 1, D2]`
|
`[D1, D2]` | `1` | `[D1, 1, D2]`
|
||||||
@ -1271,14 +1276,14 @@ def expand_dims(rt_input, axis, name=None):
|
|||||||
`[D1, (D2), (D3), D4]` | `4` | `[D1, (D2), (D3), D4, 1]`
|
`[D1, (D2), (D3), D4]` | `4` | `[D1, (D2), (D3), D4, 1]`
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
rt_input: The potentially tensor that should be expanded with a new
|
input: The potentially tensor that should be expanded with a new
|
||||||
dimension.
|
dimension.
|
||||||
axis: An integer constant indicating where the new dimension should be
|
axis: An integer constant indicating where the new dimension should be
|
||||||
inserted.
|
inserted.
|
||||||
name: A name for the operation (optional).
|
name: A name for the operation (optional).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tensor with the same values as `rt_input`, with an added dimension of
|
A tensor with the same values as `input`, with an added dimension of
|
||||||
size 1 at `axis`.
|
size 1 at `axis`.
|
||||||
|
|
||||||
#### Examples:
|
#### Examples:
|
||||||
@ -1300,24 +1305,24 @@ def expand_dims(rt_input, axis, name=None):
|
|||||||
TensorShape([2, None, 1]) [[[1], [2]], [[3]]]
|
TensorShape([2, None, 1]) [[[1], [2]], [[3]]]
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
with ops.name_scope(name, 'RaggedExpandDims', [rt_input]):
|
with ops.name_scope(name, 'RaggedExpandDims', [input]):
|
||||||
rt_input = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
|
input = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
|
||||||
rt_input, name='rt_input')
|
input, name='input')
|
||||||
|
|
||||||
if not ragged_tensor.is_ragged(rt_input):
|
if not ragged_tensor.is_ragged(input):
|
||||||
return array_ops.expand_dims(rt_input, axis)
|
return array_ops.expand_dims(input, axis)
|
||||||
|
|
||||||
ndims = None if rt_input.shape.ndims is None else rt_input.shape.ndims + 1
|
ndims = None if input.shape.ndims is None else input.shape.ndims + 1
|
||||||
axis = ragged_util.get_positive_axis(axis, ndims)
|
axis = ragged_util.get_positive_axis(axis, ndims)
|
||||||
if axis == 0:
|
if axis == 0:
|
||||||
values = rt_input
|
values = input
|
||||||
splits = array_ops.stack([0, nrows(rt_input)])
|
splits = array_ops.stack([0, nrows(input)])
|
||||||
elif axis == 1:
|
elif axis == 1:
|
||||||
values = rt_input
|
values = input
|
||||||
splits = math_ops.range(nrows(rt_input) + 1)
|
splits = math_ops.range(nrows(input) + 1)
|
||||||
else:
|
else:
|
||||||
values = expand_dims(rt_input.values, axis - 1)
|
values = expand_dims(input.values, axis - 1)
|
||||||
splits = rt_input.row_splits
|
splits = input.row_splits
|
||||||
|
|
||||||
return ragged_factory_ops.from_row_splits(values, splits)
|
return ragged_factory_ops.from_row_splits(values, splits)
|
||||||
|
|
||||||
|
441
tensorflow/python/ops/ragged/ragged_dispatch.py
Normal file
441
tensorflow/python/ops/ragged/ragged_dispatch.py
Normal file
@ -0,0 +1,441 @@
|
|||||||
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Operator dispatch for RaggedTensors."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import sparse_tensor
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import clip_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops import parsing_ops
|
||||||
|
from tensorflow.python.ops import string_ops
|
||||||
|
from tensorflow.python.ops import variables
|
||||||
|
from tensorflow.python.ops.ragged import ragged_array_ops
|
||||||
|
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||||
|
from tensorflow.python.ops.ragged import ragged_math_ops
|
||||||
|
from tensorflow.python.ops.ragged import ragged_tensor
|
||||||
|
from tensorflow.python.ops.ragged import ragged_tensor_shape
|
||||||
|
from tensorflow.python.ops.ragged import ragged_util
|
||||||
|
from tensorflow.python.util import dispatch
|
||||||
|
from tensorflow.python.util import tf_decorator
|
||||||
|
from tensorflow.python.util import tf_export
|
||||||
|
from tensorflow.python.util import tf_inspect
|
||||||
|
|
||||||
|
# @TODO(edloper): Set this to True in the CL that exports RaggedTensors.
|
||||||
|
_UPDATE_DOCSTRINGS = False
|
||||||
|
|
||||||
|
# Information about an argument to an operation: The name of the argument, its
|
||||||
|
# position in the argument list, and a boolean flag indicating whether it
|
||||||
|
# expects a list of tensors.
|
||||||
|
_ArgInfo = collections.namedtuple('ArgInfo', ['name', 'position', 'is_list'])
|
||||||
|
|
||||||
|
|
||||||
|
def _get_arg_infos(func, arg_names):
|
||||||
|
"""Returns an `_ArgInfo` for each argument of `func` specified by `arg_names`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: The function whose arguments should be described.
|
||||||
|
arg_names: The names of the arguments to get info for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of `_ArgInfo`s.
|
||||||
|
"""
|
||||||
|
arg_infos = []
|
||||||
|
|
||||||
|
# Inspect the func's argspec to find the position of each arg.
|
||||||
|
arg_spec = tf_inspect.getargspec(func)
|
||||||
|
for argname in arg_names:
|
||||||
|
assert isinstance(argname, str)
|
||||||
|
is_list = argname.startswith('[') and argname.endswith(']')
|
||||||
|
if is_list:
|
||||||
|
argname = argname[1:-1]
|
||||||
|
if argname not in arg_spec.args:
|
||||||
|
raise ValueError('Argument %r not found function in %s. Args=%s' %
|
||||||
|
(argname, func, arg_spec.args))
|
||||||
|
arg_infos.append(_ArgInfo(argname, arg_spec.args.index(argname), is_list))
|
||||||
|
return arg_infos
|
||||||
|
|
||||||
|
|
||||||
|
def _is_convertible_to_tensor(value):
|
||||||
|
"""Returns true if `value` is convertible to a `Tensor`."""
|
||||||
|
if isinstance(value,
|
||||||
|
(ops.Tensor, variables.Variable, np.ndarray, int, float, str)):
|
||||||
|
return True
|
||||||
|
elif isinstance(value, (sparse_tensor.SparseTensor,)):
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
ops.convert_to_tensor(value)
|
||||||
|
return True
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class UnaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
|
||||||
|
"""OpDispatcher for unary ops that map a base op across ragged values."""
|
||||||
|
|
||||||
|
def __init__(self, original_op, arg_is_list=False):
|
||||||
|
self._original_op = original_op
|
||||||
|
self._arg_is_list = arg_is_list
|
||||||
|
arg_names = tf_inspect.getfullargspec(original_op)[0]
|
||||||
|
self._x = arg_names[0]
|
||||||
|
if _UPDATE_DOCSTRINGS:
|
||||||
|
original_op.__doc__ = (
|
||||||
|
original_op.__doc__.rstrip() + '\n\n' +
|
||||||
|
' `{x}` may be a `tf.RaggedTensor`.\n'.format(x=self._x))
|
||||||
|
|
||||||
|
def handle(self, args, kwargs):
|
||||||
|
if args:
|
||||||
|
x, args = args[0], args[1:]
|
||||||
|
else:
|
||||||
|
kwargs = kwargs.copy()
|
||||||
|
x = kwargs.pop(self._x, None)
|
||||||
|
if x is None:
|
||||||
|
return self.NOT_SUPPORTED
|
||||||
|
if self._arg_is_list:
|
||||||
|
found_ragged = False
|
||||||
|
for elt in x:
|
||||||
|
if ragged_tensor.is_ragged(elt):
|
||||||
|
found_ragged = True
|
||||||
|
elif not _is_convertible_to_tensor(elt):
|
||||||
|
return self.NOT_SUPPORTED
|
||||||
|
if found_ragged:
|
||||||
|
nested_splits_lists = [
|
||||||
|
elt.nested_row_splits for elt in x if ragged_tensor.is_ragged(elt)
|
||||||
|
]
|
||||||
|
inner_values = [
|
||||||
|
elt.inner_values if ragged_tensor.is_ragged(elt) else elt
|
||||||
|
for elt in x
|
||||||
|
]
|
||||||
|
with ops.control_dependencies(
|
||||||
|
ragged_util.assert_splits_match(nested_splits_lists)):
|
||||||
|
return ragged_factory_ops.from_nested_row_splits(
|
||||||
|
self._original_op(inner_values, *args, **kwargs),
|
||||||
|
nested_splits_lists[0])
|
||||||
|
else:
|
||||||
|
return self.NOT_SUPPORTED
|
||||||
|
else:
|
||||||
|
found_ragged = ragged_tensor.is_ragged(x)
|
||||||
|
if found_ragged:
|
||||||
|
mapped_values = self._original_op(x.inner_values, *args, **kwargs)
|
||||||
|
return x.with_inner_values(mapped_values)
|
||||||
|
else:
|
||||||
|
return self.NOT_SUPPORTED
|
||||||
|
|
||||||
|
|
||||||
|
class BinaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
|
||||||
|
"""OpDispatcher for binary ops that map a base op across ragged values.
|
||||||
|
|
||||||
|
Supports broadcasting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, original_op):
|
||||||
|
self._original_op = original_op
|
||||||
|
arg_names = tf_inspect.getfullargspec(original_op)[0]
|
||||||
|
self._x = arg_names[0]
|
||||||
|
self._y = arg_names[1]
|
||||||
|
if _UPDATE_DOCSTRINGS:
|
||||||
|
original_op.__doc__ = (
|
||||||
|
original_op.__doc__.rstrip() + '\n\n' +
|
||||||
|
' `{x}` and `{y}` may be a `tf.RaggedTensor`.\n'.format(
|
||||||
|
x=self._x, y=self._y))
|
||||||
|
|
||||||
|
def handle(self, args, kwargs):
|
||||||
|
# Extract the binary args.
|
||||||
|
if len(args) > 1:
|
||||||
|
x = args[0]
|
||||||
|
y = args[1]
|
||||||
|
args = args[2:]
|
||||||
|
elif args:
|
||||||
|
kwargs = kwargs.copy()
|
||||||
|
x = args[0]
|
||||||
|
y = kwargs.pop(self._y, None)
|
||||||
|
args = args[1:]
|
||||||
|
else:
|
||||||
|
kwargs = kwargs.copy()
|
||||||
|
x = kwargs.pop(self._x, None)
|
||||||
|
y = kwargs.pop(self._y, None)
|
||||||
|
|
||||||
|
# Bail if we don't have at least one ragged argument.
|
||||||
|
x_is_ragged = ragged_tensor.is_ragged(x)
|
||||||
|
y_is_ragged = ragged_tensor.is_ragged(y)
|
||||||
|
if not (x_is_ragged or y_is_ragged):
|
||||||
|
return self.NOT_SUPPORTED
|
||||||
|
|
||||||
|
# Convert args to tensors. Bail if conversion fails.
|
||||||
|
try:
|
||||||
|
if not x_is_ragged:
|
||||||
|
x = ops.convert_to_tensor(x, name=self._x, preferred_dtype=y.dtype)
|
||||||
|
if not y_is_ragged:
|
||||||
|
y = ops.convert_to_tensor(y, name=self._y, preferred_dtype=x.dtype)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return self.NOT_SUPPORTED
|
||||||
|
|
||||||
|
if ((x_is_ragged and y_is_ragged) or
|
||||||
|
(x_is_ragged and x.inner_values.shape.ndims <= y.shape.ndims) or
|
||||||
|
(y_is_ragged and y.inner_values.shape.ndims <= x.shape.ndims)):
|
||||||
|
bcast_shape = ragged_tensor_shape.broadcast_dynamic_shape(
|
||||||
|
ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(x),
|
||||||
|
ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(y))
|
||||||
|
x = ragged_tensor_shape.broadcast_to(
|
||||||
|
x, bcast_shape, broadcast_inner_dimensions=False)
|
||||||
|
y = ragged_tensor_shape.broadcast_to(
|
||||||
|
y, bcast_shape, broadcast_inner_dimensions=False)
|
||||||
|
|
||||||
|
x_values = x.inner_values if ragged_tensor.is_ragged(x) else x
|
||||||
|
y_values = y.inner_values if ragged_tensor.is_ragged(y) else y
|
||||||
|
mapped_values = self._original_op(x_values, y_values, *args, **kwargs)
|
||||||
|
if ragged_tensor.is_ragged(x):
|
||||||
|
return x.with_inner_values(mapped_values)
|
||||||
|
else:
|
||||||
|
return y.with_inner_values(mapped_values)
|
||||||
|
|
||||||
|
|
||||||
|
class RaggedDispatcher(dispatch.OpDispatcher):
|
||||||
|
"""OpDispatcher for ragged ops.
|
||||||
|
|
||||||
|
Dispatches to a wrapped op-handler if at least one of the `tensor_args`
|
||||||
|
arguments is a RaggedTensor or a RaggedTensorValue; and all of the
|
||||||
|
`tensor_args` arguments are convertible to Tensor or RaggedTensor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, original_op, ragged_op, ragged_args):
|
||||||
|
op_arg_names = tf_inspect.getfullargspec(original_op)[0]
|
||||||
|
ragged_arg_names = tf_inspect.getfullargspec(ragged_op)[0]
|
||||||
|
if op_arg_names != ragged_arg_names:
|
||||||
|
raise AssertionError(
|
||||||
|
'Signature must exactly match when overriding %s with %s: %s vs %s' %
|
||||||
|
(original_op, ragged_op, op_arg_names, ragged_arg_names))
|
||||||
|
self._ragged_op = ragged_op
|
||||||
|
self._ragged_args = _get_arg_infos(ragged_op, ragged_args)
|
||||||
|
if _UPDATE_DOCSTRINGS:
|
||||||
|
arg_list = ' and '.join('`%s`' % arg for arg in ragged_args)
|
||||||
|
original_op.__doc__ = (
|
||||||
|
original_op.__doc__.rstrip() + '\n\n' +
|
||||||
|
' {0} may be a `tf.RaggedTensor`.\n'.format(arg_list))
|
||||||
|
|
||||||
|
def handle(self, args, kwargs):
|
||||||
|
if self.is_supported(args, kwargs):
|
||||||
|
return self._ragged_op(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
return self.NOT_SUPPORTED
|
||||||
|
|
||||||
|
def is_supported(self, args, kwargs):
|
||||||
|
found_ragged = False
|
||||||
|
for arg_info in self._ragged_args:
|
||||||
|
if arg_info.position < len(args):
|
||||||
|
arg = args[arg_info.position]
|
||||||
|
else:
|
||||||
|
arg = kwargs.get(arg_info.name, None)
|
||||||
|
|
||||||
|
if arg_info.is_list:
|
||||||
|
if not isinstance(arg, (list, tuple)):
|
||||||
|
return False
|
||||||
|
for elt in arg:
|
||||||
|
if ragged_tensor.is_ragged(elt):
|
||||||
|
found_ragged = True
|
||||||
|
elif not _is_convertible_to_tensor(elt):
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
if ragged_tensor.is_ragged(arg):
|
||||||
|
found_ragged = True
|
||||||
|
elif not _is_convertible_to_tensor(arg):
|
||||||
|
return False
|
||||||
|
return found_ragged
|
||||||
|
|
||||||
|
|
||||||
|
def ragged_dispatch(original_op, tensor_args):
|
||||||
|
|
||||||
|
def decorator(ragged_op):
|
||||||
|
dispatch.RaggedDispatcher(original_op, ragged_op,
|
||||||
|
tensor_args).register(original_op)
|
||||||
|
return ragged_op
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
_UNARY_ELEMENTWISE_OPS = [
|
||||||
|
array_ops.check_numerics,
|
||||||
|
array_ops.identity,
|
||||||
|
array_ops.ones_like,
|
||||||
|
array_ops.ones_like_v2,
|
||||||
|
array_ops.zeros_like,
|
||||||
|
array_ops.zeros_like_v2,
|
||||||
|
clip_ops.clip_by_value,
|
||||||
|
math_ops.abs,
|
||||||
|
math_ops.acos,
|
||||||
|
math_ops.acosh,
|
||||||
|
math_ops.angle,
|
||||||
|
math_ops.asin,
|
||||||
|
math_ops.asinh,
|
||||||
|
math_ops.atan,
|
||||||
|
math_ops.atanh,
|
||||||
|
math_ops.cast,
|
||||||
|
math_ops.ceil,
|
||||||
|
math_ops.conj,
|
||||||
|
math_ops.cos,
|
||||||
|
math_ops.cosh,
|
||||||
|
math_ops.digamma,
|
||||||
|
math_ops.erf,
|
||||||
|
math_ops.erfc,
|
||||||
|
math_ops.exp,
|
||||||
|
math_ops.expm1,
|
||||||
|
math_ops.floor,
|
||||||
|
math_ops.imag,
|
||||||
|
math_ops.is_finite,
|
||||||
|
math_ops.is_inf,
|
||||||
|
math_ops.is_nan,
|
||||||
|
math_ops.lgamma,
|
||||||
|
math_ops.log,
|
||||||
|
math_ops.log1p,
|
||||||
|
math_ops.log_sigmoid,
|
||||||
|
math_ops.logical_not,
|
||||||
|
math_ops.negative,
|
||||||
|
math_ops.real,
|
||||||
|
math_ops.reciprocal,
|
||||||
|
math_ops.rint,
|
||||||
|
math_ops.round,
|
||||||
|
math_ops.rsqrt,
|
||||||
|
math_ops.saturate_cast,
|
||||||
|
math_ops.sign,
|
||||||
|
math_ops.sin,
|
||||||
|
math_ops.sinh,
|
||||||
|
math_ops.sqrt,
|
||||||
|
math_ops.square,
|
||||||
|
math_ops.tan,
|
||||||
|
parsing_ops.decode_compressed,
|
||||||
|
string_ops.string_to_number,
|
||||||
|
string_ops.string_to_hash_bucket,
|
||||||
|
string_ops.as_string,
|
||||||
|
string_ops.decode_base64,
|
||||||
|
string_ops.encode_base64,
|
||||||
|
string_ops.regex_full_match,
|
||||||
|
string_ops.regex_replace,
|
||||||
|
string_ops.string_strip,
|
||||||
|
string_ops.string_to_hash_bucket,
|
||||||
|
string_ops.string_to_hash_bucket_fast,
|
||||||
|
string_ops.string_to_hash_bucket_strong,
|
||||||
|
string_ops.substr,
|
||||||
|
string_ops.substr_v2,
|
||||||
|
string_ops.string_length,
|
||||||
|
string_ops.string_length_v2,
|
||||||
|
string_ops.unicode_script,
|
||||||
|
]
|
||||||
|
|
||||||
|
_UNARY_LIST_ELEMENTWISE_OPS = [
|
||||||
|
math_ops.add_n,
|
||||||
|
string_ops.string_join,
|
||||||
|
]
|
||||||
|
|
||||||
|
_BINARY_ELEMENTWISE_OPS = [
|
||||||
|
math_ops.add,
|
||||||
|
math_ops.atan2,
|
||||||
|
math_ops.complex,
|
||||||
|
math_ops.div_no_nan,
|
||||||
|
math_ops.divide,
|
||||||
|
math_ops.equal,
|
||||||
|
math_ops.floordiv,
|
||||||
|
math_ops.floormod,
|
||||||
|
math_ops.greater,
|
||||||
|
math_ops.greater_equal,
|
||||||
|
math_ops.less,
|
||||||
|
math_ops.less_equal,
|
||||||
|
math_ops.logical_and,
|
||||||
|
math_ops.logical_or,
|
||||||
|
math_ops.logical_xor,
|
||||||
|
math_ops.maximum,
|
||||||
|
math_ops.minimum,
|
||||||
|
math_ops.multiply,
|
||||||
|
math_ops.not_equal,
|
||||||
|
math_ops.pow,
|
||||||
|
math_ops.realdiv,
|
||||||
|
math_ops.squared_difference,
|
||||||
|
math_ops.subtract,
|
||||||
|
math_ops.truediv,
|
||||||
|
math_ops.truncatediv,
|
||||||
|
math_ops.truncatemod,
|
||||||
|
]
|
||||||
|
|
||||||
|
# (original_op, ragged_op, ragged_args)
|
||||||
|
_RAGGED_DISPATCH_OPS = [
|
||||||
|
(array_ops.batch_gather, ragged_array_ops.batch_gather,
|
||||||
|
['params', 'indices']),
|
||||||
|
(array_ops.concat, ragged_array_ops.concat, ['values']),
|
||||||
|
(array_ops.expand_dims_v2, ragged_array_ops.expand_dims, ['input']),
|
||||||
|
(array_ops.gather_v2, ragged_array_ops.gather, ['params', 'indices']),
|
||||||
|
(array_ops.gather_nd, ragged_array_ops.gather_nd, ['params', 'indices']),
|
||||||
|
(array_ops.stack, ragged_array_ops.stack, ['values']),
|
||||||
|
(array_ops.tile, ragged_array_ops.tile, ['input']),
|
||||||
|
(array_ops.where, ragged_array_ops.where, ['condition', 'x', 'y']),
|
||||||
|
(math_ops.unsorted_segment_sum, ragged_math_ops.segment_sum,
|
||||||
|
['data', 'segment_ids']),
|
||||||
|
(math_ops.unsorted_segment_prod, ragged_math_ops.segment_prod,
|
||||||
|
['data', 'segment_ids']),
|
||||||
|
(math_ops.unsorted_segment_min, ragged_math_ops.segment_min,
|
||||||
|
['data', 'segment_ids']),
|
||||||
|
(math_ops.unsorted_segment_max, ragged_math_ops.segment_max,
|
||||||
|
['data', 'segment_ids']),
|
||||||
|
(math_ops.unsorted_segment_mean, ragged_math_ops.segment_mean,
|
||||||
|
['data', 'segment_ids']),
|
||||||
|
(math_ops.unsorted_segment_sqrt_n, ragged_math_ops.segment_sqrt_n,
|
||||||
|
['data', 'segment_ids']),
|
||||||
|
(math_ops.reduce_sum, ragged_math_ops.reduce_sum, ['input_tensor']),
|
||||||
|
(math_ops.reduce_prod, ragged_math_ops.reduce_prod, ['input_tensor']),
|
||||||
|
(math_ops.reduce_min, ragged_math_ops.reduce_min, ['input_tensor']),
|
||||||
|
(math_ops.reduce_max, ragged_math_ops.reduce_max, ['input_tensor']),
|
||||||
|
(math_ops.reduce_mean, ragged_math_ops.reduce_mean, ['input_tensor']),
|
||||||
|
(math_ops.reduce_any, ragged_math_ops.reduce_any, ['input_tensor']),
|
||||||
|
(math_ops.reduce_all, ragged_math_ops.reduce_all, ['input_tensor']),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def register_dispatchers():
|
||||||
|
"""Constructs & registers OpDispatchers for ragged ops."""
|
||||||
|
|
||||||
|
op_list = (
|
||||||
|
_UNARY_ELEMENTWISE_OPS + _UNARY_LIST_ELEMENTWISE_OPS +
|
||||||
|
_BINARY_ELEMENTWISE_OPS + [x[0] for x in _RAGGED_DISPATCH_OPS])
|
||||||
|
for op in op_list:
|
||||||
|
_, undecorated_op = tf_decorator.unwrap(op)
|
||||||
|
if not hasattr(undecorated_op, tf_export.API_ATTRS['tensorflow'].names):
|
||||||
|
raise AssertionError('Expected %s to be an exported symbol '
|
||||||
|
'(while adding a RaggedTensor dispatcher)')
|
||||||
|
|
||||||
|
for op in _UNARY_ELEMENTWISE_OPS:
|
||||||
|
UnaryRaggedElementwiseDispatcher(op).register(op)
|
||||||
|
|
||||||
|
for op in _UNARY_LIST_ELEMENTWISE_OPS:
|
||||||
|
UnaryRaggedElementwiseDispatcher(op, True).register(op)
|
||||||
|
|
||||||
|
for op in _BINARY_ELEMENTWISE_OPS:
|
||||||
|
BinaryRaggedElementwiseDispatcher(op).register(op)
|
||||||
|
|
||||||
|
for (original_op, ragged_op, args) in _RAGGED_DISPATCH_OPS:
|
||||||
|
RaggedDispatcher(original_op, ragged_op, args).register(original_op)
|
||||||
|
|
||||||
|
docstring = (
|
||||||
|
'\n\n### Additional ops that support `RaggedTensor`\n\n' + '\n'.join([
|
||||||
|
'* `tf.%s`' % tf_export.get_canonical_name_for_symbol(op)
|
||||||
|
for op in op_list
|
||||||
|
]))
|
||||||
|
|
||||||
|
return docstring
|
@ -12,7 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Tests for ragged.elementwise_ops."""
|
"""Tests for RaggedTensor operator dispatch."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
@ -21,106 +21,108 @@ from __future__ import print_function
|
|||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import clip_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops import parsing_ops
|
||||||
from tensorflow.python.ops import ragged
|
from tensorflow.python.ops import ragged
|
||||||
|
from tensorflow.python.ops import string_ops
|
||||||
from tensorflow.python.platform import googletest
|
from tensorflow.python.platform import googletest
|
||||||
|
|
||||||
# Constants listing various op types to test. Each elementwise operation
|
# Constants listing various op types to test. Each operation
|
||||||
# should be included in at least one list below, or tested separately if
|
# should be included in at least one list below, or tested separately if
|
||||||
# necessary (e.g., because it expects additional arguments).
|
# necessary (e.g., because it expects additional arguments).
|
||||||
UNARY_FLOAT_OPS = [
|
UNARY_FLOAT_OPS = [
|
||||||
ragged.abs,
|
math_ops.abs,
|
||||||
ragged.acos,
|
math_ops.acos,
|
||||||
ragged.acosh,
|
math_ops.acosh,
|
||||||
ragged.angle,
|
math_ops.angle,
|
||||||
ragged.asin,
|
math_ops.asin,
|
||||||
ragged.asinh,
|
math_ops.asinh,
|
||||||
ragged.atan,
|
math_ops.atan,
|
||||||
ragged.atanh,
|
math_ops.atanh,
|
||||||
ragged.ceil,
|
math_ops.ceil,
|
||||||
ragged.conj,
|
math_ops.conj,
|
||||||
ragged.cos,
|
math_ops.cos,
|
||||||
ragged.cosh,
|
math_ops.cosh,
|
||||||
ragged.digamma,
|
math_ops.digamma,
|
||||||
ragged.erf,
|
math_ops.erf,
|
||||||
ragged.erfc,
|
math_ops.erfc,
|
||||||
ragged.exp,
|
math_ops.exp,
|
||||||
ragged.expm1,
|
math_ops.expm1,
|
||||||
ragged.floor,
|
math_ops.floor,
|
||||||
ragged.imag,
|
math_ops.imag,
|
||||||
ragged.is_finite,
|
math_ops.is_finite,
|
||||||
ragged.is_inf,
|
math_ops.is_inf,
|
||||||
ragged.is_nan,
|
math_ops.is_nan,
|
||||||
ragged.lgamma,
|
math_ops.lgamma,
|
||||||
ragged.log,
|
math_ops.log,
|
||||||
ragged.log1p,
|
math_ops.log1p,
|
||||||
ragged.log_sigmoid,
|
math_ops.log_sigmoid,
|
||||||
ragged.negative,
|
math_ops.negative,
|
||||||
ragged.real,
|
math_ops.real,
|
||||||
ragged.reciprocal,
|
math_ops.reciprocal,
|
||||||
ragged.rint,
|
math_ops.rint,
|
||||||
ragged.round,
|
math_ops.round,
|
||||||
ragged.rsqrt,
|
math_ops.rsqrt,
|
||||||
ragged.sign,
|
math_ops.sign,
|
||||||
ragged.sin,
|
math_ops.sin,
|
||||||
ragged.sinh,
|
math_ops.sinh,
|
||||||
ragged.sqrt,
|
math_ops.sqrt,
|
||||||
ragged.square,
|
math_ops.square,
|
||||||
ragged.tan,
|
math_ops.tan,
|
||||||
ragged.as_string,
|
array_ops.identity,
|
||||||
ragged.identity,
|
array_ops.ones_like,
|
||||||
ragged.ones_like,
|
array_ops.zeros_like,
|
||||||
ragged.zeros_like,
|
|
||||||
]
|
]
|
||||||
UNARY_BOOL_OPS = [
|
UNARY_BOOL_OPS = [
|
||||||
ragged.logical_not,
|
math_ops.logical_not,
|
||||||
]
|
]
|
||||||
UNARY_STRING_OPS = [
|
UNARY_STRING_OPS = [
|
||||||
ragged.decode_base64,
|
string_ops.decode_base64,
|
||||||
ragged.encode_base64,
|
string_ops.encode_base64,
|
||||||
ragged.string_strip,
|
string_ops.string_strip,
|
||||||
ragged.decode_compressed,
|
parsing_ops.decode_compressed,
|
||||||
]
|
]
|
||||||
BINARY_FLOAT_OPS = [
|
BINARY_FLOAT_OPS = [
|
||||||
ragged.add,
|
math_ops.add,
|
||||||
ragged.atan2,
|
math_ops.atan2,
|
||||||
ragged.complex,
|
math_ops.complex,
|
||||||
ragged.div,
|
math_ops.div_no_nan,
|
||||||
ragged.div_no_nan,
|
math_ops.divide,
|
||||||
ragged.divide,
|
math_ops.equal,
|
||||||
ragged.equal,
|
math_ops.floordiv,
|
||||||
ragged.floordiv,
|
math_ops.floormod,
|
||||||
ragged.floormod,
|
math_ops.greater,
|
||||||
ragged.greater,
|
math_ops.greater_equal,
|
||||||
ragged.greater_equal,
|
math_ops.less,
|
||||||
ragged.less,
|
math_ops.less_equal,
|
||||||
ragged.less_equal,
|
math_ops.maximum,
|
||||||
ragged.maximum,
|
math_ops.minimum,
|
||||||
ragged.minimum,
|
math_ops.multiply,
|
||||||
ragged.multiply,
|
math_ops.not_equal,
|
||||||
ragged.not_equal,
|
math_ops.pow,
|
||||||
ragged.pow,
|
math_ops.realdiv,
|
||||||
ragged.realdiv,
|
math_ops.squared_difference,
|
||||||
ragged.squared_difference,
|
math_ops.subtract,
|
||||||
ragged.subtract,
|
math_ops.truediv,
|
||||||
ragged.truediv,
|
|
||||||
]
|
]
|
||||||
BINARY_BOOL_OPS = [
|
BINARY_BOOL_OPS = [
|
||||||
ragged.logical_and,
|
math_ops.logical_and,
|
||||||
ragged.logical_or,
|
math_ops.logical_or,
|
||||||
ragged.logical_xor,
|
math_ops.logical_xor,
|
||||||
]
|
]
|
||||||
UNARY_INT_OPS = [
|
UNARY_INT_OPS = [
|
||||||
ragged.unicode_script,
|
string_ops.unicode_script,
|
||||||
]
|
]
|
||||||
BINARY_INT_OPS = [
|
BINARY_INT_OPS = [
|
||||||
ragged.truncatediv,
|
math_ops.truncatediv,
|
||||||
ragged.truncatemod,
|
math_ops.truncatemod,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -171,50 +173,49 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
|||||||
[{'x': ragged.constant_value([['abcd', 'efgh'], ['aabbccdd']]), 'op': op}
|
[{'x': ragged.constant_value([['abcd', 'efgh'], ['aabbccdd']]), 'op': op}
|
||||||
for op in UNARY_STRING_OPS] +
|
for op in UNARY_STRING_OPS] +
|
||||||
[
|
[
|
||||||
{'op': ragged.clip_by_value,
|
{'op': clip_ops.clip_by_value,
|
||||||
'x': ragged.constant_value([[-2.0, 3.0], [-3.0]]),
|
'x': ragged.constant_value([[-2.0, 3.0], [-3.0]]),
|
||||||
'clip_value_min': 0.1, 'clip_value_max': 4.0},
|
'clip_value_min': 0.1, 'clip_value_max': 4.0},
|
||||||
{'op': ragged.cast,
|
{'op': math_ops.cast,
|
||||||
'x': ragged.constant_value([[-2.0, 3.0], [-3.0]]),
|
'x': ragged.constant_value([[-2.0, 3.0], [-3.0]]),
|
||||||
'dtype': dtypes.int32},
|
'dtype': dtypes.int32},
|
||||||
{'op': ragged.saturate_cast,
|
{'op': math_ops.saturate_cast,
|
||||||
'x': ragged.constant_value([[-2.0, 3.0], [-3.0]]),
|
'x': ragged.constant_value([[-2.0, 3.0], [-3.0]]),
|
||||||
'dtype': dtypes.int32},
|
'dtype': dtypes.int32},
|
||||||
{'op': ragged.string_to_hash_bucket,
|
{'op': string_ops.string_to_hash_bucket,
|
||||||
'x': ragged.constant_value([['abcd', 'efgh'], ['aabbccdd']]),
|
'x': ragged.constant_value([['abcd', 'efgh'], ['aabbccdd']]),
|
||||||
'num_buckets': 1000},
|
'num_buckets': 1000},
|
||||||
{'op': ragged.string_to_hash_bucket_fast,
|
{'op': string_ops.string_to_hash_bucket_fast,
|
||||||
'x': ragged.constant_value([['abcd', 'efgh'], ['aabbccdd']]),
|
'x': ragged.constant_value([['abcd', 'efgh'], ['aabbccdd']]),
|
||||||
'num_buckets': 1000},
|
'num_buckets': 1000},
|
||||||
{'op': ragged.string_to_hash_bucket_strong,
|
{'op': string_ops.string_to_hash_bucket_strong,
|
||||||
'x': ragged.constant_value([['abcd', 'efgh'], ['aabbccdd']]),
|
'x': ragged.constant_value([['abcd', 'efgh'], ['aabbccdd']]),
|
||||||
'num_buckets': 1000,
|
'num_buckets': 1000,
|
||||||
'key': [1231, 12512]},
|
'key': [1231, 12512]},
|
||||||
{'op': ragged.string_to_number,
|
{'op': string_ops.string_to_number,
|
||||||
'x': ragged.constant_value([['-2.0', '3.0'], ['-3.0']])},
|
'x': ragged.constant_value([['-2.0', '3.0'], ['-3.0']])},
|
||||||
{'op': ragged.regex_full_match,
|
{'op': string_ops.regex_full_match,
|
||||||
'x': ragged.constant_value([['hello', '123'], ['1+1']]),
|
'x': ragged.constant_value([['hello', '123'], ['1+1']]),
|
||||||
'pattern': r'\w+'},
|
'pattern': r'\w+'},
|
||||||
{'op': ragged.regex_replace,
|
{'op': string_ops.regex_replace,
|
||||||
'x': ragged.constant_value([['hello', '123'], ['1+1']]),
|
'x': ragged.constant_value([['hello', '123'], ['1+1']]),
|
||||||
'pattern': r'\d',
|
'pattern': r'\d',
|
||||||
'rewrite': '#'},
|
'rewrite': '#'},
|
||||||
{'op': ragged.substr,
|
{'op': string_ops.substr,
|
||||||
'x': ragged.constant_value([['hello', '123'], ['1+1']]),
|
'x': ragged.constant_value([['hello', '123'], ['1+1']]),
|
||||||
'pos': 2, 'len': 3},
|
'pos': 2, 'len': 3},
|
||||||
{'op': ragged.check_numerics,
|
{'op': array_ops.check_numerics,
|
||||||
'x': ragged.constant_value([[-2.0, 3.0], [-3.0]]),
|
'x': ragged.constant_value([[-2.0, 3.0], [-3.0]]),
|
||||||
'message': 'check-numerics'},
|
'message': 'check-numerics'},
|
||||||
]
|
]
|
||||||
) # pyformat: disable
|
) # pyformat: disable
|
||||||
def testUnaryOp(self, x, op=ragged.abs, **extra_args):
|
def testUnaryElementwiseOp(self, x, op=math_ops.abs, **extra_args):
|
||||||
x = ragged.convert_to_tensor_or_ragged_tensor(x)
|
x = ragged.convert_to_tensor_or_ragged_tensor(x)
|
||||||
result = op(x, **extra_args)
|
result = op(x, **extra_args)
|
||||||
|
|
||||||
# Run the wrapped op on the dense values, for comparison.
|
# Run the wrapped op on the dense values, for comparison.
|
||||||
dense_x = x.inner_values if isinstance(x, ragged.RaggedTensor) else x
|
dense_x = x.inner_values if isinstance(x, ragged.RaggedTensor) else x
|
||||||
expected_flat_values = array_ops.reshape(
|
expected_flat_values = array_ops.reshape(op(dense_x, **extra_args), [-1])
|
||||||
op.__wrapped__(dense_x, **extra_args), [-1])
|
|
||||||
|
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
# Check that the result has the expected shape.
|
# Check that the result has the expected shape.
|
||||||
@ -285,12 +286,17 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
|||||||
#=====================================================================
|
#=====================================================================
|
||||||
{'x': ragged.constant_value([[[1, 2], [3], [4]], [[], [5, 7, 8]]]),
|
{'x': ragged.constant_value([[[1, 2], [3], [4]], [[], [5, 7, 8]]]),
|
||||||
'y': ragged.constant_value([[[3, 8], [2], [5]], [[], [1, 9, 8]]]),
|
'y': ragged.constant_value([[[3, 8], [2], [5]], [[], [1, 9, 8]]]),
|
||||||
'use_kwargs': True},
|
'use_kwargs': ('x', 'y')},
|
||||||
{'x': ragged.constant_value([[[1, 2]], [[3, 4], [5, 6], [7, 8]]],
|
{'x': ragged.constant_value([[[1, 2]], [[3, 4], [5, 6], [7, 8]]],
|
||||||
ragged_rank=1),
|
ragged_rank=1),
|
||||||
'y': ragged.constant_value([[[9, 3]], [[5, 2], [3, 4], [7, 6]]],
|
'y': ragged.constant_value([[[9, 3]], [[5, 2], [3, 4], [7, 6]]],
|
||||||
ragged_rank=1),
|
ragged_rank=1),
|
||||||
'use_kwargs': True},
|
'use_kwargs': ('x', 'y')},
|
||||||
|
{'x': ragged.constant_value([[[1, 2]], [[3, 4], [5, 6], [7, 8]]],
|
||||||
|
ragged_rank=1),
|
||||||
|
'y': ragged.constant_value([[[9, 3]], [[5, 2], [3, 4], [7, 6]]],
|
||||||
|
ragged_rank=1),
|
||||||
|
'use_kwargs': ('x',)},
|
||||||
] +
|
] +
|
||||||
#=========================================================================
|
#=========================================================================
|
||||||
# Test each unary op.
|
# Test each unary op.
|
||||||
@ -306,16 +312,16 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
|||||||
[{'x': ragged.constant_value([[True, True], [False]]),
|
[{'x': ragged.constant_value([[True, True], [False]]),
|
||||||
'y': ragged.constant_value([[False, True], [False]]),
|
'y': ragged.constant_value([[False, True], [False]]),
|
||||||
'op': op}
|
'op': op}
|
||||||
for op in BINARY_BOOL_OPS] +
|
for op in BINARY_BOOL_OPS]
|
||||||
[
|
|
||||||
]
|
|
||||||
) # pyformat: disable
|
) # pyformat: disable
|
||||||
def testBinaryOp(self, x, y, op=ragged.add, **extra_args):
|
def testBinaryElementwiseOp(self, x, y, op=math_ops.add, **extra_args):
|
||||||
use_kwargs = extra_args.pop('use_kwargs', False)
|
use_kwargs = extra_args.pop('use_kwargs', ())
|
||||||
x = ragged.convert_to_tensor_or_ragged_tensor(x)
|
x = ragged.convert_to_tensor_or_ragged_tensor(x)
|
||||||
y = ragged.convert_to_tensor_or_ragged_tensor(y)
|
y = ragged.convert_to_tensor_or_ragged_tensor(y)
|
||||||
if use_kwargs:
|
if 'x' in use_kwargs and 'y' in use_kwargs:
|
||||||
result = op(x=x, y=y, **extra_args)
|
result = op(x=x, y=y, **extra_args)
|
||||||
|
elif 'y' in use_kwargs:
|
||||||
|
result = op(x, y=y, **extra_args)
|
||||||
else:
|
else:
|
||||||
result = op(x, y, **extra_args)
|
result = op(x, y, **extra_args)
|
||||||
|
|
||||||
@ -323,7 +329,7 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
|||||||
dense_x = x.inner_values if isinstance(x, ragged.RaggedTensor) else x
|
dense_x = x.inner_values if isinstance(x, ragged.RaggedTensor) else x
|
||||||
dense_y = y.inner_values if isinstance(y, ragged.RaggedTensor) else y
|
dense_y = y.inner_values if isinstance(y, ragged.RaggedTensor) else y
|
||||||
expected_flat_values = array_ops.reshape(
|
expected_flat_values = array_ops.reshape(
|
||||||
op.__wrapped__(dense_x, dense_y, **extra_args), [-1])
|
op(dense_x, dense_y, **extra_args), [-1])
|
||||||
|
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
# Check that the result has the expected shape.
|
# Check that the result has the expected shape.
|
||||||
@ -358,16 +364,17 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
|||||||
ragged.constant_value([[[2, 9], [12]], [[8]]])),
|
ragged.constant_value([[[2, 9], [12]], [[8]]])),
|
||||||
'use_kwargs': True},
|
'use_kwargs': True},
|
||||||
] + [
|
] + [
|
||||||
{'op': ragged.add_n,
|
{'op': math_ops.add_n,
|
||||||
'inputs': (ragged.constant_value([[1, 3], [-3]]),
|
'inputs': (ragged.constant_value([[1, 3], [-3]]),
|
||||||
ragged.constant_value([[4, 7], [88]]),
|
ragged.constant_value([[4, 7], [88]]),
|
||||||
ragged.constant_value([[2, 9], [12]]))},
|
ragged.constant_value([[2, 9], [12]]))},
|
||||||
{'op': ragged.string_join,
|
{'op': string_ops.string_join,
|
||||||
'inputs': (ragged.constant_value([['a', 'b'], ['c']]),
|
'inputs': (ragged.constant_value([['a', 'b'], ['c']]),
|
||||||
ragged.constant_value([['foo', 'bar'], ['baz']]),
|
ragged.constant_value([['foo', 'bar'], ['baz']]),
|
||||||
ragged.constant_value([['2', '9'], ['12']]))},
|
ragged.constant_value([['2', '9'], ['12']]))},
|
||||||
]) # pyformat: disable
|
]) # pyformat: disable
|
||||||
def testListValuedOp(self, inputs, op=ragged.add_n, **extra_args):
|
def testListValuedElementwiseOp(self, inputs, op=math_ops.add_n,
|
||||||
|
**extra_args):
|
||||||
use_kwargs = extra_args.pop('use_kwargs', False)
|
use_kwargs = extra_args.pop('use_kwargs', False)
|
||||||
inputs = [ragged.convert_to_tensor_or_ragged_tensor(x) for x in inputs]
|
inputs = [ragged.convert_to_tensor_or_ragged_tensor(x) for x in inputs]
|
||||||
if use_kwargs:
|
if use_kwargs:
|
||||||
@ -381,7 +388,7 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
|||||||
for x in inputs
|
for x in inputs
|
||||||
]
|
]
|
||||||
expected_flat_values = array_ops.reshape(
|
expected_flat_values = array_ops.reshape(
|
||||||
op.__wrapped__(dense_inputs, **extra_args), [-1])
|
op(dense_inputs, **extra_args), [-1])
|
||||||
|
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
# Check that the result has the expected shape.
|
# Check that the result has the expected shape.
|
||||||
@ -395,13 +402,13 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
|||||||
self.assertAllEqual(expected_flat_values, result_flat_values)
|
self.assertAllEqual(expected_flat_values, result_flat_values)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testUnknownRankError(self):
|
def testElementwiseOpUnknownRankError(self):
|
||||||
x = ragged.constant([[1, 2], [3]])
|
x = ragged.constant([[1, 2], [3]])
|
||||||
y = ragged.from_row_splits(
|
y = ragged.from_row_splits(
|
||||||
array_ops.placeholder_with_default([1, 2, 3], shape=None), x.row_splits)
|
array_ops.placeholder_with_default([1, 2, 3], shape=None), x.row_splits)
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(ValueError,
|
||||||
ValueError, r'Unable to broadcast: unknown rank'):
|
r'Unable to broadcast: unknown rank'):
|
||||||
ragged.add(x, y)
|
math_ops.add(x, y)
|
||||||
|
|
||||||
@parameterized.parameters([
|
@parameterized.parameters([
|
||||||
dict(
|
dict(
|
||||||
@ -417,26 +424,31 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
|||||||
y=ragged.constant_value([[1]]),
|
y=ragged.constant_value([[1]]),
|
||||||
expected=[[[2]]]),
|
expected=[[[2]]]),
|
||||||
])
|
])
|
||||||
def testBroadcastAdd(self, x, y, expected):
|
def testElementwiseOpBroadcast(self, x, y, expected):
|
||||||
x = ragged.convert_to_tensor_or_ragged_tensor(x, dtype=dtypes.int32)
|
x = ragged.convert_to_tensor_or_ragged_tensor(x, dtype=dtypes.int32)
|
||||||
y = ragged.convert_to_tensor_or_ragged_tensor(y, dtype=dtypes.int32)
|
y = ragged.convert_to_tensor_or_ragged_tensor(y, dtype=dtypes.int32)
|
||||||
result = x + y
|
result = x + y
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
self.assertEqual(result.eval().tolist(), expected)
|
self.assertEqual(result.eval().tolist(), expected)
|
||||||
|
|
||||||
def testShapeMismatch(self):
|
def testElementwiseOpShapeMismatch(self):
|
||||||
x = ragged.constant([[1, 2, 3], [4, 5]])
|
x = ragged.constant([[1, 2, 3], [4, 5]])
|
||||||
y = ragged.constant([[1, 2, 3], [4, 5, 6]])
|
y = ragged.constant([[1, 2, 3], [4, 5, 6]])
|
||||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||||
'Incompatible shapes'):
|
'Incompatible shapes'):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
ragged.add(x, y).eval()
|
math_ops.add(x, y).eval()
|
||||||
|
|
||||||
def testDocstring(self):
|
def testBinaryOpSparseAndRagged(self):
|
||||||
self.assertRegexpMatches(
|
x = ragged.constant([[1, 2, 3], [4, 5]])
|
||||||
ragged.add.__doc__,
|
y = sparse_tensor.SparseTensor([[0, 0], [0, 1], [2, 0]], [1, 2, 3], [3, 2])
|
||||||
'Ragged version of the elementwise operation `tf.math.add`')
|
with self.assertRaises(TypeError):
|
||||||
self.assertEqual(ragged.add.__name__, 'add')
|
with self.cached_session():
|
||||||
|
math_ops.add(x, y).eval()
|
||||||
|
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
with self.cached_session():
|
||||||
|
math_ops.add_n([x, y]).eval()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
@ -1,389 +0,0 @@
|
|||||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""Elementwise operations for RaggedTensors."""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import collections
|
|
||||||
|
|
||||||
from tensorflow.python.framework import ops
|
|
||||||
from tensorflow.python.ops import array_ops
|
|
||||||
from tensorflow.python.ops import clip_ops
|
|
||||||
from tensorflow.python.ops import math_ops
|
|
||||||
from tensorflow.python.ops import parsing_ops
|
|
||||||
from tensorflow.python.ops import string_ops
|
|
||||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
|
||||||
from tensorflow.python.ops.ragged import ragged_tensor
|
|
||||||
from tensorflow.python.ops.ragged import ragged_tensor_shape
|
|
||||||
from tensorflow.python.util import tf_decorator
|
|
||||||
from tensorflow.python.util import tf_export
|
|
||||||
from tensorflow.python.util import tf_inspect
|
|
||||||
|
|
||||||
# Information about an argument to an operation: The name of the argument, its
|
|
||||||
# position in the argument list, and a boolean flag indicating whether it
|
|
||||||
# expects a list of tensors.
|
|
||||||
_ArgInfo = collections.namedtuple('ArgInfo', ['name', 'position', 'is_list'])
|
|
||||||
|
|
||||||
|
|
||||||
def make_elementwise_op(op, *elementwise_args):
|
|
||||||
"""Returns a ragged-tensor version of the elementwise operation `op`.
|
|
||||||
|
|
||||||
The returned operation will:
|
|
||||||
|
|
||||||
1. Broadcast the elementwise arguments to have a compatible shape.
|
|
||||||
An exception is raised if the tensors not broadcast-compatible.
|
|
||||||
2. Call `op`, substituting the dense values of the broadcasted tensor for
|
|
||||||
each elementwise argument.
|
|
||||||
3. Return a potentially ragged tensor constructed from the output of `op`
|
|
||||||
and the broadcasted tensors' nested row splits.
|
|
||||||
|
|
||||||
For example, you can construct a ragged-tensor version of the standard
|
|
||||||
operation `tf.add` by calling `make_elementwise_op(tf.add, 'x', 'y')`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
op: The operation to wrap.
|
|
||||||
*elementwise_args: The names of arguments to `op` that are treated as
|
|
||||||
elementwise. Arguments that take a list of tensors should have their
|
|
||||||
names wrapped in square brackets (e.g. "[inputs]").
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If any name specified in `elementwise_args` is not the name
|
|
||||||
of an argument to `op`.
|
|
||||||
"""
|
|
||||||
elementwise_arg_infos = _get_arg_infos(op, elementwise_args)
|
|
||||||
|
|
||||||
def ragged_op(*args, **kwargs):
|
|
||||||
"""Ragged version of `op`."""
|
|
||||||
args = list(args)
|
|
||||||
|
|
||||||
# Collect all of the elementwise arguments, and put them in a single
|
|
||||||
# dict whose values are the (potentially ragged) tensors that need to
|
|
||||||
# be broadcast to a common shape. The keys of this dict are tuples
|
|
||||||
# (argkey, index), where argkey is an int for poitional args or a string
|
|
||||||
# for keyword args; and index is None for non-list args and the index of the
|
|
||||||
# tensor for list args.
|
|
||||||
elementwise_args = {}
|
|
||||||
for (name, position, is_list) in elementwise_arg_infos.values():
|
|
||||||
if position < len(args):
|
|
||||||
if is_list:
|
|
||||||
args[position] = list(args[position])
|
|
||||||
for (index, arg) in enumerate(args[position]):
|
|
||||||
elementwise_args[position, index] = arg
|
|
||||||
else:
|
|
||||||
elementwise_args[position, None] = args[position]
|
|
||||||
elif name in kwargs:
|
|
||||||
if is_list:
|
|
||||||
kwargs[name] = list(kwargs[name])
|
|
||||||
for (i, arg) in enumerate(kwargs[name]):
|
|
||||||
elementwise_args[name, i] = arg
|
|
||||||
else:
|
|
||||||
elementwise_args[name, None] = kwargs[name]
|
|
||||||
|
|
||||||
with ops.name_scope(None, op.__name__, elementwise_args.values()):
|
|
||||||
# Convert all inputs to tensors or ragged tensors.
|
|
||||||
for ((key, index), tensor) in elementwise_args.items():
|
|
||||||
argname = elementwise_arg_infos[key].name
|
|
||||||
converted = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
|
|
||||||
tensor, name=argname)
|
|
||||||
elementwise_args[key, index] = converted
|
|
||||||
|
|
||||||
# Broadcast tensors to have compatible shapes.
|
|
||||||
broadcast_args, result_splits, broadcast_check_ops = \
|
|
||||||
_broadcast_elementwise_args(elementwise_args)
|
|
||||||
|
|
||||||
# Replace tensor arguments with their dense values.
|
|
||||||
for ((key, index), tensor) in broadcast_args.items():
|
|
||||||
if ragged_tensor.is_ragged(tensor):
|
|
||||||
if isinstance(key, int) and index is None:
|
|
||||||
args[key] = tensor.inner_values
|
|
||||||
elif isinstance(key, int) and index is not None:
|
|
||||||
args[key][index] = tensor.inner_values
|
|
||||||
elif isinstance(key, str) and index is None:
|
|
||||||
kwargs[key] = tensor.inner_values
|
|
||||||
else:
|
|
||||||
assert isinstance(key, str) and index is not None
|
|
||||||
kwargs[key][index] = tensor.inner_values
|
|
||||||
|
|
||||||
# Call the elementwise op on the broadcasted dense values.
|
|
||||||
with ops.control_dependencies(broadcast_check_ops):
|
|
||||||
result_values = op(*args, **kwargs)
|
|
||||||
|
|
||||||
# Restore any ragged dimensions that we stripped off, and return the
|
|
||||||
# result.
|
|
||||||
return ragged_factory_ops.from_nested_row_splits(result_values,
|
|
||||||
result_splits)
|
|
||||||
|
|
||||||
# Construct the docstring.
|
|
||||||
op_name = tf_export.get_canonical_name_for_symbol(op)
|
|
||||||
assert op_name is not None, op
|
|
||||||
argnames = ', '.join('`%s`' % s.strip('[]') for s in elementwise_args)
|
|
||||||
docstring = _ELEMENTWISE_DOCSTRING % dict(op_name=op_name, argnames=argnames)
|
|
||||||
|
|
||||||
# Update name, docstring, signature, etc., for the wrapper, and return it.
|
|
||||||
return tf_decorator.make_decorator(op, ragged_op, decorator_doc=docstring)
|
|
||||||
|
|
||||||
|
|
||||||
_ELEMENTWISE_DOCSTRING = """\
|
|
||||||
Ragged version of the elementwise operation `tf.%(op_name)s`.
|
|
||||||
|
|
||||||
The following elementwise arguments may be ragged or dense:
|
|
||||||
%(argnames)s.
|
|
||||||
These arguments will be broadcast to a compatible shape if necessary.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def _get_arg_infos(func, elementwise_args):
|
|
||||||
"""Returns `_ArgInfo`s for each `func` arg specified by `elementwise_args`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
func: The function whose arguments should be described.
|
|
||||||
elementwise_args: The names of the arguments to get info for.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dictionary that maps both names and positions of arguments to
|
|
||||||
`_ArgInfo` tuples.
|
|
||||||
"""
|
|
||||||
arg_infos = {}
|
|
||||||
|
|
||||||
# Inspect the func's argspec to find the position of each arg.
|
|
||||||
arg_spec = tf_inspect.getargspec(func)
|
|
||||||
for argname in elementwise_args:
|
|
||||||
assert isinstance(argname, str)
|
|
||||||
is_list = argname.startswith('[') and argname.endswith(']')
|
|
||||||
if is_list:
|
|
||||||
argname = argname[1:-1]
|
|
||||||
assert argname in arg_spec.args, (func, argname, arg_spec.args)
|
|
||||||
arg_info = _ArgInfo(argname, arg_spec.args.index(argname), is_list)
|
|
||||||
arg_infos[arg_info.name] = arg_info
|
|
||||||
arg_infos[arg_info.position] = arg_info
|
|
||||||
return arg_infos
|
|
||||||
|
|
||||||
|
|
||||||
def _broadcast_elementwise_args(elementwise_args):
|
|
||||||
"""Broadcasts the values of `elementwise_args` to have compatible shapes.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
elementwise_args: A dictionary whose keys are potentially ragged tensors.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tuple `(broadcast_args, broadcast_splits, checks)` where:
|
|
||||||
|
|
||||||
* `broadcast_args` is a dictionary with the same keys as
|
|
||||||
`elementwise_args`, mapping to broadcasted tensors.
|
|
||||||
* `broadcast_splits` is the broadcasted nested row splits.
|
|
||||||
* `checks` is a possibly empty tuple of assertion operations that should
|
|
||||||
be added as control dependencies.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If broadcasting fails.
|
|
||||||
"""
|
|
||||||
# No elementwise arguments were used: nothing to do!
|
|
||||||
if not elementwise_args:
|
|
||||||
return elementwise_args, (), ()
|
|
||||||
|
|
||||||
# A single elementwise argument was used: no broadcasting necessary.
|
|
||||||
if len(elementwise_args) == 1:
|
|
||||||
arg = list(elementwise_args.values())[0]
|
|
||||||
if ragged_tensor.is_ragged(arg):
|
|
||||||
return elementwise_args, arg.nested_row_splits, ()
|
|
||||||
else:
|
|
||||||
return elementwise_args, (), ()
|
|
||||||
|
|
||||||
# Multiple elementwise arguments.
|
|
||||||
else:
|
|
||||||
is_ragged = [ragged_tensor.is_ragged(t) for t in elementwise_args.values()]
|
|
||||||
if not any(is_ragged):
|
|
||||||
return elementwise_args, (), ()
|
|
||||||
|
|
||||||
# If we have a single ragged tensor plus a set of scalars, then we can
|
|
||||||
# rely on the underlying elementwise op to do broadcasting.
|
|
||||||
if (sum(is_ragged) == 1 and
|
|
||||||
all((ragged_tensor.is_ragged(t) or t.shape.ndims == 0)
|
|
||||||
for t in elementwise_args.values())):
|
|
||||||
nested_splits_lists = [
|
|
||||||
t.nested_row_splits
|
|
||||||
for t in elementwise_args.values()
|
|
||||||
if ragged_tensor.is_ragged(t)][0]
|
|
||||||
return elementwise_args, nested_splits_lists, ()
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Get the shapes of all the elementwise arguments.
|
|
||||||
shapes = [ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(t)
|
|
||||||
for t in elementwise_args.values()]
|
|
||||||
|
|
||||||
# Broadcast the shapes to all have the same rank (the max rank).
|
|
||||||
ranks = [t.shape.ndims for t in elementwise_args.values()]
|
|
||||||
if any(rank is None for rank in ranks):
|
|
||||||
raise ValueError('Unable to broadcast: unknown rank')
|
|
||||||
broadcast_rank = max(ranks)
|
|
||||||
shapes = [shape.broadcast_to_rank(broadcast_rank) for shape in shapes]
|
|
||||||
|
|
||||||
# For each dimension, broadcast the shapes to be compatible.
|
|
||||||
for axis in range(broadcast_rank):
|
|
||||||
# For each i, broadcast shape[i+1] to be compatible with shape[i]; and
|
|
||||||
# then finally broadcast shape[0] to be compatible with shape[-1].
|
|
||||||
for i in range(len(shapes)):
|
|
||||||
j = (i + 1) % len(shapes)
|
|
||||||
dim_size = shapes[i].dimension_size(axis)
|
|
||||||
shapes[j] = shapes[j].broadcast_dimension(axis, dim_size)
|
|
||||||
broadcast_shape = shapes[0]
|
|
||||||
|
|
||||||
# Broadcast every elementwise arg to the shape that we calculated.
|
|
||||||
elementwise_args = dict([
|
|
||||||
(key, ragged_tensor_shape.broadcast_to(t, broadcast_shape, False))
|
|
||||||
for (key, t) in elementwise_args.items()])
|
|
||||||
nested_splits_lists = list(elementwise_args.values())[0].nested_row_splits
|
|
||||||
return elementwise_args, nested_splits_lists, ()
|
|
||||||
|
|
||||||
|
|
||||||
# A list of symbols that should be exported in the "ragged" package.
|
|
||||||
_symbols_to_export = []
|
|
||||||
|
|
||||||
|
|
||||||
def _add_elementwise_ops_to_this_module(specs, verbose=False):
|
|
||||||
"""Adds ragged versions of the given ops to this module.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
specs: A list of tuples containing the arguments for `make_elementwise_op`.
|
|
||||||
verbose: If true, then display each op that gets added.
|
|
||||||
"""
|
|
||||||
for spec in specs:
|
|
||||||
original_op = spec[0]
|
|
||||||
ragged_op = make_elementwise_op(*spec)
|
|
||||||
canonical_name = tf_export.get_canonical_name_for_symbol(original_op)
|
|
||||||
if '.' not in canonical_name:
|
|
||||||
op_name = canonical_name
|
|
||||||
else:
|
|
||||||
op_name = original_op.__name__
|
|
||||||
|
|
||||||
# Temporary hack (will be removed once dispatch is added for RaggedTensors):
|
|
||||||
if op_name == 'neg': op_name = 'negative'
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
print('Adding ragged_elementwise_op: tf.ragged.%s (based on tf.%s)' %
|
|
||||||
(op_name, canonical_name))
|
|
||||||
globals()[op_name] = ragged_op
|
|
||||||
_symbols_to_export.append(op_name)
|
|
||||||
|
|
||||||
|
|
||||||
# A list of tuples containing arguments for `make_elementwise_op`, for each
|
|
||||||
# elementwise operation that should have a ragged version built. Each tuple
|
|
||||||
# contains a standard `Tensor` operation, and the names of any arguments
|
|
||||||
# that are processed in elementwise fashion.
|
|
||||||
_TF_ELEMENTWISE_OPS = [
|
|
||||||
# Unary math operations.
|
|
||||||
(clip_ops.clip_by_value, 't'),
|
|
||||||
(math_ops.abs, 'x'),
|
|
||||||
(math_ops.acos, 'x'),
|
|
||||||
(math_ops.acosh, 'x'),
|
|
||||||
(math_ops.angle, 'input'),
|
|
||||||
(math_ops.asin, 'x'),
|
|
||||||
(math_ops.asinh, 'x'),
|
|
||||||
(math_ops.atan, 'x'),
|
|
||||||
(math_ops.atanh, 'x'),
|
|
||||||
(math_ops.cast, 'x'),
|
|
||||||
(math_ops.ceil, 'x'),
|
|
||||||
(math_ops.conj, 'x'),
|
|
||||||
(math_ops.cos, 'x'),
|
|
||||||
(math_ops.cosh, 'x'),
|
|
||||||
(math_ops.digamma, 'x'),
|
|
||||||
(math_ops.erf, 'x'),
|
|
||||||
(math_ops.erfc, 'x'),
|
|
||||||
(math_ops.exp, 'x'),
|
|
||||||
(math_ops.expm1, 'x'),
|
|
||||||
(math_ops.floor, 'x'),
|
|
||||||
(math_ops.imag, 'input'),
|
|
||||||
(math_ops.is_finite, 'x'),
|
|
||||||
(math_ops.is_inf, 'x'),
|
|
||||||
(math_ops.is_nan, 'x'),
|
|
||||||
(math_ops.lgamma, 'x'),
|
|
||||||
(math_ops.log, 'x'),
|
|
||||||
(math_ops.log1p, 'x'),
|
|
||||||
(math_ops.log_sigmoid, 'x'),
|
|
||||||
(math_ops.logical_not, 'x'),
|
|
||||||
(math_ops.negative, 'x'),
|
|
||||||
(math_ops.real, 'input'),
|
|
||||||
(math_ops.reciprocal, 'x'),
|
|
||||||
(math_ops.rint, 'x'),
|
|
||||||
(math_ops.round, 'x'),
|
|
||||||
(math_ops.rsqrt, 'x'),
|
|
||||||
(math_ops.saturate_cast, 'value'),
|
|
||||||
(math_ops.sign, 'x'),
|
|
||||||
(math_ops.sin, 'x'),
|
|
||||||
(math_ops.sinh, 'x'),
|
|
||||||
(math_ops.sqrt, 'x'),
|
|
||||||
(math_ops.square, 'x'),
|
|
||||||
(math_ops.tan, 'x'),
|
|
||||||
|
|
||||||
# Binary math operations
|
|
||||||
(math_ops.add, 'x', 'y'),
|
|
||||||
(math_ops.atan2, 'y', 'x'),
|
|
||||||
(math_ops.complex, 'real', 'imag'),
|
|
||||||
(math_ops.div, 'x', 'y'),
|
|
||||||
(math_ops.div_no_nan, 'x', 'y'),
|
|
||||||
(math_ops.divide, 'x', 'y'),
|
|
||||||
(math_ops.equal, 'x', 'y'),
|
|
||||||
(math_ops.floordiv, 'x', 'y'),
|
|
||||||
(math_ops.floormod, 'x', 'y'),
|
|
||||||
(math_ops.greater, 'x', 'y'),
|
|
||||||
(math_ops.greater_equal, 'x', 'y'),
|
|
||||||
(math_ops.less, 'x', 'y'),
|
|
||||||
(math_ops.less_equal, 'x', 'y'),
|
|
||||||
(math_ops.logical_and, 'x', 'y'),
|
|
||||||
(math_ops.logical_or, 'x', 'y'),
|
|
||||||
(math_ops.logical_xor, 'x', 'y'),
|
|
||||||
(math_ops.maximum, 'x', 'y'),
|
|
||||||
(math_ops.minimum, 'x', 'y'),
|
|
||||||
(math_ops.multiply, 'x', 'y'),
|
|
||||||
(math_ops.not_equal, 'x', 'y'),
|
|
||||||
(math_ops.pow, 'x', 'y'),
|
|
||||||
(math_ops.realdiv, 'x', 'y'),
|
|
||||||
(math_ops.squared_difference, 'x', 'y'),
|
|
||||||
(math_ops.subtract, 'x', 'y'),
|
|
||||||
(math_ops.truediv, 'x', 'y'),
|
|
||||||
(math_ops.truncatediv, 'x', 'y'),
|
|
||||||
(math_ops.truncatemod, 'x', 'y'),
|
|
||||||
|
|
||||||
# N-ary math operations
|
|
||||||
(math_ops.add_n, '[inputs]'),
|
|
||||||
|
|
||||||
# String operations
|
|
||||||
(string_ops.as_string, 'input'),
|
|
||||||
(string_ops.decode_base64, 'input'),
|
|
||||||
(string_ops.encode_base64, 'input'),
|
|
||||||
(string_ops.regex_full_match, 'input'),
|
|
||||||
(string_ops.regex_replace, 'input'),
|
|
||||||
(string_ops.string_join, '[inputs]'),
|
|
||||||
(string_ops.string_strip, 'input'),
|
|
||||||
(string_ops.string_to_hash_bucket, 'input'),
|
|
||||||
(string_ops.string_to_hash_bucket_fast, 'input'),
|
|
||||||
(string_ops.string_to_hash_bucket_strong, 'input'),
|
|
||||||
(string_ops.substr, 'input'),
|
|
||||||
(string_ops.unicode_script, 'input'),
|
|
||||||
|
|
||||||
# Array ops
|
|
||||||
(array_ops.check_numerics, 'tensor'),
|
|
||||||
(array_ops.identity, 'input'),
|
|
||||||
(array_ops.ones_like, 'tensor'),
|
|
||||||
(array_ops.zeros_like, 'tensor'),
|
|
||||||
|
|
||||||
# Parsing ops
|
|
||||||
(parsing_ops.decode_compressed, 'bytes'),
|
|
||||||
(parsing_ops.string_to_number, 'string_tensor'),
|
|
||||||
]
|
|
||||||
_add_elementwise_ops_to_this_module(_TF_ELEMENTWISE_OPS)
|
|
||||||
|
|
@ -18,6 +18,7 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
@ -55,7 +56,7 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
),
|
),
|
||||||
# [d1, (d2)] -> [d1, (d2)]
|
# [d1, (d2)] -> [d1, (d2)]
|
||||||
dict(
|
dict(
|
||||||
fn=lambda x: x+1,
|
fn=lambda x: x + np.int64(1),
|
||||||
elems=[[1, 2, 3], [4, 5], [6, 7]],
|
elems=[[1, 2, 3], [4, 5], [6, 7]],
|
||||||
expected_output=[[2, 3, 4], [5, 6], [7, 8]],
|
expected_output=[[2, 3, 4], [5, 6], [7, 8]],
|
||||||
dtype=dtypes.int64,
|
dtype=dtypes.int64,
|
||||||
@ -64,7 +65,7 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
),
|
),
|
||||||
# [d1, (d2), d3] -> [d1, (d2), d3]
|
# [d1, (d2), d3] -> [d1, (d2), d3]
|
||||||
dict(
|
dict(
|
||||||
fn=lambda x: x+1,
|
fn=lambda x: x + np.int64(1),
|
||||||
elems=[[[1, 2], [3, 4]], [], [[5, 6], [7, 8], [9, 0]]],
|
elems=[[[1, 2], [3, 4]], [], [[5, 6], [7, 8], [9, 0]]],
|
||||||
elems_ragged_rank=1,
|
elems_ragged_rank=1,
|
||||||
expected_ragged_rank=1,
|
expected_ragged_rank=1,
|
||||||
@ -131,7 +132,7 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
),
|
),
|
||||||
# [d1, (d2), (d3), (d4a), (d5)] -> [d1, (d2), (d3), (d4b), (d5)]
|
# [d1, (d2), (d3), (d4a), (d5)] -> [d1, (d2), (d3), (d4b), (d5)]
|
||||||
dict(
|
dict(
|
||||||
fn=lambda x: ragged.add(x, 1),
|
fn=lambda x: x + np.int64(1),
|
||||||
elems=[[[[[1, 2, 3]], [[4], [5]]]], [[[[6, 7]]], [[[8], []]]]],
|
elems=[[[[[1, 2, 3]], [[4], [5]]]], [[[[6, 7]]], [[[8], []]]]],
|
||||||
expected_output=[[[[[2, 3, 4]], [[5], [6]]]],
|
expected_output=[[[[[2, 3, 4]], [[5], [6]]]],
|
||||||
[[[[7, 8]]], [[[9], []]]]],
|
[[[[7, 8]]], [[[9], []]]]],
|
||||||
@ -196,8 +197,8 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
def _increment(f):
|
def _increment(f):
|
||||||
return {
|
return {
|
||||||
'batman': ragged.add(f['batman'], 1),
|
'batman': f['batman'] + 1,
|
||||||
'robin': ragged.add(f['robin'], 1),
|
'robin': f['robin'] + 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
output = ragged.map_fn(
|
output = ragged.map_fn(
|
||||||
|
@ -143,8 +143,11 @@ Computes the %(combination)s along segments of a RaggedTensor.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def _ragged_segment_aggregate(unsorted_segment_op, data, segment_ids,
|
def _ragged_segment_aggregate(unsorted_segment_op,
|
||||||
num_segments, name=None):
|
data,
|
||||||
|
segment_ids,
|
||||||
|
num_segments,
|
||||||
|
name=None):
|
||||||
"""Aggregates along segments of a RaggedTensor using `unsorted_segment_op`.
|
"""Aggregates along segments of a RaggedTensor using `unsorted_segment_op`.
|
||||||
|
|
||||||
Returns a RaggedTensor `output` with `num_segments` rows, where the row
|
Returns a RaggedTensor `output` with `num_segments` rows, where the row
|
||||||
@ -212,8 +215,7 @@ def _ragged_segment_aggregate(unsorted_segment_op, data, segment_ids,
|
|||||||
assert output_row_lengths.dtype == dtypes.int64
|
assert output_row_lengths.dtype == dtypes.int64
|
||||||
|
|
||||||
# Build the splits tensor for the output RaggedTensor.
|
# Build the splits tensor for the output RaggedTensor.
|
||||||
output_splits = array_ops.concat(
|
output_splits = array_ops.concat([
|
||||||
[
|
|
||||||
array_ops.zeros([1], dtypes.int64),
|
array_ops.zeros([1], dtypes.int64),
|
||||||
math_ops.cumsum(output_row_lengths)
|
math_ops.cumsum(output_row_lengths)
|
||||||
],
|
],
|
||||||
@ -311,7 +313,7 @@ _set_ragged_segment_docstring(segment_sqrt_n, 'sum divided by sqrt(N)',
|
|||||||
_RAGGED_REDUCE_DOCSTRING = """\
|
_RAGGED_REDUCE_DOCSTRING = """\
|
||||||
Computes the %(combination)s of elements across dimensions of a `RaggedTensor`.
|
Computes the %(combination)s of elements across dimensions of a `RaggedTensor`.
|
||||||
|
|
||||||
Reduces `rt_input` along the dimensions given in `axis` by taking the
|
Reduces `input_tensor` along the dimensions given in `axis` by taking the
|
||||||
%(combination)s of values. If a reduced dimension has no elements for
|
%(combination)s of values. If a reduced dimension has no elements for
|
||||||
some index, then the value for that index will be %(default)s.
|
some index, then the value for that index will be %(default)s.
|
||||||
|
|
||||||
@ -319,18 +321,18 @@ Computes the %(combination)s of elements across dimensions of a `RaggedTensor`.
|
|||||||
`axis` is not specified, then all dimensions are reduced, and a scalar
|
`axis` is not specified, then all dimensions are reduced, and a scalar
|
||||||
value is returned.
|
value is returned.
|
||||||
Args:
|
Args:
|
||||||
rt_input: A `RaggedTensor` containing the values to be %(combined)s.
|
input_tensor: A `RaggedTensor` containing the values to be %(combined)s.
|
||||||
axis: The dimensions to reduce. May be `None` (to reduce all axes), an
|
axis: The dimensions to reduce. May be `None` (to reduce all axes), an
|
||||||
`int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce
|
`int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce
|
||||||
a given set of axes), or a `Tensor` with a constant value. Must be in
|
a given set of axes), or a `Tensor` with a constant value. Must be in
|
||||||
the range `[0, rt_input.rank]`.
|
the range `[0, input_tensor.rank]`.
|
||||||
name: A name prefix for the returned tensor (optional).
|
name: A name prefix for the returned tensor (optional).
|
||||||
Returns:
|
Returns:
|
||||||
A `RaggedTensor` containing the %(combined)s values. The returned tensor
|
A `RaggedTensor` containing the %(combined)s values. The returned tensor
|
||||||
has the same dtype as `data`, and its shape is given by removing the
|
has the same dtype as `data`, and its shape is given by removing the
|
||||||
dimensions specified in `axis` from `rt_input.shape`. The `ragged_rank`
|
dimensions specified in `axis` from `input_tensor.shape`. The `ragged_rank`
|
||||||
of the returned tensor is given by substracting any ragged dimensions
|
of the returned tensor is given by substracting any ragged dimensions
|
||||||
specified in `axis` from `rt_input.ragged_rank`.
|
specified in `axis` from `input_tensor.ragged_rank`.
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If `axis` contains a `Tensor` whose value is not constant.
|
ValueError: If `axis` contains a `Tensor` whose value is not constant.
|
||||||
####Example:
|
####Example:
|
||||||
@ -387,7 +389,11 @@ _RAGGED_REDUCE_ANY_EXAMPLE = """
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def _ragged_reduce_aggregate(reduce_op, unsorted_segment_op, rt_input, axis,
|
def _ragged_reduce_aggregate(reduce_op,
|
||||||
|
unsorted_segment_op,
|
||||||
|
rt_input,
|
||||||
|
axis,
|
||||||
|
keepdims,
|
||||||
name=None):
|
name=None):
|
||||||
"""Aggregates across axes of a RaggedTensor using the given `Tensor` ops.
|
"""Aggregates across axes of a RaggedTensor using the given `Tensor` ops.
|
||||||
|
|
||||||
@ -412,6 +418,7 @@ def _ragged_reduce_aggregate(reduce_op, unsorted_segment_op, rt_input, axis,
|
|||||||
`int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce a
|
`int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce a
|
||||||
given set of axes), or a `Tensor` with a constant value. Must be in the
|
given set of axes), or a `Tensor` with a constant value. Must be in the
|
||||||
range `[0, rt_input.rank)`.
|
range `[0, rt_input.rank)`.
|
||||||
|
keepdims: If true, retains reduced dimensions with length 1.
|
||||||
name: A name prefix for the returned tensor (optional).
|
name: A name prefix for the returned tensor (optional).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -426,6 +433,9 @@ def _ragged_reduce_aggregate(reduce_op, unsorted_segment_op, rt_input, axis,
|
|||||||
if not ragged_tensor.is_ragged(rt_input):
|
if not ragged_tensor.is_ragged(rt_input):
|
||||||
return reduce_op(rt_input, axis, name=name)
|
return reduce_op(rt_input, axis, name=name)
|
||||||
|
|
||||||
|
if keepdims:
|
||||||
|
raise ValueError('keepdims=True is not supported for RaggedTensors.')
|
||||||
|
|
||||||
if isinstance(axis, ops.Tensor):
|
if isinstance(axis, ops.Tensor):
|
||||||
axis = tensor_util.constant_value(axis)
|
axis = tensor_util.constant_value(axis)
|
||||||
if axis is None:
|
if axis is None:
|
||||||
@ -448,9 +458,9 @@ def _ragged_reduce_aggregate(reduce_op, unsorted_segment_op, rt_input, axis,
|
|||||||
# once will probably require a nontrivial c++ op.
|
# once will probably require a nontrivial c++ op.
|
||||||
axis = sorted(axis)
|
axis = sorted(axis)
|
||||||
inner_reduced = _ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
|
inner_reduced = _ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
|
||||||
rt_input, axis[-1])
|
rt_input, axis[-1], keepdims)
|
||||||
return _ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
|
return _ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
|
||||||
inner_reduced, axis[:-1])
|
inner_reduced, axis[:-1], keepdims)
|
||||||
|
|
||||||
axis = ragged_util.get_positive_axis(axis, rt_input.shape.ndims)
|
axis = ragged_util.get_positive_axis(axis, rt_input.shape.ndims)
|
||||||
|
|
||||||
@ -476,48 +486,48 @@ def _ragged_reduce_aggregate(reduce_op, unsorted_segment_op, rt_input, axis,
|
|||||||
# sum_{j} rt_input [i_0, ..., i_[axis-1], j, i_axis+1], ..., i_N]
|
# sum_{j} rt_input [i_0, ..., i_[axis-1], j, i_axis+1], ..., i_N]
|
||||||
return rt_input.with_values(
|
return rt_input.with_values(
|
||||||
_ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
|
_ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
|
||||||
rt_input.values, axis - 1))
|
rt_input.values, axis - 1, keepdims))
|
||||||
|
|
||||||
|
|
||||||
def reduce_sum(rt_input, axis=None, name=None):
|
def reduce_sum(input_tensor, axis=None, keepdims=None, name=None):
|
||||||
"""For docs, see: _RAGGED_REDUCE_DOCSTRING."""
|
"""For docs, see: _RAGGED_REDUCE_DOCSTRING."""
|
||||||
return _ragged_reduce_aggregate(math_ops.reduce_sum,
|
return _ragged_reduce_aggregate(math_ops.reduce_sum,
|
||||||
math_ops.unsorted_segment_sum, rt_input, axis,
|
math_ops.unsorted_segment_sum, input_tensor,
|
||||||
name or 'RaggedReduceSum')
|
axis, keepdims, name or 'RaggedReduceSum')
|
||||||
|
|
||||||
|
|
||||||
def reduce_prod(rt_input, axis=None, name=None):
|
def reduce_prod(input_tensor, axis=None, keepdims=None, name=None):
|
||||||
"""For docs, see: _RAGGED_REDUCE_DOCSTRING."""
|
"""For docs, see: _RAGGED_REDUCE_DOCSTRING."""
|
||||||
return _ragged_reduce_aggregate(math_ops.reduce_prod,
|
return _ragged_reduce_aggregate(math_ops.reduce_prod,
|
||||||
math_ops.unsorted_segment_prod, rt_input,
|
math_ops.unsorted_segment_prod, input_tensor,
|
||||||
axis, name or 'RaggedReduceProd')
|
axis, keepdims, name or 'RaggedReduceProd')
|
||||||
|
|
||||||
|
|
||||||
def reduce_min(rt_input, axis=None, name=None):
|
def reduce_min(input_tensor, axis=None, keepdims=None, name=None):
|
||||||
"""For docs, see: _RAGGED_REDUCE_DOCSTRING."""
|
"""For docs, see: _RAGGED_REDUCE_DOCSTRING."""
|
||||||
return _ragged_reduce_aggregate(math_ops.reduce_min,
|
return _ragged_reduce_aggregate(math_ops.reduce_min,
|
||||||
math_ops.unsorted_segment_min, rt_input, axis,
|
math_ops.unsorted_segment_min, input_tensor,
|
||||||
name or 'RaggedReduceMin')
|
axis, keepdims, name or 'RaggedReduceMin')
|
||||||
|
|
||||||
|
|
||||||
def reduce_max(rt_input, axis=None, name=None):
|
def reduce_max(input_tensor, axis=None, keepdims=None, name=None):
|
||||||
"""For docs, see: _RAGGED_REDUCE_DOCSTRING."""
|
"""For docs, see: _RAGGED_REDUCE_DOCSTRING."""
|
||||||
return _ragged_reduce_aggregate(math_ops.reduce_max,
|
return _ragged_reduce_aggregate(math_ops.reduce_max,
|
||||||
math_ops.unsorted_segment_max, rt_input, axis,
|
math_ops.unsorted_segment_max, input_tensor,
|
||||||
name or 'RaggedReduceMax')
|
axis, keepdims, name or 'RaggedReduceMax')
|
||||||
|
|
||||||
|
|
||||||
def reduce_mean(rt_input, axis=None, name=None):
|
def reduce_mean(input_tensor, axis=None, keepdims=None, name=None):
|
||||||
"""For docs, see: _RAGGED_REDUCE_DOCSTRING."""
|
"""For docs, see: _RAGGED_REDUCE_DOCSTRING."""
|
||||||
with ops.name_scope(name, 'RaggedReduceMean', [rt_input, axis]):
|
with ops.name_scope(name, 'RaggedReduceMean', [input_tensor, axis]):
|
||||||
total = reduce_sum(rt_input, axis)
|
total = reduce_sum(input_tensor, axis, keepdims)
|
||||||
if ragged_tensor.is_ragged(rt_input):
|
if ragged_tensor.is_ragged(input_tensor):
|
||||||
ones = ragged_factory_ops.from_nested_row_splits(
|
ones = ragged_factory_ops.from_nested_row_splits(
|
||||||
array_ops.ones_like(rt_input.inner_values),
|
array_ops.ones_like(input_tensor.inner_values),
|
||||||
rt_input.nested_row_splits)
|
input_tensor.nested_row_splits)
|
||||||
else:
|
else:
|
||||||
ones = array_ops.ones_like(rt_input)
|
ones = array_ops.ones_like(input_tensor)
|
||||||
count = reduce_sum(ones, axis)
|
count = reduce_sum(ones, axis, keepdims)
|
||||||
if ragged_tensor.is_ragged(total):
|
if ragged_tensor.is_ragged(total):
|
||||||
return ragged_factory_ops.from_nested_row_splits(
|
return ragged_factory_ops.from_nested_row_splits(
|
||||||
total.inner_values / count.inner_values, total.nested_row_splits)
|
total.inner_values / count.inner_values, total.nested_row_splits)
|
||||||
@ -525,20 +535,25 @@ def reduce_mean(rt_input, axis=None, name=None):
|
|||||||
return total / count
|
return total / count
|
||||||
|
|
||||||
|
|
||||||
def _cast(rt_input, dtype):
|
def _cast(input_tensor, dtype):
|
||||||
return ragged_functional_ops.map_inner_values(math_ops.cast, rt_input, dtype)
|
return ragged_functional_ops.map_inner_values(math_ops.cast, input_tensor,
|
||||||
|
dtype)
|
||||||
|
|
||||||
|
|
||||||
def reduce_all(rt_input, axis=None, name=None):
|
def reduce_all(input_tensor, axis=None, keepdims=None, name=None):
|
||||||
"""For docs, see: _RAGGED_REDUCE_DOCSTRING."""
|
"""For docs, see: _RAGGED_REDUCE_DOCSTRING."""
|
||||||
with ops.name_scope(name, 'RaggedReduceAll', [rt_input, axis]):
|
with ops.name_scope(name, 'RaggedReduceAll', [input_tensor, axis]):
|
||||||
return _cast(reduce_prod(_cast(rt_input, dtypes.int32), axis), dtypes.bool)
|
return _cast(
|
||||||
|
reduce_prod(_cast(input_tensor, dtypes.int32), axis, keepdims),
|
||||||
|
dtypes.bool)
|
||||||
|
|
||||||
|
|
||||||
def reduce_any(rt_input, axis=None, name=None):
|
def reduce_any(input_tensor, axis=None, keepdims=None, name=None):
|
||||||
"""For docs, see: _RAGGED_REDUCE_DOCSTRING."""
|
"""For docs, see: _RAGGED_REDUCE_DOCSTRING."""
|
||||||
with ops.name_scope(name, 'RaggedReduceAny', [rt_input, axis]):
|
with ops.name_scope(name, 'RaggedReduceAny', [input_tensor, axis]):
|
||||||
return _cast(reduce_sum(_cast(rt_input, dtypes.int32), axis), dtypes.bool)
|
return _cast(
|
||||||
|
reduce_sum(_cast(input_tensor, dtypes.int32), axis, keepdims),
|
||||||
|
dtypes.bool)
|
||||||
|
|
||||||
|
|
||||||
def _set_ragged_reduce_docstring(func, combination, combined, default, example):
|
def _set_ragged_reduce_docstring(func, combination, combined, default, example):
|
||||||
@ -554,9 +569,11 @@ _set_ragged_reduce_docstring(reduce_sum, 'sum', 'summed', '0',
|
|||||||
_set_ragged_reduce_docstring(reduce_prod, 'product', 'multiplied', '1',
|
_set_ragged_reduce_docstring(reduce_prod, 'product', 'multiplied', '1',
|
||||||
_RAGGED_REDUCE_PROD_EXAMPLE)
|
_RAGGED_REDUCE_PROD_EXAMPLE)
|
||||||
_set_ragged_reduce_docstring(reduce_min, 'minimum', 'minimized',
|
_set_ragged_reduce_docstring(reduce_min, 'minimum', 'minimized',
|
||||||
'`rt_input.dtype.min`', _RAGGED_REDUCE_MIN_EXAMPLE)
|
'`input_tensor.dtype.min`',
|
||||||
|
_RAGGED_REDUCE_MIN_EXAMPLE)
|
||||||
_set_ragged_reduce_docstring(reduce_max, 'maximum', 'maximized',
|
_set_ragged_reduce_docstring(reduce_max, 'maximum', 'maximized',
|
||||||
'`rt_input.dtype.max`', _RAGGED_REDUCE_MAX_EXAMPLE)
|
'`input_tensor.dtype.max`',
|
||||||
|
_RAGGED_REDUCE_MAX_EXAMPLE)
|
||||||
_set_ragged_reduce_docstring(reduce_mean, 'mean', 'averaged', 'NaN',
|
_set_ragged_reduce_docstring(reduce_mean, 'mean', 'averaged', 'NaN',
|
||||||
_RAGGED_REDUCE_MEAN_EXAMPLE)
|
_RAGGED_REDUCE_MEAN_EXAMPLE)
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python.ops.ragged import ragged_elementwise_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops.ragged import ragged_getitem
|
from tensorflow.python.ops.ragged import ragged_getitem
|
||||||
from tensorflow.python.ops.ragged import ragged_tensor
|
from tensorflow.python.ops.ragged import ragged_tensor
|
||||||
from tensorflow.python.util import tf_decorator
|
from tensorflow.python.util import tf_decorator
|
||||||
@ -33,40 +33,39 @@ def _right(operator):
|
|||||||
ragged_tensor.RaggedTensor.__getitem__ = ragged_getitem.ragged_tensor_getitem
|
ragged_tensor.RaggedTensor.__getitem__ = ragged_getitem.ragged_tensor_getitem
|
||||||
|
|
||||||
# Ordering operators
|
# Ordering operators
|
||||||
ragged_tensor.RaggedTensor.__ge__ = ragged_elementwise_ops.greater_equal
|
ragged_tensor.RaggedTensor.__ge__ = math_ops.greater_equal
|
||||||
ragged_tensor.RaggedTensor.__gt__ = ragged_elementwise_ops.greater
|
ragged_tensor.RaggedTensor.__gt__ = math_ops.greater
|
||||||
ragged_tensor.RaggedTensor.__le__ = ragged_elementwise_ops.less_equal
|
ragged_tensor.RaggedTensor.__le__ = math_ops.less_equal
|
||||||
ragged_tensor.RaggedTensor.__lt__ = ragged_elementwise_ops.less
|
ragged_tensor.RaggedTensor.__lt__ = math_ops.less
|
||||||
|
|
||||||
# Logical operators
|
# Logical operators
|
||||||
ragged_tensor.RaggedTensor.__and__ = ragged_elementwise_ops.logical_and
|
ragged_tensor.RaggedTensor.__and__ = math_ops.logical_and
|
||||||
ragged_tensor.RaggedTensor.__rand__ = _right(ragged_elementwise_ops.logical_and)
|
ragged_tensor.RaggedTensor.__rand__ = _right(math_ops.logical_and)
|
||||||
ragged_tensor.RaggedTensor.__invert__ = ragged_elementwise_ops.logical_not
|
ragged_tensor.RaggedTensor.__invert__ = math_ops.logical_not
|
||||||
ragged_tensor.RaggedTensor.__ror__ = _right(ragged_elementwise_ops.logical_or)
|
ragged_tensor.RaggedTensor.__ror__ = _right(math_ops.logical_or)
|
||||||
ragged_tensor.RaggedTensor.__or__ = ragged_elementwise_ops.logical_or
|
ragged_tensor.RaggedTensor.__or__ = math_ops.logical_or
|
||||||
ragged_tensor.RaggedTensor.__xor__ = ragged_elementwise_ops.logical_xor
|
ragged_tensor.RaggedTensor.__xor__ = math_ops.logical_xor
|
||||||
ragged_tensor.RaggedTensor.__rxor__ = _right(ragged_elementwise_ops.logical_xor)
|
ragged_tensor.RaggedTensor.__rxor__ = _right(math_ops.logical_xor)
|
||||||
|
|
||||||
# Arithmetic operators
|
# Arithmetic operators
|
||||||
ragged_tensor.RaggedTensor.__abs__ = ragged_elementwise_ops.abs
|
ragged_tensor.RaggedTensor.__abs__ = math_ops.abs
|
||||||
ragged_tensor.RaggedTensor.__add__ = ragged_elementwise_ops.add
|
ragged_tensor.RaggedTensor.__add__ = math_ops.add
|
||||||
ragged_tensor.RaggedTensor.__radd__ = _right(ragged_elementwise_ops.add)
|
ragged_tensor.RaggedTensor.__radd__ = _right(math_ops.add)
|
||||||
ragged_tensor.RaggedTensor.__div__ = ragged_elementwise_ops.div
|
ragged_tensor.RaggedTensor.__div__ = math_ops.div
|
||||||
ragged_tensor.RaggedTensor.__rdiv__ = _right(ragged_elementwise_ops.div)
|
ragged_tensor.RaggedTensor.__rdiv__ = _right(math_ops.div)
|
||||||
ragged_tensor.RaggedTensor.__floordiv__ = ragged_elementwise_ops.floordiv
|
ragged_tensor.RaggedTensor.__floordiv__ = math_ops.floordiv
|
||||||
ragged_tensor.RaggedTensor.__rfloordiv__ = _right(
|
ragged_tensor.RaggedTensor.__rfloordiv__ = _right(math_ops.floordiv)
|
||||||
ragged_elementwise_ops.floordiv)
|
ragged_tensor.RaggedTensor.__mod__ = math_ops.floormod
|
||||||
ragged_tensor.RaggedTensor.__mod__ = ragged_elementwise_ops.floormod
|
ragged_tensor.RaggedTensor.__rmod__ = _right(math_ops.floormod)
|
||||||
ragged_tensor.RaggedTensor.__rmod__ = _right(ragged_elementwise_ops.floormod)
|
ragged_tensor.RaggedTensor.__mul__ = math_ops.multiply
|
||||||
ragged_tensor.RaggedTensor.__mul__ = ragged_elementwise_ops.multiply
|
ragged_tensor.RaggedTensor.__rmul__ = _right(math_ops.multiply)
|
||||||
ragged_tensor.RaggedTensor.__rmul__ = _right(ragged_elementwise_ops.multiply)
|
ragged_tensor.RaggedTensor.__neg__ = math_ops.negative
|
||||||
ragged_tensor.RaggedTensor.__neg__ = ragged_elementwise_ops.negative
|
ragged_tensor.RaggedTensor.__pow__ = math_ops.pow
|
||||||
ragged_tensor.RaggedTensor.__pow__ = ragged_elementwise_ops.pow
|
ragged_tensor.RaggedTensor.__rpow__ = _right(math_ops.pow)
|
||||||
ragged_tensor.RaggedTensor.__rpow__ = _right(ragged_elementwise_ops.pow)
|
ragged_tensor.RaggedTensor.__sub__ = math_ops.subtract
|
||||||
ragged_tensor.RaggedTensor.__sub__ = ragged_elementwise_ops.subtract
|
ragged_tensor.RaggedTensor.__rsub__ = _right(math_ops.subtract)
|
||||||
ragged_tensor.RaggedTensor.__rsub__ = _right(ragged_elementwise_ops.subtract)
|
ragged_tensor.RaggedTensor.__truediv__ = math_ops.truediv
|
||||||
ragged_tensor.RaggedTensor.__truediv__ = ragged_elementwise_ops.truediv
|
ragged_tensor.RaggedTensor.__rtruediv__ = _right(math_ops.truediv)
|
||||||
ragged_tensor.RaggedTensor.__rtruediv__ = _right(ragged_elementwise_ops.truediv)
|
|
||||||
|
|
||||||
|
|
||||||
# Dummy methods
|
# Dummy methods
|
||||||
|
@ -2696,6 +2696,7 @@ class _UnaryMapValueDispatcher(dispatch.OpDispatcher):
|
|||||||
if args:
|
if args:
|
||||||
x, args = args[0], args[1:]
|
x, args = args[0], args[1:]
|
||||||
else:
|
else:
|
||||||
|
kwargs = kwargs.copy()
|
||||||
x = kwargs.pop(self._x, None)
|
x = kwargs.pop(self._x, None)
|
||||||
if isinstance(x, sparse_tensor.SparseTensor):
|
if isinstance(x, sparse_tensor.SparseTensor):
|
||||||
return sparse_tensor.SparseTensor(
|
return sparse_tensor.SparseTensor(
|
||||||
|
@ -38,6 +38,7 @@ from tensorflow.python.ops import math_ops
|
|||||||
from tensorflow.python.ops.gen_string_ops import *
|
from tensorflow.python.ops.gen_string_ops import *
|
||||||
from tensorflow.python.util import compat as util_compat
|
from tensorflow.python.util import compat as util_compat
|
||||||
from tensorflow.python.util import deprecation
|
from tensorflow.python.util import deprecation
|
||||||
|
from tensorflow.python.util import dispatch
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
# pylint: enable=g-bad-import-order
|
# pylint: enable=g-bad-import-order
|
||||||
# pylint: enable=wildcard-import
|
# pylint: enable=wildcard-import
|
||||||
@ -45,6 +46,7 @@ from tensorflow.python.util.tf_export import tf_export
|
|||||||
|
|
||||||
# pylint: disable=redefined-builtin
|
# pylint: disable=redefined-builtin
|
||||||
@tf_export("strings.regex_full_match")
|
@tf_export("strings.regex_full_match")
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def regex_full_match(input, pattern, name=None):
|
def regex_full_match(input, pattern, name=None):
|
||||||
r"""Match elements of `input` with regex `pattern`.
|
r"""Match elements of `input` with regex `pattern`.
|
||||||
|
|
||||||
@ -76,6 +78,7 @@ regex_full_match.__doc__ = gen_string_ops.regex_full_match.__doc__
|
|||||||
@tf_export(
|
@tf_export(
|
||||||
"strings.regex_replace", v1=["strings.regex_replace", "regex_replace"])
|
"strings.regex_replace", v1=["strings.regex_replace", "regex_replace"])
|
||||||
@deprecation.deprecated_endpoints("regex_replace")
|
@deprecation.deprecated_endpoints("regex_replace")
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def regex_replace(input, pattern, rewrite, replace_global=True, name=None):
|
def regex_replace(input, pattern, rewrite, replace_global=True, name=None):
|
||||||
r"""Replace elements of `input` matching regex `pattern` with `rewrite`.
|
r"""Replace elements of `input` matching regex `pattern` with `rewrite`.
|
||||||
|
|
||||||
@ -350,10 +353,13 @@ reduce_join.__doc__ = reduce_join.__doc__.replace("tf.reduce_join(",
|
|||||||
# This wrapper provides backwards compatibility for code that predates the
|
# This wrapper provides backwards compatibility for code that predates the
|
||||||
# unit argument and that passed 'name' as a positional argument.
|
# unit argument and that passed 'name' as a positional argument.
|
||||||
@tf_export(v1=["strings.length"])
|
@tf_export(v1=["strings.length"])
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def string_length(input, name=None, unit="BYTE"):
|
def string_length(input, name=None, unit="BYTE"):
|
||||||
return gen_string_ops.string_length(input, unit=unit, name=name)
|
return gen_string_ops.string_length(input, unit=unit, name=name)
|
||||||
|
|
||||||
|
|
||||||
@tf_export("strings.length", v1=[])
|
@tf_export("strings.length", v1=[])
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def string_length_v2(input, unit="BYTE", name=None):
|
def string_length_v2(input, unit="BYTE", name=None):
|
||||||
return string_length(input, name, unit)
|
return string_length(input, name, unit)
|
||||||
|
|
||||||
@ -370,11 +376,13 @@ substr_deprecated.__doc__ = gen_string_ops.substr.__doc__
|
|||||||
|
|
||||||
|
|
||||||
@tf_export(v1=["strings.substr"])
|
@tf_export(v1=["strings.substr"])
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def substr(input, pos, len, name=None, unit="BYTE"):
|
def substr(input, pos, len, name=None, unit="BYTE"):
|
||||||
return gen_string_ops.substr(input, pos, len, unit=unit, name=name)
|
return gen_string_ops.substr(input, pos, len, unit=unit, name=name)
|
||||||
|
|
||||||
|
|
||||||
@tf_export("strings.substr", v1=[])
|
@tf_export("strings.substr", v1=[])
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def substr_v2(input, pos, len, unit="BYTE", name=None):
|
def substr_v2(input, pos, len, unit="BYTE", name=None):
|
||||||
return substr(input, pos, len, name=name, unit=unit)
|
return substr(input, pos, len, name=name, unit=unit)
|
||||||
|
|
||||||
@ -395,6 +403,7 @@ ops.NotDifferentiable("DecodeBase64")
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("strings.to_number", v1=[])
|
@tf_export("strings.to_number", v1=[])
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def string_to_number(input, out_type=dtypes.float32, name=None):
|
def string_to_number(input, out_type=dtypes.float32, name=None):
|
||||||
r"""Converts each string in the input Tensor to the specified numeric type.
|
r"""Converts each string in the input Tensor to the specified numeric type.
|
||||||
|
|
||||||
@ -418,6 +427,7 @@ tf_export(v1=["strings.to_number", "string_to_number"])(
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("strings.to_hash_bucket", v1=[])
|
@tf_export("strings.to_hash_bucket", v1=[])
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
def string_to_hash_bucket(input, num_buckets, name=None):
|
def string_to_hash_bucket(input, num_buckets, name=None):
|
||||||
# pylint: disable=line-too-long
|
# pylint: disable=line-too-long
|
||||||
r"""Converts each string in the input Tensor to its hash mod by a number of buckets.
|
r"""Converts each string in the input Tensor to its hash mod by a number of buckets.
|
||||||
|
@ -166,15 +166,14 @@ def dispatch_for_types(op, *types):
|
|||||||
|
|
||||||
def add_dispatch_list(target):
|
def add_dispatch_list(target):
|
||||||
"""Decorator that adds a dispatch_list attribute to an op."""
|
"""Decorator that adds a dispatch_list attribute to an op."""
|
||||||
assert not hasattr(target, DISPATCH_ATTR)
|
if hasattr(target, DISPATCH_ATTR):
|
||||||
|
raise AssertionError("%s already has a dispatch list" % target)
|
||||||
setattr(target, DISPATCH_ATTR, [])
|
setattr(target, DISPATCH_ATTR, [])
|
||||||
return target
|
return target
|
||||||
|
|
||||||
|
|
||||||
def add_dispatch_support(target):
|
def add_dispatch_support(target):
|
||||||
"""Decorator that adds a dispatch handling wrapper to an op."""
|
"""Decorator that adds a dispatch handling wrapper to an op."""
|
||||||
add_dispatch_list(target)
|
|
||||||
|
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
"""Call target, and fall back on dispatchers if there is a TypeError."""
|
"""Call target, and fall back on dispatchers if there is a TypeError."""
|
||||||
try:
|
try:
|
||||||
@ -188,5 +187,5 @@ def add_dispatch_support(target):
|
|||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
setattr(wrapper, DISPATCH_ATTR, [])
|
add_dispatch_list(wrapper)
|
||||||
return tf_decorator.make_decorator(target, wrapper)
|
return tf_decorator.make_decorator(target, wrapper)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user