Add dispatch support to more Python APIs.

PiperOrigin-RevId: 311763060
Change-Id: Ib35371483aa083e245996508a82fd13d8ac43131
This commit is contained in:
Edward Loper 2020-05-15 10:58:42 -07:00 committed by TensorFlower Gardener
parent 26104505b8
commit 77245d07d1
52 changed files with 696 additions and 28 deletions

View File

@ -24,6 +24,7 @@ from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import keras_export
# b/123041942
@ -41,6 +42,7 @@ _TF_ACTIVATIONS_V2 = {
@keras_export('keras.activations.softmax')
@dispatch.add_dispatch_support
def softmax(x, axis=-1):
"""Softmax converts a real vector to a vector of categorical probabilities.
@ -82,6 +84,7 @@ def softmax(x, axis=-1):
@keras_export('keras.activations.elu')
@dispatch.add_dispatch_support
def elu(x, alpha=1.0):
"""Exponential linear unit.
@ -100,6 +103,7 @@ def elu(x, alpha=1.0):
@keras_export('keras.activations.selu')
@dispatch.add_dispatch_support
def selu(x):
"""Scaled Exponential Linear Unit (SELU).
@ -153,6 +157,7 @@ def selu(x):
@keras_export('keras.activations.softplus')
@dispatch.add_dispatch_support
def softplus(x):
"""Softplus activation function, `softplus(x) = log(exp(x) + 1)`.
@ -174,6 +179,7 @@ def softplus(x):
@keras_export('keras.activations.softsign')
@dispatch.add_dispatch_support
def softsign(x):
"""Softsign activation function, `softsign(x) = x / (abs(x) + 1)`.
@ -194,6 +200,7 @@ def softsign(x):
@keras_export('keras.activations.swish')
@dispatch.add_dispatch_support
def swish(x):
"""Swish activation function, `swish(x) = x * sigmoid(x)`.
@ -224,6 +231,7 @@ def swish(x):
@keras_export('keras.activations.relu')
@dispatch.add_dispatch_support
def relu(x, alpha=0., max_value=None, threshold=0):
"""Applies the rectified linear unit activation function.
@ -264,6 +272,7 @@ def relu(x, alpha=0., max_value=None, threshold=0):
@keras_export('keras.activations.tanh')
@dispatch.add_dispatch_support
def tanh(x):
"""Hyperbolic tangent activation function.
@ -285,6 +294,7 @@ def tanh(x):
@keras_export('keras.activations.sigmoid')
@dispatch.add_dispatch_support
def sigmoid(x):
"""Sigmoid activation function, `sigmoid(x) = 1 / (1 + exp(-x))`.
@ -314,6 +324,7 @@ def sigmoid(x):
@keras_export('keras.activations.exponential')
@dispatch.add_dispatch_support
def exponential(x):
"""Exponential activation function.
@ -334,6 +345,7 @@ def exponential(x):
@keras_export('keras.activations.hard_sigmoid')
@dispatch.add_dispatch_support
def hard_sigmoid(x):
"""Hard sigmoid activation function.
@ -360,6 +372,7 @@ def hard_sigmoid(x):
@keras_export('keras.activations.linear')
@dispatch.add_dispatch_support
def linear(x):
"""Linear activation function (pass-through).
@ -380,6 +393,7 @@ def linear(x):
@keras_export('keras.activations.serialize')
@dispatch.add_dispatch_support
def serialize(activation):
"""Returns the string identifier of an activation function.
@ -410,6 +424,7 @@ def serialize(activation):
@keras_export('keras.activations.deserialize')
@dispatch.add_dispatch_support
def deserialize(name, custom_objects=None):
"""Returns activation function given a string identifier.
@ -447,6 +462,7 @@ def deserialize(name, custom_objects=None):
@keras_export('keras.activations.get')
@dispatch.add_dispatch_support
def get(identifier):
"""Returns function.

View File

@ -76,6 +76,7 @@ from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import moving_averages
from tensorflow.python.training.tracking import util as tracking_util
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_contextlib
@ -173,6 +174,7 @@ def backend():
@keras_export('keras.backend.cast_to_floatx')
@dispatch.add_dispatch_support
def cast_to_floatx(x):
"""Cast a Numpy array to the default Keras float type.
@ -799,6 +801,7 @@ def is_sparse(tensor):
@keras_export('keras.backend.to_dense')
@dispatch.add_dispatch_support
def to_dense(tensor):
"""Converts a sparse tensor into a dense tensor and returns it.
@ -1007,6 +1010,7 @@ def _initialize_variables(session):
@keras_export('keras.backend.constant')
@dispatch.add_dispatch_support
def constant(value, dtype=None, shape=None, name=None):
"""Creates a constant tensor.
@ -1163,6 +1167,7 @@ def is_placeholder(x):
@keras_export('keras.backend.shape')
@dispatch.add_dispatch_support
def shape(x):
"""Returns the symbolic shape of a tensor or variable.
@ -1245,6 +1250,7 @@ def ndim(x):
@keras_export('keras.backend.dtype')
@dispatch.add_dispatch_support
def dtype(x):
"""Returns the dtype of a Keras tensor or variable, as a string.
@ -1343,6 +1349,7 @@ def zeros(shape, dtype=None, name=None):
@keras_export('keras.backend.ones')
@dispatch.add_dispatch_support
def ones(shape, dtype=None, name=None):
"""Instantiates an all-ones variable and returns it.
@ -1377,6 +1384,7 @@ def ones(shape, dtype=None, name=None):
@keras_export('keras.backend.eye')
@dispatch.add_dispatch_support
def eye(size, dtype=None, name=None):
"""Instantiate an identity matrix and returns it.
@ -1433,6 +1441,7 @@ def zeros_like(x, dtype=None, name=None):
@keras_export('keras.backend.ones_like')
@dispatch.add_dispatch_support
def ones_like(x, dtype=None, name=None):
"""Instantiates an all-ones variable of the same shape as another tensor.
@ -1563,6 +1572,7 @@ def count_params(x):
@keras_export('keras.backend.cast')
@dispatch.add_dispatch_support
def cast(x, dtype):
"""Casts a tensor to a different dtype and returns it.
@ -1647,6 +1657,7 @@ def moving_average_update(x, value, momentum):
@keras_export('keras.backend.dot')
@dispatch.add_dispatch_support
def dot(x, y):
"""Multiplies 2 tensors (and/or variables) and returns a tensor.
@ -1707,6 +1718,7 @@ def dot(x, y):
@keras_export('keras.backend.batch_dot')
@dispatch.add_dispatch_support
def batch_dot(x, y, axes=None):
"""Batchwise dot product.
@ -1895,6 +1907,7 @@ def batch_dot(x, y, axes=None):
@keras_export('keras.backend.transpose')
@dispatch.add_dispatch_support
def transpose(x):
"""Transposes a tensor and returns it.
@ -1926,6 +1939,7 @@ def transpose(x):
@keras_export('keras.backend.gather')
@dispatch.add_dispatch_support
def gather(reference, indices):
"""Retrieves the elements of indices `indices` in the tensor `reference`.
@ -1961,6 +1975,7 @@ def gather(reference, indices):
@keras_export('keras.backend.max')
@dispatch.add_dispatch_support
def max(x, axis=None, keepdims=False):
"""Maximum value in a tensor.
@ -1979,6 +1994,7 @@ def max(x, axis=None, keepdims=False):
@keras_export('keras.backend.min')
@dispatch.add_dispatch_support
def min(x, axis=None, keepdims=False):
"""Minimum value in a tensor.
@ -1997,6 +2013,7 @@ def min(x, axis=None, keepdims=False):
@keras_export('keras.backend.sum')
@dispatch.add_dispatch_support
def sum(x, axis=None, keepdims=False):
"""Sum of the values in a tensor, alongside the specified axis.
@ -2015,6 +2032,7 @@ def sum(x, axis=None, keepdims=False):
@keras_export('keras.backend.prod')
@dispatch.add_dispatch_support
def prod(x, axis=None, keepdims=False):
"""Multiplies the values in a tensor, alongside the specified axis.
@ -2033,6 +2051,7 @@ def prod(x, axis=None, keepdims=False):
@keras_export('keras.backend.cumsum')
@dispatch.add_dispatch_support
def cumsum(x, axis=0):
"""Cumulative sum of the values in a tensor, alongside the specified axis.
@ -2047,6 +2066,7 @@ def cumsum(x, axis=0):
@keras_export('keras.backend.cumprod')
@dispatch.add_dispatch_support
def cumprod(x, axis=0):
"""Cumulative product of the values in a tensor, alongside the specified axis.
@ -2081,6 +2101,7 @@ def var(x, axis=None, keepdims=False):
@keras_export('keras.backend.std')
@dispatch.add_dispatch_support
def std(x, axis=None, keepdims=False):
"""Standard deviation of a tensor, alongside the specified axis.
@ -2107,6 +2128,7 @@ def std(x, axis=None, keepdims=False):
@keras_export('keras.backend.mean')
@dispatch.add_dispatch_support
def mean(x, axis=None, keepdims=False):
"""Mean of a tensor, alongside the specified axis.
@ -2127,6 +2149,7 @@ def mean(x, axis=None, keepdims=False):
@keras_export('keras.backend.any')
@dispatch.add_dispatch_support
def any(x, axis=None, keepdims=False):
"""Bitwise reduction (logical OR).
@ -2143,6 +2166,7 @@ def any(x, axis=None, keepdims=False):
@keras_export('keras.backend.all')
@dispatch.add_dispatch_support
def all(x, axis=None, keepdims=False):
"""Bitwise reduction (logical AND).
@ -2159,6 +2183,7 @@ def all(x, axis=None, keepdims=False):
@keras_export('keras.backend.argmax')
@dispatch.add_dispatch_support
def argmax(x, axis=-1):
"""Returns the index of the maximum value along an axis.
@ -2173,6 +2198,7 @@ def argmax(x, axis=-1):
@keras_export('keras.backend.argmin')
@dispatch.add_dispatch_support
def argmin(x, axis=-1):
"""Returns the index of the minimum value along an axis.
@ -2187,6 +2213,7 @@ def argmin(x, axis=-1):
@keras_export('keras.backend.square')
@dispatch.add_dispatch_support
def square(x):
"""Element-wise square.
@ -2200,6 +2227,7 @@ def square(x):
@keras_export('keras.backend.abs')
@dispatch.add_dispatch_support
def abs(x):
"""Element-wise absolute value.
@ -2213,6 +2241,7 @@ def abs(x):
@keras_export('keras.backend.sqrt')
@dispatch.add_dispatch_support
def sqrt(x):
"""Element-wise square root.
@ -2229,6 +2258,7 @@ def sqrt(x):
@keras_export('keras.backend.exp')
@dispatch.add_dispatch_support
def exp(x):
"""Element-wise exponential.
@ -2242,6 +2272,7 @@ def exp(x):
@keras_export('keras.backend.log')
@dispatch.add_dispatch_support
def log(x):
"""Element-wise log.
@ -2276,6 +2307,7 @@ def logsumexp(x, axis=None, keepdims=False):
@keras_export('keras.backend.round')
@dispatch.add_dispatch_support
def round(x):
"""Element-wise rounding to the closest integer.
@ -2291,6 +2323,7 @@ def round(x):
@keras_export('keras.backend.sign')
@dispatch.add_dispatch_support
def sign(x):
"""Element-wise sign.
@ -2304,6 +2337,7 @@ def sign(x):
@keras_export('keras.backend.pow')
@dispatch.add_dispatch_support
def pow(x, a):
"""Element-wise exponentiation.
@ -2318,6 +2352,7 @@ def pow(x, a):
@keras_export('keras.backend.clip')
@dispatch.add_dispatch_support
def clip(x, min_value, max_value):
"""Element-wise value clipping.
@ -2341,6 +2376,7 @@ def clip(x, min_value, max_value):
@keras_export('keras.backend.equal')
@dispatch.add_dispatch_support
def equal(x, y):
"""Element-wise equality between two tensors.
@ -2355,6 +2391,7 @@ def equal(x, y):
@keras_export('keras.backend.not_equal')
@dispatch.add_dispatch_support
def not_equal(x, y):
"""Element-wise inequality between two tensors.
@ -2369,6 +2406,7 @@ def not_equal(x, y):
@keras_export('keras.backend.greater')
@dispatch.add_dispatch_support
def greater(x, y):
"""Element-wise truth value of (x > y).
@ -2383,6 +2421,7 @@ def greater(x, y):
@keras_export('keras.backend.greater_equal')
@dispatch.add_dispatch_support
def greater_equal(x, y):
"""Element-wise truth value of (x >= y).
@ -2397,6 +2436,7 @@ def greater_equal(x, y):
@keras_export('keras.backend.less')
@dispatch.add_dispatch_support
def less(x, y):
"""Element-wise truth value of (x < y).
@ -2411,6 +2451,7 @@ def less(x, y):
@keras_export('keras.backend.less_equal')
@dispatch.add_dispatch_support
def less_equal(x, y):
"""Element-wise truth value of (x <= y).
@ -2425,6 +2466,7 @@ def less_equal(x, y):
@keras_export('keras.backend.maximum')
@dispatch.add_dispatch_support
def maximum(x, y):
"""Element-wise maximum of two tensors.
@ -2449,6 +2491,7 @@ def maximum(x, y):
@keras_export('keras.backend.minimum')
@dispatch.add_dispatch_support
def minimum(x, y):
"""Element-wise minimum of two tensors.
@ -2463,6 +2506,7 @@ def minimum(x, y):
@keras_export('keras.backend.sin')
@dispatch.add_dispatch_support
def sin(x):
"""Computes sin of x element-wise.
@ -2476,6 +2520,7 @@ def sin(x):
@keras_export('keras.backend.cos')
@dispatch.add_dispatch_support
def cos(x):
"""Computes cos of x element-wise.
@ -2621,6 +2666,7 @@ def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3):
@keras_export('keras.backend.batch_normalization')
@dispatch.add_dispatch_support
def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3):
"""Applies batch normalization on x given mean, var, beta and gamma.
@ -2683,6 +2729,7 @@ def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3):
@keras_export('keras.backend.concatenate')
@dispatch.add_dispatch_support
def concatenate(tensors, axis=-1):
"""Concatenates a list of tensors alongside the specified axis.
@ -2720,6 +2767,7 @@ def concatenate(tensors, axis=-1):
@keras_export('keras.backend.reshape')
@dispatch.add_dispatch_support
def reshape(x, shape):
"""Reshapes a tensor to the specified shape.
@ -2749,6 +2797,7 @@ def reshape(x, shape):
@keras_export('keras.backend.permute_dimensions')
@dispatch.add_dispatch_support
def permute_dimensions(x, pattern):
"""Permutes axes in a tensor.
@ -2780,6 +2829,7 @@ def permute_dimensions(x, pattern):
@keras_export('keras.backend.resize_images')
@dispatch.add_dispatch_support
def resize_images(x, height_factor, width_factor, data_format,
interpolation='nearest'):
"""Resizes the images contained in a 4D tensor.
@ -2843,6 +2893,7 @@ def resize_images(x, height_factor, width_factor, data_format,
@keras_export('keras.backend.resize_volumes')
@dispatch.add_dispatch_support
def resize_volumes(x, depth_factor, height_factor, width_factor, data_format):
"""Resizes the volume contained in a 5D tensor.
@ -2875,6 +2926,7 @@ def resize_volumes(x, depth_factor, height_factor, width_factor, data_format):
@keras_export('keras.backend.repeat_elements')
@dispatch.add_dispatch_support
def repeat_elements(x, rep, axis):
"""Repeats the elements of a tensor along an axis, like `np.repeat`.
@ -2936,6 +2988,7 @@ def repeat_elements(x, rep, axis):
@keras_export('keras.backend.repeat')
@dispatch.add_dispatch_support
def repeat(x, n):
"""Repeats a 2D tensor.
@ -2971,6 +3024,7 @@ def repeat(x, n):
@keras_export('keras.backend.arange')
@dispatch.add_dispatch_support
def arange(start, stop=None, step=1, dtype='int32'):
"""Creates a 1D tensor containing a sequence of integers.
@ -3009,6 +3063,7 @@ def arange(start, stop=None, step=1, dtype='int32'):
@keras_export('keras.backend.tile')
@dispatch.add_dispatch_support
def tile(x, n):
"""Creates a tensor by tiling `x` by `n`.
@ -3026,6 +3081,7 @@ def tile(x, n):
@keras_export('keras.backend.flatten')
@dispatch.add_dispatch_support
def flatten(x):
"""Flatten a tensor.
@ -3051,6 +3107,7 @@ def flatten(x):
@keras_export('keras.backend.batch_flatten')
@dispatch.add_dispatch_support
def batch_flatten(x):
"""Turn a nD tensor into a 2D tensor with same 0th dimension.
@ -3076,6 +3133,7 @@ def batch_flatten(x):
@keras_export('keras.backend.expand_dims')
@dispatch.add_dispatch_support
def expand_dims(x, axis=-1):
"""Adds a 1-sized dimension at index "axis".
@ -3090,6 +3148,7 @@ def expand_dims(x, axis=-1):
@keras_export('keras.backend.squeeze')
@dispatch.add_dispatch_support
def squeeze(x, axis):
"""Removes a 1-dimension from the tensor at index "axis".
@ -3104,6 +3163,7 @@ def squeeze(x, axis):
@keras_export('keras.backend.temporal_padding')
@dispatch.add_dispatch_support
def temporal_padding(x, padding=(1, 1)):
"""Pads the middle dimension of a 3D tensor.
@ -3121,6 +3181,7 @@ def temporal_padding(x, padding=(1, 1)):
@keras_export('keras.backend.spatial_2d_padding')
@dispatch.add_dispatch_support
def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):
"""Pads the 2nd and 3rd dimensions of a 4D tensor.
@ -3152,6 +3213,7 @@ def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):
@keras_export('keras.backend.spatial_3d_padding')
@dispatch.add_dispatch_support
def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None):
"""Pads 5D tensor with zeros along the depth, height, width dimensions.
@ -3196,6 +3258,7 @@ def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None):
@keras_export('keras.backend.stack')
@dispatch.add_dispatch_support
def stack(x, axis=0):
"""Stacks a list of rank `R` tensors into a rank `R+1` tensor.
@ -3222,6 +3285,7 @@ def stack(x, axis=0):
@keras_export('keras.backend.one_hot')
@dispatch.add_dispatch_support
def one_hot(indices, num_classes):
"""Computes the one-hot representation of an integer tensor.
@ -3241,6 +3305,7 @@ def one_hot(indices, num_classes):
@keras_export('keras.backend.reverse')
@dispatch.add_dispatch_support
def reverse(x, axes):
"""Reverse a tensor along the specified axes.
@ -3321,6 +3386,7 @@ def get_value(x):
@keras_export('keras.backend.batch_get_value')
@dispatch.add_dispatch_support
def batch_get_value(tensors):
"""Returns the value of more than one tensor variable.
@ -3382,6 +3448,7 @@ def set_value(x, value):
@keras_export('keras.backend.batch_set_value')
@dispatch.add_dispatch_support
def batch_set_value(tuples):
"""Sets the values of many tensor variables at once.
@ -3424,6 +3491,7 @@ set_value.__doc__ = set_value.__doc__.format(snippet=_VALUE_SET_CODE_STRING)
@keras_export('keras.backend.print_tensor')
@dispatch.add_dispatch_support
def print_tensor(x, message=''):
"""Prints `message` and the tensor value when evaluated.
@ -3861,6 +3929,7 @@ def gradients(loss, variables):
@keras_export('keras.backend.stop_gradient')
@dispatch.add_dispatch_support
def stop_gradient(variables):
"""Returns `variables` but with zero gradient w.r.t. every other variable.
@ -3882,6 +3951,7 @@ def stop_gradient(variables):
@keras_export('keras.backend.rnn')
@dispatch.add_dispatch_support
def rnn(step_function,
inputs,
initial_states,
@ -4276,6 +4346,7 @@ def rnn(step_function,
@keras_export('keras.backend.switch')
@dispatch.add_dispatch_support
def switch(condition, then_expression, else_expression):
"""Switches between two operations depending on a scalar value.
@ -4409,6 +4480,7 @@ def in_test_phase(x, alt, training=None):
@keras_export('keras.backend.relu')
@dispatch.add_dispatch_support
def relu(x, alpha=0., max_value=None, threshold=0):
"""Rectified linear unit.
@ -4462,6 +4534,7 @@ def relu(x, alpha=0., max_value=None, threshold=0):
@keras_export('keras.backend.elu')
@dispatch.add_dispatch_support
def elu(x, alpha=1.):
"""Exponential linear unit.
@ -4480,6 +4553,7 @@ def elu(x, alpha=1.):
@keras_export('keras.backend.softmax')
@dispatch.add_dispatch_support
def softmax(x, axis=-1):
"""Softmax of a tensor.
@ -4495,6 +4569,7 @@ def softmax(x, axis=-1):
@keras_export('keras.backend.softplus')
@dispatch.add_dispatch_support
def softplus(x):
"""Softplus of a tensor.
@ -4508,6 +4583,7 @@ def softplus(x):
@keras_export('keras.backend.softsign')
@dispatch.add_dispatch_support
def softsign(x):
"""Softsign of a tensor.
@ -4527,6 +4603,7 @@ def _backtrack_identity(tensor):
@keras_export('keras.backend.categorical_crossentropy')
@dispatch.add_dispatch_support
def categorical_crossentropy(target, output, from_logits=False, axis=-1):
"""Categorical crossentropy between an output tensor and a target tensor.
@ -4595,6 +4672,7 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1):
@keras_export('keras.backend.sparse_categorical_crossentropy')
@dispatch.add_dispatch_support
def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
"""Categorical crossentropy with integer targets.
@ -4676,6 +4754,7 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
@keras_export('keras.backend.binary_crossentropy')
@dispatch.add_dispatch_support
def binary_crossentropy(target, output, from_logits=False):
"""Binary crossentropy between an output tensor and a target tensor.
@ -4712,6 +4791,7 @@ def binary_crossentropy(target, output, from_logits=False):
@keras_export('keras.backend.sigmoid')
@dispatch.add_dispatch_support
def sigmoid(x):
"""Element-wise sigmoid.
@ -4725,6 +4805,7 @@ def sigmoid(x):
@keras_export('keras.backend.hard_sigmoid')
@dispatch.add_dispatch_support
def hard_sigmoid(x):
"""Segment-wise linear approximation of sigmoid.
@ -4747,6 +4828,7 @@ def hard_sigmoid(x):
@keras_export('keras.backend.tanh')
@dispatch.add_dispatch_support
def tanh(x):
"""Element-wise tanh.
@ -4760,6 +4842,7 @@ def tanh(x):
@keras_export('keras.backend.dropout')
@dispatch.add_dispatch_support
def dropout(x, level, noise_shape=None, seed=None):
"""Sets entries in `x` to zero at random, while scaling the entire tensor.
@ -4780,6 +4863,7 @@ def dropout(x, level, noise_shape=None, seed=None):
@keras_export('keras.backend.l2_normalize')
@dispatch.add_dispatch_support
def l2_normalize(x, axis=None):
"""Normalizes a tensor wrt the L2 norm alongside the specified axis.
@ -4794,6 +4878,7 @@ def l2_normalize(x, axis=None):
@keras_export('keras.backend.in_top_k')
@dispatch.add_dispatch_support
def in_top_k(predictions, targets, k):
"""Returns whether the `targets` are in the top `k` `predictions`.
@ -4896,6 +4981,7 @@ def _preprocess_padding(padding):
@keras_export('keras.backend.conv1d')
@dispatch.add_dispatch_support
def conv1d(x,
kernel,
strides=1,
@ -4946,6 +5032,7 @@ def conv1d(x,
@keras_export('keras.backend.conv2d')
@dispatch.add_dispatch_support
def conv2d(x,
kernel,
strides=(1, 1),
@ -4989,6 +5076,7 @@ def conv2d(x,
@keras_export('keras.backend.conv2d_transpose')
@dispatch.add_dispatch_support
def conv2d_transpose(x,
kernel,
output_shape,
@ -5129,6 +5217,7 @@ def separable_conv1d(x,
@keras_export('keras.backend.separable_conv2d')
@dispatch.add_dispatch_support
def separable_conv2d(x,
depthwise_kernel,
pointwise_kernel,
@ -5186,6 +5275,7 @@ def separable_conv2d(x,
@keras_export('keras.backend.depthwise_conv2d')
@dispatch.add_dispatch_support
def depthwise_conv2d(x,
depthwise_kernel,
strides=(1, 1),
@ -5235,6 +5325,7 @@ def depthwise_conv2d(x,
@keras_export('keras.backend.conv3d')
@dispatch.add_dispatch_support
def conv3d(x,
kernel,
strides=(1, 1, 1),
@ -5337,6 +5428,7 @@ def conv3d_transpose(x,
@keras_export('keras.backend.pool2d')
@dispatch.add_dispatch_support
def pool2d(x,
pool_size,
strides=(1, 1),
@ -5396,6 +5488,7 @@ def pool2d(x,
@keras_export('keras.backend.pool3d')
@dispatch.add_dispatch_support
def pool3d(x,
pool_size,
strides=(1, 1, 1),
@ -5526,6 +5619,7 @@ def local_conv(inputs,
@keras_export('keras.backend.local_conv1d')
@dispatch.add_dispatch_support
def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None):
"""Apply 1D conv with un-shared weights.
@ -5561,6 +5655,7 @@ def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None):
@keras_export('keras.backend.local_conv2d')
@dispatch.add_dispatch_support
def local_conv2d(inputs,
kernel,
kernel_size,
@ -5602,6 +5697,7 @@ def local_conv2d(inputs,
@keras_export('keras.backend.bias_add')
@dispatch.add_dispatch_support
def bias_add(x, bias, data_format=None):
"""Adds a bias vector to a tensor.
@ -5646,6 +5742,7 @@ def bias_add(x, bias, data_format=None):
@keras_export('keras.backend.random_normal')
@dispatch.add_dispatch_support
def random_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
"""Returns a tensor with normal distribution of values.
@ -5682,6 +5779,7 @@ def random_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
@keras_export('keras.backend.random_uniform')
@dispatch.add_dispatch_support
def random_uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
"""Returns a tensor with uniform distribution of values.
@ -5715,6 +5813,7 @@ def random_uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
@deprecated(None, 'Use `tf.keras.backend.random_bernoulli` instead.')
@keras_export('keras.backend.random_binomial')
@dispatch.add_dispatch_support
def random_binomial(shape, p=0.0, dtype=None, seed=None):
"""Returns a tensor with random binomial distribution of values.
@ -5751,6 +5850,7 @@ def random_binomial(shape, p=0.0, dtype=None, seed=None):
@keras_export('keras.backend.random_bernoulli')
@dispatch.add_dispatch_support
def random_bernoulli(shape, p=0.0, dtype=None, seed=None):
"""Returns a tensor with random bernoulli distribution of values.
@ -5767,6 +5867,7 @@ def random_bernoulli(shape, p=0.0, dtype=None, seed=None):
@keras_export('keras.backend.truncated_normal')
@dispatch.add_dispatch_support
def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
"""Returns a tensor with truncated random normal distribution of values.
@ -5801,6 +5902,7 @@ def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
@keras_export('keras.backend.ctc_label_dense_to_sparse')
@dispatch.add_dispatch_support
def ctc_label_dense_to_sparse(labels, label_lengths):
"""Converts CTC labels from dense to sparse.
@ -5847,6 +5949,7 @@ def ctc_label_dense_to_sparse(labels, label_lengths):
@keras_export('keras.backend.ctc_batch_cost')
@dispatch.add_dispatch_support
def ctc_batch_cost(y_true, y_pred, input_length, label_length):
"""Runs CTC loss algorithm on each batch element.
@ -5879,6 +5982,7 @@ def ctc_batch_cost(y_true, y_pred, input_length, label_length):
@keras_export('keras.backend.ctc_decode')
@dispatch.add_dispatch_support
def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
"""Decodes the output of a softmax.

View File

@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import keras_export
# The type of float to use throughout a session.
@ -30,6 +31,7 @@ _IMAGE_DATA_FORMAT = 'channels_last'
@keras_export('keras.backend.epsilon')
@dispatch.add_dispatch_support
def epsilon():
"""Returns the value of the fuzz factor used in numeric expressions.
@ -110,6 +112,7 @@ def set_floatx(value):
@keras_export('keras.backend.image_data_format')
@dispatch.add_dispatch_support
def image_data_format():
"""Returns the default image data format convention.

View File

@ -38,6 +38,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.ops.losses import util as tf_losses_util
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls
@ -1164,6 +1165,7 @@ class Huber(LossFunctionWrapper):
'keras.losses.mean_squared_error',
'keras.losses.mse',
'keras.losses.MSE')
@dispatch.add_dispatch_support
def mean_squared_error(y_true, y_pred):
"""Computes the mean squared error between labels and predictions.
@ -1199,6 +1201,7 @@ def mean_squared_error(y_true, y_pred):
'keras.losses.mean_absolute_error',
'keras.losses.mae',
'keras.losses.MAE')
@dispatch.add_dispatch_support
def mean_absolute_error(y_true, y_pred):
"""Computes the mean absolute error between labels and predictions.
@ -1231,6 +1234,7 @@ def mean_absolute_error(y_true, y_pred):
'keras.losses.mean_absolute_percentage_error',
'keras.losses.mape',
'keras.losses.MAPE')
@dispatch.add_dispatch_support
def mean_absolute_percentage_error(y_true, y_pred):
"""Computes the mean absolute percentage error between `y_true` and `y_pred`.
@ -1267,6 +1271,7 @@ def mean_absolute_percentage_error(y_true, y_pred):
'keras.losses.mean_squared_logarithmic_error',
'keras.losses.msle',
'keras.losses.MSLE')
@dispatch.add_dispatch_support
def mean_squared_logarithmic_error(y_true, y_pred):
"""Computes the mean squared logarithmic error between `y_true` and `y_pred`.
@ -1315,6 +1320,7 @@ def _maybe_convert_labels(y_true):
@keras_export('keras.metrics.squared_hinge', 'keras.losses.squared_hinge')
@dispatch.add_dispatch_support
def squared_hinge(y_true, y_pred):
"""Computes the squared hinge loss between `y_true` and `y_pred`.
@ -1347,6 +1353,7 @@ def squared_hinge(y_true, y_pred):
@keras_export('keras.metrics.hinge', 'keras.losses.hinge')
@dispatch.add_dispatch_support
def hinge(y_true, y_pred):
"""Computes the hinge loss between `y_true` and `y_pred`.
@ -1378,6 +1385,7 @@ def hinge(y_true, y_pred):
@keras_export('keras.losses.categorical_hinge')
@dispatch.add_dispatch_support
def categorical_hinge(y_true, y_pred):
"""Computes the categorical hinge loss between `y_true` and `y_pred`.
@ -1410,6 +1418,7 @@ def categorical_hinge(y_true, y_pred):
@keras_export('keras.losses.huber', v1=[])
@dispatch.add_dispatch_support
def huber(y_true, y_pred, delta=1.0):
"""Computes Huber loss value.
@ -1447,6 +1456,7 @@ def huber(y_true, y_pred, delta=1.0):
@keras_export('keras.losses.log_cosh', 'keras.losses.logcosh')
@dispatch.add_dispatch_support
def log_cosh(y_true, y_pred):
"""Logarithm of the hyperbolic cosine of the prediction error.
@ -1485,6 +1495,7 @@ def log_cosh(y_true, y_pred):
@keras_export('keras.metrics.categorical_crossentropy',
'keras.losses.categorical_crossentropy')
@dispatch.add_dispatch_support
def categorical_crossentropy(y_true,
y_pred,
from_logits=False,
@ -1525,6 +1536,7 @@ def categorical_crossentropy(y_true,
@keras_export('keras.metrics.sparse_categorical_crossentropy',
'keras.losses.sparse_categorical_crossentropy')
@dispatch.add_dispatch_support
def sparse_categorical_crossentropy(y_true, y_pred, from_logits=False, axis=-1):
"""Computes the sparse categorical crossentropy loss.
@ -1556,6 +1568,7 @@ def sparse_categorical_crossentropy(y_true, y_pred, from_logits=False, axis=-1):
@keras_export('keras.metrics.binary_crossentropy',
'keras.losses.binary_crossentropy')
@dispatch.add_dispatch_support
def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0):
"""Computes the binary crossentropy loss.
@ -1599,6 +1612,7 @@ def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0):
'keras.losses.kullback_leibler_divergence',
'keras.losses.kld',
'keras.losses.KLD')
@dispatch.add_dispatch_support
def kl_divergence(y_true, y_pred):
"""Computes Kullback-Leibler divergence loss between `y_true` and `y_pred`.
@ -1635,6 +1649,7 @@ def kl_divergence(y_true, y_pred):
@keras_export('keras.metrics.poisson', 'keras.losses.poisson')
@dispatch.add_dispatch_support
def poisson(y_true, y_pred):
"""Computes the Poisson loss between y_true and y_pred.
@ -1676,6 +1691,7 @@ def poisson(y_true, y_pred):
'keras.losses.cosine',
'keras.losses.cosine_similarity',
])
@dispatch.add_dispatch_support
def cosine_similarity(y_true, y_pred, axis=-1):
"""Computes the cosine similarity between labels and predictions.

View File

@ -69,6 +69,7 @@ from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.ops.losses import util as tf_losses_utils
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import keras_export
@ -3212,6 +3213,7 @@ def accuracy(y_true, y_pred):
@keras_export('keras.metrics.binary_accuracy')
@dispatch.add_dispatch_support
def binary_accuracy(y_true, y_pred, threshold=0.5):
"""Calculates how often predictions matches binary labels.
@ -3239,6 +3241,7 @@ def binary_accuracy(y_true, y_pred, threshold=0.5):
@keras_export('keras.metrics.categorical_accuracy')
@dispatch.add_dispatch_support
def categorical_accuracy(y_true, y_pred):
"""Calculates how often predictions matches one-hot labels.
@ -3267,6 +3270,7 @@ def categorical_accuracy(y_true, y_pred):
@keras_export('keras.metrics.sparse_categorical_accuracy')
@dispatch.add_dispatch_support
def sparse_categorical_accuracy(y_true, y_pred):
"""Calculates how often predictions matches integer labels.
@ -3307,6 +3311,7 @@ def sparse_categorical_accuracy(y_true, y_pred):
@keras_export('keras.metrics.top_k_categorical_accuracy')
@dispatch.add_dispatch_support
def top_k_categorical_accuracy(y_true, y_pred, k=5):
"""Computes how often targets are in the top `K` predictions.
@ -3332,6 +3337,7 @@ def top_k_categorical_accuracy(y_true, y_pred, k=5):
@keras_export('keras.metrics.sparse_top_k_categorical_accuracy')
@dispatch.add_dispatch_support
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
"""Computes how often integer targets are in the top `K` predictions.

View File

@ -57,6 +57,7 @@ _BaseSlice = slice
@tf_export("reshape", v1=["reshape", "manip.reshape"])
@dispatch.add_dispatch_support
def reshape(tensor, shape, name=None): # pylint: disable=redefined-outer-name
r"""Reshapes a tensor.
@ -197,6 +198,7 @@ def reshape(tensor, shape, name=None): # pylint: disable=redefined-outer-name
@tf_export("fill")
@dispatch.add_dispatch_support
def fill(dims, value, name=None):
r"""Creates a tensor filled with a scalar value.
@ -455,6 +457,7 @@ listdiff.__doc__ = gen_array_ops.list_diff.__doc__ + "\n" + listdiff.__doc__
"This op will be removed after the deprecation date. "
"Please switch to tf.sets.difference().")
@tf_export(v1=["setdiff1d"])
@dispatch.add_dispatch_support
def setdiff1d(x, y, index_dtype=dtypes.int32, name=None):
"""Computes the difference between two lists of numbers or strings.
@ -498,6 +501,7 @@ setdiff1d.__doc__ = gen_array_ops.list_diff.__doc__
@tf_export("broadcast_dynamic_shape")
@dispatch.add_dispatch_support
def broadcast_dynamic_shape(shape_x, shape_y):
"""Computes the shape of a broadcast given symbolic shapes.
@ -523,6 +527,7 @@ def broadcast_dynamic_shape(shape_x, shape_y):
@tf_export("broadcast_static_shape")
@dispatch.add_dispatch_support
def broadcast_static_shape(shape_x, shape_y):
"""Computes the shape of a broadcast given known shapes.
@ -550,6 +555,7 @@ def broadcast_static_shape(shape_x, shape_y):
@tf_export("shape", v1=[])
@dispatch.add_dispatch_support
def shape_v2(input, out_type=dtypes.int32, name=None):
# pylint: disable=redefined-builtin
"""Returns the shape of a tensor.
@ -596,6 +602,7 @@ def shape_v2(input, out_type=dtypes.int32, name=None):
@tf_export(v1=["shape"])
@dispatch.add_dispatch_support
def shape(input, name=None, out_type=dtypes.int32):
# pylint: disable=redefined-builtin
"""Returns the shape of a tensor.
@ -650,6 +657,7 @@ def shape_internal(input, name=None, optimize=True, out_type=dtypes.int32):
@tf_export("shape_n")
@dispatch.add_dispatch_support
def shape_n(input, out_type=dtypes.int32, name=None):
# pylint: disable=redefined-builtin
"""Returns shape of tensors.
@ -1007,6 +1015,7 @@ def _slice_helper(tensor, slice_spec, var=None):
# pylint: disable=undefined-variable,protected-access,redefined-outer-name
@tf_export("slice")
@dispatch.add_dispatch_support
def slice(input_, begin, size, name=None):
# pylint: disable=redefined-builtin
"""Extracts a slice from a tensor.
@ -1062,6 +1071,7 @@ def slice(input_, begin, size, name=None):
# pylint: disable=invalid-name
@tf_export("strided_slice")
@dispatch.add_dispatch_support
def strided_slice(input_,
begin,
end,
@ -1253,6 +1263,7 @@ ops.Tensor._override_operator("__getitem__", _slice_helper)
@tf_export("parallel_stack")
@dispatch.add_dispatch_support
def parallel_stack(values, name="parallel_stack"):
"""Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor in parallel.
@ -1489,6 +1500,7 @@ ops.register_tensor_conversion_function((list, tuple),
@tf_export("unstack")
@dispatch.add_dispatch_support
def unstack(value, num=None, axis=0, name="unstack"):
"""Unpacks the given dimension of a rank-`R` tensor into rank-`(R-1)` tensors.
@ -1632,6 +1644,7 @@ def concat(values, axis, name="concat"):
@tf_export(v1=["boolean_mask"])
@dispatch.add_dispatch_support
def boolean_mask(tensor, mask, name="boolean_mask", axis=None):
"""Apply boolean mask to tensor.
@ -1824,6 +1837,7 @@ def sparse_mask(a, mask_indices, name=None):
@tf_export("unique")
@dispatch.add_dispatch_support
def unique(x, out_idx=dtypes.int32, name=None):
"""Finds unique elements in a 1-D tensor.
@ -1871,6 +1885,7 @@ unique.__doc__ = gen_array_ops.unique.__doc__
@tf_export("unique_with_counts")
@dispatch.add_dispatch_support
def unique_with_counts(x, out_idx=dtypes.int32, name=None):
"""Finds unique elements in a 1-D tensor.
@ -1923,6 +1938,7 @@ unique_with_counts.__doc__ = gen_array_ops.unique_with_counts.__doc__
@tf_export("split")
@dispatch.add_dispatch_support
def split(value, num_or_size_splits, axis=0, num=None, name="split"):
"""Splits a tensor `value` into a list of sub tensors.
@ -2000,6 +2016,7 @@ def split(value, num_or_size_splits, axis=0, num=None, name="split"):
@tf_export("transpose", v1=[])
@dispatch.add_dispatch_support
def transpose_v2(a, perm=None, conjugate=False, name="transpose"):
"""Transposes `a`, where `a` is a Tensor.
@ -2080,6 +2097,7 @@ def transpose_v2(a, perm=None, conjugate=False, name="transpose"):
@tf_export(v1=["transpose"])
@dispatch.add_dispatch_support
def transpose(a, perm=None, name="transpose", conjugate=False):
"""Transposes `a`.
@ -2170,6 +2188,7 @@ def transpose(a, perm=None, name="transpose", conjugate=False):
@tf_export(
"linalg.matrix_transpose",
v1=["linalg.transpose", "linalg.matrix_transpose", "matrix_transpose"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("matrix_transpose", "linalg.transpose")
def matrix_transpose(a, name="matrix_transpose", conjugate=False):
"""Transposes last two dimensions of tensor `a`.
@ -2248,6 +2267,7 @@ def matrix_transpose(a, name="matrix_transpose", conjugate=False):
@tf_export("linalg.diag", v1=["linalg.diag", "matrix_diag"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("matrix_diag")
def matrix_diag(diagonal,
name="diag",
@ -2416,6 +2436,7 @@ def matrix_diag(diagonal,
@tf_export("linalg.diag_part", v1=["linalg.diag_part", "matrix_diag_part"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("matrix_diag_part")
@dispatch.add_dispatch_support
def matrix_diag_part(
@ -2556,6 +2577,7 @@ def matrix_diag_part(
@tf_export("linalg.set_diag", v1=["linalg.set_diag", "matrix_set_diag"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("matrix_set_diag")
def matrix_set_diag(
input, # pylint:disable=redefined-builtin
@ -2719,6 +2741,7 @@ def _tag_zeros_tensor(fun):
@tf_export("zeros")
@dispatch.add_dispatch_support
@_tag_zeros_tensor
def zeros(shape, dtype=dtypes.float32, name=None):
"""Creates a tensor with all elements set to zero.
@ -2971,6 +2994,7 @@ def ones_like_impl(tensor, dtype, name, optimize=True):
@tf_export("ones")
@dispatch.add_dispatch_support
def ones(shape, dtype=dtypes.float32, name=None):
"""Creates a tensor with all elements set to one (1).
@ -3182,6 +3206,7 @@ def sparse_placeholder(dtype, shape=None, name=None):
@tf_export("pad", v1=[])
@dispatch.add_dispatch_support
def pad_v2(tensor, paddings, mode="CONSTANT", constant_values=0, name=None):
"""Pads a tensor.
@ -3240,6 +3265,7 @@ def pad_v2(tensor, paddings, mode="CONSTANT", constant_values=0, name=None):
@tf_export(v1=["pad"])
@dispatch.add_dispatch_support
def pad(tensor, paddings, mode="CONSTANT", name=None, constant_values=0): # pylint: disable=invalid-name
"""Pads a tensor.
@ -3357,6 +3383,7 @@ def _get_paddings_constant(paddings):
@tf_export("meshgrid")
@dispatch.add_dispatch_support
def meshgrid(*args, **kwargs):
"""Broadcasts parameters for evaluation on an N-D grid.
@ -3500,6 +3527,7 @@ def _TileGradShape(op):
@tf_export("edit_distance")
@dispatch.add_dispatch_support
def edit_distance(hypothesis, truth, normalize=True, name="edit_distance"):
"""Computes the Levenshtein distance between sequences.
@ -3694,6 +3722,7 @@ def required_space_to_batch_paddings(input_shape,
@tf_export(v1=["nn.space_to_batch", "space_to_batch"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("space_to_batch")
def space_to_batch( # pylint: disable=missing-docstring
input, # pylint: disable=redefined-builtin
@ -3717,6 +3746,7 @@ space_to_batch.__doc__ = gen_array_ops.space_to_batch.__doc__
@tf_export("space_to_batch", "nn.space_to_batch", v1=[])
@dispatch.add_dispatch_support
def space_to_batch_v2(input, block_shape, paddings, name=None): # pylint: disable=redefined-builtin
return space_to_batch_nd(input, block_shape, paddings, name)
@ -3725,6 +3755,7 @@ space_to_batch_v2.__doc__ = gen_array_ops.space_to_batch_nd.__doc__
@tf_export(v1=["nn.space_to_depth", "space_to_depth"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("space_to_depth")
def space_to_depth(input, block_size, name=None, data_format="NHWC"): # pylint: disable=redefined-builtin
return gen_array_ops.space_to_depth(input, block_size, data_format, name=name)
@ -3734,6 +3765,7 @@ space_to_depth.__doc__ = gen_array_ops.space_to_depth.__doc__
@tf_export("nn.space_to_depth", v1=[])
@dispatch.add_dispatch_support
def space_to_depth_v2(input, block_size, data_format="NHWC", name=None): # pylint: disable=redefined-builtin
return gen_array_ops.space_to_depth(input, block_size, data_format, name=name)
@ -3742,6 +3774,7 @@ space_to_depth_v2.__doc__ = gen_array_ops.space_to_depth.__doc__
@tf_export(v1=["nn.depth_to_space", "depth_to_space"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("depth_to_space")
def depth_to_space(input, block_size, name=None, data_format="NHWC"): # pylint: disable=redefined-builtin
return gen_array_ops.depth_to_space(input, block_size, data_format, name=name)
@ -3751,6 +3784,7 @@ depth_to_space.__doc__ = gen_array_ops.depth_to_space.__doc__
@tf_export("nn.depth_to_space", v1=[])
@dispatch.add_dispatch_support
def depth_to_space_v2(input, block_size, data_format="NHWC", name=None): # pylint: disable=redefined-builtin
return gen_array_ops.depth_to_space(input, block_size, data_format, name=name)
@ -3759,6 +3793,7 @@ depth_to_space_v2.__doc__ = gen_array_ops.depth_to_space.__doc__
@tf_export(v1=["batch_to_space"])
@dispatch.add_dispatch_support
def batch_to_space(input, crops, block_size, name=None, block_shape=None): # pylint: disable=redefined-builtin,missing-docstring
block_size = deprecation.deprecated_argument_lookup("block_shape",
block_shape, "block_size",
@ -3776,6 +3811,7 @@ batch_to_space.__doc__ = gen_array_ops.batch_to_space.__doc__
@tf_export("batch_to_space", v1=[])
@dispatch.add_dispatch_support
def batch_to_space_v2(input, block_shape, crops, name=None): # pylint: disable=redefined-builtin
"""BatchToSpace for N-D tensors of type T.
@ -4091,6 +4127,7 @@ def _all_dimensions(x):
@tf_export("sequence_mask")
@dispatch.add_dispatch_support
def sequence_mask(lengths, maxlen=None, dtype=dtypes.bool, name=None):
"""Returns a mask tensor representing the first N positions of each cell.
@ -4317,6 +4354,7 @@ def where(condition, x=None, y=None, name=None):
@tf_export("where", v1=["where_v2"])
@dispatch.add_dispatch_support
def where_v2(condition, x=None, y=None, name=None):
"""Return the elements where `condition` is `True` (multiplexing `x` and `y`).
@ -5003,6 +5041,7 @@ def batch_gather_nd(params, indices, batch_dims, name=None):
# because round_mode was added later.
# (And also now because of 'axis' processing).
@tf_export(v1=["quantize_v2"])
@dispatch.add_dispatch_support
@deprecation.deprecated(
"2017-10-25",
"`tf.quantize_v2` is deprecated, please use `tf.quantization.quantize` "
@ -5056,6 +5095,7 @@ quantize_v2.__doc__ = """Please use `tf.quantization.quantize` instead."""
# tf.quantization.quantize; we can deprecate tf.quantization.quantize in next
# version of TensorFlow.
@tf_export("quantization.quantize", v1=["quantization.quantize", "quantize"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("quantize")
def quantize(
input, # pylint: disable=redefined-builtin
@ -5095,6 +5135,7 @@ def quantize(
@tf_export("quantization.dequantize", v1=["quantization.dequantize",
"dequantize"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("dequantize")
def dequantize( # pylint: disable=missing-docstring
input, # pylint: disable=redefined-builtin
@ -5130,6 +5171,7 @@ dequantize.__doc__ = gen_array_ops.dequantize.__doc__
@tf_export("quantization.quantize_and_dequantize")
@dispatch.add_dispatch_support
def quantize_and_dequantize(
input, # pylint: disable=redefined-builtin
input_min,
@ -5189,6 +5231,7 @@ def quantize_and_dequantize(
@tf_export("searchsorted")
@dispatch.add_dispatch_support
def searchsorted(sorted_sequence,
values,
side="left",
@ -5253,6 +5296,7 @@ quantize.__doc__ = gen_array_ops.quantize_v2.__doc__
@tf_export("image.extract_patches")
@dispatch.add_dispatch_support
def extract_image_patches_v2(images, sizes, strides, rates, padding, name=None):
r"""Extract `patches` from `images`.
@ -5374,6 +5418,7 @@ def extract_image_patches_v2(images, sizes, strides, rates, padding, name=None):
@tf_export(v1=["image.extract_image_patches", "extract_image_patches"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None, "ksizes is deprecated, use sizes instead",
"ksizes")
def extract_image_patches( # pylint: disable=missing-docstring
@ -5422,6 +5467,7 @@ extract_image_patches.__doc__ = gen_array_ops.extract_image_patches.__doc__
@tf_export("fingerprint")
@dispatch.add_dispatch_support
def fingerprint(data, method="farmhash64", name=None):
r"""Generates fingerprint values.
@ -5668,6 +5714,7 @@ def _with_nonzero_rank(data):
@tf_export("repeat")
@dispatch.add_dispatch_support
def repeat(input, repeats, axis=None, name=None): # pylint: disable=redefined-builtin
"""Repeat elements of `input`.

View File

@ -24,12 +24,14 @@ from tensorflow.python.ops import array_ops # pylint: disable=unused-import
from tensorflow.python.ops import gen_candidate_sampling_ops
from tensorflow.python.ops import math_ops # pylint: disable=unused-import
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@tf_export(
'random.uniform_candidate_sampler',
v1=['random.uniform_candidate_sampler', 'nn.uniform_candidate_sampler'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('nn.uniform_candidate_sampler')
def uniform_candidate_sampler(true_classes, num_true, num_sampled, unique,
range_max, seed=None, name=None):
@ -92,6 +94,7 @@ def uniform_candidate_sampler(true_classes, num_true, num_sampled, unique,
'random.log_uniform_candidate_sampler',
'nn.log_uniform_candidate_sampler'
])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('nn.log_uniform_candidate_sampler')
def log_uniform_candidate_sampler(true_classes, num_true, num_sampled, unique,
range_max, seed=None, name=None):
@ -154,6 +157,7 @@ def log_uniform_candidate_sampler(true_classes, num_true, num_sampled, unique,
@tf_export(
'random.learned_unigram_candidate_sampler',
'nn.learned_unigram_candidate_sampler')
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints(['nn.learned_unigram_candidate_sampler'])
def learned_unigram_candidate_sampler(true_classes, num_true, num_sampled,
unique, range_max, seed=None, name=None):
@ -213,6 +217,7 @@ def learned_unigram_candidate_sampler(true_classes, num_true, num_sampled,
@tf_export('random.fixed_unigram_candidate_sampler',
'nn.fixed_unigram_candidate_sampler')
@dispatch.add_dispatch_support
def fixed_unigram_candidate_sampler(true_classes,
num_true,
num_sampled,
@ -341,6 +346,7 @@ def all_candidate_sampler(true_classes, num_true, num_sampled, unique,
@tf_export('nn.compute_accidental_hits')
@dispatch.add_dispatch_support
def compute_accidental_hits(true_classes, sampled_candidates, num_true,
seed=None, name=None):
"""Compute the position ids in `sampled_candidates` matching `true_classes`.

View File

@ -35,6 +35,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
NUMERIC_TYPES = frozenset(
@ -375,6 +376,7 @@ def _binary_assert(sym, opname, op_func, static_func, x, y, data, summarize,
@tf_export(
'debugging.assert_proper_iterable',
v1=['debugging.assert_proper_iterable', 'assert_proper_iterable'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_proper_iterable')
def assert_proper_iterable(values):
"""Static assert that values is a "proper" iterable.
@ -404,6 +406,7 @@ def assert_proper_iterable(values):
@tf_export('debugging.assert_negative', v1=[])
@dispatch.add_dispatch_support
def assert_negative_v2(x, message=None, summarize=None, name=None):
"""Assert the condition `x < 0` holds element-wise.
@ -436,6 +439,7 @@ def assert_negative_v2(x, message=None, summarize=None, name=None):
@tf_export(v1=['debugging.assert_negative', 'assert_negative'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_negative')
@_unary_assert_doc('< 0', 'negative')
def assert_negative(x, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring
@ -456,6 +460,7 @@ def assert_negative(x, data=None, summarize=None, message=None, name=None): # p
@tf_export('debugging.assert_positive', v1=[])
@dispatch.add_dispatch_support
def assert_positive_v2(x, message=None, summarize=None, name=None):
"""Assert the condition `x > 0` holds element-wise.
@ -488,6 +493,7 @@ def assert_positive_v2(x, message=None, summarize=None, name=None):
@tf_export(v1=['debugging.assert_positive', 'assert_positive'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_positive')
@_unary_assert_doc('> 0', 'positive')
def assert_positive(x, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring
@ -507,6 +513,7 @@ def assert_positive(x, data=None, summarize=None, message=None, name=None): # p
@tf_export('debugging.assert_non_negative', v1=[])
@dispatch.add_dispatch_support
def assert_non_negative_v2(x, message=None, summarize=None, name=None):
"""Assert the condition `x >= 0` holds element-wise.
@ -541,6 +548,7 @@ def assert_non_negative_v2(x, message=None, summarize=None, name=None):
@tf_export(v1=['debugging.assert_non_negative', 'assert_non_negative'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_non_negative')
@_unary_assert_doc('>= 0', 'non-negative')
def assert_non_negative(x, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring
@ -561,6 +569,7 @@ def assert_non_negative(x, data=None, summarize=None, message=None, name=None):
@tf_export('debugging.assert_non_positive', v1=[])
@dispatch.add_dispatch_support
def assert_non_positive_v2(x, message=None, summarize=None, name=None):
"""Assert the condition `x <= 0` holds element-wise.
@ -595,6 +604,7 @@ def assert_non_positive_v2(x, message=None, summarize=None, name=None):
@tf_export(v1=['debugging.assert_non_positive', 'assert_non_positive'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_non_positive')
@_unary_assert_doc('<= 0', 'non-positive')
def assert_non_positive(x, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring
@ -615,6 +625,7 @@ def assert_non_positive(x, data=None, summarize=None, message=None, name=None):
@tf_export('debugging.assert_equal', 'assert_equal', v1=[])
@dispatch.add_dispatch_support
def assert_equal_v2(x, y, message=None, summarize=None, name=None):
"""Assert the condition `x == y` holds element-wise.
@ -649,6 +660,7 @@ def assert_equal_v2(x, y, message=None, summarize=None, name=None):
@tf_export(v1=['debugging.assert_equal', 'assert_equal'])
@dispatch.add_dispatch_support
@_binary_assert_doc('==')
def assert_equal(x, y, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring
with ops.name_scope(name, 'assert_equal', [x, y, data]):
@ -660,6 +672,7 @@ def assert_equal(x, y, data=None, summarize=None, message=None, name=None): # p
@tf_export('debugging.assert_none_equal', v1=[])
@dispatch.add_dispatch_support
def assert_none_equal_v2(x, y, summarize=None, message=None, name=None):
"""Assert the condition `x != y` holds for all elements.
@ -698,6 +711,7 @@ def assert_none_equal_v2(x, y, summarize=None, message=None, name=None):
@tf_export(v1=['debugging.assert_none_equal', 'assert_none_equal'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_none_equal')
@_binary_assert_doc('!=')
def assert_none_equal(
@ -707,6 +721,7 @@ def assert_none_equal(
@tf_export('debugging.assert_near', v1=[])
@dispatch.add_dispatch_support
def assert_near_v2(x, y, rtol=None, atol=None, message=None, summarize=None,
name=None):
"""Assert the condition `x` and `y` are close element-wise.
@ -760,6 +775,7 @@ def assert_near_v2(x, y, rtol=None, atol=None, message=None, summarize=None,
@tf_export(v1=['debugging.assert_near', 'assert_near'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_near')
def assert_near(
x, y, rtol=None, atol=None, data=None, summarize=None, message=None,
@ -839,6 +855,7 @@ def assert_near(
@tf_export('debugging.assert_less', 'assert_less', v1=[])
@dispatch.add_dispatch_support
def assert_less_v2(x, y, message=None, summarize=None, name=None):
"""Assert the condition `x < y` holds element-wise.
@ -874,6 +891,7 @@ def assert_less_v2(x, y, message=None, summarize=None, name=None):
@tf_export(v1=['debugging.assert_less', 'assert_less'])
@dispatch.add_dispatch_support
@_binary_assert_doc('<')
def assert_less(x, y, data=None, summarize=None, message=None, name=None):
return _binary_assert('<', 'assert_less', math_ops.less, np.less, x, y, data,
@ -881,6 +899,7 @@ def assert_less(x, y, data=None, summarize=None, message=None, name=None):
@tf_export('debugging.assert_less_equal', v1=[])
@dispatch.add_dispatch_support
def assert_less_equal_v2(x, y, message=None, summarize=None, name=None):
"""Assert the condition `x <= y` holds element-wise.
@ -917,6 +936,7 @@ def assert_less_equal_v2(x, y, message=None, summarize=None, name=None):
@tf_export(v1=['debugging.assert_less_equal', 'assert_less_equal'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_less_equal')
@_binary_assert_doc('<=')
def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None):
@ -925,6 +945,7 @@ def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None):
@tf_export('debugging.assert_greater', 'assert_greater', v1=[])
@dispatch.add_dispatch_support
def assert_greater_v2(x, y, message=None, summarize=None, name=None):
"""Assert the condition `x > y` holds element-wise.
@ -961,6 +982,7 @@ def assert_greater_v2(x, y, message=None, summarize=None, name=None):
@tf_export(v1=['debugging.assert_greater', 'assert_greater'])
@dispatch.add_dispatch_support
@_binary_assert_doc('>')
def assert_greater(x, y, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring
return _binary_assert('>', 'assert_greater', math_ops.greater, np.greater, x,
@ -968,6 +990,7 @@ def assert_greater(x, y, data=None, summarize=None, message=None, name=None): #
@tf_export('debugging.assert_greater_equal', v1=[])
@dispatch.add_dispatch_support
def assert_greater_equal_v2(x, y, message=None, summarize=None, name=None):
"""Assert the condition `x >= y` holds element-wise.
@ -1005,6 +1028,7 @@ def assert_greater_equal_v2(x, y, message=None, summarize=None, name=None):
@tf_export(v1=['debugging.assert_greater_equal', 'assert_greater_equal'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_greater_equal')
@_binary_assert_doc('>=')
def assert_greater_equal(x, y, data=None, summarize=None, message=None,
@ -1062,6 +1086,7 @@ def _assert_rank_condition(
@tf_export('debugging.assert_rank', 'assert_rank', v1=[])
@dispatch.add_dispatch_support
def assert_rank_v2(x, rank, message=None, name=None):
"""Assert that `x` has rank equal to `rank`.
@ -1095,6 +1120,7 @@ def assert_rank_v2(x, rank, message=None, name=None):
@tf_export(v1=['debugging.assert_rank', 'assert_rank'])
@dispatch.add_dispatch_support
def assert_rank(x, rank, data=None, summarize=None, message=None, name=None):
"""Assert `x` has rank equal to `rank`.
@ -1157,6 +1183,7 @@ def assert_rank(x, rank, data=None, summarize=None, message=None, name=None):
@tf_export('debugging.assert_rank_at_least', v1=[])
@dispatch.add_dispatch_support
def assert_rank_at_least_v2(x, rank, message=None, name=None):
"""Assert that `x` has rank of at least `rank`.
@ -1190,6 +1217,7 @@ def assert_rank_at_least_v2(x, rank, message=None, name=None):
@tf_export(v1=['debugging.assert_rank_at_least', 'assert_rank_at_least'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_rank_at_least')
def assert_rank_at_least(
x, rank, data=None, summarize=None, message=None, name=None):
@ -1322,6 +1350,7 @@ def _assert_ranks_condition(
@tf_export('debugging.assert_rank_in', v1=[])
@dispatch.add_dispatch_support
def assert_rank_in_v2(x, ranks, message=None, name=None):
"""Assert that `x` has a rank in `ranks`.
@ -1354,6 +1383,7 @@ def assert_rank_in_v2(x, ranks, message=None, name=None):
@tf_export(v1=['debugging.assert_rank_in', 'assert_rank_in'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_rank_in')
def assert_rank_in(
x, ranks, data=None, summarize=None, message=None, name=None):
@ -1417,6 +1447,7 @@ def assert_rank_in(
@tf_export('debugging.assert_integer', v1=[])
@dispatch.add_dispatch_support
def assert_integer_v2(x, message=None, name=None):
"""Assert that `x` is of integer dtype.
@ -1437,6 +1468,7 @@ def assert_integer_v2(x, message=None, name=None):
@tf_export(v1=['debugging.assert_integer', 'assert_integer'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_integer')
def assert_integer(x, message=None, name=None):
"""Assert that `x` is of integer dtype.
@ -1476,6 +1508,7 @@ def assert_integer(x, message=None, name=None):
@tf_export('debugging.assert_type', v1=[])
@dispatch.add_dispatch_support
def assert_type_v2(tensor, tf_type, message=None, name=None):
"""Asserts that the given `Tensor` is of the specified type.
@ -1495,6 +1528,7 @@ def assert_type_v2(tensor, tf_type, message=None, name=None):
@tf_export(v1=['debugging.assert_type', 'assert_type'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_type')
def assert_type(tensor, tf_type, message=None, name=None):
"""Statically asserts that the given `Tensor` is of the specified type.
@ -1584,6 +1618,7 @@ _TensorDimSizes = collections.namedtuple(
@tf_export('debugging.assert_shapes', v1=[])
@dispatch.add_dispatch_support
def assert_shapes_v2(shapes, data=None, summarize=None, message=None,
name=None):
"""Assert tensor shapes and dimension size relationships between tensors.
@ -1650,6 +1685,7 @@ def assert_shapes_v2(shapes, data=None, summarize=None, message=None,
@tf_export(v1=['debugging.assert_shapes'])
@dispatch.add_dispatch_support
def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
"""Assert tensor shapes and dimension size relationships between tensors.
@ -1939,6 +1975,7 @@ def is_numeric_tensor(tensor):
'math.is_non_decreasing', 'debugging.is_non_decreasing',
'is_non_decreasing'
])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('debugging.is_non_decreasing',
'is_non_decreasing')
def is_non_decreasing(x, name=None):
@ -1980,6 +2017,7 @@ def is_non_decreasing(x, name=None):
'math.is_strictly_increasing', 'debugging.is_strictly_increasing',
'is_strictly_increasing'
])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('debugging.is_strictly_increasing',
'is_strictly_increasing')
def is_strictly_increasing(x, name=None):
@ -2066,6 +2104,7 @@ def _assert_same_base_type(items, expected_type=None):
@tf_export(
'debugging.assert_same_float_dtype',
v1=['debugging.assert_same_float_dtype', 'assert_same_float_dtype'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_same_float_dtype')
def assert_same_float_dtype(tensors=None, dtype=None):
"""Validate and return float type based on `tensors` and `dtype`.
@ -2098,6 +2137,7 @@ def assert_same_float_dtype(tensors=None, dtype=None):
@tf_export('debugging.assert_scalar', v1=[])
@dispatch.add_dispatch_support
def assert_scalar_v2(tensor, message=None, name=None):
"""Asserts that the given `tensor` is a scalar.
@ -2120,6 +2160,7 @@ def assert_scalar_v2(tensor, message=None, name=None):
@tf_export(v1=['debugging.assert_scalar', 'assert_scalar'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_scalar')
def assert_scalar(tensor, name=None, message=None):
"""Asserts that the given `tensor` is a scalar (i.e. zero-dimensional).
@ -2154,6 +2195,7 @@ def assert_scalar(tensor, name=None, message=None):
@tf_export('ensure_shape')
@dispatch.add_dispatch_support
def ensure_shape(x, shape, name=None):
"""Updates the shape of a tensor and checks at runtime that the shape holds.

View File

@ -152,6 +152,7 @@ def _clip_by_value_grad(op, grad):
@tf_export("clip_by_norm")
@dispatch.add_dispatch_support
def clip_by_norm(t, clip_norm, axes=None, name=None):
"""Clips tensor values to a maximum L2-norm.
@ -235,6 +236,7 @@ def clip_by_norm(t, clip_norm, axes=None, name=None):
@tf_export("linalg.global_norm", v1=["linalg.global_norm", "global_norm"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("global_norm")
def global_norm(t_list, name=None):
"""Computes the global norm of multiple tensors.
@ -285,6 +287,7 @@ def global_norm(t_list, name=None):
@tf_export("clip_by_global_norm")
@dispatch.add_dispatch_support
def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):
"""Clips values of multiple tensors by the ratio of the sum of their norms.
@ -382,6 +385,7 @@ def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):
"use clip_by_norm(t, clip_norm * tf.cast(tf.size(t), tf.float32), name) "
"instead.")
@tf_export(v1=["clip_by_average_norm"])
@dispatch.add_dispatch_support
def clip_by_average_norm(t, clip_norm, name=None):
"""Clips tensor values to a maximum average L2-norm.

View File

@ -27,6 +27,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@ -93,6 +94,7 @@ def remove_squeezable_dimensions(
@tf_export('math.confusion_matrix', v1=[])
@dispatch.add_dispatch_support
def confusion_matrix(labels,
predictions,
num_classes=None,
@ -202,6 +204,7 @@ def confusion_matrix(labels,
@tf_export(v1=['math.confusion_matrix', 'confusion_matrix'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('confusion_matrix', 'train.confusion_matrix')
def confusion_matrix_v1(labels,
predictions,

View File

@ -54,6 +54,7 @@ from tensorflow.python.ops.gen_control_flow_ops import *
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.lazy_loader import LazyLoader
@ -110,6 +111,7 @@ def _summarize_eager(tensor, summarize=None):
# Assert and Print are special symbols in python, so we must
# use an upper-case version of them.
@tf_export("debugging.Assert", "Assert")
@dispatch.add_dispatch_support
@tf_should_use.should_use_result
def Assert(condition, data, summarize=None, name=None):
"""Asserts that the given condition is true.
@ -1095,6 +1097,7 @@ def _UnpackIfSingleton(res):
# pylint: disable=redefined-outer-name
# pylint: disable=g-doc-args
@tf_export(v1=["cond"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(
None, "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.",
"fn1", "fn2")
@ -1318,6 +1321,7 @@ def _cast_indexed_slice_indices(a, b):
@tf_export("cond", v1=[])
@dispatch.add_dispatch_support
def cond_for_tf_v2(pred, true_fn=None, false_fn=None, name=None):
"""Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
@ -2942,6 +2946,7 @@ def group(*inputs, **kwargs):
@tf_export("tuple", v1=[])
@dispatch.add_dispatch_support
def tuple_v2(tensors, control_inputs=None, name=None):
"""Group tensors together.
@ -2978,6 +2983,7 @@ def tuple_v2(tensors, control_inputs=None, name=None):
@tf_export(v1=["tuple"])
@dispatch.add_dispatch_support
def tuple(tensors, name=None, control_inputs=None): # pylint: disable=redefined-builtin
"""Group tensors together.
@ -3312,6 +3318,7 @@ def _indexed_case_helper(branch_fns, default, branch_index, name):
@tf_export("case", v1=[])
@dispatch.add_dispatch_support
def case_v2(pred_fn_pairs,
default=None,
exclusive=False,
@ -3416,6 +3423,7 @@ def case_v2(pred_fn_pairs,
@tf_export(v1=["case"])
@dispatch.add_dispatch_support
def case(pred_fn_pairs,
default=None,
exclusive=False,

View File

@ -43,6 +43,7 @@ from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.nn_grad import _BroadcastMul
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@ -70,6 +71,7 @@ def _generate_defun_backend(unique_api_name, preferred_device, func):
# pylint: disable=protected-access, invalid-name
@tf_export(v1=["nn.ctc_loss"])
@dispatch.add_dispatch_support
def ctc_loss(labels,
inputs=None,
sequence_length=None,
@ -284,6 +286,7 @@ def _CTCLossV2Grad(op, grad_loss, _):
@tf_export("nn.ctc_greedy_decoder")
@dispatch.add_dispatch_support
def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
"""Performs greedy decoding on the logits given in input (best path).
@ -333,6 +336,7 @@ def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
@tf_export(v1=["nn.ctc_beam_search_decoder"])
@dispatch.add_dispatch_support
def ctc_beam_search_decoder(inputs,
sequence_length,
beam_width=100,
@ -395,6 +399,7 @@ def ctc_beam_search_decoder(inputs,
@tf_export("nn.ctc_beam_search_decoder", v1=["nn.ctc_beam_search_decoder_v2"])
@dispatch.add_dispatch_support
def ctc_beam_search_decoder_v2(inputs,
sequence_length,
beam_width=100,
@ -731,6 +736,7 @@ def _ctc_loss_shape(op):
# pylint: disable=protected-access, invalid-name
@tf_export(v1=["nn.ctc_loss_v2"])
@dispatch.add_dispatch_support
def ctc_loss_v2(labels,
logits,
label_length,
@ -825,6 +831,7 @@ def ctc_loss_v2(labels,
@tf_export("nn.ctc_loss", v1=[])
@dispatch.add_dispatch_support
def ctc_loss_v3(labels,
logits,
label_length,
@ -1056,6 +1063,7 @@ def ctc_loss_dense(labels,
@tf_export("nn.collapse_repeated")
@dispatch.add_dispatch_support
def collapse_repeated(labels, seq_length, name=None):
"""Merge repeated labels into single labels.
@ -1153,6 +1161,7 @@ def dense_labels_to_sparse(dense, length):
@tf_export("nn.ctc_unique_labels")
@dispatch.add_dispatch_support
def ctc_unique_labels(labels, name=None):
"""Get unique labels and indices for batched labels for `tf.nn.ctc_loss`.

View File

@ -36,6 +36,7 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@ -250,6 +251,7 @@ def _embedding_lookup_and_transform(params,
@tf_export(v1=["nn.embedding_lookup"])
@dispatch.add_dispatch_support
def embedding_lookup(
params,
ids,
@ -327,6 +329,7 @@ def embedding_lookup(
@tf_export("nn.embedding_lookup", v1=[])
@dispatch.add_dispatch_support
def embedding_lookup_v2(params, ids, max_norm=None, name=None):
"""Looks up embeddings for the given `ids` from a list of tensors.
@ -392,6 +395,7 @@ def embedding_lookup_v2(params, ids, max_norm=None, name=None):
@tf_export(v1=["nn.embedding_lookup_sparse"])
@dispatch.add_dispatch_support
def embedding_lookup_sparse(params,
sp_ids,
sp_weights,
@ -574,6 +578,7 @@ def embedding_lookup_sparse(params,
@tf_export("nn.embedding_lookup_sparse", v1=[])
@dispatch.add_dispatch_support
def embedding_lookup_sparse_v2(params,
sp_ids,
sp_weights,
@ -664,6 +669,7 @@ def embedding_lookup_sparse_v2(params,
@tf_export("nn.safe_embedding_lookup_sparse", v1=[])
@dispatch.add_dispatch_support
def safe_embedding_lookup_sparse_v2(embedding_weights,
sparse_ids,
sparse_weights=None,
@ -765,6 +771,7 @@ def safe_embedding_lookup_sparse_v2(embedding_weights,
@tf_export(v1=["nn.safe_embedding_lookup_sparse"])
@dispatch.add_dispatch_support
def safe_embedding_lookup_sparse(embedding_weights,
sparse_ids,
sparse_weights=None,

View File

@ -38,6 +38,7 @@ from tensorflow.python.ops.gen_functional_ops import remote_call
from tensorflow.python.ops.gen_functional_ops import symbolic_gradient
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@ -45,6 +46,7 @@ from tensorflow.python.util.tf_export import tf_export
# TODO(yuanbyu, mrry): Handle stride to support sliding windows.
@tf_export(v1=["foldl"])
@dispatch.add_dispatch_support
def foldl(fn,
elems,
initializer=None,
@ -162,6 +164,7 @@ def foldl(fn,
@tf_export("foldl", v1=[])
@dispatch.add_dispatch_support
@deprecation.deprecated_arg_values(
None,
"""back_prop=False is deprecated. Consider using tf.stop_gradient instead.
@ -238,6 +241,7 @@ def foldl_v2(fn,
@tf_export(v1=["foldr"])
@dispatch.add_dispatch_support
def foldr(fn,
elems,
initializer=None,
@ -356,6 +360,7 @@ def foldr(fn,
@tf_export("foldr", v1=[])
@dispatch.add_dispatch_support
@deprecation.deprecated_arg_values(
None,
"""back_prop=False is deprecated. Consider using tf.stop_gradient instead.
@ -432,6 +437,7 @@ def foldr_v2(fn,
@tf_export(v1=["scan"])
@dispatch.add_dispatch_support
def scan(fn,
elems,
initializer=None,
@ -686,6 +692,7 @@ def scan(fn,
@tf_export("scan", v1=[])
@dispatch.add_dispatch_support
@deprecation.deprecated_arg_values(
None,
"""back_prop=False is deprecated. Consider using tf.stop_gradient instead.

View File

@ -26,10 +26,12 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@tf_export('histogram_fixed_width_bins')
@dispatch.add_dispatch_support
def histogram_fixed_width_bins(values,
value_range,
nbins=100,
@ -101,6 +103,7 @@ def histogram_fixed_width_bins(values,
@tf_export('histogram_fixed_width')
@dispatch.add_dispatch_support
def histogram_fixed_width(values,
value_range,
nbins=100,

View File

@ -39,6 +39,7 @@ from tensorflow.python.ops import sort_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
ops.NotDifferentiable('RandomCrop')
@ -323,6 +324,7 @@ def fix_image_flip_shape(image, result):
@tf_export('image.random_flip_up_down')
@dispatch.add_dispatch_support
def random_flip_up_down(image, seed=None):
"""Randomly flips an image vertically (upside down).
@ -363,6 +365,7 @@ def random_flip_up_down(image, seed=None):
@tf_export('image.random_flip_left_right')
@dispatch.add_dispatch_support
def random_flip_left_right(image, seed=None):
"""Randomly flip an image horizontally (left to right).
@ -450,6 +453,7 @@ def _random_flip(image, flip_index, seed, scope_name):
@tf_export('image.flip_left_right')
@dispatch.add_dispatch_support
def flip_left_right(image):
"""Flip an image horizontally (left to right).
@ -484,6 +488,7 @@ def flip_left_right(image):
@tf_export('image.flip_up_down')
@dispatch.add_dispatch_support
def flip_up_down(image):
"""Flip an image vertically (upside down).
@ -549,6 +554,7 @@ def _flip(image, flip_index, scope_name):
@tf_export('image.rot90')
@dispatch.add_dispatch_support
def rot90(image, k=1, name=None):
"""Rotate image(s) counter-clockwise by 90 degrees.
@ -660,6 +666,7 @@ def _rot90_4D(images, k, name_scope):
@tf_export('image.transpose', v1=['image.transpose', 'image.transpose_image'])
@dispatch.add_dispatch_support
def transpose(image, name=None):
"""Transpose image(s) by swapping the height and width dimension.
@ -718,6 +725,7 @@ def transpose(image, name=None):
@tf_export('image.central_crop')
@dispatch.add_dispatch_support
def central_crop(image, central_fraction):
"""Crop the central region of the image(s).
@ -850,6 +858,7 @@ def central_crop(image, central_fraction):
@tf_export('image.pad_to_bounding_box')
@dispatch.add_dispatch_support
def pad_to_bounding_box(image, offset_height, offset_width, target_height,
target_width):
"""Pad `image` with zeros to the specified `height` and `width`.
@ -959,6 +968,7 @@ def pad_to_bounding_box(image, offset_height, offset_width, target_height,
@tf_export('image.crop_to_bounding_box')
@dispatch.add_dispatch_support
def crop_to_bounding_box(image, offset_height, offset_width, target_height,
target_width):
"""Crops an image to a specified bounding box.
@ -1041,6 +1051,7 @@ def crop_to_bounding_box(image, offset_height, offset_width, target_height,
@tf_export(
'image.resize_with_crop_or_pad',
v1=['image.resize_with_crop_or_pad', 'image.resize_image_with_crop_or_pad'])
@dispatch.add_dispatch_support
def resize_image_with_crop_or_pad(image, target_height, target_width):
"""Crops and/or pads an image to a target width and height.
@ -1258,6 +1269,7 @@ def _resize_images_common(images, resizer_fn, size, preserve_aspect_ratio, name,
@tf_export(v1=['image.resize_images', 'image.resize'])
@dispatch.add_dispatch_support
def resize_images(images,
size,
method=ResizeMethodV1.BILINEAR,
@ -1343,6 +1355,7 @@ def resize_images(images,
@tf_export('image.resize', v1=[])
@dispatch.add_dispatch_support
def resize_images_v2(images,
size,
method=ResizeMethod.BILINEAR,
@ -1594,6 +1607,7 @@ def _resize_image_with_pad_common(image, target_height, target_width,
@tf_export(v1=['image.resize_image_with_pad'])
@dispatch.add_dispatch_support
def resize_image_with_pad_v1(image,
target_height,
target_width,
@ -1636,6 +1650,7 @@ def resize_image_with_pad_v1(image,
@tf_export('image.resize_with_pad', v1=[])
@dispatch.add_dispatch_support
def resize_image_with_pad_v2(image,
target_height,
target_width,
@ -1676,6 +1691,7 @@ def resize_image_with_pad_v2(image,
@tf_export('image.per_image_standardization')
@dispatch.add_dispatch_support
def per_image_standardization(image):
"""Linearly scales each image in `image` to have mean 0 and variance 1.
@ -1721,6 +1737,7 @@ def per_image_standardization(image):
@tf_export('image.random_brightness')
@dispatch.add_dispatch_support
def random_brightness(image, max_delta, seed=None):
"""Adjust the brightness of images by a random factor.
@ -1756,6 +1773,7 @@ def random_brightness(image, max_delta, seed=None):
@tf_export('image.random_contrast')
@dispatch.add_dispatch_support
def random_contrast(image, lower, upper, seed=None):
"""Adjust the contrast of an image or images by a random factor.
@ -1796,6 +1814,7 @@ def random_contrast(image, lower, upper, seed=None):
@tf_export('image.adjust_brightness')
@dispatch.add_dispatch_support
def adjust_brightness(image, delta):
"""Adjust the brightness of RGB or Grayscale images.
@ -1847,6 +1866,7 @@ def adjust_brightness(image, delta):
@tf_export('image.adjust_contrast')
@dispatch.add_dispatch_support
def adjust_contrast(images, contrast_factor):
"""Adjust contrast of RGB or grayscale images.
@ -1903,6 +1923,7 @@ def adjust_contrast(images, contrast_factor):
@tf_export('image.adjust_gamma')
@dispatch.add_dispatch_support
def adjust_gamma(image, gamma=1, gain=1):
"""Performs [Gamma Correction](http://en.wikipedia.org/wiki/Gamma_correction).
@ -1967,6 +1988,7 @@ def adjust_gamma(image, gamma=1, gain=1):
@tf_export('image.convert_image_dtype')
@dispatch.add_dispatch_support
def convert_image_dtype(image, dtype, saturate=False, name=None):
"""Convert `image` to `dtype`, scaling its values if needed.
@ -2066,6 +2088,7 @@ def convert_image_dtype(image, dtype, saturate=False, name=None):
@tf_export('image.rgb_to_grayscale')
@dispatch.add_dispatch_support
def rgb_to_grayscale(images, name=None):
"""Converts one or more images from RGB to Grayscale.
@ -2101,6 +2124,7 @@ def rgb_to_grayscale(images, name=None):
@tf_export('image.grayscale_to_rgb')
@dispatch.add_dispatch_support
def grayscale_to_rgb(images, name=None):
"""Converts one or more images from Grayscale to RGB.
@ -2137,6 +2161,7 @@ def grayscale_to_rgb(images, name=None):
# pylint: disable=invalid-name
@tf_export('image.random_hue')
@dispatch.add_dispatch_support
def random_hue(image, max_delta, seed=None):
"""Adjust the hue of RGB images by a random factor.
@ -2179,6 +2204,7 @@ def random_hue(image, max_delta, seed=None):
@tf_export('image.adjust_hue')
@dispatch.add_dispatch_support
def adjust_hue(image, delta, name=None):
"""Adjust hue of RGB images.
@ -2246,6 +2272,7 @@ def adjust_hue(image, delta, name=None):
# pylint: disable=invalid-name
@tf_export('image.random_jpeg_quality')
@dispatch.add_dispatch_support
def random_jpeg_quality(image, min_jpeg_quality, max_jpeg_quality, seed=None):
"""Randomly changes jpeg encoding quality for inducing jpeg noise.
@ -2293,6 +2320,7 @@ def random_jpeg_quality(image, min_jpeg_quality, max_jpeg_quality, seed=None):
@tf_export('image.adjust_jpeg_quality')
@dispatch.add_dispatch_support
def adjust_jpeg_quality(image, jpeg_quality, name=None):
"""Adjust jpeg encoding quality of an image.
@ -2343,6 +2371,7 @@ def adjust_jpeg_quality(image, jpeg_quality, name=None):
@tf_export('image.random_saturation')
@dispatch.add_dispatch_support
def random_saturation(image, lower, upper, seed=None):
"""Adjust the saturation of RGB images by a random factor.
@ -2389,6 +2418,7 @@ def random_saturation(image, lower, upper, seed=None):
@tf_export('image.adjust_saturation')
@dispatch.add_dispatch_support
def adjust_saturation(image, saturation_factor, name=None):
"""Adjust saturation of RGB images.
@ -2480,42 +2510,43 @@ tf_export(
'io.decode_and_crop_jpeg',
'image.decode_and_crop_jpeg',
v1=['io.decode_and_crop_jpeg', 'image.decode_and_crop_jpeg'])(
gen_image_ops.decode_and_crop_jpeg)
dispatch.add_dispatch_support(gen_image_ops.decode_and_crop_jpeg))
tf_export(
'io.decode_bmp',
'image.decode_bmp',
v1=['io.decode_bmp', 'image.decode_bmp'])(
gen_image_ops.decode_bmp)
dispatch.add_dispatch_support(gen_image_ops.decode_bmp))
tf_export(
'io.decode_gif',
'image.decode_gif',
v1=['io.decode_gif', 'image.decode_gif'])(
gen_image_ops.decode_gif)
dispatch.add_dispatch_support(gen_image_ops.decode_gif))
tf_export(
'io.decode_jpeg',
'image.decode_jpeg',
v1=['io.decode_jpeg', 'image.decode_jpeg'])(
gen_image_ops.decode_jpeg)
dispatch.add_dispatch_support(gen_image_ops.decode_jpeg))
tf_export(
'io.decode_png',
'image.decode_png',
v1=['io.decode_png', 'image.decode_png'])(
gen_image_ops.decode_png)
dispatch.add_dispatch_support(gen_image_ops.decode_png))
tf_export(
'io.encode_jpeg',
'image.encode_jpeg',
v1=['io.encode_jpeg', 'image.encode_jpeg'])(
gen_image_ops.encode_jpeg)
dispatch.add_dispatch_support(gen_image_ops.encode_jpeg))
tf_export(
'io.extract_jpeg_shape',
'image.extract_jpeg_shape',
v1=['io.extract_jpeg_shape', 'image.extract_jpeg_shape'])(
gen_image_ops.extract_jpeg_shape)
dispatch.add_dispatch_support(gen_image_ops.extract_jpeg_shape))
@tf_export('io.encode_png', 'image.encode_png')
@dispatch.add_dispatch_support
def encode_png(image, compression=-1, name=None):
r"""PNG-encode an image.
@ -2548,6 +2579,7 @@ def encode_png(image, compression=-1, name=None):
'io.decode_image',
'image.decode_image',
v1=['io.decode_image', 'image.decode_image'])
@dispatch.add_dispatch_support
def decode_image(contents,
channels=None,
dtype=dtypes.uint8,
@ -2661,6 +2693,7 @@ def decode_image(contents,
@tf_export('image.total_variation')
@dispatch.add_dispatch_support
def total_variation(images, name=None):
"""Calculate and return the total variation for one or more images.
@ -2732,6 +2765,7 @@ def total_variation(images, name=None):
@tf_export('image.sample_distorted_bounding_box', v1=[])
@dispatch.add_dispatch_support
def sample_distorted_bounding_box_v2(image_size,
bounding_boxes,
seed=0,
@ -2831,6 +2865,7 @@ def sample_distorted_bounding_box_v2(image_size,
@tf_export(v1=['image.sample_distorted_bounding_box'])
@dispatch.add_dispatch_support
@deprecation.deprecated(
date=None,
instructions='`seed2` arg is deprecated.'
@ -2945,6 +2980,7 @@ def sample_distorted_bounding_box(image_size,
@tf_export('image.non_max_suppression')
@dispatch.add_dispatch_support
def non_max_suppression(boxes,
scores,
max_output_size,
@ -2997,6 +3033,7 @@ def non_max_suppression(boxes,
@tf_export('image.non_max_suppression_with_scores')
@dispatch.add_dispatch_support
def non_max_suppression_with_scores(boxes,
scores,
max_output_size,
@ -3083,6 +3120,7 @@ def non_max_suppression_with_scores(boxes,
@tf_export('image.non_max_suppression_overlaps')
@dispatch.add_dispatch_support
def non_max_suppression_with_overlaps(overlaps,
scores,
max_output_size,
@ -3134,6 +3172,7 @@ _rgb_to_yiq_kernel = [[0.299, 0.59590059, 0.2115],
@tf_export('image.rgb_to_yiq')
@dispatch.add_dispatch_support
def rgb_to_yiq(images):
"""Converts one or more images from RGB to YIQ.
@ -3167,6 +3206,7 @@ _yiq_to_rgb_kernel = [[1, 1, 1], [0.95598634, -0.27201283, -1.10674021],
@tf_export('image.yiq_to_rgb')
@dispatch.add_dispatch_support
def yiq_to_rgb(images):
"""Converts one or more images from YIQ to RGB.
@ -3195,6 +3235,7 @@ _rgb_to_yuv_kernel = [[0.299, -0.14714119, 0.61497538],
@tf_export('image.rgb_to_yuv')
@dispatch.add_dispatch_support
def rgb_to_yuv(images):
"""Converts one or more images from RGB to YUV.
@ -3221,6 +3262,7 @@ _yuv_to_rgb_kernel = [[1, 1, 1], [0, -0.394642334, 2.03206185],
@tf_export('image.yuv_to_rgb')
@dispatch.add_dispatch_support
def yuv_to_rgb(images):
"""Converts one or more images from YUV to RGB.
@ -3314,6 +3356,7 @@ def _verify_compatible_image_shapes(img1, img2):
@tf_export('image.psnr')
@dispatch.add_dispatch_support
def psnr(a, b, max_val, name=None):
"""Returns the Peak Signal-to-Noise Ratio between a and b.
@ -3525,6 +3568,7 @@ def _ssim_per_channel(img1,
@tf_export('image.ssim')
@dispatch.add_dispatch_support
def ssim(img1,
img2,
max_val,
@ -3604,6 +3648,7 @@ _MSSSIM_WEIGHTS = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)
@tf_export('image.ssim_multiscale')
@dispatch.add_dispatch_support
def ssim_multiscale(img1,
img2,
max_val,
@ -3731,6 +3776,7 @@ def ssim_multiscale(img1,
@tf_export('image.image_gradients')
@dispatch.add_dispatch_support
def image_gradients(image):
"""Returns image gradients (dy, dx) for each color channel.
@ -3804,6 +3850,7 @@ def image_gradients(image):
@tf_export('image.sobel_edges')
@dispatch.add_dispatch_support
def sobel_edges(image):
"""Returns a tensor holding Sobel edge maps.
@ -3888,21 +3935,22 @@ resize_area_deprecation = deprecation.deprecated(
instructions=(
'Use `tf.image.resize(...method=ResizeMethod.AREA...)` instead.'))
tf_export(v1=['image.resize_area'])(
resize_area_deprecation(gen_image_ops.resize_area))
resize_area_deprecation(
dispatch.add_dispatch_support(gen_image_ops.resize_area)))
resize_bicubic_deprecation = deprecation.deprecated(
date=None,
instructions=(
'Use `tf.image.resize(...method=ResizeMethod.BICUBIC...)` instead.'))
tf_export(v1=['image.resize_bicubic'])(
resize_bicubic_deprecation(resize_bicubic))
dispatch.add_dispatch_support(resize_bicubic_deprecation(resize_bicubic)))
resize_bilinear_deprecation = deprecation.deprecated(
date=None,
instructions=(
'Use `tf.image.resize(...method=ResizeMethod.BILINEAR...)` instead.'))
tf_export(v1=['image.resize_bilinear'])(
resize_bilinear_deprecation(resize_bilinear))
dispatch.add_dispatch_support(resize_bilinear_deprecation(resize_bilinear)))
resize_nearest_neighbor_deprecation = deprecation.deprecated(
date=None,
@ -3910,10 +3958,12 @@ resize_nearest_neighbor_deprecation = deprecation.deprecated(
'Use `tf.image.resize(...method=ResizeMethod.NEAREST_NEIGHBOR...)` '
'instead.'))
tf_export(v1=['image.resize_nearest_neighbor'])(
resize_nearest_neighbor_deprecation(resize_nearest_neighbor))
dispatch.add_dispatch_support(
resize_nearest_neighbor_deprecation(resize_nearest_neighbor)))
@tf_export('image.crop_and_resize', v1=[])
@dispatch.add_dispatch_support
def crop_and_resize_v2(image,
boxes,
box_indices,
@ -3997,6 +4047,7 @@ def crop_and_resize_v2(image,
@tf_export(v1=['image.crop_and_resize'])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None,
'box_ind is deprecated, use box_indices instead',
'box_ind')
@ -4019,6 +4070,7 @@ crop_and_resize_v1.__doc__ = gen_image_ops.crop_and_resize.__doc__
@tf_export(v1=['image.extract_glimpse'])
@dispatch.add_dispatch_support
def extract_glimpse(
input, # pylint: disable=redefined-builtin
size,
@ -4104,6 +4156,7 @@ def extract_glimpse(
@tf_export('image.extract_glimpse', v1=[])
@dispatch.add_dispatch_support
def extract_glimpse_v2(
input, # pylint: disable=redefined-builtin
size,
@ -4190,6 +4243,7 @@ def extract_glimpse_v2(
@tf_export('image.combined_non_max_suppression')
@dispatch.add_dispatch_support
def combined_non_max_suppression(boxes,
scores,
max_output_size_per_class,
@ -4442,6 +4496,7 @@ def _suppression_loop_body(boxes, iou_threshold, output_size, idx, tile_size):
@tf_export('image.non_max_suppression_padded')
@dispatch.add_dispatch_support
def non_max_suppression_padded(boxes,
scores,
max_output_size,
@ -4816,6 +4871,7 @@ def non_max_suppression_padded_v1(boxes,
@tf_export('image.draw_bounding_boxes', v1=[])
@dispatch.add_dispatch_support
def draw_bounding_boxes_v2(images, boxes, colors, name=None):
"""Draw bounding boxes on a batch of images.
@ -4870,6 +4926,7 @@ def draw_bounding_boxes_v2(images, boxes, colors, name=None):
@tf_export(v1=['image.draw_bounding_boxes'])
@dispatch.add_dispatch_support
def draw_bounding_boxes(images, boxes, name=None, colors=None):
"""Draw bounding boxes on a batch of images.
@ -4922,6 +4979,7 @@ def draw_bounding_boxes(images, boxes, name=None, colors=None):
@tf_export('image.generate_bounding_box_proposals')
@dispatch.add_dispatch_support
def generate_bounding_box_proposals(scores,
bbox_deltas,
image_info,

View File

@ -41,7 +41,7 @@ cholesky = linalg_ops.cholesky
cholesky_solve = linalg_ops.cholesky_solve
det = linalg_ops.matrix_determinant
slogdet = gen_linalg_ops.log_matrix_determinant
tf_export('linalg.slogdet')(slogdet)
tf_export('linalg.slogdet')(dispatch.add_dispatch_support(slogdet))
diag = array_ops.matrix_diag
diag_part = array_ops.matrix_diag_part
eigh = linalg_ops.self_adjoint_eig
@ -51,7 +51,7 @@ eye = linalg_ops.eye
inv = linalg_ops.matrix_inverse
logm = gen_linalg_ops.matrix_logarithm
lu = gen_linalg_ops.lu
tf_export('linalg.logm')(logm)
tf_export('linalg.logm')(dispatch.add_dispatch_support(logm))
lstsq = linalg_ops.matrix_solve_ls
norm = linalg_ops.norm
qr = linalg_ops.qr
@ -230,6 +230,7 @@ def _matrix_exp_pade13(matrix):
@tf_export('linalg.expm')
@dispatch.add_dispatch_support
def matrix_exponential(input, name=None): # pylint: disable=redefined-builtin
r"""Computes the matrix exponential of one or more square matrices.
@ -340,6 +341,7 @@ def matrix_exponential(input, name=None): # pylint: disable=redefined-builtin
@tf_export('linalg.tridiagonal_solve')
@dispatch.add_dispatch_support
def tridiagonal_solve(diagonals,
rhs,
diagonals_format='compact',
@ -541,6 +543,7 @@ def _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
@tf_export('linalg.tridiagonal_matmul')
@dispatch.add_dispatch_support
def tridiagonal_matmul(diagonals, rhs, diagonals_format='compact', name=None):
r"""Multiplies tridiagonal matrix by matrix.
@ -638,6 +641,7 @@ def _maybe_validate_matrix(a, validate_args):
@tf_export('linalg.matrix_rank')
@dispatch.add_dispatch_support
def matrix_rank(a, tol=None, validate_args=False, name=None):
"""Compute the matrix rank of one or more matrices.
@ -676,6 +680,7 @@ def matrix_rank(a, tol=None, validate_args=False, name=None):
@tf_export('linalg.pinv')
@dispatch.add_dispatch_support
def pinv(a, rcond=None, validate_args=False, name=None):
"""Compute the Moore-Penrose pseudo-inverse of one or more matrices.
@ -805,6 +810,7 @@ def pinv(a, rcond=None, validate_args=False, name=None):
@tf_export('linalg.lu_solve')
@dispatch.add_dispatch_support
def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None):
"""Solves systems of linear eqns `A X = RHS`, given LU factorizations.
@ -902,6 +908,7 @@ def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None):
@tf_export('linalg.lu_matrix_inverse')
@dispatch.add_dispatch_support
def lu_matrix_inverse(lower_upper, perm, validate_args=False, name=None):
"""Computes the inverse given the LU decomposition(s) of one or more matrices.
@ -966,6 +973,7 @@ def lu_matrix_inverse(lower_upper, perm, validate_args=False, name=None):
@tf_export('linalg.lu_reconstruct')
@dispatch.add_dispatch_support
def lu_reconstruct(lower_upper, perm, validate_args=False, name=None):
"""The reconstruct one or more matrices from their LU decomposition(s).

View File

@ -27,10 +27,12 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as linalg
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@tf_export('linalg.experimental.conjugate_gradient')
@dispatch.add_dispatch_support
def conjugate_gradient(operator,
rhs,
preconditioner=None,

View File

@ -32,6 +32,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops.gen_linalg_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
# Names below are lower_case.
@ -82,6 +83,7 @@ def _RegularizedGramianCholesky(matrix, l2_regularizer, first_kind):
@tf_export(
'linalg.triangular_solve',
v1=['linalg.triangular_solve', 'matrix_triangular_solve'])
@dispatch.add_dispatch_support
def matrix_triangular_solve(matrix, rhs, lower=True, adjoint=False, name=None):
"""Solve systems of linear equations with upper or lower triangular matrices.
@ -143,6 +145,7 @@ def matrix_triangular_solve(matrix, rhs, lower=True, adjoint=False, name=None):
@tf_export(
'linalg.cholesky_solve', v1=['linalg.cholesky_solve', 'cholesky_solve'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('cholesky_solve')
def cholesky_solve(chol, rhs, name=None):
"""Solves systems of linear eqns `A X = RHS`, given Cholesky factorizations.
@ -187,6 +190,7 @@ def cholesky_solve(chol, rhs, name=None):
@tf_export('eye', 'linalg.eye')
@dispatch.add_dispatch_support
def eye(num_rows,
num_columns=None,
batch_shape=None,
@ -234,6 +238,7 @@ def eye(num_rows,
@tf_export('linalg.lstsq', v1=['linalg.lstsq', 'matrix_solve_ls'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('matrix_solve_ls')
def matrix_solve_ls(matrix, rhs, l2_regularizer=0.0, fast=True, name=None):
r"""Solves one or more linear least-squares problems.
@ -371,6 +376,7 @@ def matrix_solve_ls(matrix, rhs, l2_regularizer=0.0, fast=True, name=None):
@tf_export('linalg.eig', 'eig', v1=[])
@dispatch.add_dispatch_support
def eig(tensor, name=None):
"""Computes the eigen decomposition of a batch of matrices.
@ -401,6 +407,7 @@ def eig(tensor, name=None):
@tf_export('linalg.eigvals', 'eigvals', v1=[])
@dispatch.add_dispatch_support
def eigvals(tensor, name=None):
"""Computes the eigenvalues of one or more matrices.
@ -427,6 +434,7 @@ def eigvals(tensor, name=None):
@tf_export('linalg.eigh', v1=['linalg.eigh', 'self_adjoint_eig'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('self_adjoint_eig')
def self_adjoint_eig(tensor, name=None):
"""Computes the eigen decomposition of a batch of self-adjoint matrices.
@ -450,6 +458,7 @@ def self_adjoint_eig(tensor, name=None):
@tf_export('linalg.eigvalsh', v1=['linalg.eigvalsh', 'self_adjoint_eigvals'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('self_adjoint_eigvals')
def self_adjoint_eigvals(tensor, name=None):
"""Computes the eigenvalues of one or more self-adjoint matrices.
@ -473,6 +482,7 @@ def self_adjoint_eigvals(tensor, name=None):
@tf_export('linalg.svd', v1=['linalg.svd', 'svd'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('svd')
def svd(tensor, full_matrices=False, compute_uv=True, name=None):
r"""Computes the singular value decompositions of one or more matrices.
@ -544,6 +554,7 @@ def svd(tensor, full_matrices=False, compute_uv=True, name=None):
# pylint: disable=redefined-builtin
@tf_export('norm', 'linalg.norm', v1=[])
@dispatch.add_dispatch_support
def norm_v2(tensor,
ord='euclidean',
axis=None,
@ -615,6 +626,7 @@ def norm_v2(tensor,
# pylint: disable=redefined-builtin
@tf_export(v1=['norm', 'linalg.norm'])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(
None, 'keep_dims is deprecated, use keepdims instead', 'keep_dims')
def norm(tensor,

View File

@ -38,6 +38,7 @@ from tensorflow.python.ops import string_ops
from tensorflow.python.ops.gen_logging_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.platform import tf_logging
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
@ -71,6 +72,7 @@ except NameError:
"only a concern in graph mode. Below is an example "
"of how to ensure tf.print executes in graph mode:\n")
@tf_export(v1=["Print"])
@dispatch.add_dispatch_support
def Print(input_, data, message=None, first_n=None, summarize=None, name=None):
"""Prints a list of tensors.
@ -136,6 +138,7 @@ def _is_filepath(output_stream):
# function definition.
# pylint: disable=g-doc-args
@tf_export("print")
@dispatch.add_dispatch_support
def print_v2(*inputs, **kwargs):
"""Print the specified inputs.

View File

@ -29,6 +29,7 @@ from tensorflow.python.ops import nn
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.ops.losses import util
from tensorflow.python.util import dispatch
from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.deprecation import deprecated_argument_lookup
from tensorflow.python.util.tf_export import tf_export
@ -136,6 +137,7 @@ def _num_elements(losses):
@tf_export(v1=["losses.compute_weighted_loss"])
@dispatch.add_dispatch_support
def compute_weighted_loss(
losses, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES,
reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
@ -204,6 +206,7 @@ def compute_weighted_loss(
@tf_export(v1=["losses.absolute_difference"])
@dispatch.add_dispatch_support
def absolute_difference(
labels, predictions, weights=1.0, scope=None,
loss_collection=ops.GraphKeys.LOSSES,
@ -257,6 +260,7 @@ def absolute_difference(
@tf_export(v1=["losses.cosine_distance"])
@dispatch.add_dispatch_support
@deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def cosine_distance(
labels, predictions, axis=None, weights=1.0, scope=None,
@ -313,6 +317,7 @@ def cosine_distance(
@tf_export(v1=["losses.hinge_loss"])
@dispatch.add_dispatch_support
def hinge_loss(labels, logits, weights=1.0, scope=None,
loss_collection=ops.GraphKeys.LOSSES,
reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
@ -363,6 +368,7 @@ def hinge_loss(labels, logits, weights=1.0, scope=None,
@tf_export(v1=["losses.huber_loss"])
@dispatch.add_dispatch_support
def huber_loss(labels, predictions, weights=1.0, delta=1.0, scope=None,
loss_collection=ops.GraphKeys.LOSSES,
reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
@ -439,6 +445,7 @@ def huber_loss(labels, predictions, weights=1.0, delta=1.0, scope=None,
@tf_export(v1=["losses.log_loss"])
@dispatch.add_dispatch_support
def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None,
loss_collection=ops.GraphKeys.LOSSES,
reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
@ -496,6 +503,7 @@ def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None,
# TODO(b/37208492): Add reduction arg.
@tf_export(v1=["losses.mean_pairwise_squared_error"])
@dispatch.add_dispatch_support
def mean_pairwise_squared_error(
labels, predictions, weights=1.0, scope=None,
loss_collection=ops.GraphKeys.LOSSES):
@ -592,6 +600,7 @@ def mean_pairwise_squared_error(
@tf_export(v1=["losses.mean_squared_error"])
@dispatch.add_dispatch_support
def mean_squared_error(
labels, predictions, weights=1.0, scope=None,
loss_collection=ops.GraphKeys.LOSSES,
@ -645,6 +654,7 @@ def mean_squared_error(
@tf_export(v1=["losses.sigmoid_cross_entropy"])
@dispatch.add_dispatch_support
def sigmoid_cross_entropy(
multi_class_labels, logits, weights=1.0, label_smoothing=0, scope=None,
loss_collection=ops.GraphKeys.LOSSES,
@ -709,6 +719,7 @@ def sigmoid_cross_entropy(
@tf_export(v1=["losses.softmax_cross_entropy"])
@dispatch.add_dispatch_support
def softmax_cross_entropy(
onehot_labels, logits, weights=1.0, label_smoothing=0, scope=None,
loss_collection=ops.GraphKeys.LOSSES,
@ -831,6 +842,7 @@ def _remove_squeezable_dimensions(
@tf_export(v1=["losses.sparse_softmax_cross_entropy"])
@dispatch.add_dispatch_support
def sparse_softmax_cross_entropy(
labels, logits, weights=1.0, scope=None,
loss_collection=ops.GraphKeys.LOSSES,

View File

@ -20,11 +20,13 @@ from __future__ import print_function
from tensorflow.python.ops import gen_manip_ops as _gen_manip_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
# pylint: disable=protected-access
@tf_export('roll', v1=['roll', 'manip.roll'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('manip.roll')
def roll(input, shift, axis, name=None): # pylint: disable=redefined-builtin
return _gen_manip_ops.roll(input, shift, axis, name)

View File

@ -104,6 +104,7 @@ nextafter = gen_math_ops.next_after
@tf_export("linspace", v1=["lin_space", "linspace"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("lin_space")
def linspace_nd(start, stop, num, name=None, axis=0):
r"""Generates evenly-spaced values in an interval along a given axis.
@ -214,8 +215,8 @@ linspace = linspace_nd
arg_max = deprecation.deprecated(None, "Use `tf.math.argmax` instead")(arg_max) # pylint: disable=used-before-assignment
arg_min = deprecation.deprecated(None, "Use `tf.math.argmin` instead")(arg_min) # pylint: disable=used-before-assignment
tf_export(v1=["arg_max"])(arg_max)
tf_export(v1=["arg_min"])(arg_min)
tf_export(v1=["arg_max"])(dispatch.add_dispatch_support(arg_max))
tf_export(v1=["arg_min"])(dispatch.add_dispatch_support(arg_min))
# This is set by resource_variable_ops.py. It is included in this way since
@ -234,6 +235,7 @@ def _set_doc(doc):
# pylint: disable=redefined-builtin
@tf_export(v1=["math.argmax", "argmax"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None, "Use the `axis` argument instead",
"dimension")
@_set_doc(
@ -250,6 +252,7 @@ def argmax(input,
@tf_export("math.argmax", "argmax", v1=[])
@dispatch.add_dispatch_support
def argmax_v2(input, axis=None, output_type=dtypes.int64, name=None):
"""Returns the index with the largest value across axes of a tensor.
@ -283,6 +286,7 @@ def argmax_v2(input, axis=None, output_type=dtypes.int64, name=None):
@tf_export(v1=["math.argmin", "argmin"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None, "Use the `axis` argument instead",
"dimension")
@_set_doc(
@ -299,6 +303,7 @@ def argmin(input,
@tf_export("math.argmin", "argmin", v1=[])
@dispatch.add_dispatch_support
def argmin_v2(input, axis=None, output_type=dtypes.int64, name=None):
"""Returns the index with the smallest value across axes of a tensor.
@ -549,6 +554,7 @@ def _neg(x, name=None):
@tf_export(v1=["math.scalar_mul", "scalar_mul"])
@dispatch.add_dispatch_support
def scalar_mul(scalar, x, name=None):
"""Multiplies a scalar times a `Tensor` or `IndexedSlices` object.
@ -581,6 +587,7 @@ def scalar_mul(scalar, x, name=None):
@tf_export("math.scalar_mul", "scalar_mul", v1=[])
@dispatch.add_dispatch_support
@_set_doc(scalar_mul.__doc__)
def scalar_mul_v2(scalar, x, name=None):
with ops.name_scope(name, "scalar_mul", [x]) as name:
@ -701,6 +708,7 @@ def sign(x, name=None):
@tf_export("math.real", v1=["math.real", "real"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("real")
@dispatch.add_dispatch_support
def real(input, name=None):
@ -735,6 +743,7 @@ def real(input, name=None):
@tf_export("math.imag", v1=["math.imag", "imag"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("imag")
@dispatch.add_dispatch_support
def imag(input, name=None):
@ -768,6 +777,7 @@ def imag(input, name=None):
@tf_export("math.angle", v1=["math.angle", "angle"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("angle")
@dispatch.add_dispatch_support
def angle(input, name=None):
@ -937,6 +947,7 @@ def saturate_cast(value, dtype, name=None):
@deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.")
@tf_export(v1=["to_float"])
@dispatch.add_dispatch_support
def to_float(x, name="ToFloat"):
"""Casts a tensor to type `float32`.
@ -956,6 +967,7 @@ def to_float(x, name="ToFloat"):
@deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.")
@tf_export(v1=["to_double"])
@dispatch.add_dispatch_support
def to_double(x, name="ToDouble"):
"""Casts a tensor to type `float64`.
@ -975,6 +987,7 @@ def to_double(x, name="ToDouble"):
@deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.")
@tf_export(v1=["to_int32"])
@dispatch.add_dispatch_support
def to_int32(x, name="ToInt32"):
"""Casts a tensor to type `int32`.
@ -994,6 +1007,7 @@ def to_int32(x, name="ToInt32"):
@deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.")
@tf_export(v1=["to_int64"])
@dispatch.add_dispatch_support
def to_int64(x, name="ToInt64"):
"""Casts a tensor to type `int64`.
@ -1013,6 +1027,7 @@ def to_int64(x, name="ToInt64"):
@deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.")
@tf_export(v1=["to_bfloat16"])
@dispatch.add_dispatch_support
def to_bfloat16(x, name="ToBFloat16"):
"""Casts a tensor to type `bfloat16`.
@ -1032,6 +1047,7 @@ def to_bfloat16(x, name="ToBFloat16"):
@deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.")
@tf_export(v1=["to_complex64"])
@dispatch.add_dispatch_support
def to_complex64(x, name="ToComplex64"):
"""Casts a tensor to type `complex64`.
@ -1051,6 +1067,7 @@ def to_complex64(x, name="ToComplex64"):
@deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.")
@tf_export(v1=["to_complex128"])
@dispatch.add_dispatch_support
def to_complex128(x, name="ToComplex128"):
"""Casts a tensor to type `complex128`.
@ -1265,6 +1282,7 @@ def truediv(x, y, name=None):
date=None,
instructions="Deprecated in favor of operator or tf.math.divide.")
@tf_export(v1=["div"])
@dispatch.add_dispatch_support
def div(x, y, name=None):
"""Divides x / y elementwise (using Python 2 division operator semantics).
@ -1288,6 +1306,7 @@ def div(x, y, name=None):
@tf_export("math.divide_no_nan", v1=["math.divide_no_nan", "div_no_nan"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("div_no_nan")
@dispatch.add_dispatch_support
def div_no_nan(x, y, name=None):
@ -1620,6 +1639,7 @@ ops.Tensor._override_operator("__ne__", tensor_not_equals)
@tf_export("range")
@dispatch.add_dispatch_support
def range(start, limit=None, delta=1, dtype=None, name="range"): # pylint: disable=redefined-builtin
"""Creates a sequence of numbers.
@ -1751,6 +1771,7 @@ def _may_reduce_to_scalar(keepdims, axis, output):
@tf_export(v1=["math.reduce_sum", "reduce_sum"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None,
"keep_dims is deprecated, use keepdims instead",
"keep_dims")
@ -1885,6 +1906,7 @@ def reduce_sum_with_dims(input_tensor,
@tf_export("math.reduce_euclidean_norm")
@dispatch.add_dispatch_support
def reduce_euclidean_norm(input_tensor, axis=None, keepdims=False, name=None):
"""Computes the Euclidean norm of elements across dimensions of a tensor.
@ -1928,6 +1950,7 @@ def reduce_euclidean_norm(input_tensor, axis=None, keepdims=False, name=None):
@tf_export(v1=["math.count_nonzero", "count_nonzero"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None,
"keep_dims is deprecated, use keepdims instead",
"keep_dims")
@ -2005,6 +2028,7 @@ def count_nonzero(input_tensor=None,
@tf_export("math.count_nonzero", v1=[])
@dispatch.add_dispatch_support
def count_nonzero_v2(
input, # pylint: disable=redefined-builtin
axis=None,
@ -2072,6 +2096,7 @@ def count_nonzero_v2(
@tf_export(v1=["math.reduce_mean", "reduce_mean"])
@dispatch.add_dispatch_support
def reduce_mean_v1(input_tensor,
axis=None,
keepdims=None,
@ -2198,6 +2223,7 @@ def reduce_mean(input_tensor, axis=None, keepdims=False, name=None):
@tf_export("math.reduce_variance")
@dispatch.add_dispatch_support
def reduce_variance(input_tensor, axis=None, keepdims=False, name=None):
"""Computes the variance of elements across dimensions of a tensor.
@ -2246,6 +2272,7 @@ def reduce_variance(input_tensor, axis=None, keepdims=False, name=None):
@tf_export("math.reduce_std")
@dispatch.add_dispatch_support
def reduce_std(input_tensor, axis=None, keepdims=False, name=None):
"""Computes the standard deviation of elements across dimensions of a tensor.
@ -2328,6 +2355,7 @@ def reduce_prod(input_tensor, axis=None, keepdims=False, name=None):
@tf_export(v1=["math.reduce_prod", "reduce_prod"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None,
"keep_dims is deprecated, use keepdims instead",
"keep_dims")
@ -2373,6 +2401,7 @@ def reduce_prod_v1(input_tensor,
@tf_export(v1=["math.reduce_min", "reduce_min"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None,
"keep_dims is deprecated, use keepdims instead",
"keep_dims")
@ -2459,6 +2488,7 @@ def reduce_min(input_tensor, axis=None, keepdims=False, name=None):
@tf_export(v1=["math.reduce_max", "reduce_max"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None,
"keep_dims is deprecated, use keepdims instead",
"keep_dims")
@ -2563,6 +2593,7 @@ def reduce_max_with_dims(input_tensor,
@tf_export(v1=["math.reduce_all", "reduce_all"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None,
"keep_dims is deprecated, use keepdims instead",
"keep_dims")
@ -2662,6 +2693,7 @@ def reduce_all(input_tensor, axis=None, keepdims=False, name=None):
@tf_export(v1=["math.reduce_any", "reduce_any"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None,
"keep_dims is deprecated, use keepdims instead",
"keep_dims")
@ -2761,6 +2793,7 @@ def reduce_any(input_tensor, axis=None, keepdims=False, name=None):
@tf_export(v1=["math.reduce_logsumexp", "reduce_logsumexp"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None,
"keep_dims is deprecated, use keepdims instead",
"keep_dims")
@ -2817,6 +2850,7 @@ def reduce_logsumexp_v1(input_tensor,
@tf_export("math.reduce_logsumexp", "reduce_logsumexp", v1=[])
@dispatch.add_dispatch_support
def reduce_logsumexp(input_tensor, axis=None, keepdims=False, name=None):
"""Computes log(sum(exp(elements across dimensions of a tensor))).
@ -2877,6 +2911,7 @@ def reduce_logsumexp(input_tensor, axis=None, keepdims=False, name=None):
@tf_export("linalg.trace", v1=["linalg.trace", "trace"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("trace")
@dispatch.add_dispatch_support
def trace(x, name=None):
@ -3116,6 +3151,7 @@ def matmul(a,
@tf_export("linalg.matvec")
@dispatch.add_dispatch_support
def matvec(a,
b,
transpose_a=False,
@ -3219,6 +3255,7 @@ _OverrideBinaryOperatorHelper(matmul, "matmul")
sparse_matmul = deprecation.deprecated(None, "Use `tf.linalg.matmul` instead")(
gen_math_ops.sparse_mat_mul)
tf_export(v1=["sparse_matmul"])(sparse_matmul)
@dispatch.add_dispatch_support
@ops.RegisterStatistics("MatMul", "flops")
@ -3371,6 +3408,7 @@ def add_n(inputs, name=None):
@tf_export("math.accumulate_n", v1=["math.accumulate_n", "accumulate_n"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("accumulate_n")
def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None):
"""Returns the element-wise sum of a list of tensors.
@ -3449,6 +3487,7 @@ def _accumulate_n_grad(op, grad):
@tf_export("math.sigmoid", "nn.sigmoid", "sigmoid")
@dispatch.add_dispatch_support
def sigmoid(x, name=None):
r"""Computes sigmoid of `x` element-wise.
@ -3521,6 +3560,7 @@ def log_sigmoid(x, name=None):
@tf_export("math.bincount", v1=[])
@dispatch.add_dispatch_support
def bincount(arr,
weights=None,
minlength=None,
@ -3596,6 +3636,7 @@ def bincount(arr,
@tf_export(v1=["math.bincount", "bincount"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("bincount")
def bincount_v1(arr,
weights=None,
@ -3629,6 +3670,7 @@ def bincount_v1(arr,
@tf_export("math.cumsum", "cumsum")
@dispatch.add_dispatch_support
def cumsum(x, axis=0, exclusive=False, reverse=False, name=None):
"""Compute the cumulative sum of the tensor `x` along `axis`.
@ -3700,6 +3742,7 @@ def cumsum(x, axis=0, exclusive=False, reverse=False, name=None):
@tf_export("math.cumprod", v1=["math.cumprod", "cumprod"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("cumprod")
def cumprod(x, axis=0, exclusive=False, reverse=False, name=None):
"""Compute the cumulative product of the tensor `x` along `axis`.
@ -3753,6 +3796,7 @@ def cumprod(x, axis=0, exclusive=False, reverse=False, name=None):
@tf_export("math.cumulative_logsumexp", v1=["math.cumulative_logsumexp"])
@dispatch.add_dispatch_support
def cumulative_logsumexp(x, axis=0, exclusive=False, reverse=False, name=None):
"""Compute the cumulative log-sum-exp of the tensor `x` along `axis`.
@ -3912,6 +3956,7 @@ def _unsorted_segment_N(data, segment_ids, num_segments):
@tf_export(
"math.unsorted_segment_mean",
v1=["math.unsorted_segment_mean", "unsorted_segment_mean"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("unsorted_segment_mean")
@dispatch.add_dispatch_support
def unsorted_segment_mean(data, segment_ids, num_segments, name=None):
@ -3958,6 +4003,7 @@ def unsorted_segment_mean(data, segment_ids, num_segments, name=None):
@tf_export(
"math.unsorted_segment_sqrt_n",
v1=["math.unsorted_segment_sqrt_n", "unsorted_segment_sqrt_n"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("unsorted_segment_sqrt_n")
@dispatch.add_dispatch_support
def unsorted_segment_sqrt_n(data, segment_ids, num_segments, name=None):
@ -4307,6 +4353,7 @@ def sparse_segment_sqrt_n_v2(data,
@tf_export("tensordot", "linalg.tensordot")
@dispatch.add_dispatch_support
def tensordot(a, b, axes, name=None):
r"""Tensor contraction of a and b along specified axes and outer product.
@ -4493,6 +4540,7 @@ def tensordot(a, b, axes, name=None):
@tf_export("math.polyval")
@dispatch.add_dispatch_support
def polyval(coeffs, x, name=None):
r"""Computes the elementwise value of a polynomial.
@ -4563,6 +4611,7 @@ def polyval(coeffs, x, name=None):
@tf_export("math.reciprocal_no_nan")
@dispatch.add_dispatch_support
def reciprocal_no_nan(x, name=None):
"""Performs a safe reciprocal operation, element wise.
@ -4665,6 +4714,7 @@ def ndtri(x, name=None):
@tf_export("math.ceil", v1=["math.ceil", "ceil"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("ceil")
@dispatch.add_dispatch_support
def ceil(x, name=None):
@ -4778,6 +4828,7 @@ def exp(x, name=None):
@tf_export("math.sobol_sample")
@dispatch.add_dispatch_support
def sobol_sample(dim, num_results, skip=0, dtype=dtypes.float32, name=None):
"""Generates points from the Sobol sequence.
@ -4802,6 +4853,7 @@ def sobol_sample(dim, num_results, skip=0, dtype=dtypes.float32, name=None):
@tf_export("math.rsqrt", v1=["math.rsqrt", "rsqrt"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("rsqrt")
@dispatch.add_dispatch_support
def rsqrt(x, name=None):

View File

@ -39,12 +39,14 @@ from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import util as losses_util
from tensorflow.python.platform import device_context
from tensorflow.python.util import dispatch
from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.deprecation import deprecated_argument_lookup
from tensorflow.python.util.tf_export import tf_export
@tf_export("nn.log_poisson_loss")
@dispatch.add_dispatch_support
def log_poisson_loss(targets, log_input, compute_full_loss=False, name=None):
"""Computes log Poisson loss given `log_input`.
@ -110,6 +112,7 @@ def log_poisson_loss(targets, log_input, compute_full_loss=False, name=None):
@tf_export(v1=["nn.sigmoid_cross_entropy_with_logits"])
@dispatch.add_dispatch_support
def sigmoid_cross_entropy_with_logits( # pylint: disable=invalid-name
_sentinel=None,
labels=None,
@ -192,6 +195,7 @@ def sigmoid_cross_entropy_with_logits( # pylint: disable=invalid-name
# Note: intentionally calling this v2 to not allow existing code with indirect
# imports to ignore the sentinel behavior.
@tf_export("nn.sigmoid_cross_entropy_with_logits", v1=[])
@dispatch.add_dispatch_support
def sigmoid_cross_entropy_with_logits_v2( # pylint: disable=invalid-name
labels=None,
logits=None,
@ -242,6 +246,7 @@ def sigmoid_cross_entropy_with_logits_v2( # pylint: disable=invalid-name
@tf_export("nn.weighted_cross_entropy_with_logits", v1=[])
@dispatch.add_dispatch_support
def weighted_cross_entropy_with_logits_v2(labels, logits, pos_weight,
name=None):
"""Computes a weighted cross entropy.
@ -320,6 +325,7 @@ def weighted_cross_entropy_with_logits_v2(labels, logits, pos_weight,
@tf_export(v1=["nn.weighted_cross_entropy_with_logits"])
@dispatch.add_dispatch_support
@deprecated_args(None, "targets is deprecated, use labels instead", "targets")
def weighted_cross_entropy_with_logits(labels=None,
logits=None,
@ -384,6 +390,7 @@ def weighted_cross_entropy_with_logits(labels=None,
@tf_export("nn.compute_average_loss")
@dispatch.add_dispatch_support
def compute_average_loss(per_example_loss,
sample_weight=None,
global_batch_size=None):
@ -440,6 +447,7 @@ def compute_average_loss(per_example_loss,
@tf_export("nn.scale_regularization_loss")
@dispatch.add_dispatch_support
def scale_regularization_loss(regularization_loss):
"""Scales the sum of the given regularization losses by number of replicas.
@ -478,6 +486,7 @@ def scale_regularization_loss(regularization_loss):
@tf_export(v1=["nn.relu_layer"])
@dispatch.add_dispatch_support
def relu_layer(x, weights, biases, name=None):
"""Computes Relu(x * weight + biases).
@ -501,6 +510,7 @@ def relu_layer(x, weights, biases, name=None):
@tf_export("nn.swish")
@dispatch.add_dispatch_support
@custom_gradient.custom_gradient
def swish(features):
# pylint: disable=g-doc-args
@ -538,6 +548,7 @@ def swish(features):
# pylint: disable=redefined-builtin
@tf_export("linalg.normalize")
@dispatch.add_dispatch_support
def normalize(tensor, ord="euclidean", axis=None, name=None):
"""Normalizes `tensor` along dimension `axis` using specified norm.
@ -590,6 +601,7 @@ def normalize(tensor, ord="euclidean", axis=None, name=None):
@tf_export(v1=["math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize"])
@dispatch.add_dispatch_support
@deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None):
"""Normalizes along dimension `axis` using an L2 norm.
@ -618,6 +630,7 @@ def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None):
@tf_export("math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize", v1=[])
@dispatch.add_dispatch_support
def l2_normalize_v2(x, axis=None, epsilon=1e-12, name=None):
"""Normalizes along dimension `axis` using an L2 norm.
@ -668,6 +681,7 @@ def _count_nonzero(input_tensor, dtype=dtypes.int64):
@tf_export("math.zero_fraction", "nn.zero_fraction")
@dispatch.add_dispatch_support
def zero_fraction(value, name=None):
"""Returns the fraction of zeros in `value`.
@ -710,6 +724,7 @@ def zero_fraction(value, name=None):
# pylint: disable=redefined-builtin
@tf_export(v1=["nn.depthwise_conv2d"])
@dispatch.add_dispatch_support
def depthwise_conv2d(input,
filter,
strides,
@ -838,6 +853,7 @@ def depthwise_conv2d(input,
@tf_export("nn.depthwise_conv2d", v1=[])
@dispatch.add_dispatch_support
def depthwise_conv2d_v2(input,
filter,
strides,
@ -935,6 +951,7 @@ def depthwise_conv2d_v2(input,
# pylint: disable=redefined-builtin,line-too-long
@tf_export(v1=["nn.separable_conv2d"])
@dispatch.add_dispatch_support
def separable_conv2d(input,
depthwise_filter,
pointwise_filter,
@ -1042,6 +1059,7 @@ def separable_conv2d(input,
@tf_export("nn.separable_conv2d", v1=[])
@dispatch.add_dispatch_support
def separable_conv2d_v2(
input,
depthwise_filter,
@ -1117,6 +1135,7 @@ def separable_conv2d_v2(
@tf_export(v1=["nn.sufficient_statistics"])
@dispatch.add_dispatch_support
def sufficient_statistics(x, axes, shift=None, keep_dims=None, name=None,
keepdims=None):
"""Calculate the sufficient statistics for the mean and variance of `x`.
@ -1174,6 +1193,7 @@ def sufficient_statistics(x, axes, shift=None, keep_dims=None, name=None,
@tf_export("nn.sufficient_statistics", v1=[])
@dispatch.add_dispatch_support
def sufficient_statistics_v2(x, axes, shift=None, keepdims=False, name=None):
"""Calculate the sufficient statistics for the mean and variance of `x`.
@ -1203,6 +1223,7 @@ def sufficient_statistics_v2(x, axes, shift=None, keepdims=False, name=None):
@tf_export("nn.normalize_moments")
@dispatch.add_dispatch_support
def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
"""Calculate the mean and variance of based on the sufficient statistics.
@ -1235,6 +1256,7 @@ def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
@tf_export(v1=["nn.moments"])
@dispatch.add_dispatch_support
def moments(
x,
axes,
@ -1300,6 +1322,7 @@ def moments(
@tf_export("nn.moments", v1=[])
@dispatch.add_dispatch_support
def moments_v2(
x,
axes,
@ -1336,6 +1359,7 @@ def moments_v2(
@tf_export(v1=["nn.weighted_moments"])
@dispatch.add_dispatch_support
def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=None,
keepdims=None):
"""Returns the frequency-weighted mean and variance of `x`.
@ -1414,6 +1438,7 @@ def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=None,
@tf_export("nn.weighted_moments", v1=[])
@dispatch.add_dispatch_support
def weighted_moments_v2(x, axes, frequency_weights, keepdims=False, name=None):
"""Returns the frequency-weighted mean and variance of `x`.
@ -1438,6 +1463,7 @@ def weighted_moments_v2(x, axes, frequency_weights, keepdims=False, name=None):
@tf_export("nn.batch_normalization")
@dispatch.add_dispatch_support
def batch_normalization(x,
mean,
variance,
@ -1508,6 +1534,7 @@ def batch_normalization(x,
@tf_export(v1=["nn.fused_batch_norm"])
@dispatch.add_dispatch_support
def fused_batch_norm(
x,
scale,
@ -1631,6 +1658,7 @@ def fused_batch_norm(
@tf_export(v1=["nn.batch_norm_with_global_normalization"])
@dispatch.add_dispatch_support
def batch_norm_with_global_normalization(t=None,
m=None,
v=None,
@ -1685,6 +1713,7 @@ def batch_norm_with_global_normalization(t=None,
# pylint: disable=redefined-builtin,line-too-long
@tf_export("nn.batch_norm_with_global_normalization", v1=[])
@dispatch.add_dispatch_support
def batch_norm_with_global_normalization_v2(input,
mean,
variance,
@ -1934,6 +1963,7 @@ def _compute_sampled_logits(weights,
@tf_export("nn.nce_loss", v1=[])
@dispatch.add_dispatch_support
def nce_loss_v2(weights,
biases,
labels,
@ -2038,6 +2068,7 @@ def nce_loss_v2(weights,
@tf_export(v1=["nn.nce_loss"])
@dispatch.add_dispatch_support
def nce_loss(weights,
biases,
labels,
@ -2149,6 +2180,7 @@ def nce_loss(weights,
@tf_export("nn.sampled_softmax_loss", v1=[])
@dispatch.add_dispatch_support
def sampled_softmax_loss_v2(weights,
biases,
labels,
@ -2240,6 +2272,7 @@ def sampled_softmax_loss_v2(weights,
@tf_export(v1=["nn.sampled_softmax_loss"])
@dispatch.add_dispatch_support
def sampled_softmax_loss(weights,
biases,
labels,

View File

@ -239,6 +239,7 @@ class _NonAtrousConvolution(object):
@tf_export("nn.dilation2d", v1=[])
@dispatch.add_dispatch_support
def dilation2d_v2(
input, # pylint: disable=redefined-builtin
filters, # pylint: disable=redefined-builtin
@ -306,6 +307,7 @@ def dilation2d_v2(
@tf_export(v1=["nn.dilation2d"])
@dispatch.add_dispatch_support
def dilation2d_v1( # pylint: disable=missing-docstring
input, # pylint: disable=redefined-builtin
filter=None, # pylint: disable=redefined-builtin
@ -324,6 +326,7 @@ dilation2d_v1.__doc__ = gen_nn_ops.dilation2d.__doc__
@tf_export("nn.with_space_to_batch")
@dispatch.add_dispatch_support
def with_space_to_batch(
input, # pylint: disable=redefined-builtin
dilation_rate,
@ -772,6 +775,7 @@ def _get_strides_and_dilation_rate(num_spatial_dims, strides, dilation_rate):
@tf_export(v1=["nn.convolution"])
@dispatch.add_dispatch_support
def convolution(
input, # pylint: disable=redefined-builtin
filter, # pylint: disable=redefined-builtin
@ -907,7 +911,8 @@ def convolution(
@tf_export("nn.convolution", v1=[])
def convolution_v2(
@dispatch.add_dispatch_support
def convolution_v2( # pylint: disable=missing-docstring
input, # pylint: disable=redefined-builtin
filters,
strides=None,
@ -1116,6 +1121,7 @@ class Convolution(object):
@tf_export(v1=["nn.pool"])
@dispatch.add_dispatch_support
def pool(
input, # pylint: disable=redefined-builtin
window_shape,
@ -1290,6 +1296,7 @@ def pool(
@tf_export("nn.pool", v1=[])
@dispatch.add_dispatch_support
def pool_v2(
input, # pylint: disable=redefined-builtin
window_shape,
@ -1389,6 +1396,7 @@ def pool_v2(
@tf_export("nn.atrous_conv2d")
@dispatch.add_dispatch_support
def atrous_conv2d(value, filters, rate, padding, name=None):
"""Atrous convolution (a.k.a. convolution with holes or dilated convolution).
@ -1576,6 +1584,7 @@ def convert_padding(padding):
@tf_export(v1=["nn.conv1d"])
@dispatch.add_dispatch_support
@deprecation.deprecated_arg_values(
None,
"`NCHW` for data_format is deprecated, use `NCW` instead",
@ -1674,6 +1683,7 @@ def conv1d(
@tf_export("nn.conv1d", v1=[])
@dispatch.add_dispatch_support
def conv1d_v2(
input, # pylint: disable=redefined-builtin
filters,
@ -1739,6 +1749,7 @@ def conv1d_v2(
@tf_export("nn.conv1d_transpose")
@dispatch.add_dispatch_support
def conv1d_transpose(
input, # pylint: disable=redefined-builtin
filters,
@ -1827,6 +1838,7 @@ def conv1d_transpose(
@tf_export("nn.conv2d", v1=[])
@dispatch.add_dispatch_support
def conv2d_v2(input, # pylint: disable=redefined-builtin
filters,
strides,
@ -1927,6 +1939,7 @@ def conv2d_v2(input, # pylint: disable=redefined-builtin
@tf_export(v1=["nn.conv2d"])
@dispatch.add_dispatch_support
def conv2d( # pylint: disable=redefined-builtin,dangerous-default-value
input,
filter=None,
@ -2024,6 +2037,7 @@ def conv2d( # pylint: disable=redefined-builtin,dangerous-default-value
@tf_export(v1=["nn.conv2d_backprop_filter"])
@dispatch.add_dispatch_support
def conv2d_backprop_filter( # pylint: disable=redefined-builtin,dangerous-default-value
input,
filter_sizes,
@ -2084,6 +2098,7 @@ def conv2d_backprop_filter( # pylint: disable=redefined-builtin,dangerous-defau
@tf_export(v1=["nn.conv2d_backprop_input"])
@dispatch.add_dispatch_support
def conv2d_backprop_input( # pylint: disable=redefined-builtin,dangerous-default-value
input_sizes,
filter=None,
@ -2148,6 +2163,7 @@ def conv2d_backprop_input( # pylint: disable=redefined-builtin,dangerous-defaul
@tf_export(v1=["nn.conv2d_transpose"])
@dispatch.add_dispatch_support
def conv2d_transpose(
value=None,
filter=None, # pylint: disable=redefined-builtin
@ -2224,6 +2240,7 @@ def conv2d_transpose(
@tf_export("nn.conv2d_transpose", v1=[])
@dispatch.add_dispatch_support
def conv2d_transpose_v2(
input, # pylint: disable=redefined-builtin
filters, # pylint: disable=redefined-builtin
@ -2301,6 +2318,7 @@ def conv2d_transpose_v2(
@tf_export("nn.atrous_conv2d_transpose")
@dispatch.add_dispatch_support
def atrous_conv2d_transpose(value,
filters,
output_shape,
@ -2459,6 +2477,7 @@ def atrous_conv2d_transpose(value,
@tf_export(v1=["nn.depthwise_conv2d_native"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("nn.depthwise_conv2d_native")
def depthwise_conv2d_native( # pylint: disable=redefined-builtin,dangerous-default-value
input,
@ -2538,6 +2557,7 @@ def depthwise_conv2d_native( # pylint: disable=redefined-builtin,dangerous-defa
"nn.depthwise_conv2d_native_backprop_input",
"nn.depthwise_conv2d_backprop_input"
])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("nn.depthwise_conv2d_native_backprop_input")
def depthwise_conv2d_native_backprop_input( # pylint: disable=redefined-builtin,dangerous-default-value
input_sizes,
@ -2607,6 +2627,7 @@ def depthwise_conv2d_native_backprop_input( # pylint: disable=redefined-builtin
"nn.depthwise_conv2d_native_backprop_filter",
"nn.depthwise_conv2d_backprop_filter"
])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("nn.depthwise_conv2d_native_backprop_filter")
def depthwise_conv2d_native_backprop_filter( # pylint: disable=redefined-builtin,dangerous-default-value
input,
@ -2672,6 +2693,7 @@ def depthwise_conv2d_native_backprop_filter( # pylint: disable=redefined-builti
@tf_export("nn.conv3d", v1=[])
@dispatch.add_dispatch_support
def conv3d_v2(input, # pylint: disable=redefined-builtin,missing-docstring
filters,
strides,
@ -2691,6 +2713,7 @@ def conv3d_v2(input, # pylint: disable=redefined-builtin,missing-docstring
@tf_export(v1=["nn.conv3d"])
@dispatch.add_dispatch_support
def conv3d_v1( # pylint: disable=missing-docstring,dangerous-default-value
input, # pylint: disable=redefined-builtin
filter=None, # pylint: disable=redefined-builtin
@ -2711,6 +2734,7 @@ conv3d_v1.__doc__ = gen_nn_ops.conv3d.__doc__
@tf_export(v1=["nn.conv3d_transpose"])
@dispatch.add_dispatch_support
def conv3d_transpose(
value,
filter=None, # pylint: disable=redefined-builtin
@ -2782,6 +2806,7 @@ def conv3d_transpose(
@tf_export("nn.conv3d_transpose", v1=[])
@dispatch.add_dispatch_support
def conv3d_transpose_v2(input, # pylint: disable=redefined-builtin
filters,
output_shape,
@ -2861,6 +2886,7 @@ CONV_TRANSPOSE_OPS = (
@tf_export("nn.conv_transpose")
@dispatch.add_dispatch_support
def conv_transpose(input, # pylint: disable=redefined-builtin
filters,
output_shape,
@ -2958,6 +2984,7 @@ _tf_deterministic_ops.value = None
@tf_export("nn.bias_add")
@dispatch.add_dispatch_support
def bias_add(value, bias, data_format=None, name=None):
"""Adds `bias` to `value`.
@ -3047,6 +3074,7 @@ def bias_add_v1(value, bias, name=None):
@tf_export(v1=["nn.crelu"])
@dispatch.add_dispatch_support
def crelu(features, name=None, axis=-1):
"""Computes Concatenated ReLU.
@ -3079,12 +3107,14 @@ def crelu(features, name=None, axis=-1):
@tf_export("nn.crelu", v1=[])
@dispatch.add_dispatch_support
def crelu_v2(features, axis=-1, name=None):
return crelu(features, name=name, axis=axis)
crelu_v2.__doc__ = crelu.__doc__
@tf_export("nn.relu6")
@dispatch.add_dispatch_support
def relu6(features, name=None):
"""Computes Rectified Linear 6: `min(max(features, 0), 6)`.
@ -3107,6 +3137,7 @@ def relu6(features, name=None):
@tf_export("nn.leaky_relu")
@dispatch.add_dispatch_support
def leaky_relu(features, alpha=0.2, name=None):
"""Compute the Leaky ReLU activation function.
@ -3245,6 +3276,7 @@ def _softmax(logits, compute_op, dim=-1, name=None):
@tf_export(v1=["nn.softmax", "math.softmax"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def softmax(logits, axis=None, name=None, dim=None):
"""Computes softmax activations.
@ -3289,6 +3321,7 @@ def softmax(logits, axis=None, name=None, dim=None):
@tf_export("nn.softmax", "math.softmax", v1=[])
@dispatch.add_dispatch_support
def softmax_v2(logits, axis=None, name=None):
"""Computes softmax activations.
@ -3316,6 +3349,7 @@ def softmax_v2(logits, axis=None, name=None):
@tf_export(v1=["nn.log_softmax", "math.log_softmax"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def log_softmax(logits, axis=None, name=None, dim=None):
"""Computes log softmax activations.
@ -3346,6 +3380,7 @@ def log_softmax(logits, axis=None, name=None, dim=None):
@tf_export("nn.log_softmax", "math.log_softmax", v1=[])
@dispatch.add_dispatch_support
def log_softmax_v2(logits, axis=None, name=None):
"""Computes log softmax activations.
@ -3382,6 +3417,7 @@ def _ensure_xent_args(name, sentinel, labels, logits):
@tf_export("nn.softmax_cross_entropy_with_logits", v1=[])
@dispatch.add_dispatch_support
def softmax_cross_entropy_with_logits_v2(labels, logits, axis=-1, name=None):
"""Computes softmax cross entropy between `logits` and `labels`.
@ -3444,6 +3480,7 @@ def softmax_cross_entropy_with_logits_v2(labels, logits, axis=-1, name=None):
@tf_export(v1=["nn.softmax_cross_entropy_with_logits_v2"])
@dispatch.add_dispatch_support
@deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def softmax_cross_entropy_with_logits_v2_helper(
labels, logits, axis=None, name=None, dim=None):
@ -3571,6 +3608,7 @@ See `tf.nn.softmax_cross_entropy_with_logits_v2`.
@tf_export(v1=["nn.softmax_cross_entropy_with_logits"])
@dispatch.add_dispatch_support
@deprecation.deprecated(date=None, instructions=_XENT_DEPRECATION)
def softmax_cross_entropy_with_logits(
_sentinel=None, # pylint: disable=invalid-name
@ -3639,6 +3677,7 @@ def softmax_cross_entropy_with_logits(
@tf_export(v1=["nn.sparse_softmax_cross_entropy_with_logits"])
@dispatch.add_dispatch_support
def sparse_softmax_cross_entropy_with_logits(
_sentinel=None, # pylint: disable=invalid-name
labels=None,
@ -3764,6 +3803,7 @@ def sparse_softmax_cross_entropy_with_logits(
@tf_export("nn.sparse_softmax_cross_entropy_with_logits", v1=[])
@dispatch.add_dispatch_support
def sparse_softmax_cross_entropy_with_logits_v2(labels, logits, name=None):
"""Computes sparse softmax cross entropy between `logits` and `labels`.
@ -3816,6 +3856,7 @@ def sparse_softmax_cross_entropy_with_logits_v2(labels, logits, name=None):
@tf_export("nn.avg_pool", v1=["nn.avg_pool_v2"])
@dispatch.add_dispatch_support
def avg_pool_v2(input, ksize, strides, padding, data_format=None, name=None): # pylint: disable=redefined-builtin
"""Performs the avg pooling on the input.
@ -3878,6 +3919,7 @@ def avg_pool_v2(input, ksize, strides, padding, data_format=None, name=None): #
@tf_export(v1=["nn.avg_pool", "nn.avg_pool2d"])
@dispatch.add_dispatch_support
def avg_pool(value, ksize, strides, padding, data_format="NHWC",
name=None, input=None): # pylint: disable=redefined-builtin
"""Performs the average pooling on the input.
@ -3922,6 +3964,7 @@ def avg_pool(value, ksize, strides, padding, data_format="NHWC",
@tf_export("nn.avg_pool2d", v1=[])
@dispatch.add_dispatch_support
def avg_pool2d(input, ksize, strides, padding, data_format="NHWC", name=None): # pylint: disable=redefined-builtin
"""Performs the average pooling on the input.
@ -3961,6 +4004,7 @@ def avg_pool2d(input, ksize, strides, padding, data_format="NHWC", name=None):
@tf_export("nn.avg_pool1d")
@dispatch.add_dispatch_support
def avg_pool1d(input, ksize, strides, padding, data_format="NWC", name=None): # pylint: disable=redefined-builtin
"""Performs the average pooling on the input.
@ -4006,6 +4050,7 @@ def avg_pool1d(input, ksize, strides, padding, data_format="NWC", name=None): #
@tf_export("nn.avg_pool3d")
@dispatch.add_dispatch_support
def avg_pool3d(input, ksize, strides, padding, data_format="NDHWC", name=None): # pylint: disable=redefined-builtin
"""Performs the average pooling on the input.
@ -4046,6 +4091,7 @@ def avg_pool3d(input, ksize, strides, padding, data_format="NDHWC", name=None):
# pylint: disable=redefined-builtin
@tf_export("nn.max_pool", v1=["nn.max_pool_v2"])
@dispatch.add_dispatch_support
def max_pool_v2(input, ksize, strides, padding, data_format=None, name=None):
"""Performs the max pooling on the input.
@ -4106,6 +4152,7 @@ def max_pool_v2(input, ksize, strides, padding, data_format=None, name=None):
@tf_export(v1=["nn.max_pool"])
@dispatch.add_dispatch_support
def max_pool(value,
ksize,
strides,
@ -4155,6 +4202,7 @@ def max_pool(value,
# pylint: disable=redefined-builtin
@tf_export("nn.max_pool1d")
@dispatch.add_dispatch_support
def max_pool1d(input, ksize, strides, padding, data_format="NWC", name=None):
"""Performs the max pooling on the input.
@ -4199,6 +4247,7 @@ def max_pool1d(input, ksize, strides, padding, data_format="NWC", name=None):
# pylint: disable=redefined-builtin
@tf_export("nn.max_pool2d")
@dispatch.add_dispatch_support
def max_pool2d(input, ksize, strides, padding, data_format="NHWC", name=None):
"""Performs the max pooling on the input.
@ -4237,6 +4286,7 @@ def max_pool2d(input, ksize, strides, padding, data_format="NHWC", name=None):
# pylint: disable=redefined-builtin
@tf_export("nn.max_pool3d")
@dispatch.add_dispatch_support
def max_pool3d(input, ksize, strides, padding, data_format="NDHWC", name=None):
"""Performs the max pooling on the input.
@ -4279,6 +4329,7 @@ def max_pool3d(input, ksize, strides, padding, data_format="NDHWC", name=None):
@tf_export("nn.max_pool_with_argmax", v1=[])
@dispatch.add_dispatch_support
def max_pool_with_argmax_v2(
input, # pylint: disable=redefined-builtin
ksize,
@ -4348,6 +4399,7 @@ def max_pool_with_argmax_v2(
@tf_export(v1=["nn.max_pool_with_argmax"])
@dispatch.add_dispatch_support
def max_pool_with_argmax_v1( # pylint: disable=missing-docstring,invalid-name
input, # pylint: disable=redefined-builtin
ksize,
@ -4442,6 +4494,7 @@ def _calc_bias_add_flops(graph, node):
@tf_export(v1=["nn.xw_plus_b"])
@dispatch.add_dispatch_support
def xw_plus_b(x, weights, biases, name=None): # pylint: disable=invalid-name
"""Computes matmul(x, weights) + biases.
@ -4691,6 +4744,7 @@ def dropout_v2(x, rate, noise_shape=None, seed=None, name=None):
@tf_export("math.top_k", "nn.top_k")
@dispatch.add_dispatch_support
def top_k(input, k=1, sorted=True, name=None): # pylint: disable=redefined-builtin
"""Finds values and indices of the `k` largest entries for the last dimension.
@ -4751,6 +4805,7 @@ def nth_element(input, n, reverse=False, name=None): # pylint: disable=redefine
@tf_export(v1=["nn.fractional_max_pool"])
@dispatch.add_dispatch_support
@deprecation.deprecated(date=None, instructions="`seed2` and `deterministic` "
"args are deprecated. Use fractional_max_pool_v2.")
def fractional_max_pool(value,
@ -4837,6 +4892,7 @@ def fractional_max_pool(value,
@tf_export("nn.fractional_max_pool", v1=[])
@dispatch.add_dispatch_support
def fractional_max_pool_v2(value,
pooling_ratio,
pseudo_random=False,
@ -4922,6 +4978,7 @@ def fractional_max_pool_v2(value,
@tf_export(v1=["nn.fractional_avg_pool"])
@dispatch.add_dispatch_support
@deprecation.deprecated(date=None, instructions="`seed2` and `deterministic` "
"args are deprecated. Use fractional_avg_pool_v2.")
def fractional_avg_pool(value,
@ -4987,6 +5044,7 @@ def fractional_avg_pool(value,
@tf_export("nn.fractional_avg_pool", v1=[])
@dispatch.add_dispatch_support
def fractional_avg_pool_v2(value,
pooling_ratio,
pseudo_random=False,
@ -5065,6 +5123,7 @@ def _calc_dilation2d_flops(graph, node):
@tf_export(v1=["nn.erosion2d"])
@dispatch.add_dispatch_support
def erosion2d(value, kernel, strides, rates, padding, name=None):
"""Computes the grayscale erosion of 4-D `value` and 3-D `kernel` tensors.
@ -5124,6 +5183,7 @@ def erosion2d(value, kernel, strides, rates, padding, name=None):
@tf_export("nn.erosion2d", v1=[])
@dispatch.add_dispatch_support
def erosion2d_v2(value,
filters,
strides,
@ -5193,6 +5253,7 @@ def erosion2d_v2(value,
@tf_export(v1=["math.in_top_k", "nn.in_top_k"])
@dispatch.add_dispatch_support
def in_top_k(predictions, targets, k, name=None):
r"""Says whether the targets are in the top `K` predictions.
@ -5227,6 +5288,7 @@ def in_top_k(predictions, targets, k, name=None):
@tf_export("math.in_top_k", "nn.in_top_k", v1=[])
@dispatch.add_dispatch_support
def in_top_k_v2(targets, predictions, k, name=None):
return in_top_k(predictions, targets, k, name)
@ -5234,7 +5296,11 @@ def in_top_k_v2(targets, predictions, k, name=None):
in_top_k_v2.__doc__ = in_top_k.__doc__
tf_export(v1=["nn.quantized_avg_pool"])(gen_nn_ops.quantized_avg_pool)
tf_export(v1=["nn.quantized_conv2d"])(gen_nn_ops.quantized_conv2d)
tf_export(v1=["nn.quantized_relu_x"])(gen_nn_ops.quantized_relu_x)
tf_export(v1=["nn.quantized_max_pool"])(gen_nn_ops.quantized_max_pool)
tf_export(v1=["nn.quantized_avg_pool"])(
dispatch.add_dispatch_support(gen_nn_ops.quantized_avg_pool))
tf_export(v1=["nn.quantized_conv2d"])(
dispatch.add_dispatch_support(gen_nn_ops.quantized_conv2d))
tf_export(v1=["nn.quantized_relu_x"])(
dispatch.add_dispatch_support(gen_nn_ops.quantized_relu_x))
tf_export(v1=["nn.quantized_max_pool"])(
dispatch.add_dispatch_support(gen_nn_ops.quantized_max_pool))

View File

@ -25,10 +25,12 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@tf_export(v1=["debugging.assert_all_finite", "verify_tensor_all_finite"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("verify_tensor_all_finite")
def verify_tensor_all_finite(t=None, msg=None, name=None, x=None, message=None):
"""Assert that the tensor does not contain any NaN's or Inf's.
@ -50,6 +52,7 @@ def verify_tensor_all_finite(t=None, msg=None, name=None, x=None, message=None):
@tf_export("debugging.assert_all_finite", v1=[])
@dispatch.add_dispatch_support
def verify_tensor_all_finite_v2(x, message, name=None):
"""Assert that the tensor does not contain any NaN's or Inf's.

View File

@ -30,6 +30,7 @@ from tensorflow.python.ops import parsing_config
from tensorflow.python.ops.gen_parsing_ops import *
# pylint: enable=wildcard-import,undefined-variable
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@ -77,6 +78,7 @@ def _prepend_none_dimension(features):
@tf_export("io.parse_example", v1=[])
@dispatch.add_dispatch_support
def parse_example_v2(serialized, features, example_names=None, name=None):
# pylint: disable=line-too-long
"""Parses `Example` protos into a `dict` of tensors.
@ -314,6 +316,7 @@ def parse_example_v2(serialized, features, example_names=None, name=None):
@tf_export(v1=["io.parse_example", "parse_example"])
@dispatch.add_dispatch_support
def parse_example(serialized, features, name=None, example_names=None):
return parse_example_v2(serialized, features, example_names, name)
@ -373,6 +376,7 @@ def _parse_example_raw(serialized, names, params, name):
@tf_export(v1=["io.parse_single_example", "parse_single_example"])
@dispatch.add_dispatch_support
def parse_single_example(serialized, features, name=None, example_names=None):
"""Parses a single `Example` proto.
@ -407,6 +411,7 @@ def parse_single_example(serialized, features, name=None, example_names=None):
@tf_export("io.parse_single_example", v1=[])
@dispatch.add_dispatch_support
def parse_single_example_v2(
serialized, features, example_names=None, name=None
):
@ -448,6 +453,7 @@ def parse_single_example_v2(
@tf_export("io.parse_sequence_example")
@dispatch.add_dispatch_support
def parse_sequence_example(serialized,
context_features=None,
sequence_features=None,
@ -692,6 +698,7 @@ def _parse_sequence_example_raw(serialized,
@tf_export("io.parse_single_sequence_example",
v1=["io.parse_single_sequence_example",
"parse_single_sequence_example"])
@dispatch.add_dispatch_support
def parse_single_sequence_example(
serialized, context_features=None, sequence_features=None,
example_name=None, name=None):
@ -835,6 +842,7 @@ def _parse_single_sequence_example_raw(serialized,
@tf_export("io.decode_raw", v1=[])
@dispatch.add_dispatch_support
def decode_raw(input_bytes,
out_type,
little_endian=True,
@ -877,6 +885,7 @@ def decode_raw(input_bytes,
@tf_export(v1=["decode_raw", "io.decode_raw"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None,
"bytes is deprecated, use input_bytes instead",
"bytes")
@ -921,6 +930,7 @@ def decode_raw_v1(
# Swap `name` and `na_value` for backward compatibility.
@tf_export(v1=["io.decode_csv", "decode_csv"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("decode_csv")
def decode_csv(records,
record_defaults,
@ -970,6 +980,7 @@ def decode_csv(records,
@tf_export("io.decode_csv", v1=[])
@dispatch.add_dispatch_support
def decode_csv_v2(records,
record_defaults,
field_delim=",",

View File

@ -22,10 +22,11 @@ from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops.gen_decode_proto_ops import decode_proto_v2 as decode_proto
from tensorflow.python.ops.gen_encode_proto_ops import encode_proto
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
tf_export("io.decode_proto")(decode_proto)
tf_export("io.encode_proto")(encode_proto)
tf_export("io.decode_proto")(dispatch.add_dispatch_support(decode_proto))
tf_export("io.encode_proto")(dispatch.add_dispatch_support(encode_proto))
ops.NotDifferentiable("DecodeProtoV2")
ops.NotDifferentiable("EncodeProto")

View File

@ -32,6 +32,7 @@ from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.ops.ragged import segment_id_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
#===============================================================================
@ -40,6 +41,7 @@ from tensorflow.python.util.tf_export import tf_export
@tf_export('ragged.boolean_mask')
@dispatch.add_dispatch_support
def boolean_mask(data, mask, name=None):
"""Applies a boolean mask to `data` without flattening the mask dimensions.
@ -538,6 +540,7 @@ def ragged_one_hot(indices,
# ragged.stack_dynamic_partitions
#===============================================================================
@tf_export('ragged.stack_dynamic_partitions')
@dispatch.add_dispatch_support
def stack_dynamic_partitions(data, partitions, num_partitions, name=None):
"""Stacks dynamic partitions of a Tensor or RaggedTensor.
@ -699,6 +702,7 @@ def reverse(tensor, axis, name=None):
@tf_export('ragged.cross')
@dispatch.add_dispatch_support
def cross(inputs, name=None):
"""Generates feature cross from a list of tensors.
@ -725,6 +729,7 @@ def cross(inputs, name=None):
@tf_export('ragged.cross_hashed')
@dispatch.add_dispatch_support
def cross_hashed(inputs, num_buckets=0, hash_key=None, name=None):
"""Generates hashed feature cross from a list of tensors.

View File

@ -27,6 +27,7 @@ from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_gather_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@ -71,6 +72,7 @@ def concat(values, axis, name=None):
@tf_export('ragged.stack')
@dispatch.add_dispatch_support
def stack(values, axis=0, name=None):
"""Stacks a list of rank-`R` tensors into one rank-`(R+1)` `RaggedTensor`.

View File

@ -27,6 +27,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@ -34,6 +35,7 @@ from tensorflow.python.util.tf_export import tf_export
# Op to construct a constant RaggedTensor from a nested Python list.
#===============================================================================
@tf_export("ragged.constant")
@dispatch.add_dispatch_support
def constant(pylist, dtype=None, ragged_rank=None, inner_shape=None,
name=None, row_splits_dtype=dtypes.int64):
"""Constructs a constant RaggedTensor from a nested Python list.
@ -86,6 +88,7 @@ def constant(pylist, dtype=None, ragged_rank=None, inner_shape=None,
@tf_export(v1=["ragged.constant_value"])
@dispatch.add_dispatch_support
def constant_value(pylist, dtype=None, ragged_rank=None, inner_shape=None,
row_splits_dtype="int64"):
"""Constructs a RaggedTensorValue from a nested Python list.
@ -311,6 +314,7 @@ def _default_inner_shape_for_pylist(pylist, ragged_rank):
@tf_export(v1=["ragged.placeholder"])
@dispatch.add_dispatch_support
def placeholder(dtype, ragged_rank, value_shape=None, name=None):
"""Creates a placeholder for a `tf.RaggedTensor` that will always be fed.

View File

@ -24,10 +24,12 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_config
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@tf_export("ragged.map_flat_values")
@dispatch.add_dispatch_support
def map_flat_values(op, *args, **kwargs):
"""Applies `op` to the values of one or more RaggedTensors.

View File

@ -30,6 +30,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import segment_id_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@ -38,6 +39,7 @@ from tensorflow.python.util.tf_export import tf_export
#===============================================================================
# pylint: disable=redefined-builtin
@tf_export('ragged.range')
@dispatch.add_dispatch_support
def range(starts, limits=None, deltas=1, dtype=None,
name=None, row_splits_dtype=dtypes.int64):
"""Returns a `RaggedTensor` containing the specified sequences of numbers.

View File

@ -29,10 +29,12 @@ from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import compat as util_compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@tf_export("strings.bytes_split")
@dispatch.add_dispatch_support
def string_bytes_split(input, name=None): # pylint: disable=redefined-builtin
"""Split string elements of `input` into bytes.
@ -80,6 +82,7 @@ def string_bytes_split(input, name=None): # pylint: disable=redefined-builtin
# pylint: disable=redefined-builtin
@tf_export("strings.unicode_encode")
@dispatch.add_dispatch_support
def unicode_encode(input,
output_encoding,
errors="replace",
@ -177,6 +180,7 @@ def unicode_encode(input,
# pylint: disable=redefined-builtin
@tf_export("strings.unicode_decode")
@dispatch.add_dispatch_support
def unicode_decode(input,
input_encoding,
errors="replace",
@ -222,6 +226,7 @@ def unicode_decode(input,
@tf_export("strings.unicode_decode_with_offsets")
@dispatch.add_dispatch_support
def unicode_decode_with_offsets(input,
input_encoding,
errors="replace",
@ -283,6 +288,7 @@ def unicode_decode_with_offsets(input,
@tf_export("strings.unicode_split")
@dispatch.add_dispatch_support
def unicode_split(input,
input_encoding,
errors="replace",
@ -330,6 +336,7 @@ def unicode_split(input,
@tf_export("strings.unicode_split_with_offsets")
@dispatch.add_dispatch_support
def unicode_split_with_offsets(input,
input_encoding,
errors="replace",
@ -453,6 +460,7 @@ def _unicode_decode(input, input_encoding, errors, replacement_char,
@tf_export("strings.split", v1=[])
@dispatch.add_dispatch_support
def string_split_v2(input, sep=None, maxsplit=-1, name=None): # pylint: disable=redefined-builtin
"""Split elements of `input` based on `sep` into a `RaggedTensor`.
@ -514,6 +522,7 @@ def string_split_v2(input, sep=None, maxsplit=-1, name=None): # pylint: disable
@tf_export(v1=["string_split"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None,
"delimiter is deprecated, please use sep instead.",
"delimiter")
@ -578,6 +587,7 @@ def string_split(source, sep=None, skip_empty=True, delimiter=None,
# In TensorFlow 1.x, "tf.strings.split" uses the new signature (with maxsplit),
# but we need to add the result_type argument.
@tf_export(v1=["strings.split"])
@dispatch.add_dispatch_support
def strings_split_v1(input=None, sep=None, maxsplit=-1, # pylint: disable=redefined-builtin
result_type="SparseTensor", source=None, name=None):
"""Split elements of `input` based on `sep`.
@ -651,6 +661,7 @@ def reduce_join(inputs, axis=None, keepdims=None, separator="", name=None):
@tf_export("strings.ngrams")
@dispatch.add_dispatch_support
def ngrams(data,
ngram_width,
separator=" ",

View File

@ -25,12 +25,14 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
# For background on "segments" and "segment ids", see:
# https://www.tensorflow.org/api_docs/python/tf/math#Segmentation
@tf_export("ragged.row_splits_to_segment_ids")
@dispatch.add_dispatch_support
def row_splits_to_segment_ids(splits, name=None, out_type=None):
"""Generates the segmentation corresponding to a RaggedTensor `row_splits`.
@ -74,6 +76,7 @@ def row_splits_to_segment_ids(splits, name=None, out_type=None):
# For background on "segments" and "segment ids", see:
# https://www.tensorflow.org/api_docs/python/tf/math#Segmentation
@tf_export("ragged.segment_ids_to_row_splits")
@dispatch.add_dispatch_support
def segment_ids_to_row_splits(segment_ids, num_segments=None,
out_type=None, name=None):
"""Generates the RaggedTensor `row_splits` corresponding to a segmentation.

View File

@ -36,10 +36,12 @@ from tensorflow.python.ops.gen_random_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@tf_export("random.normal", v1=["random.normal", "random_normal"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("random_normal")
def random_normal(shape,
mean=0.0,
@ -155,6 +157,7 @@ def parameterized_truncated_normal(shape,
@tf_export("random.truncated_normal",
v1=["random.truncated_normal", "truncated_normal"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("truncated_normal")
def truncated_normal(shape,
mean=0.0,
@ -202,6 +205,7 @@ ops.NotDifferentiable("TruncatedNormal")
@tf_export("random.uniform", v1=["random.uniform", "random_uniform"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("random_uniform")
def random_uniform(shape,
minval=0,
@ -313,6 +317,7 @@ ops.NotDifferentiable("RandomUniform")
@tf_export("random.shuffle", v1=["random.shuffle", "random_shuffle"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("random_shuffle")
def random_shuffle(value, seed=None, name=None):
"""Randomly shuffles a tensor along its first dimension.
@ -345,6 +350,7 @@ def random_shuffle(value, seed=None, name=None):
@tf_export("image.random_crop", v1=["image.random_crop", "random_crop"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("random_crop")
def random_crop(value, size, seed=None, name=None):
"""Randomly crops a tensor to a given size.
@ -389,6 +395,7 @@ def random_crop(value, size, seed=None, name=None):
@tf_export(v1=["random.multinomial", "multinomial"])
@dispatch.add_dispatch_support
@deprecation.deprecated(
date=None, instructions="Use `tf.random.categorical` instead.")
def multinomial(logits, num_samples, seed=None, name=None, output_dtype=None):
@ -468,6 +475,7 @@ def _maybe_set_static_shape_helper(tensor, shape, postfix_tensor):
@tf_export("random.gamma", v1=["random.gamma", "random_gamma"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("random_gamma")
def random_gamma(shape,
alpha,
@ -561,6 +569,7 @@ def random_gamma(shape,
@tf_export(v1=["random.poisson", "random_poisson"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("random_poisson")
def random_poisson(lam, shape, dtype=dtypes.float32, seed=None, name=None):
"""Draws `shape` samples from each of the given Poisson distribution(s).
@ -601,6 +610,7 @@ def random_poisson(lam, shape, dtype=dtypes.float32, seed=None, name=None):
@tf_export("random.poisson", v1=[])
@dispatch.add_dispatch_support
def random_poisson_v2(shape, lam, dtype=dtypes.float32, seed=None, name=None):
"""Draws `shape` samples from each of the given Poisson distribution(s).

View File

@ -32,6 +32,7 @@ from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@ -342,6 +343,7 @@ def _reverse_seq(input_seq, lengths):
"keras.layers.RNN(cell))`, which is equivalent to "
"this API")
@tf_export(v1=["nn.bidirectional_dynamic_rnn"])
@dispatch.add_dispatch_support
def bidirectional_dynamic_rnn(cell_fw,
cell_bw,
inputs,
@ -499,6 +501,7 @@ def bidirectional_dynamic_rnn(cell_fw,
None,
"Please use `keras.layers.RNN(cell)`, which is equivalent to this API")
@tf_export(v1=["nn.dynamic_rnn"])
@dispatch.add_dispatch_support
def dynamic_rnn(cell,
inputs,
sequence_length=None,
@ -912,6 +915,7 @@ def _dynamic_rnn_loop(cell,
@tf_export(v1=["nn.raw_rnn"])
@dispatch.add_dispatch_support
def raw_rnn(cell,
loop_fn,
parallel_iterations=None,
@ -1238,6 +1242,7 @@ def raw_rnn(cell,
"Please use `keras.layers.RNN(cell, unroll=True)`, "
"which is equivalent to this API")
@tf_export(v1=["nn.static_rnn"])
@dispatch.add_dispatch_support
def static_rnn(cell,
inputs,
initial_state=None,
@ -1416,6 +1421,7 @@ def static_rnn(cell,
"Please use `keras.layers.RNN(cell, stateful=True)`, "
"which is equivalent to this API")
@tf_export(v1=["nn.static_state_saving_rnn"])
@dispatch.add_dispatch_support
def static_state_saving_rnn(cell,
inputs,
state_saver,
@ -1510,6 +1516,7 @@ def static_state_saving_rnn(cell,
"keras.layers.RNN(cell, unroll=True))`, which is "
"equivalent to this API")
@tf_export(v1=["nn.static_bidirectional_rnn"])
@dispatch.add_dispatch_support
def static_bidirectional_rnn(cell_fw,
cell_bw,
inputs,

View File

@ -39,6 +39,7 @@ from tensorflow.python.ops import gen_script_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import lazy_loader
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
@ -370,6 +371,7 @@ def _EagerPyFuncGrad(op, *dy):
@tf_export("py_function")
@dispatch.add_dispatch_support
def eager_py_func(func, inp, Tout, name=None):
"""Wraps a python function into a TensorFlow op that executes it eagerly.
@ -551,6 +553,7 @@ def py_func_common(func, inp, Tout, stateful=True, name=None):
stateful argument making all functions stateful.
""")
@tf_export(v1=["py_func"])
@dispatch.add_dispatch_support
def py_func(func, inp, Tout, stateful=True, name=None):
return py_func_common(func, inp, Tout, stateful, name=name)
@ -559,6 +562,7 @@ py_func.__doc__ = "%s" % py_func_common.__doc__
@tf_export("numpy_function")
@dispatch.add_dispatch_support
def numpy_function(func, inp, Tout, name=None):
"""Wraps a python function and uses it as a TensorFlow op.

View File

@ -23,6 +23,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import gen_set_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@ -32,6 +33,7 @@ _VALID_DTYPES = set([
@tf_export("sets.size", v1=["sets.size", "sets.set_size"])
@dispatch.add_dispatch_support
def set_size(a, validate_indices=True):
"""Compute number of unique elements along last dimension of `a`.
@ -135,6 +137,7 @@ def _set_operation(a, b, set_operation, validate_indices=True):
@tf_export(
"sets.intersection", v1=["sets.intersection", "sets.set_intersection"])
@dispatch.add_dispatch_support
def set_intersection(a, b, validate_indices=True):
"""Compute set intersection of elements in last dimension of `a` and `b`.
@ -205,6 +208,7 @@ def set_intersection(a, b, validate_indices=True):
@tf_export(
"sets.difference", v1=["sets.difference", "sets.set_difference"])
@dispatch.add_dispatch_support
def set_difference(a, b, aminusb=True, validate_indices=True):
"""Compute set difference of elements in last dimension of `a` and `b`.
@ -286,6 +290,7 @@ def set_difference(a, b, aminusb=True, validate_indices=True):
@tf_export(
"sets.union", v1=["sets.union", "sets.set_union"])
@dispatch.add_dispatch_support
def set_union(a, b, validate_indices=True):
"""Compute set union of elements in last dimension of `a` and `b`.

View File

@ -25,6 +25,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops as _array_ops
from tensorflow.python.ops import math_ops as _math_ops
from tensorflow.python.ops.signal import fft_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@ -50,6 +51,7 @@ def _validate_dct_arguments(input_tensor, dct_type, n, axis, norm):
# TODO(rjryan): Implement `axis` parameter.
@tf_export("signal.dct", v1=["signal.dct", "spectral.dct"])
@dispatch.add_dispatch_support
def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin
"""Computes the 1D [Discrete Cosine Transform (DCT)][dct] of `input`.
@ -181,6 +183,7 @@ def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disabl
# TODO(rjryan): Implement `n` and `axis` parameters.
@tf_export("signal.idct", v1=["signal.idct", "spectral.idct"])
@dispatch.add_dispatch_support
def idct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin
"""Computes the 1D [Inverse Discrete Cosine Transform (DCT)][idct] of `input`.

View File

@ -26,6 +26,7 @@ from tensorflow.python.ops import array_ops as _array_ops
from tensorflow.python.ops import gen_spectral_ops
from tensorflow.python.ops import manip_ops
from tensorflow.python.ops import math_ops as _math_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@ -181,17 +182,23 @@ ifft2d = gen_spectral_ops.ifft2d
fft3d = gen_spectral_ops.fft3d
ifft3d = gen_spectral_ops.ifft3d
rfft = _rfft_wrapper(gen_spectral_ops.rfft, 1, "rfft")
tf_export("signal.rfft", v1=["signal.rfft", "spectral.rfft"])(rfft)
tf_export("signal.rfft", v1=["signal.rfft", "spectral.rfft"])(
dispatch.add_dispatch_support(rfft))
irfft = _irfft_wrapper(gen_spectral_ops.irfft, 1, "irfft")
tf_export("signal.irfft", v1=["signal.irfft", "spectral.irfft"])(irfft)
tf_export("signal.irfft", v1=["signal.irfft", "spectral.irfft"])(
dispatch.add_dispatch_support(irfft))
rfft2d = _rfft_wrapper(gen_spectral_ops.rfft2d, 2, "rfft2d")
tf_export("signal.rfft2d", v1=["signal.rfft2d", "spectral.rfft2d"])(rfft2d)
tf_export("signal.rfft2d", v1=["signal.rfft2d", "spectral.rfft2d"])(
dispatch.add_dispatch_support(rfft2d))
irfft2d = _irfft_wrapper(gen_spectral_ops.irfft2d, 2, "irfft2d")
tf_export("signal.irfft2d", v1=["signal.irfft2d", "spectral.irfft2d"])(irfft2d)
tf_export("signal.irfft2d", v1=["signal.irfft2d", "spectral.irfft2d"])(
dispatch.add_dispatch_support(irfft2d))
rfft3d = _rfft_wrapper(gen_spectral_ops.rfft3d, 3, "rfft3d")
tf_export("signal.rfft3d", v1=["signal.rfft3d", "spectral.rfft3d"])(rfft3d)
tf_export("signal.rfft3d", v1=["signal.rfft3d", "spectral.rfft3d"])(
dispatch.add_dispatch_support(rfft3d))
irfft3d = _irfft_wrapper(gen_spectral_ops.irfft3d, 3, "irfft3d")
tf_export("signal.irfft3d", v1=["signal.irfft3d", "spectral.irfft3d"])(irfft3d)
tf_export("signal.irfft3d", v1=["signal.irfft3d", "spectral.irfft3d"])(
dispatch.add_dispatch_support(irfft3d))
def _fft_size_for_grad(grad, rank):
@ -363,6 +370,7 @@ def _irfft_grad_helper(rank, rfft_fn):
@tf_export("signal.fftshift")
@dispatch.add_dispatch_support
def fftshift(x, axes=None, name=None):
"""Shift the zero-frequency component to the center of the spectrum.
@ -407,6 +415,7 @@ def fftshift(x, axes=None, name=None):
@tf_export("signal.ifftshift")
@dispatch.add_dispatch_support
def ifftshift(x, axes=None, name=None):
"""The inverse of fftshift.

View File

@ -24,6 +24,7 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.signal import shape_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@ -90,6 +91,7 @@ def _validate_arguments(num_mel_bins, sample_rate,
@tf_export('signal.linear_to_mel_weight_matrix')
@dispatch.add_dispatch_support
def linear_to_mel_weight_matrix(num_mel_bins=20,
num_spectrogram_bins=129,
sample_rate=8000,

View File

@ -22,10 +22,12 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.signal import dct_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@tf_export('signal.mfccs_from_log_mel_spectrograms')
@dispatch.add_dispatch_support
def mfccs_from_log_mel_spectrograms(log_mel_spectrograms, name=None):
"""Computes [MFCCs][mfcc] of `log_mel_spectrograms`.

View File

@ -23,10 +23,12 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@tf_export("signal.overlap_and_add")
@dispatch.add_dispatch_support
def overlap_and_add(signal, frame_step, name=None):
"""Reconstructs a signal from a framed representation.

View File

@ -25,6 +25,7 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.signal import util_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@ -55,6 +56,7 @@ def _infer_frame_shape(signal, frame_length, frame_step, pad_end, axis):
@tf_export("signal.frame")
@dispatch.add_dispatch_support
def frame(signal, frame_length, frame_step, pad_end=False, pad_value=0, axis=-1,
name=None):
"""Expands `signal`'s `axis` dimension into frames of `frame_length`.

View File

@ -31,10 +31,12 @@ from tensorflow.python.ops.signal import fft_ops
from tensorflow.python.ops.signal import reconstruction_ops
from tensorflow.python.ops.signal import shape_ops
from tensorflow.python.ops.signal import window_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@tf_export('signal.stft')
@dispatch.add_dispatch_support
def stft(signals, frame_length, frame_step, fft_length=None,
window_fn=window_ops.hann_window,
pad_end=False, name=None):
@ -95,6 +97,7 @@ def stft(signals, frame_length, frame_step, fft_length=None,
@tf_export('signal.inverse_stft_window_fn')
@dispatch.add_dispatch_support
def inverse_stft_window_fn(frame_step,
forward_window_fn=window_ops.hann_window,
name=None):
@ -156,6 +159,7 @@ def inverse_stft_window_fn(frame_step,
@tf_export('signal.inverse_stft')
@dispatch.add_dispatch_support
def inverse_stft(stfts,
frame_length,
frame_step,
@ -291,6 +295,7 @@ def _enclosing_power_of_two(value):
@tf_export('signal.mdct')
@dispatch.add_dispatch_support
def mdct(signals, frame_length, window_fn=window_ops.vorbis_window,
pad_end=False, norm=None, name=None):
"""Computes the [Modified Discrete Cosine Transform][mdct] of `signals`.
@ -366,6 +371,7 @@ def mdct(signals, frame_length, window_fn=window_ops.vorbis_window,
@tf_export('signal.inverse_mdct')
@dispatch.add_dispatch_support
def inverse_mdct(mdcts,
window_fn=window_ops.vorbis_window,
norm=None,

View File

@ -27,6 +27,7 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@ -52,6 +53,7 @@ def _check_params(window_length, dtype):
@tf_export('signal.kaiser_window')
@dispatch.add_dispatch_support
def kaiser_window(window_length, beta=12., dtype=dtypes.float32, name=None):
"""Generate a [Kaiser window][kaiser].
@ -91,6 +93,7 @@ def kaiser_window(window_length, beta=12., dtype=dtypes.float32, name=None):
@tf_export('signal.kaiser_bessel_derived_window')
@dispatch.add_dispatch_support
def kaiser_bessel_derived_window(window_length, beta=12.,
dtype=dtypes.float32, name=None):
"""Generate a [Kaiser Bessel derived window][kbd].
@ -118,6 +121,7 @@ def kaiser_bessel_derived_window(window_length, beta=12.,
@tf_export('signal.vorbis_window')
@dispatch.add_dispatch_support
def vorbis_window(window_length, dtype=dtypes.float32, name=None):
"""Generate a [Vorbis power complementary window][vorbis].
@ -142,6 +146,7 @@ def vorbis_window(window_length, dtype=dtypes.float32, name=None):
@tf_export('signal.hann_window')
@dispatch.add_dispatch_support
def hann_window(window_length, periodic=True, dtype=dtypes.float32, name=None):
"""Generate a [Hann window][hann].
@ -167,6 +172,7 @@ def hann_window(window_length, periodic=True, dtype=dtypes.float32, name=None):
@tf_export('signal.hamming_window')
@dispatch.add_dispatch_support
def hamming_window(window_length, periodic=True, dtype=dtypes.float32,
name=None):
"""Generate a [Hamming][hamming] window.

View File

@ -30,10 +30,12 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@tf_export('sort')
@dispatch.add_dispatch_support
def sort(values, axis=-1, direction='ASCENDING', name=None):
"""Sorts a tensor.
@ -67,6 +69,7 @@ def sort(values, axis=-1, direction='ASCENDING', name=None):
@tf_export('argsort')
@dispatch.add_dispatch_support
def argsort(values, axis=-1, direction='ASCENDING', stable=False, name=None):
"""Returns the indices of a tensor that give its sorted order along an axis.

View File

@ -1065,6 +1065,7 @@ def sparse_slice(sp_input, start, size, name=None):
@tf_export(v1=["sparse_to_dense"])
@dispatch.add_dispatch_support
@deprecation.deprecated(
None,
"Create a `tf.sparse.SparseTensor` and use `tf.sparse.to_dense` instead.")
@ -1994,6 +1995,7 @@ def sparse_fill_empty_rows(sp_input, default_value, name=None):
@tf_export(v1=["io.serialize_sparse", "serialize_sparse"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("serialize_sparse")
def serialize_sparse(sp_input, name=None, out_type=dtypes.string):
"""Serialize a `SparseTensor` into a 3-vector (1-D `Tensor`) object.
@ -2014,6 +2016,7 @@ def serialize_sparse(sp_input, name=None, out_type=dtypes.string):
@tf_export("io.serialize_sparse", v1=[])
@dispatch.add_dispatch_support
def serialize_sparse_v2(sp_input, out_type=dtypes.string, name=None):
"""Serialize a `SparseTensor` into a 3-vector (1-D `Tensor`) object.
@ -2040,6 +2043,7 @@ def serialize_sparse_v2(sp_input, out_type=dtypes.string, name=None):
@tf_export(v1=["io.serialize_many_sparse", "serialize_many_sparse"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("serialize_many_sparse")
def serialize_many_sparse(sp_input, name=None, out_type=dtypes.string):
"""Serialize `N`-minibatch `SparseTensor` into an `[N, 3]` `Tensor`.
@ -2069,6 +2073,7 @@ def serialize_many_sparse(sp_input, name=None, out_type=dtypes.string):
@tf_export("io.serialize_many_sparse", v1=[])
@dispatch.add_dispatch_support
def serialize_many_sparse_v2(sp_input, out_type=dtypes.string, name=None):
"""Serialize `N`-minibatch `SparseTensor` into an `[N, 3]` `Tensor`.
@ -2172,6 +2177,7 @@ def deserialize_sparse(serialized_sparse, dtype, rank=None, name=None):
@tf_export(
"io.deserialize_many_sparse",
v1=["io.deserialize_many_sparse", "deserialize_many_sparse"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("deserialize_many_sparse")
def deserialize_many_sparse(serialized_sparse, dtype, rank=None, name=None):
"""Deserialize and concatenate `SparseTensors` from a serialized minibatch.

View File

@ -42,11 +42,13 @@ from tensorflow.python.ops import gen_special_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
# TODO(b/27419586) Change docstring for required dtype of x once int allowed
@tf_export('math.lbeta', v1=['math.lbeta', 'lbeta'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('lbeta')
def lbeta(x, name=None):
r"""Computes \\(ln(|Beta(x)|)\\), reducing along the last dimension.
@ -102,6 +104,7 @@ def lbeta(x, name=None):
@tf_export('math.special.dawsn')
@dispatch.add_dispatch_support
def dawsn(x, name=None):
"""Computes Dawson's integral of `x` element-wise.
@ -131,6 +134,7 @@ def dawsn(x, name=None):
@tf_export('math.special.expint')
@dispatch.add_dispatch_support
def expint(x, name=None):
"""Computes the Exponential integral of `x` element-wise.
@ -159,6 +163,7 @@ def expint(x, name=None):
@tf_export('math.special.fresnel_cos')
@dispatch.add_dispatch_support
def fresnel_cos(x, name=None):
"""Computes Fresnel's cosine integral of `x` element-wise.
@ -188,6 +193,7 @@ def fresnel_cos(x, name=None):
@tf_export('math.special.fresnel_sin')
@dispatch.add_dispatch_support
def fresnel_sin(x, name=None):
"""Computes Fresnel's sine integral of `x` element-wise.
@ -216,6 +222,7 @@ def fresnel_sin(x, name=None):
@tf_export('math.special.spence')
@dispatch.add_dispatch_support
def spence(x, name=None):
"""Computes Spence's integral of `x` element-wise.
@ -244,6 +251,7 @@ def spence(x, name=None):
@tf_export('math.bessel_i0')
@dispatch.add_dispatch_support
def bessel_i0(x, name=None):
"""Computes the Bessel i0 function of `x` element-wise.
@ -268,6 +276,7 @@ def bessel_i0(x, name=None):
@tf_export('math.bessel_i1')
@dispatch.add_dispatch_support
def bessel_i1(x, name=None):
"""Computes the Bessel i1 function of `x` element-wise.
@ -325,6 +334,7 @@ def _enclosing_tpu_context():
@tf_export('einsum', 'linalg.einsum')
@dispatch.add_dispatch_support
def einsum(equation, *inputs, **kwargs):
"""Tensor contraction over specified indices and outer product.

View File

@ -27,6 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_stateless_random_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
ops.NotDifferentiable("StatelessMultinomial")
@ -40,6 +41,7 @@ ops.NotDifferentiable("StatelessTruncatedNormal")
@tf_export("random.experimental.stateless_split")
@dispatch.add_dispatch_support
def split(seed, num=2):
"""Splits an RNG seed into `num` new seeds by adding a leading axis.
@ -73,6 +75,7 @@ def split(seed, num=2):
@tf_export("random.experimental.stateless_fold_in")
@dispatch.add_dispatch_support
def fold_in(seed, data):
"""Folds in data to an RNG seed to form a new RNG seed.
@ -111,6 +114,7 @@ def fold_in(seed, data):
@tf_export("random.stateless_uniform")
@dispatch.add_dispatch_support
def stateless_random_uniform(shape,
seed,
minval=0,
@ -205,6 +209,7 @@ def stateless_random_uniform(shape,
@tf_export("random.stateless_binomial")
@dispatch.add_dispatch_support
def stateless_random_binomial(shape,
seed,
counts,
@ -274,6 +279,7 @@ def stateless_random_binomial(shape,
@tf_export("random.stateless_gamma")
@dispatch.add_dispatch_support
def stateless_random_gamma(shape,
seed,
alpha,
@ -372,6 +378,7 @@ def stateless_random_gamma(shape,
@tf_export("random.stateless_poisson")
@dispatch.add_dispatch_support
def stateless_random_poisson(shape,
seed,
lam,
@ -434,6 +441,7 @@ def stateless_random_poisson(shape,
@tf_export("random.stateless_normal")
@dispatch.add_dispatch_support
def stateless_random_normal(shape,
seed,
mean=0.0,
@ -474,6 +482,7 @@ def stateless_random_normal(shape,
@tf_export("random.stateless_truncated_normal")
@dispatch.add_dispatch_support
def stateless_truncated_normal(shape,
seed,
mean=0.0,
@ -520,6 +529,7 @@ def stateless_truncated_normal(shape,
@tf_export(v1=["random.stateless_multinomial"])
@dispatch.add_dispatch_support
@deprecation.deprecated(
date=None, instructions="Use `tf.random.stateless_categorical` instead.")
def stateless_multinomial(logits,
@ -562,6 +572,7 @@ def stateless_multinomial(logits,
@tf_export("random.stateless_categorical")
@dispatch.add_dispatch_support
def stateless_categorical(logits,
num_samples,
seed,

View File

@ -73,6 +73,7 @@ regex_full_match.__doc__ = gen_string_ops.regex_full_match.__doc__
@tf_export(
"strings.regex_replace", v1=["strings.regex_replace", "regex_replace"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("regex_replace")
@dispatch.add_dispatch_support
def regex_replace(input, pattern, rewrite, replace_global=True, name=None):
@ -112,6 +113,7 @@ def regex_replace(input, pattern, rewrite, replace_global=True, name=None):
@tf_export("strings.format")
@dispatch.add_dispatch_support
def string_format(template, inputs, placeholder="{}", summarize=3, name=None):
r"""Formats a string template using a list of tensors.
@ -300,6 +302,7 @@ def _reduce_join_reduction_dims(x, axis):
@tf_export(v1=["strings.reduce_join", "reduce_join"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None,
"keep_dims is deprecated, use keepdims instead",
"keep_dims")
@ -412,6 +415,7 @@ string_length_v2.__doc__ = gen_string_ops.string_length.__doc__
@tf_export(v1=["substr"])
@dispatch.add_dispatch_support
@deprecation.deprecated(None, "Use `tf.strings.substr` instead of `tf.substr`.")
def substr_deprecated(input, pos, len, name=None, unit="BYTE"):
return substr(input, pos, len, name=name, unit=unit)
@ -476,6 +480,7 @@ def string_to_number(input, out_type=dtypes.float32, name=None):
@tf_export(v1=["strings.to_number", "string_to_number"])
@dispatch.add_dispatch_support
def string_to_number_v1(
string_tensor=None,
out_type=dtypes.float32,
@ -519,6 +524,7 @@ def string_to_hash_bucket(input, num_buckets, name=None):
@tf_export(v1=["strings.to_hash_bucket", "string_to_hash_bucket"])
@dispatch.add_dispatch_support
def string_to_hash_bucket_v1(
string_tensor=None,
num_buckets=None,
@ -532,6 +538,7 @@ string_to_hash_bucket_v1.__doc__ = gen_string_ops.string_to_hash_bucket.__doc__
@tf_export("strings.join", v1=["strings.join", "string_join"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("string_join")
@dispatch.add_dispatch_support
def string_join(inputs, separator="", name=None):