# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operator dispatch for RaggedTensors."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_bitwise_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_batch_gather_ops
from tensorflow.python.ops.ragged import ragged_concat_ops
from tensorflow.python.ops.ragged import ragged_gather_ops
from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_squeeze_op
from tensorflow.python.ops.ragged import ragged_string_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_shape
from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.ops.ragged import ragged_where_op
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_export
from tensorflow.python.util import tf_inspect

# @TODO(edloper): Set this to True in the CL that exports RaggedTensors.
_UPDATE_DOCSTRINGS = False

# Information about an argument to an operation: the name of the argument, its
# position in the argument list, and a boolean flag indicating whether it
# expects a list of tensors.
_ArgInfo = collections.namedtuple('ArgInfo', ['name', 'position', 'is_list'])


def _get_arg_infos(func, arg_names):
  """Returns an `_ArgInfo` for each argument of `func` specified by `arg_names`.

  Args:
    func: The function whose arguments should be described.
    arg_names: The names of the arguments to get info for.

  Returns:
    A list of `_ArgInfo`s.
  """
  arg_infos = []

  # Inspect the func's argspec to find the position of each arg.
  arg_spec = tf_inspect.getargspec(func)
  for argname in arg_names:
    assert isinstance(argname, str)
    is_list = argname.startswith('[') and argname.endswith(']')
    if is_list:
      argname = argname[1:-1]
    if argname not in arg_spec.args:
      raise ValueError('Argument %r not found in function %s. Args=%s' %
                       (argname, func, arg_spec.args))
    arg_infos.append(_ArgInfo(argname, arg_spec.args.index(argname), is_list))
  return arg_infos
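
# For example, given a function with signature `concat(values, axis, name=None)`
# (a hypothetical stand-in, not a real op), the bracket syntax marks arguments
# that expect a list of tensors:
#
#   _get_arg_infos(concat, ['[values]', 'axis'])
#   # => [ArgInfo(name='values', position=0, is_list=True),
#   #     ArgInfo(name='axis', position=1, is_list=False)]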


def _is_convertible_to_tensor(value):
  """Returns true if `value` is convertible to a `Tensor`."""
  if value is None:
    return True
  if isinstance(value,
                (ops.Tensor, variables.Variable, np.ndarray, int, float, str)):
    return True
  elif isinstance(value, (sparse_tensor.SparseTensor,)):
    return False
  else:
    try:
      ops.convert_to_tensor(value)
      return True
    except (TypeError, ValueError):
      return False


class UnaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
  """OpDispatcher for unary ops that map a base op across ragged values."""

  def __init__(self, original_op, arg_is_list=False):
    self._original_op = original_op
    self._arg_is_list = arg_is_list
    arg_names = tf_inspect.getfullargspec(original_op)[0]
    self._x = arg_names[0]
    if _UPDATE_DOCSTRINGS:
      original_op.__doc__ = (
          original_op.__doc__.rstrip() + '\n\n' +
          '    `{x}` may be a `tf.RaggedTensor`.\n'.format(x=self._x))

  def handle(self, args, kwargs):
    if args:
      x, args = args[0], args[1:]
    else:
      kwargs = kwargs.copy()
      x = kwargs.pop(self._x, None)
    if x is None:
      return self.NOT_SUPPORTED
    if self._arg_is_list:
      found_ragged = False
      for elt in x:
        if ragged_tensor.is_ragged(elt):
          found_ragged = True
        elif not _is_convertible_to_tensor(elt):
          return self.NOT_SUPPORTED
      if found_ragged:
        x = ragged_tensor.match_row_splits_dtypes(*x)
        nested_splits_lists = [
            elt.nested_row_splits for elt in x if ragged_tensor.is_ragged(elt)
        ]
        flat_values = [
            elt.flat_values if ragged_tensor.is_ragged(elt) else elt
            for elt in x
        ]
        with ops.control_dependencies(
            ragged_util.assert_splits_match(nested_splits_lists)):
          return ragged_tensor.RaggedTensor.from_nested_row_splits(
              self._original_op(flat_values, *args, **kwargs),
              nested_splits_lists[0], validate=False)
      else:
        return self.NOT_SUPPORTED
    else:
      found_ragged = ragged_tensor.is_ragged(x)
      if found_ragged:
        mapped_values = self._original_op(x.flat_values, *args, **kwargs)
        return x.with_flat_values(mapped_values)
      else:
        return self.NOT_SUPPORTED
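
# A minimal sketch of what the unary dispatch above does (illustrative comment
# only): the wrapped op is applied to the ragged input's `flat_values`, and the
# ragged structure is then reattached.  For example, dispatching `tf.abs` on
# rt = tf.ragged.constant([[1., -2.], [-3.]]) is equivalent to:
#
#   rt.with_flat_values(tf.abs(rt.flat_values))
#   # => [[1., 2.], [3.]]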


class BinaryRaggedElementwiseDispatcher(dispatch.OpDispatcher):
  """OpDispatcher for binary ops that map a base op across ragged values.

  Supports broadcasting.
  """

  def __init__(self, original_op):
    self._original_op = original_op
    arg_names = tf_inspect.getfullargspec(original_op)[0]
    self._x = arg_names[0]
    self._y = arg_names[1]
    if _UPDATE_DOCSTRINGS:
      original_op.__doc__ = (
          original_op.__doc__.rstrip() + '\n\n' +
          '    `{x}` and `{y}` may each be a `tf.RaggedTensor`.\n'.format(
              x=self._x, y=self._y))

  def handle(self, args, kwargs):
    # Extract the binary args.
    if len(args) > 1:
      x = args[0]
      y = args[1]
      args = args[2:]
    elif args:
      kwargs = kwargs.copy()
      x = args[0]
      y = kwargs.pop(self._y, None)
      args = args[1:]
    else:
      kwargs = kwargs.copy()
      x = kwargs.pop(self._x, None)
      y = kwargs.pop(self._y, None)

    # Bail if we don't have at least one ragged argument.
    x_is_ragged = ragged_tensor.is_ragged(x)
    y_is_ragged = ragged_tensor.is_ragged(y)
    if not (x_is_ragged or y_is_ragged):
      return self.NOT_SUPPORTED

    # Convert args to tensors.  Bail if conversion fails.
    try:
      if not x_is_ragged:
        x = ops.convert_to_tensor(x, name=self._x, preferred_dtype=y.dtype)
      if not y_is_ragged:
        y = ops.convert_to_tensor(y, name=self._y, preferred_dtype=x.dtype)
    except (TypeError, ValueError):
      return self.NOT_SUPPORTED

    if x_is_ragged and y_is_ragged:
      x, y = ragged_tensor.match_row_splits_dtypes(x, y)

    if ((x_is_ragged and y_is_ragged) or
        (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims) or
        (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims)):
      bcast_shape = ragged_tensor_shape.broadcast_dynamic_shape(
          ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(x),
          ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(y))
      x = ragged_tensor_shape.broadcast_to(
          x, bcast_shape, broadcast_inner_dimensions=False)
      y = ragged_tensor_shape.broadcast_to(
          y, bcast_shape, broadcast_inner_dimensions=False)

    x_values = x.flat_values if ragged_tensor.is_ragged(x) else x
    y_values = y.flat_values if ragged_tensor.is_ragged(y) else y
    mapped_values = self._original_op(x_values, y_values, *args, **kwargs)
    if ragged_tensor.is_ragged(x):
      return x.with_flat_values(mapped_values)
    else:
      return y.with_flat_values(mapped_values)
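
# Illustrative sketch (comment only): for a ragged x and a scalar y, no ragged
# broadcasting is needed, so tf.add(tf.ragged.constant([[1, 2], [3]]), 10)
# reduces to:
#
#   x.with_flat_values(tf.add(x.flat_values, 10))
#   # => [[11, 12], [13]]
#
# When both operands are ragged, or the dense operand's rank is at least the
# ragged operand's flat-values rank, both operands are first broadcast to a
# common RaggedTensorDynamicShape before the base op is applied.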


class RaggedDispatcher(dispatch.OpDispatcher):
  """OpDispatcher for ragged ops.

  Dispatches to a wrapped op-handler if at least one of the `ragged_args`
  arguments is a RaggedTensor or a RaggedTensorValue, and all of the
  `ragged_args` arguments are convertible to Tensor or RaggedTensor.
  """

  def __init__(self, original_op, ragged_op, ragged_args):
    op_arg_names = tf_inspect.getfullargspec(original_op)[0]
    ragged_arg_names = tf_inspect.getfullargspec(ragged_op)[0]
    if op_arg_names != ragged_arg_names:
      raise AssertionError(
          'Signature must exactly match when overriding %s with %s: %s vs %s' %
          (original_op, ragged_op, op_arg_names, ragged_arg_names))
    self._ragged_op = ragged_op
    self._ragged_args = _get_arg_infos(ragged_op, ragged_args)
    if _UPDATE_DOCSTRINGS:
      arg_list = ' and '.join('`%s`' % arg for arg in ragged_args)
      original_op.__doc__ = (
          original_op.__doc__.rstrip() + '\n\n' +
          '    {0} may be a `tf.RaggedTensor`.\n'.format(arg_list))

  def handle(self, args, kwargs):
    if self.is_supported(args, kwargs):
      return self._ragged_op(*args, **kwargs)
    else:
      return self.NOT_SUPPORTED

  def is_supported(self, args, kwargs):
    found_ragged = False
    for arg_info in self._ragged_args:
      if arg_info.position < len(args):
        arg = args[arg_info.position]
      else:
        arg = kwargs.get(arg_info.name, None)

      if arg_info.is_list:
        if not isinstance(arg, (list, tuple)):
          return False
        for elt in arg:
          if ragged_tensor.is_ragged(elt):
            found_ragged = True
          elif not _is_convertible_to_tensor(elt):
            return False
      else:
        if ragged_tensor.is_ragged(arg):
          found_ragged = True
        elif not _is_convertible_to_tensor(arg):
          return False
    return found_ragged
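
# A hedged usage sketch (comment only; `my_op` and `my_ragged_op` are
# hypothetical): a RaggedDispatcher reroutes calls on an exported op to a
# ragged-aware implementation with an identical signature whenever any of the
# named arguments is ragged:
#
#   RaggedDispatcher(my_op, my_ragged_op, ['data']).register(my_op)
#   # my_op(dense_tensor)  -> original kernel
#   # my_op(ragged_tensor) -> my_ragged_op(ragged_tensor)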


_UNARY_ELEMENTWISE_OPS = [
    array_ops.check_numerics,
    array_ops.identity,
    array_ops.ones_like,
    array_ops.ones_like_v2,
    array_ops.zeros_like,
    array_ops.zeros_like_v2,
    clip_ops.clip_by_value,
    gen_bitwise_ops.invert,
    math_ops.abs,
    math_ops.acos,
    math_ops.acosh,
    math_ops.angle,
    math_ops.asin,
    math_ops.asinh,
    math_ops.atan,
    math_ops.atanh,
    math_ops.cast,
    math_ops.ceil,
    math_ops.conj,
    math_ops.cos,
    math_ops.cosh,
    math_ops.digamma,
    math_ops.erf,
    math_ops.erfc,
    math_ops.erfinv,
    math_ops.exp,
    math_ops.expm1,
    math_ops.floor,
    math_ops.imag,
    math_ops.is_finite,
    math_ops.is_inf,
    math_ops.is_nan,
    math_ops.lgamma,
    math_ops.log,
    math_ops.log1p,
    math_ops.log_sigmoid,
    math_ops.logical_not,
    math_ops.ndtri,
    math_ops.negative,
    math_ops.real,
    math_ops.reciprocal,
    math_ops.rint,
    math_ops.round,
    math_ops.rsqrt,
    math_ops.saturate_cast,
    math_ops.sign,
    math_ops.sin,
    math_ops.sinh,
    math_ops.sqrt,
    math_ops.square,
    math_ops.tan,
    parsing_ops.decode_compressed,
    string_ops.string_to_number,
    string_ops.as_string,
    string_ops.decode_base64,
    string_ops.encode_base64,
    string_ops.regex_full_match,
    string_ops.regex_replace,
    string_ops.string_strip,
    string_ops.string_to_hash_bucket,
    string_ops.string_to_hash_bucket_fast,
    string_ops.string_to_hash_bucket_strong,
    string_ops.substr,
    string_ops.substr_v2,
    string_ops.string_length,
    string_ops.string_length_v2,
    string_ops.unicode_script,
]


_UNARY_LIST_ELEMENTWISE_OPS = [
    math_ops.add_n,
    string_ops.string_join,
]

_BINARY_ELEMENTWISE_OPS = [
    gen_bitwise_ops.bitwise_and,
    gen_bitwise_ops.bitwise_or,
    gen_bitwise_ops.bitwise_xor,
    gen_bitwise_ops.left_shift,
    gen_bitwise_ops.right_shift,
    math_ops.add,
    math_ops.atan2,
    math_ops.complex,
    math_ops.div_no_nan,
    math_ops.divide,
    math_ops.equal,
    math_ops.floordiv,
    math_ops.floormod,
    math_ops.greater,
    math_ops.greater_equal,
    math_ops.less,
    math_ops.less_equal,
    math_ops.logical_and,
    math_ops.logical_or,
    math_ops.logical_xor,
    math_ops.maximum,
    math_ops.minimum,
    math_ops.multiply,
    math_ops.not_equal,
    math_ops.pow,
    math_ops.realdiv,
    math_ops.squared_difference,
    math_ops.subtract,
    math_ops.truediv,
    math_ops.truncatediv,
    math_ops.truncatemod,
]


# We don't need to register a separate delegation handler for these v1 ops,
# since they delegate to the v2 ops (which already have a handler).  But we
# still want to include them in the ragged_op_list() output.
_V2_OPS_THAT_ARE_DELEGATED_TO_FROM_V1_OPS = [
    math_ops.reduce_sum,
    math_ops.reduce_prod,
    math_ops.reduce_min,
    math_ops.reduce_max,
    math_ops.reduce_mean,
    math_ops.reduce_any,
    math_ops.reduce_all,
    string_ops.string_to_number,
    string_ops.string_to_hash_bucket,
    string_ops.reduce_join_v2,
]


def _ragged_gather_v1(params, indices, validate_indices=None, name=None,
                      axis=0, batch_dims=0):
  return ragged_gather_ops.gather(
      params=params,
      indices=indices,
      validate_indices=validate_indices,
      axis=axis,
      batch_dims=batch_dims,
      name=name)


def _ragged_gather_nd_v1(params, indices, name=None, batch_dims=0):
  return ragged_gather_ops.gather_nd(
      params=params,
      indices=indices,
      batch_dims=batch_dims,
      name=name)


def _ragged_expand_dims_v1(input, axis=None, name=None, dim=None):  # pylint: disable=redefined-builtin
  if dim is not None:
    axis = dim
  return ragged_array_ops.expand_dims(input=input, axis=axis, name=name)


def _ragged_size_v1(input, name=None, out_type=dtypes.int32):  # pylint: disable=redefined-builtin
  return ragged_array_ops.size(input=input, out_type=out_type, name=name)


def _ragged_squeeze_v1(input, axis=None, name=None, squeeze_dims=None):  # pylint: disable=redefined-builtin
  axis = deprecation.deprecated_argument_lookup('axis', axis, 'squeeze_dims',
                                                squeeze_dims)
  return ragged_squeeze_op.squeeze(input, axis, name)


def _ragged_dynamic_partition(data, partitions, num_partitions, name=None):
  """RaggedTensor dispatch override for tf.dynamic_partition."""
  if not isinstance(num_partitions, int) or num_partitions < 0:
    raise TypeError('num_partitions must be a non-negative integer')
  result = ragged_array_ops.stack_dynamic_partitions(data, partitions,
                                                     num_partitions, name)
  return [result[i] for i in range(num_partitions)]
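
# Behavior sketch (illustrative comment): with data=[10, 20, 30, 40],
# partitions=[0, 1, 1, 0], and num_partitions=2, the partitions are stacked
# into a single RaggedTensor and then split back into a Python list:
#
#   _ragged_dynamic_partition(data, partitions, 2)
#   # => [[10, 40], [20, 30]]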


# (original_op, ragged_op, ragged_args)
_RAGGED_DISPATCH_OPS = [
    (array_ops.batch_gather, ragged_batch_gather_ops.batch_gather,
     ['params', 'indices']),
    (array_ops.concat, ragged_concat_ops.concat, ['[values]']),
    (array_ops.expand_dims, _ragged_expand_dims_v1, ['input']),
    (array_ops.expand_dims_v2, ragged_array_ops.expand_dims, ['input']),
    (array_ops.gather, _ragged_gather_v1, ['params', 'indices']),
    (array_ops.gather_v2, ragged_gather_ops.gather, ['params', 'indices']),
    (array_ops.gather_nd, _ragged_gather_nd_v1, ['params', 'indices']),
    (array_ops.gather_nd_v2, ragged_gather_ops.gather_nd, ['params',
                                                           'indices']),
    (array_ops.one_hot, ragged_array_ops.ragged_one_hot, ['indices']),
    (array_ops.rank, ragged_array_ops.rank, ['input']),
    (array_ops.reverse, ragged_array_ops.reverse, ['tensor']),
    (array_ops.size, _ragged_size_v1, ['input']),
    (array_ops.size_v2, ragged_array_ops.size, ['input']),
    (array_ops.squeeze, _ragged_squeeze_v1, ['input']),
    (array_ops.squeeze_v2, ragged_squeeze_op.squeeze, ['input']),
    (array_ops.stack, ragged_concat_ops.stack, ['[values]']),
    (array_ops.tile, ragged_array_ops.tile, ['input']),
    (array_ops.where, ragged_where_op.where, ['condition', 'x', 'y']),
    (data_flow_ops.dynamic_partition, _ragged_dynamic_partition,
     ['data', 'partitions']),
    (math_ops.unsorted_segment_sum, ragged_math_ops.segment_sum,
     ['data', 'segment_ids']),
    (math_ops.unsorted_segment_prod, ragged_math_ops.segment_prod,
     ['data', 'segment_ids']),
    (math_ops.unsorted_segment_min, ragged_math_ops.segment_min,
     ['data', 'segment_ids']),
    (math_ops.unsorted_segment_max, ragged_math_ops.segment_max,
     ['data', 'segment_ids']),
    (math_ops.unsorted_segment_mean, ragged_math_ops.segment_mean,
     ['data', 'segment_ids']),
    (math_ops.unsorted_segment_sqrt_n, ragged_math_ops.segment_sqrt_n,
     ['data', 'segment_ids']),
    (string_ops.reduce_join_v2, ragged_string_ops.reduce_join, ['inputs']),
    (math_ops.reduce_sum, ragged_math_ops.reduce_sum, ['input_tensor']),
    (math_ops.reduce_prod, ragged_math_ops.reduce_prod, ['input_tensor']),
    (math_ops.reduce_min, ragged_math_ops.reduce_min, ['input_tensor']),
    (math_ops.reduce_max, ragged_math_ops.reduce_max, ['input_tensor']),
    (math_ops.reduce_mean, ragged_math_ops.reduce_mean, ['input_tensor']),
    (math_ops.reduce_any, ragged_math_ops.reduce_any, ['input_tensor']),
    (math_ops.reduce_all, ragged_math_ops.reduce_all, ['input_tensor']),
]


def register_dispatchers():
  """Constructs & registers OpDispatchers for ragged ops."""

  op_list = (
      _UNARY_ELEMENTWISE_OPS + _UNARY_LIST_ELEMENTWISE_OPS +
      _BINARY_ELEMENTWISE_OPS + [x[0] for x in _RAGGED_DISPATCH_OPS])
  for op in op_list:
    _, undecorated_op = tf_decorator.unwrap(op)
    if not hasattr(undecorated_op,
                   tf_export.API_ATTRS[tf_export.TENSORFLOW_API_NAME].names):
      raise AssertionError('Expected %s to be an exported symbol '
                           '(while adding a RaggedTensor dispatcher)' % op)

  for op in _UNARY_ELEMENTWISE_OPS:
    UnaryRaggedElementwiseDispatcher(op).register(op)

  for op in _UNARY_LIST_ELEMENTWISE_OPS:
    UnaryRaggedElementwiseDispatcher(op, True).register(op)

  for op in _BINARY_ELEMENTWISE_OPS:
    BinaryRaggedElementwiseDispatcher(op).register(op)

  for (original_op, ragged_op, args) in _RAGGED_DISPATCH_OPS:
    RaggedDispatcher(original_op, ragged_op, args).register(original_op)
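
# Once register_dispatchers() has run, calls to the exported ops transparently
# handle ragged inputs.  For example (illustrative comment only):
#
#   rt1 = tf.ragged.constant([[1, 2], [3]])
#   rt2 = tf.ragged.constant([[4], [5, 6]])
#   tf.concat([rt1, rt2], axis=0)   # dispatched to ragged_concat_ops.concat
#   # => [[1, 2], [3], [4], [5, 6]]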


def _ragged_op_signature(op, ragged_args):
  """Returns a signature for the given op, marking ragged args in bold."""
  op_name = tf_export.get_canonical_name_for_symbol(op)
  argspec = tf_inspect.getfullargspec(op)
  arg_names = argspec.args

  # Mark ragged arguments in bold.
  for pos in ragged_args:
    arg_names[pos] = '**' + arg_names[pos] + '**'

  # Add argument defaults.
  if argspec.defaults:
    for pos in range(-1, -len(argspec.defaults) - 1, -1):
      arg_names[pos] += '=`{!r}`'.format(argspec.defaults[pos])

  # Add varargs and keyword args.
  if argspec.varargs:
    arg_names.append('*' + argspec.varargs)
  if argspec.varkw:
    arg_names.append('**' + argspec.varkw)

  return '* `tf.{}`({})'.format(op_name, ', '.join(arg_names))
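
# Sample output (illustrative; the exact canonical name depends on the op's
# tf_export registration):
#
#   _ragged_op_signature(math_ops.add, [0, 1])
#   # => '* `tf.math.add`(**x**, **y**, name=`None`)'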


def _op_is_in_tf_version(op, version):
  if version == 1:
    return (tf_export.get_v1_names(tf_decorator.unwrap(op)[1]) or
            op in _V2_OPS_THAT_ARE_DELEGATED_TO_FROM_V1_OPS)
  elif version == 2:
    return tf_export.get_v2_names(tf_decorator.unwrap(op)[1])
  else:
    raise ValueError('Expected version 1 or 2.')


def ragged_op_list(tf_version=1):
  """Returns a string listing operators that have dispatchers registered."""
  lines = []
  for op in _UNARY_ELEMENTWISE_OPS + _UNARY_LIST_ELEMENTWISE_OPS:
    if _op_is_in_tf_version(op, tf_version):
      lines.append(_ragged_op_signature(op, [0]))
  for op in _BINARY_ELEMENTWISE_OPS:
    if _op_is_in_tf_version(op, tf_version):
      lines.append(_ragged_op_signature(op, [0, 1]))
  for op, _, ragged_args in _RAGGED_DISPATCH_OPS:
    if _op_is_in_tf_version(op, tf_version):
      arginfos = _get_arg_infos(op, ragged_args)
      ragged_args = [arginfo.position for arginfo in arginfos]
      lines.append(_ragged_op_signature(op, ragged_args))
  return ('\n\n### Additional ops that support `RaggedTensor`\n\n'
          'Arguments that accept `RaggedTensor`s are marked in **bold**.\n\n' +
          '\n'.join(sorted(lines)) + '\n')


register_dispatchers()