Add dispatch support to more Python APIs.

PiperOrigin-RevId: 311763060
Change-Id: Ib35371483aa083e245996508a82fd13d8ac43131
This commit is contained in:
Edward Loper 2020-05-15 10:58:42 -07:00 committed by TensorFlower Gardener
parent 26104505b8
commit 77245d07d1
52 changed files with 696 additions and 28 deletions

View File

@ -24,6 +24,7 @@ from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn from tensorflow.python.ops import nn
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
# b/123041942 # b/123041942
@ -41,6 +42,7 @@ _TF_ACTIVATIONS_V2 = {
@keras_export('keras.activations.softmax') @keras_export('keras.activations.softmax')
@dispatch.add_dispatch_support
def softmax(x, axis=-1): def softmax(x, axis=-1):
"""Softmax converts a real vector to a vector of categorical probabilities. """Softmax converts a real vector to a vector of categorical probabilities.
@ -82,6 +84,7 @@ def softmax(x, axis=-1):
@keras_export('keras.activations.elu') @keras_export('keras.activations.elu')
@dispatch.add_dispatch_support
def elu(x, alpha=1.0): def elu(x, alpha=1.0):
"""Exponential linear unit. """Exponential linear unit.
@ -100,6 +103,7 @@ def elu(x, alpha=1.0):
@keras_export('keras.activations.selu') @keras_export('keras.activations.selu')
@dispatch.add_dispatch_support
def selu(x): def selu(x):
"""Scaled Exponential Linear Unit (SELU). """Scaled Exponential Linear Unit (SELU).
@ -153,6 +157,7 @@ def selu(x):
@keras_export('keras.activations.softplus') @keras_export('keras.activations.softplus')
@dispatch.add_dispatch_support
def softplus(x): def softplus(x):
"""Softplus activation function, `softplus(x) = log(exp(x) + 1)`. """Softplus activation function, `softplus(x) = log(exp(x) + 1)`.
@ -174,6 +179,7 @@ def softplus(x):
@keras_export('keras.activations.softsign') @keras_export('keras.activations.softsign')
@dispatch.add_dispatch_support
def softsign(x): def softsign(x):
"""Softsign activation function, `softsign(x) = x / (abs(x) + 1)`. """Softsign activation function, `softsign(x) = x / (abs(x) + 1)`.
@ -194,6 +200,7 @@ def softsign(x):
@keras_export('keras.activations.swish') @keras_export('keras.activations.swish')
@dispatch.add_dispatch_support
def swish(x): def swish(x):
"""Swish activation function, `swish(x) = x * sigmoid(x)`. """Swish activation function, `swish(x) = x * sigmoid(x)`.
@ -224,6 +231,7 @@ def swish(x):
@keras_export('keras.activations.relu') @keras_export('keras.activations.relu')
@dispatch.add_dispatch_support
def relu(x, alpha=0., max_value=None, threshold=0): def relu(x, alpha=0., max_value=None, threshold=0):
"""Applies the rectified linear unit activation function. """Applies the rectified linear unit activation function.
@ -264,6 +272,7 @@ def relu(x, alpha=0., max_value=None, threshold=0):
@keras_export('keras.activations.tanh') @keras_export('keras.activations.tanh')
@dispatch.add_dispatch_support
def tanh(x): def tanh(x):
"""Hyperbolic tangent activation function. """Hyperbolic tangent activation function.
@ -285,6 +294,7 @@ def tanh(x):
@keras_export('keras.activations.sigmoid') @keras_export('keras.activations.sigmoid')
@dispatch.add_dispatch_support
def sigmoid(x): def sigmoid(x):
"""Sigmoid activation function, `sigmoid(x) = 1 / (1 + exp(-x))`. """Sigmoid activation function, `sigmoid(x) = 1 / (1 + exp(-x))`.
@ -314,6 +324,7 @@ def sigmoid(x):
@keras_export('keras.activations.exponential') @keras_export('keras.activations.exponential')
@dispatch.add_dispatch_support
def exponential(x): def exponential(x):
"""Exponential activation function. """Exponential activation function.
@ -334,6 +345,7 @@ def exponential(x):
@keras_export('keras.activations.hard_sigmoid') @keras_export('keras.activations.hard_sigmoid')
@dispatch.add_dispatch_support
def hard_sigmoid(x): def hard_sigmoid(x):
"""Hard sigmoid activation function. """Hard sigmoid activation function.
@ -360,6 +372,7 @@ def hard_sigmoid(x):
@keras_export('keras.activations.linear') @keras_export('keras.activations.linear')
@dispatch.add_dispatch_support
def linear(x): def linear(x):
"""Linear activation function (pass-through). """Linear activation function (pass-through).
@ -380,6 +393,7 @@ def linear(x):
@keras_export('keras.activations.serialize') @keras_export('keras.activations.serialize')
@dispatch.add_dispatch_support
def serialize(activation): def serialize(activation):
"""Returns the string identifier of an activation function. """Returns the string identifier of an activation function.
@ -410,6 +424,7 @@ def serialize(activation):
@keras_export('keras.activations.deserialize') @keras_export('keras.activations.deserialize')
@dispatch.add_dispatch_support
def deserialize(name, custom_objects=None): def deserialize(name, custom_objects=None):
"""Returns activation function given a string identifier. """Returns activation function given a string identifier.
@ -447,6 +462,7 @@ def deserialize(name, custom_objects=None):
@keras_export('keras.activations.get') @keras_export('keras.activations.get')
@dispatch.add_dispatch_support
def get(identifier): def get(identifier):
"""Returns function. """Returns function.

View File

@ -76,6 +76,7 @@ from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import moving_averages from tensorflow.python.training import moving_averages
from tensorflow.python.training.tracking import util as tracking_util from tensorflow.python.training.tracking import util as tracking_util
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util import object_identity from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_contextlib from tensorflow.python.util import tf_contextlib
@ -173,6 +174,7 @@ def backend():
@keras_export('keras.backend.cast_to_floatx') @keras_export('keras.backend.cast_to_floatx')
@dispatch.add_dispatch_support
def cast_to_floatx(x): def cast_to_floatx(x):
"""Cast a Numpy array to the default Keras float type. """Cast a Numpy array to the default Keras float type.
@ -799,6 +801,7 @@ def is_sparse(tensor):
@keras_export('keras.backend.to_dense') @keras_export('keras.backend.to_dense')
@dispatch.add_dispatch_support
def to_dense(tensor): def to_dense(tensor):
"""Converts a sparse tensor into a dense tensor and returns it. """Converts a sparse tensor into a dense tensor and returns it.
@ -1007,6 +1010,7 @@ def _initialize_variables(session):
@keras_export('keras.backend.constant') @keras_export('keras.backend.constant')
@dispatch.add_dispatch_support
def constant(value, dtype=None, shape=None, name=None): def constant(value, dtype=None, shape=None, name=None):
"""Creates a constant tensor. """Creates a constant tensor.
@ -1163,6 +1167,7 @@ def is_placeholder(x):
@keras_export('keras.backend.shape') @keras_export('keras.backend.shape')
@dispatch.add_dispatch_support
def shape(x): def shape(x):
"""Returns the symbolic shape of a tensor or variable. """Returns the symbolic shape of a tensor or variable.
@ -1245,6 +1250,7 @@ def ndim(x):
@keras_export('keras.backend.dtype') @keras_export('keras.backend.dtype')
@dispatch.add_dispatch_support
def dtype(x): def dtype(x):
"""Returns the dtype of a Keras tensor or variable, as a string. """Returns the dtype of a Keras tensor or variable, as a string.
@ -1343,6 +1349,7 @@ def zeros(shape, dtype=None, name=None):
@keras_export('keras.backend.ones') @keras_export('keras.backend.ones')
@dispatch.add_dispatch_support
def ones(shape, dtype=None, name=None): def ones(shape, dtype=None, name=None):
"""Instantiates an all-ones variable and returns it. """Instantiates an all-ones variable and returns it.
@ -1377,6 +1384,7 @@ def ones(shape, dtype=None, name=None):
@keras_export('keras.backend.eye') @keras_export('keras.backend.eye')
@dispatch.add_dispatch_support
def eye(size, dtype=None, name=None): def eye(size, dtype=None, name=None):
"""Instantiate an identity matrix and returns it. """Instantiate an identity matrix and returns it.
@ -1433,6 +1441,7 @@ def zeros_like(x, dtype=None, name=None):
@keras_export('keras.backend.ones_like') @keras_export('keras.backend.ones_like')
@dispatch.add_dispatch_support
def ones_like(x, dtype=None, name=None): def ones_like(x, dtype=None, name=None):
"""Instantiates an all-ones variable of the same shape as another tensor. """Instantiates an all-ones variable of the same shape as another tensor.
@ -1563,6 +1572,7 @@ def count_params(x):
@keras_export('keras.backend.cast') @keras_export('keras.backend.cast')
@dispatch.add_dispatch_support
def cast(x, dtype): def cast(x, dtype):
"""Casts a tensor to a different dtype and returns it. """Casts a tensor to a different dtype and returns it.
@ -1647,6 +1657,7 @@ def moving_average_update(x, value, momentum):
@keras_export('keras.backend.dot') @keras_export('keras.backend.dot')
@dispatch.add_dispatch_support
def dot(x, y): def dot(x, y):
"""Multiplies 2 tensors (and/or variables) and returns a tensor. """Multiplies 2 tensors (and/or variables) and returns a tensor.
@ -1707,6 +1718,7 @@ def dot(x, y):
@keras_export('keras.backend.batch_dot') @keras_export('keras.backend.batch_dot')
@dispatch.add_dispatch_support
def batch_dot(x, y, axes=None): def batch_dot(x, y, axes=None):
"""Batchwise dot product. """Batchwise dot product.
@ -1895,6 +1907,7 @@ def batch_dot(x, y, axes=None):
@keras_export('keras.backend.transpose') @keras_export('keras.backend.transpose')
@dispatch.add_dispatch_support
def transpose(x): def transpose(x):
"""Transposes a tensor and returns it. """Transposes a tensor and returns it.
@ -1926,6 +1939,7 @@ def transpose(x):
@keras_export('keras.backend.gather') @keras_export('keras.backend.gather')
@dispatch.add_dispatch_support
def gather(reference, indices): def gather(reference, indices):
"""Retrieves the elements of indices `indices` in the tensor `reference`. """Retrieves the elements of indices `indices` in the tensor `reference`.
@ -1961,6 +1975,7 @@ def gather(reference, indices):
@keras_export('keras.backend.max') @keras_export('keras.backend.max')
@dispatch.add_dispatch_support
def max(x, axis=None, keepdims=False): def max(x, axis=None, keepdims=False):
"""Maximum value in a tensor. """Maximum value in a tensor.
@ -1979,6 +1994,7 @@ def max(x, axis=None, keepdims=False):
@keras_export('keras.backend.min') @keras_export('keras.backend.min')
@dispatch.add_dispatch_support
def min(x, axis=None, keepdims=False): def min(x, axis=None, keepdims=False):
"""Minimum value in a tensor. """Minimum value in a tensor.
@ -1997,6 +2013,7 @@ def min(x, axis=None, keepdims=False):
@keras_export('keras.backend.sum') @keras_export('keras.backend.sum')
@dispatch.add_dispatch_support
def sum(x, axis=None, keepdims=False): def sum(x, axis=None, keepdims=False):
"""Sum of the values in a tensor, alongside the specified axis. """Sum of the values in a tensor, alongside the specified axis.
@ -2015,6 +2032,7 @@ def sum(x, axis=None, keepdims=False):
@keras_export('keras.backend.prod') @keras_export('keras.backend.prod')
@dispatch.add_dispatch_support
def prod(x, axis=None, keepdims=False): def prod(x, axis=None, keepdims=False):
"""Multiplies the values in a tensor, alongside the specified axis. """Multiplies the values in a tensor, alongside the specified axis.
@ -2033,6 +2051,7 @@ def prod(x, axis=None, keepdims=False):
@keras_export('keras.backend.cumsum') @keras_export('keras.backend.cumsum')
@dispatch.add_dispatch_support
def cumsum(x, axis=0): def cumsum(x, axis=0):
"""Cumulative sum of the values in a tensor, alongside the specified axis. """Cumulative sum of the values in a tensor, alongside the specified axis.
@ -2047,6 +2066,7 @@ def cumsum(x, axis=0):
@keras_export('keras.backend.cumprod') @keras_export('keras.backend.cumprod')
@dispatch.add_dispatch_support
def cumprod(x, axis=0): def cumprod(x, axis=0):
"""Cumulative product of the values in a tensor, alongside the specified axis. """Cumulative product of the values in a tensor, alongside the specified axis.
@ -2081,6 +2101,7 @@ def var(x, axis=None, keepdims=False):
@keras_export('keras.backend.std') @keras_export('keras.backend.std')
@dispatch.add_dispatch_support
def std(x, axis=None, keepdims=False): def std(x, axis=None, keepdims=False):
"""Standard deviation of a tensor, alongside the specified axis. """Standard deviation of a tensor, alongside the specified axis.
@ -2107,6 +2128,7 @@ def std(x, axis=None, keepdims=False):
@keras_export('keras.backend.mean') @keras_export('keras.backend.mean')
@dispatch.add_dispatch_support
def mean(x, axis=None, keepdims=False): def mean(x, axis=None, keepdims=False):
"""Mean of a tensor, alongside the specified axis. """Mean of a tensor, alongside the specified axis.
@ -2127,6 +2149,7 @@ def mean(x, axis=None, keepdims=False):
@keras_export('keras.backend.any') @keras_export('keras.backend.any')
@dispatch.add_dispatch_support
def any(x, axis=None, keepdims=False): def any(x, axis=None, keepdims=False):
"""Bitwise reduction (logical OR). """Bitwise reduction (logical OR).
@ -2143,6 +2166,7 @@ def any(x, axis=None, keepdims=False):
@keras_export('keras.backend.all') @keras_export('keras.backend.all')
@dispatch.add_dispatch_support
def all(x, axis=None, keepdims=False): def all(x, axis=None, keepdims=False):
"""Bitwise reduction (logical AND). """Bitwise reduction (logical AND).
@ -2159,6 +2183,7 @@ def all(x, axis=None, keepdims=False):
@keras_export('keras.backend.argmax') @keras_export('keras.backend.argmax')
@dispatch.add_dispatch_support
def argmax(x, axis=-1): def argmax(x, axis=-1):
"""Returns the index of the maximum value along an axis. """Returns the index of the maximum value along an axis.
@ -2173,6 +2198,7 @@ def argmax(x, axis=-1):
@keras_export('keras.backend.argmin') @keras_export('keras.backend.argmin')
@dispatch.add_dispatch_support
def argmin(x, axis=-1): def argmin(x, axis=-1):
"""Returns the index of the minimum value along an axis. """Returns the index of the minimum value along an axis.
@ -2187,6 +2213,7 @@ def argmin(x, axis=-1):
@keras_export('keras.backend.square') @keras_export('keras.backend.square')
@dispatch.add_dispatch_support
def square(x): def square(x):
"""Element-wise square. """Element-wise square.
@ -2200,6 +2227,7 @@ def square(x):
@keras_export('keras.backend.abs') @keras_export('keras.backend.abs')
@dispatch.add_dispatch_support
def abs(x): def abs(x):
"""Element-wise absolute value. """Element-wise absolute value.
@ -2213,6 +2241,7 @@ def abs(x):
@keras_export('keras.backend.sqrt') @keras_export('keras.backend.sqrt')
@dispatch.add_dispatch_support
def sqrt(x): def sqrt(x):
"""Element-wise square root. """Element-wise square root.
@ -2229,6 +2258,7 @@ def sqrt(x):
@keras_export('keras.backend.exp') @keras_export('keras.backend.exp')
@dispatch.add_dispatch_support
def exp(x): def exp(x):
"""Element-wise exponential. """Element-wise exponential.
@ -2242,6 +2272,7 @@ def exp(x):
@keras_export('keras.backend.log') @keras_export('keras.backend.log')
@dispatch.add_dispatch_support
def log(x): def log(x):
"""Element-wise log. """Element-wise log.
@ -2276,6 +2307,7 @@ def logsumexp(x, axis=None, keepdims=False):
@keras_export('keras.backend.round') @keras_export('keras.backend.round')
@dispatch.add_dispatch_support
def round(x): def round(x):
"""Element-wise rounding to the closest integer. """Element-wise rounding to the closest integer.
@ -2291,6 +2323,7 @@ def round(x):
@keras_export('keras.backend.sign') @keras_export('keras.backend.sign')
@dispatch.add_dispatch_support
def sign(x): def sign(x):
"""Element-wise sign. """Element-wise sign.
@ -2304,6 +2337,7 @@ def sign(x):
@keras_export('keras.backend.pow') @keras_export('keras.backend.pow')
@dispatch.add_dispatch_support
def pow(x, a): def pow(x, a):
"""Element-wise exponentiation. """Element-wise exponentiation.
@ -2318,6 +2352,7 @@ def pow(x, a):
@keras_export('keras.backend.clip') @keras_export('keras.backend.clip')
@dispatch.add_dispatch_support
def clip(x, min_value, max_value): def clip(x, min_value, max_value):
"""Element-wise value clipping. """Element-wise value clipping.
@ -2341,6 +2376,7 @@ def clip(x, min_value, max_value):
@keras_export('keras.backend.equal') @keras_export('keras.backend.equal')
@dispatch.add_dispatch_support
def equal(x, y): def equal(x, y):
"""Element-wise equality between two tensors. """Element-wise equality between two tensors.
@ -2355,6 +2391,7 @@ def equal(x, y):
@keras_export('keras.backend.not_equal') @keras_export('keras.backend.not_equal')
@dispatch.add_dispatch_support
def not_equal(x, y): def not_equal(x, y):
"""Element-wise inequality between two tensors. """Element-wise inequality between two tensors.
@ -2369,6 +2406,7 @@ def not_equal(x, y):
@keras_export('keras.backend.greater') @keras_export('keras.backend.greater')
@dispatch.add_dispatch_support
def greater(x, y): def greater(x, y):
"""Element-wise truth value of (x > y). """Element-wise truth value of (x > y).
@ -2383,6 +2421,7 @@ def greater(x, y):
@keras_export('keras.backend.greater_equal') @keras_export('keras.backend.greater_equal')
@dispatch.add_dispatch_support
def greater_equal(x, y): def greater_equal(x, y):
"""Element-wise truth value of (x >= y). """Element-wise truth value of (x >= y).
@ -2397,6 +2436,7 @@ def greater_equal(x, y):
@keras_export('keras.backend.less') @keras_export('keras.backend.less')
@dispatch.add_dispatch_support
def less(x, y): def less(x, y):
"""Element-wise truth value of (x < y). """Element-wise truth value of (x < y).
@ -2411,6 +2451,7 @@ def less(x, y):
@keras_export('keras.backend.less_equal') @keras_export('keras.backend.less_equal')
@dispatch.add_dispatch_support
def less_equal(x, y): def less_equal(x, y):
"""Element-wise truth value of (x <= y). """Element-wise truth value of (x <= y).
@ -2425,6 +2466,7 @@ def less_equal(x, y):
@keras_export('keras.backend.maximum') @keras_export('keras.backend.maximum')
@dispatch.add_dispatch_support
def maximum(x, y): def maximum(x, y):
"""Element-wise maximum of two tensors. """Element-wise maximum of two tensors.
@ -2449,6 +2491,7 @@ def maximum(x, y):
@keras_export('keras.backend.minimum') @keras_export('keras.backend.minimum')
@dispatch.add_dispatch_support
def minimum(x, y): def minimum(x, y):
"""Element-wise minimum of two tensors. """Element-wise minimum of two tensors.
@ -2463,6 +2506,7 @@ def minimum(x, y):
@keras_export('keras.backend.sin') @keras_export('keras.backend.sin')
@dispatch.add_dispatch_support
def sin(x): def sin(x):
"""Computes sin of x element-wise. """Computes sin of x element-wise.
@ -2476,6 +2520,7 @@ def sin(x):
@keras_export('keras.backend.cos') @keras_export('keras.backend.cos')
@dispatch.add_dispatch_support
def cos(x): def cos(x):
"""Computes cos of x element-wise. """Computes cos of x element-wise.
@ -2621,6 +2666,7 @@ def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3):
@keras_export('keras.backend.batch_normalization') @keras_export('keras.backend.batch_normalization')
@dispatch.add_dispatch_support
def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3): def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3):
"""Applies batch normalization on x given mean, var, beta and gamma. """Applies batch normalization on x given mean, var, beta and gamma.
@ -2683,6 +2729,7 @@ def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3):
@keras_export('keras.backend.concatenate') @keras_export('keras.backend.concatenate')
@dispatch.add_dispatch_support
def concatenate(tensors, axis=-1): def concatenate(tensors, axis=-1):
"""Concatenates a list of tensors alongside the specified axis. """Concatenates a list of tensors alongside the specified axis.
@ -2720,6 +2767,7 @@ def concatenate(tensors, axis=-1):
@keras_export('keras.backend.reshape') @keras_export('keras.backend.reshape')
@dispatch.add_dispatch_support
def reshape(x, shape): def reshape(x, shape):
"""Reshapes a tensor to the specified shape. """Reshapes a tensor to the specified shape.
@ -2749,6 +2797,7 @@ def reshape(x, shape):
@keras_export('keras.backend.permute_dimensions') @keras_export('keras.backend.permute_dimensions')
@dispatch.add_dispatch_support
def permute_dimensions(x, pattern): def permute_dimensions(x, pattern):
"""Permutes axes in a tensor. """Permutes axes in a tensor.
@ -2780,6 +2829,7 @@ def permute_dimensions(x, pattern):
@keras_export('keras.backend.resize_images') @keras_export('keras.backend.resize_images')
@dispatch.add_dispatch_support
def resize_images(x, height_factor, width_factor, data_format, def resize_images(x, height_factor, width_factor, data_format,
interpolation='nearest'): interpolation='nearest'):
"""Resizes the images contained in a 4D tensor. """Resizes the images contained in a 4D tensor.
@ -2843,6 +2893,7 @@ def resize_images(x, height_factor, width_factor, data_format,
@keras_export('keras.backend.resize_volumes') @keras_export('keras.backend.resize_volumes')
@dispatch.add_dispatch_support
def resize_volumes(x, depth_factor, height_factor, width_factor, data_format): def resize_volumes(x, depth_factor, height_factor, width_factor, data_format):
"""Resizes the volume contained in a 5D tensor. """Resizes the volume contained in a 5D tensor.
@ -2875,6 +2926,7 @@ def resize_volumes(x, depth_factor, height_factor, width_factor, data_format):
@keras_export('keras.backend.repeat_elements') @keras_export('keras.backend.repeat_elements')
@dispatch.add_dispatch_support
def repeat_elements(x, rep, axis): def repeat_elements(x, rep, axis):
"""Repeats the elements of a tensor along an axis, like `np.repeat`. """Repeats the elements of a tensor along an axis, like `np.repeat`.
@ -2936,6 +2988,7 @@ def repeat_elements(x, rep, axis):
@keras_export('keras.backend.repeat') @keras_export('keras.backend.repeat')
@dispatch.add_dispatch_support
def repeat(x, n): def repeat(x, n):
"""Repeats a 2D tensor. """Repeats a 2D tensor.
@ -2971,6 +3024,7 @@ def repeat(x, n):
@keras_export('keras.backend.arange') @keras_export('keras.backend.arange')
@dispatch.add_dispatch_support
def arange(start, stop=None, step=1, dtype='int32'): def arange(start, stop=None, step=1, dtype='int32'):
"""Creates a 1D tensor containing a sequence of integers. """Creates a 1D tensor containing a sequence of integers.
@ -3009,6 +3063,7 @@ def arange(start, stop=None, step=1, dtype='int32'):
@keras_export('keras.backend.tile') @keras_export('keras.backend.tile')
@dispatch.add_dispatch_support
def tile(x, n): def tile(x, n):
"""Creates a tensor by tiling `x` by `n`. """Creates a tensor by tiling `x` by `n`.
@ -3026,6 +3081,7 @@ def tile(x, n):
@keras_export('keras.backend.flatten') @keras_export('keras.backend.flatten')
@dispatch.add_dispatch_support
def flatten(x): def flatten(x):
"""Flatten a tensor. """Flatten a tensor.
@ -3051,6 +3107,7 @@ def flatten(x):
@keras_export('keras.backend.batch_flatten') @keras_export('keras.backend.batch_flatten')
@dispatch.add_dispatch_support
def batch_flatten(x): def batch_flatten(x):
"""Turn a nD tensor into a 2D tensor with same 0th dimension. """Turn a nD tensor into a 2D tensor with same 0th dimension.
@ -3076,6 +3133,7 @@ def batch_flatten(x):
@keras_export('keras.backend.expand_dims') @keras_export('keras.backend.expand_dims')
@dispatch.add_dispatch_support
def expand_dims(x, axis=-1): def expand_dims(x, axis=-1):
"""Adds a 1-sized dimension at index "axis". """Adds a 1-sized dimension at index "axis".
@ -3090,6 +3148,7 @@ def expand_dims(x, axis=-1):
@keras_export('keras.backend.squeeze') @keras_export('keras.backend.squeeze')
@dispatch.add_dispatch_support
def squeeze(x, axis): def squeeze(x, axis):
"""Removes a 1-dimension from the tensor at index "axis". """Removes a 1-dimension from the tensor at index "axis".
@ -3104,6 +3163,7 @@ def squeeze(x, axis):
@keras_export('keras.backend.temporal_padding') @keras_export('keras.backend.temporal_padding')
@dispatch.add_dispatch_support
def temporal_padding(x, padding=(1, 1)): def temporal_padding(x, padding=(1, 1)):
"""Pads the middle dimension of a 3D tensor. """Pads the middle dimension of a 3D tensor.
@ -3121,6 +3181,7 @@ def temporal_padding(x, padding=(1, 1)):
@keras_export('keras.backend.spatial_2d_padding') @keras_export('keras.backend.spatial_2d_padding')
@dispatch.add_dispatch_support
def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None): def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):
"""Pads the 2nd and 3rd dimensions of a 4D tensor. """Pads the 2nd and 3rd dimensions of a 4D tensor.
@ -3152,6 +3213,7 @@ def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):
@keras_export('keras.backend.spatial_3d_padding') @keras_export('keras.backend.spatial_3d_padding')
@dispatch.add_dispatch_support
def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None): def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None):
"""Pads 5D tensor with zeros along the depth, height, width dimensions. """Pads 5D tensor with zeros along the depth, height, width dimensions.
@ -3196,6 +3258,7 @@ def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None):
@keras_export('keras.backend.stack') @keras_export('keras.backend.stack')
@dispatch.add_dispatch_support
def stack(x, axis=0): def stack(x, axis=0):
"""Stacks a list of rank `R` tensors into a rank `R+1` tensor. """Stacks a list of rank `R` tensors into a rank `R+1` tensor.
@ -3222,6 +3285,7 @@ def stack(x, axis=0):
@keras_export('keras.backend.one_hot') @keras_export('keras.backend.one_hot')
@dispatch.add_dispatch_support
def one_hot(indices, num_classes): def one_hot(indices, num_classes):
"""Computes the one-hot representation of an integer tensor. """Computes the one-hot representation of an integer tensor.
@ -3241,6 +3305,7 @@ def one_hot(indices, num_classes):
@keras_export('keras.backend.reverse') @keras_export('keras.backend.reverse')
@dispatch.add_dispatch_support
def reverse(x, axes): def reverse(x, axes):
"""Reverse a tensor along the specified axes. """Reverse a tensor along the specified axes.
@ -3321,6 +3386,7 @@ def get_value(x):
@keras_export('keras.backend.batch_get_value') @keras_export('keras.backend.batch_get_value')
@dispatch.add_dispatch_support
def batch_get_value(tensors): def batch_get_value(tensors):
"""Returns the value of more than one tensor variable. """Returns the value of more than one tensor variable.
@ -3382,6 +3448,7 @@ def set_value(x, value):
@keras_export('keras.backend.batch_set_value') @keras_export('keras.backend.batch_set_value')
@dispatch.add_dispatch_support
def batch_set_value(tuples): def batch_set_value(tuples):
"""Sets the values of many tensor variables at once. """Sets the values of many tensor variables at once.
@ -3424,6 +3491,7 @@ set_value.__doc__ = set_value.__doc__.format(snippet=_VALUE_SET_CODE_STRING)
@keras_export('keras.backend.print_tensor') @keras_export('keras.backend.print_tensor')
@dispatch.add_dispatch_support
def print_tensor(x, message=''): def print_tensor(x, message=''):
"""Prints `message` and the tensor value when evaluated. """Prints `message` and the tensor value when evaluated.
@ -3861,6 +3929,7 @@ def gradients(loss, variables):
@keras_export('keras.backend.stop_gradient') @keras_export('keras.backend.stop_gradient')
@dispatch.add_dispatch_support
def stop_gradient(variables): def stop_gradient(variables):
"""Returns `variables` but with zero gradient w.r.t. every other variable. """Returns `variables` but with zero gradient w.r.t. every other variable.
@ -3882,6 +3951,7 @@ def stop_gradient(variables):
@keras_export('keras.backend.rnn') @keras_export('keras.backend.rnn')
@dispatch.add_dispatch_support
def rnn(step_function, def rnn(step_function,
inputs, inputs,
initial_states, initial_states,
@ -4276,6 +4346,7 @@ def rnn(step_function,
@keras_export('keras.backend.switch') @keras_export('keras.backend.switch')
@dispatch.add_dispatch_support
def switch(condition, then_expression, else_expression): def switch(condition, then_expression, else_expression):
"""Switches between two operations depending on a scalar value. """Switches between two operations depending on a scalar value.
@ -4409,6 +4480,7 @@ def in_test_phase(x, alt, training=None):
@keras_export('keras.backend.relu') @keras_export('keras.backend.relu')
@dispatch.add_dispatch_support
def relu(x, alpha=0., max_value=None, threshold=0): def relu(x, alpha=0., max_value=None, threshold=0):
"""Rectified linear unit. """Rectified linear unit.
@ -4462,6 +4534,7 @@ def relu(x, alpha=0., max_value=None, threshold=0):
@keras_export('keras.backend.elu') @keras_export('keras.backend.elu')
@dispatch.add_dispatch_support
def elu(x, alpha=1.): def elu(x, alpha=1.):
"""Exponential linear unit. """Exponential linear unit.
@ -4480,6 +4553,7 @@ def elu(x, alpha=1.):
@keras_export('keras.backend.softmax') @keras_export('keras.backend.softmax')
@dispatch.add_dispatch_support
def softmax(x, axis=-1): def softmax(x, axis=-1):
"""Softmax of a tensor. """Softmax of a tensor.
@ -4495,6 +4569,7 @@ def softmax(x, axis=-1):
@keras_export('keras.backend.softplus') @keras_export('keras.backend.softplus')
@dispatch.add_dispatch_support
def softplus(x): def softplus(x):
"""Softplus of a tensor. """Softplus of a tensor.
@ -4508,6 +4583,7 @@ def softplus(x):
@keras_export('keras.backend.softsign') @keras_export('keras.backend.softsign')
@dispatch.add_dispatch_support
def softsign(x): def softsign(x):
"""Softsign of a tensor. """Softsign of a tensor.
@ -4527,6 +4603,7 @@ def _backtrack_identity(tensor):
@keras_export('keras.backend.categorical_crossentropy') @keras_export('keras.backend.categorical_crossentropy')
@dispatch.add_dispatch_support
def categorical_crossentropy(target, output, from_logits=False, axis=-1): def categorical_crossentropy(target, output, from_logits=False, axis=-1):
"""Categorical crossentropy between an output tensor and a target tensor. """Categorical crossentropy between an output tensor and a target tensor.
@ -4595,6 +4672,7 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1):
@keras_export('keras.backend.sparse_categorical_crossentropy') @keras_export('keras.backend.sparse_categorical_crossentropy')
@dispatch.add_dispatch_support
def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
"""Categorical crossentropy with integer targets. """Categorical crossentropy with integer targets.
@ -4676,6 +4754,7 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
@keras_export('keras.backend.binary_crossentropy') @keras_export('keras.backend.binary_crossentropy')
@dispatch.add_dispatch_support
def binary_crossentropy(target, output, from_logits=False): def binary_crossentropy(target, output, from_logits=False):
"""Binary crossentropy between an output tensor and a target tensor. """Binary crossentropy between an output tensor and a target tensor.
@ -4712,6 +4791,7 @@ def binary_crossentropy(target, output, from_logits=False):
@keras_export('keras.backend.sigmoid') @keras_export('keras.backend.sigmoid')
@dispatch.add_dispatch_support
def sigmoid(x): def sigmoid(x):
"""Element-wise sigmoid. """Element-wise sigmoid.
@ -4725,6 +4805,7 @@ def sigmoid(x):
@keras_export('keras.backend.hard_sigmoid') @keras_export('keras.backend.hard_sigmoid')
@dispatch.add_dispatch_support
def hard_sigmoid(x): def hard_sigmoid(x):
"""Segment-wise linear approximation of sigmoid. """Segment-wise linear approximation of sigmoid.
@ -4747,6 +4828,7 @@ def hard_sigmoid(x):
@keras_export('keras.backend.tanh') @keras_export('keras.backend.tanh')
@dispatch.add_dispatch_support
def tanh(x): def tanh(x):
"""Element-wise tanh. """Element-wise tanh.
@ -4760,6 +4842,7 @@ def tanh(x):
@keras_export('keras.backend.dropout') @keras_export('keras.backend.dropout')
@dispatch.add_dispatch_support
def dropout(x, level, noise_shape=None, seed=None): def dropout(x, level, noise_shape=None, seed=None):
"""Sets entries in `x` to zero at random, while scaling the entire tensor. """Sets entries in `x` to zero at random, while scaling the entire tensor.
@ -4780,6 +4863,7 @@ def dropout(x, level, noise_shape=None, seed=None):
@keras_export('keras.backend.l2_normalize') @keras_export('keras.backend.l2_normalize')
@dispatch.add_dispatch_support
def l2_normalize(x, axis=None): def l2_normalize(x, axis=None):
"""Normalizes a tensor wrt the L2 norm alongside the specified axis. """Normalizes a tensor wrt the L2 norm alongside the specified axis.
@ -4794,6 +4878,7 @@ def l2_normalize(x, axis=None):
@keras_export('keras.backend.in_top_k') @keras_export('keras.backend.in_top_k')
@dispatch.add_dispatch_support
def in_top_k(predictions, targets, k): def in_top_k(predictions, targets, k):
"""Returns whether the `targets` are in the top `k` `predictions`. """Returns whether the `targets` are in the top `k` `predictions`.
@ -4896,6 +4981,7 @@ def _preprocess_padding(padding):
@keras_export('keras.backend.conv1d') @keras_export('keras.backend.conv1d')
@dispatch.add_dispatch_support
def conv1d(x, def conv1d(x,
kernel, kernel,
strides=1, strides=1,
@ -4946,6 +5032,7 @@ def conv1d(x,
@keras_export('keras.backend.conv2d') @keras_export('keras.backend.conv2d')
@dispatch.add_dispatch_support
def conv2d(x, def conv2d(x,
kernel, kernel,
strides=(1, 1), strides=(1, 1),
@ -4989,6 +5076,7 @@ def conv2d(x,
@keras_export('keras.backend.conv2d_transpose') @keras_export('keras.backend.conv2d_transpose')
@dispatch.add_dispatch_support
def conv2d_transpose(x, def conv2d_transpose(x,
kernel, kernel,
output_shape, output_shape,
@ -5129,6 +5217,7 @@ def separable_conv1d(x,
@keras_export('keras.backend.separable_conv2d') @keras_export('keras.backend.separable_conv2d')
@dispatch.add_dispatch_support
def separable_conv2d(x, def separable_conv2d(x,
depthwise_kernel, depthwise_kernel,
pointwise_kernel, pointwise_kernel,
@ -5186,6 +5275,7 @@ def separable_conv2d(x,
@keras_export('keras.backend.depthwise_conv2d') @keras_export('keras.backend.depthwise_conv2d')
@dispatch.add_dispatch_support
def depthwise_conv2d(x, def depthwise_conv2d(x,
depthwise_kernel, depthwise_kernel,
strides=(1, 1), strides=(1, 1),
@ -5235,6 +5325,7 @@ def depthwise_conv2d(x,
@keras_export('keras.backend.conv3d') @keras_export('keras.backend.conv3d')
@dispatch.add_dispatch_support
def conv3d(x, def conv3d(x,
kernel, kernel,
strides=(1, 1, 1), strides=(1, 1, 1),
@ -5337,6 +5428,7 @@ def conv3d_transpose(x,
@keras_export('keras.backend.pool2d') @keras_export('keras.backend.pool2d')
@dispatch.add_dispatch_support
def pool2d(x, def pool2d(x,
pool_size, pool_size,
strides=(1, 1), strides=(1, 1),
@ -5396,6 +5488,7 @@ def pool2d(x,
@keras_export('keras.backend.pool3d') @keras_export('keras.backend.pool3d')
@dispatch.add_dispatch_support
def pool3d(x, def pool3d(x,
pool_size, pool_size,
strides=(1, 1, 1), strides=(1, 1, 1),
@ -5526,6 +5619,7 @@ def local_conv(inputs,
@keras_export('keras.backend.local_conv1d') @keras_export('keras.backend.local_conv1d')
@dispatch.add_dispatch_support
def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None): def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None):
"""Apply 1D conv with un-shared weights. """Apply 1D conv with un-shared weights.
@ -5561,6 +5655,7 @@ def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None):
@keras_export('keras.backend.local_conv2d') @keras_export('keras.backend.local_conv2d')
@dispatch.add_dispatch_support
def local_conv2d(inputs, def local_conv2d(inputs,
kernel, kernel,
kernel_size, kernel_size,
@ -5602,6 +5697,7 @@ def local_conv2d(inputs,
@keras_export('keras.backend.bias_add') @keras_export('keras.backend.bias_add')
@dispatch.add_dispatch_support
def bias_add(x, bias, data_format=None): def bias_add(x, bias, data_format=None):
"""Adds a bias vector to a tensor. """Adds a bias vector to a tensor.
@ -5646,6 +5742,7 @@ def bias_add(x, bias, data_format=None):
@keras_export('keras.backend.random_normal') @keras_export('keras.backend.random_normal')
@dispatch.add_dispatch_support
def random_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): def random_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
"""Returns a tensor with normal distribution of values. """Returns a tensor with normal distribution of values.
@ -5682,6 +5779,7 @@ def random_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
@keras_export('keras.backend.random_uniform') @keras_export('keras.backend.random_uniform')
@dispatch.add_dispatch_support
def random_uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): def random_uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
"""Returns a tensor with uniform distribution of values. """Returns a tensor with uniform distribution of values.
@ -5715,6 +5813,7 @@ def random_uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
@deprecated(None, 'Use `tf.keras.backend.random_bernoulli` instead.') @deprecated(None, 'Use `tf.keras.backend.random_bernoulli` instead.')
@keras_export('keras.backend.random_binomial') @keras_export('keras.backend.random_binomial')
@dispatch.add_dispatch_support
def random_binomial(shape, p=0.0, dtype=None, seed=None): def random_binomial(shape, p=0.0, dtype=None, seed=None):
"""Returns a tensor with random binomial distribution of values. """Returns a tensor with random binomial distribution of values.
@ -5751,6 +5850,7 @@ def random_binomial(shape, p=0.0, dtype=None, seed=None):
@keras_export('keras.backend.random_bernoulli') @keras_export('keras.backend.random_bernoulli')
@dispatch.add_dispatch_support
def random_bernoulli(shape, p=0.0, dtype=None, seed=None): def random_bernoulli(shape, p=0.0, dtype=None, seed=None):
"""Returns a tensor with random bernoulli distribution of values. """Returns a tensor with random bernoulli distribution of values.
@ -5767,6 +5867,7 @@ def random_bernoulli(shape, p=0.0, dtype=None, seed=None):
@keras_export('keras.backend.truncated_normal') @keras_export('keras.backend.truncated_normal')
@dispatch.add_dispatch_support
def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
"""Returns a tensor with truncated random normal distribution of values. """Returns a tensor with truncated random normal distribution of values.
@ -5801,6 +5902,7 @@ def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
@keras_export('keras.backend.ctc_label_dense_to_sparse') @keras_export('keras.backend.ctc_label_dense_to_sparse')
@dispatch.add_dispatch_support
def ctc_label_dense_to_sparse(labels, label_lengths): def ctc_label_dense_to_sparse(labels, label_lengths):
"""Converts CTC labels from dense to sparse. """Converts CTC labels from dense to sparse.
@ -5847,6 +5949,7 @@ def ctc_label_dense_to_sparse(labels, label_lengths):
@keras_export('keras.backend.ctc_batch_cost') @keras_export('keras.backend.ctc_batch_cost')
@dispatch.add_dispatch_support
def ctc_batch_cost(y_true, y_pred, input_length, label_length): def ctc_batch_cost(y_true, y_pred, input_length, label_length):
"""Runs CTC loss algorithm on each batch element. """Runs CTC loss algorithm on each batch element.
@ -5879,6 +5982,7 @@ def ctc_batch_cost(y_true, y_pred, input_length, label_length):
@keras_export('keras.backend.ctc_decode') @keras_export('keras.backend.ctc_decode')
@dispatch.add_dispatch_support
def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1): def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
"""Decodes the output of a softmax. """Decodes the output of a softmax.

View File

@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
# The type of float to use throughout a session. # The type of float to use throughout a session.
@ -30,6 +31,7 @@ _IMAGE_DATA_FORMAT = 'channels_last'
@keras_export('keras.backend.epsilon') @keras_export('keras.backend.epsilon')
@dispatch.add_dispatch_support
def epsilon(): def epsilon():
"""Returns the value of the fuzz factor used in numeric expressions. """Returns the value of the fuzz factor used in numeric expressions.
@ -110,6 +112,7 @@ def set_floatx(value):
@keras_export('keras.backend.image_data_format') @keras_export('keras.backend.image_data_format')
@dispatch.add_dispatch_support
def image_data_format(): def image_data_format():
"""Returns the default image data format convention. """Returns the default image data format convention.

View File

@ -38,6 +38,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn from tensorflow.python.ops import nn
from tensorflow.python.ops.losses import losses_impl from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.ops.losses import util as tf_losses_util from tensorflow.python.ops.losses import util as tf_losses_util
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls from tensorflow.tools.docs import doc_controls
@ -1164,6 +1165,7 @@ class Huber(LossFunctionWrapper):
'keras.losses.mean_squared_error', 'keras.losses.mean_squared_error',
'keras.losses.mse', 'keras.losses.mse',
'keras.losses.MSE') 'keras.losses.MSE')
@dispatch.add_dispatch_support
def mean_squared_error(y_true, y_pred): def mean_squared_error(y_true, y_pred):
"""Computes the mean squared error between labels and predictions. """Computes the mean squared error between labels and predictions.
@ -1199,6 +1201,7 @@ def mean_squared_error(y_true, y_pred):
'keras.losses.mean_absolute_error', 'keras.losses.mean_absolute_error',
'keras.losses.mae', 'keras.losses.mae',
'keras.losses.MAE') 'keras.losses.MAE')
@dispatch.add_dispatch_support
def mean_absolute_error(y_true, y_pred): def mean_absolute_error(y_true, y_pred):
"""Computes the mean absolute error between labels and predictions. """Computes the mean absolute error between labels and predictions.
@ -1231,6 +1234,7 @@ def mean_absolute_error(y_true, y_pred):
'keras.losses.mean_absolute_percentage_error', 'keras.losses.mean_absolute_percentage_error',
'keras.losses.mape', 'keras.losses.mape',
'keras.losses.MAPE') 'keras.losses.MAPE')
@dispatch.add_dispatch_support
def mean_absolute_percentage_error(y_true, y_pred): def mean_absolute_percentage_error(y_true, y_pred):
"""Computes the mean absolute percentage error between `y_true` and `y_pred`. """Computes the mean absolute percentage error between `y_true` and `y_pred`.
@ -1267,6 +1271,7 @@ def mean_absolute_percentage_error(y_true, y_pred):
'keras.losses.mean_squared_logarithmic_error', 'keras.losses.mean_squared_logarithmic_error',
'keras.losses.msle', 'keras.losses.msle',
'keras.losses.MSLE') 'keras.losses.MSLE')
@dispatch.add_dispatch_support
def mean_squared_logarithmic_error(y_true, y_pred): def mean_squared_logarithmic_error(y_true, y_pred):
"""Computes the mean squared logarithmic error between `y_true` and `y_pred`. """Computes the mean squared logarithmic error between `y_true` and `y_pred`.
@ -1315,6 +1320,7 @@ def _maybe_convert_labels(y_true):
@keras_export('keras.metrics.squared_hinge', 'keras.losses.squared_hinge') @keras_export('keras.metrics.squared_hinge', 'keras.losses.squared_hinge')
@dispatch.add_dispatch_support
def squared_hinge(y_true, y_pred): def squared_hinge(y_true, y_pred):
"""Computes the squared hinge loss between `y_true` and `y_pred`. """Computes the squared hinge loss between `y_true` and `y_pred`.
@ -1347,6 +1353,7 @@ def squared_hinge(y_true, y_pred):
@keras_export('keras.metrics.hinge', 'keras.losses.hinge') @keras_export('keras.metrics.hinge', 'keras.losses.hinge')
@dispatch.add_dispatch_support
def hinge(y_true, y_pred): def hinge(y_true, y_pred):
"""Computes the hinge loss between `y_true` and `y_pred`. """Computes the hinge loss between `y_true` and `y_pred`.
@ -1378,6 +1385,7 @@ def hinge(y_true, y_pred):
@keras_export('keras.losses.categorical_hinge') @keras_export('keras.losses.categorical_hinge')
@dispatch.add_dispatch_support
def categorical_hinge(y_true, y_pred): def categorical_hinge(y_true, y_pred):
"""Computes the categorical hinge loss between `y_true` and `y_pred`. """Computes the categorical hinge loss between `y_true` and `y_pred`.
@ -1410,6 +1418,7 @@ def categorical_hinge(y_true, y_pred):
@keras_export('keras.losses.huber', v1=[]) @keras_export('keras.losses.huber', v1=[])
@dispatch.add_dispatch_support
def huber(y_true, y_pred, delta=1.0): def huber(y_true, y_pred, delta=1.0):
"""Computes Huber loss value. """Computes Huber loss value.
@ -1447,6 +1456,7 @@ def huber(y_true, y_pred, delta=1.0):
@keras_export('keras.losses.log_cosh', 'keras.losses.logcosh') @keras_export('keras.losses.log_cosh', 'keras.losses.logcosh')
@dispatch.add_dispatch_support
def log_cosh(y_true, y_pred): def log_cosh(y_true, y_pred):
"""Logarithm of the hyperbolic cosine of the prediction error. """Logarithm of the hyperbolic cosine of the prediction error.
@ -1485,6 +1495,7 @@ def log_cosh(y_true, y_pred):
@keras_export('keras.metrics.categorical_crossentropy', @keras_export('keras.metrics.categorical_crossentropy',
'keras.losses.categorical_crossentropy') 'keras.losses.categorical_crossentropy')
@dispatch.add_dispatch_support
def categorical_crossentropy(y_true, def categorical_crossentropy(y_true,
y_pred, y_pred,
from_logits=False, from_logits=False,
@ -1525,6 +1536,7 @@ def categorical_crossentropy(y_true,
@keras_export('keras.metrics.sparse_categorical_crossentropy', @keras_export('keras.metrics.sparse_categorical_crossentropy',
'keras.losses.sparse_categorical_crossentropy') 'keras.losses.sparse_categorical_crossentropy')
@dispatch.add_dispatch_support
def sparse_categorical_crossentropy(y_true, y_pred, from_logits=False, axis=-1): def sparse_categorical_crossentropy(y_true, y_pred, from_logits=False, axis=-1):
"""Computes the sparse categorical crossentropy loss. """Computes the sparse categorical crossentropy loss.
@ -1556,6 +1568,7 @@ def sparse_categorical_crossentropy(y_true, y_pred, from_logits=False, axis=-1):
@keras_export('keras.metrics.binary_crossentropy', @keras_export('keras.metrics.binary_crossentropy',
'keras.losses.binary_crossentropy') 'keras.losses.binary_crossentropy')
@dispatch.add_dispatch_support
def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0): def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0):
"""Computes the binary crossentropy loss. """Computes the binary crossentropy loss.
@ -1599,6 +1612,7 @@ def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0):
'keras.losses.kullback_leibler_divergence', 'keras.losses.kullback_leibler_divergence',
'keras.losses.kld', 'keras.losses.kld',
'keras.losses.KLD') 'keras.losses.KLD')
@dispatch.add_dispatch_support
def kl_divergence(y_true, y_pred): def kl_divergence(y_true, y_pred):
"""Computes Kullback-Leibler divergence loss between `y_true` and `y_pred`. """Computes Kullback-Leibler divergence loss between `y_true` and `y_pred`.
@ -1635,6 +1649,7 @@ def kl_divergence(y_true, y_pred):
@keras_export('keras.metrics.poisson', 'keras.losses.poisson') @keras_export('keras.metrics.poisson', 'keras.losses.poisson')
@dispatch.add_dispatch_support
def poisson(y_true, y_pred): def poisson(y_true, y_pred):
"""Computes the Poisson loss between y_true and y_pred. """Computes the Poisson loss between y_true and y_pred.
@ -1676,6 +1691,7 @@ def poisson(y_true, y_pred):
'keras.losses.cosine', 'keras.losses.cosine',
'keras.losses.cosine_similarity', 'keras.losses.cosine_similarity',
]) ])
@dispatch.add_dispatch_support
def cosine_similarity(y_true, y_pred, axis=-1): def cosine_similarity(y_true, y_pred, axis=-1):
"""Computes the cosine similarity between labels and predictions. """Computes the cosine similarity between labels and predictions.

View File

@ -69,6 +69,7 @@ from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops import weights_broadcast_ops from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.ops.losses import util as tf_losses_utils from tensorflow.python.ops.losses import util as tf_losses_utils
from tensorflow.python.training.tracking import base as trackable from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
@ -3212,6 +3213,7 @@ def accuracy(y_true, y_pred):
@keras_export('keras.metrics.binary_accuracy') @keras_export('keras.metrics.binary_accuracy')
@dispatch.add_dispatch_support
def binary_accuracy(y_true, y_pred, threshold=0.5): def binary_accuracy(y_true, y_pred, threshold=0.5):
"""Calculates how often predictions matches binary labels. """Calculates how often predictions matches binary labels.
@ -3239,6 +3241,7 @@ def binary_accuracy(y_true, y_pred, threshold=0.5):
@keras_export('keras.metrics.categorical_accuracy') @keras_export('keras.metrics.categorical_accuracy')
@dispatch.add_dispatch_support
def categorical_accuracy(y_true, y_pred): def categorical_accuracy(y_true, y_pred):
"""Calculates how often predictions matches one-hot labels. """Calculates how often predictions matches one-hot labels.
@ -3267,6 +3270,7 @@ def categorical_accuracy(y_true, y_pred):
@keras_export('keras.metrics.sparse_categorical_accuracy') @keras_export('keras.metrics.sparse_categorical_accuracy')
@dispatch.add_dispatch_support
def sparse_categorical_accuracy(y_true, y_pred): def sparse_categorical_accuracy(y_true, y_pred):
"""Calculates how often predictions matches integer labels. """Calculates how often predictions matches integer labels.
@ -3307,6 +3311,7 @@ def sparse_categorical_accuracy(y_true, y_pred):
@keras_export('keras.metrics.top_k_categorical_accuracy') @keras_export('keras.metrics.top_k_categorical_accuracy')
@dispatch.add_dispatch_support
def top_k_categorical_accuracy(y_true, y_pred, k=5): def top_k_categorical_accuracy(y_true, y_pred, k=5):
"""Computes how often targets are in the top `K` predictions. """Computes how often targets are in the top `K` predictions.
@ -3332,6 +3337,7 @@ def top_k_categorical_accuracy(y_true, y_pred, k=5):
@keras_export('keras.metrics.sparse_top_k_categorical_accuracy') @keras_export('keras.metrics.sparse_top_k_categorical_accuracy')
@dispatch.add_dispatch_support
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5): def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
"""Computes how often integer targets are in the top `K` predictions. """Computes how often integer targets are in the top `K` predictions.

View File

@ -57,6 +57,7 @@ _BaseSlice = slice
@tf_export("reshape", v1=["reshape", "manip.reshape"]) @tf_export("reshape", v1=["reshape", "manip.reshape"])
@dispatch.add_dispatch_support
def reshape(tensor, shape, name=None): # pylint: disable=redefined-outer-name def reshape(tensor, shape, name=None): # pylint: disable=redefined-outer-name
r"""Reshapes a tensor. r"""Reshapes a tensor.
@ -197,6 +198,7 @@ def reshape(tensor, shape, name=None): # pylint: disable=redefined-outer-name
@tf_export("fill") @tf_export("fill")
@dispatch.add_dispatch_support
def fill(dims, value, name=None): def fill(dims, value, name=None):
r"""Creates a tensor filled with a scalar value. r"""Creates a tensor filled with a scalar value.
@ -455,6 +457,7 @@ listdiff.__doc__ = gen_array_ops.list_diff.__doc__ + "\n" + listdiff.__doc__
"This op will be removed after the deprecation date. " "This op will be removed after the deprecation date. "
"Please switch to tf.sets.difference().") "Please switch to tf.sets.difference().")
@tf_export(v1=["setdiff1d"]) @tf_export(v1=["setdiff1d"])
@dispatch.add_dispatch_support
def setdiff1d(x, y, index_dtype=dtypes.int32, name=None): def setdiff1d(x, y, index_dtype=dtypes.int32, name=None):
"""Computes the difference between two lists of numbers or strings. """Computes the difference between two lists of numbers or strings.
@ -498,6 +501,7 @@ setdiff1d.__doc__ = gen_array_ops.list_diff.__doc__
@tf_export("broadcast_dynamic_shape") @tf_export("broadcast_dynamic_shape")
@dispatch.add_dispatch_support
def broadcast_dynamic_shape(shape_x, shape_y): def broadcast_dynamic_shape(shape_x, shape_y):
"""Computes the shape of a broadcast given symbolic shapes. """Computes the shape of a broadcast given symbolic shapes.
@ -523,6 +527,7 @@ def broadcast_dynamic_shape(shape_x, shape_y):
@tf_export("broadcast_static_shape") @tf_export("broadcast_static_shape")
@dispatch.add_dispatch_support
def broadcast_static_shape(shape_x, shape_y): def broadcast_static_shape(shape_x, shape_y):
"""Computes the shape of a broadcast given known shapes. """Computes the shape of a broadcast given known shapes.
@ -550,6 +555,7 @@ def broadcast_static_shape(shape_x, shape_y):
@tf_export("shape", v1=[]) @tf_export("shape", v1=[])
@dispatch.add_dispatch_support
def shape_v2(input, out_type=dtypes.int32, name=None): def shape_v2(input, out_type=dtypes.int32, name=None):
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin
"""Returns the shape of a tensor. """Returns the shape of a tensor.
@ -596,6 +602,7 @@ def shape_v2(input, out_type=dtypes.int32, name=None):
@tf_export(v1=["shape"]) @tf_export(v1=["shape"])
@dispatch.add_dispatch_support
def shape(input, name=None, out_type=dtypes.int32): def shape(input, name=None, out_type=dtypes.int32):
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin
"""Returns the shape of a tensor. """Returns the shape of a tensor.
@ -650,6 +657,7 @@ def shape_internal(input, name=None, optimize=True, out_type=dtypes.int32):
@tf_export("shape_n") @tf_export("shape_n")
@dispatch.add_dispatch_support
def shape_n(input, out_type=dtypes.int32, name=None): def shape_n(input, out_type=dtypes.int32, name=None):
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin
"""Returns shape of tensors. """Returns shape of tensors.
@ -1007,6 +1015,7 @@ def _slice_helper(tensor, slice_spec, var=None):
# pylint: disable=undefined-variable,protected-access,redefined-outer-name # pylint: disable=undefined-variable,protected-access,redefined-outer-name
@tf_export("slice") @tf_export("slice")
@dispatch.add_dispatch_support
def slice(input_, begin, size, name=None): def slice(input_, begin, size, name=None):
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin
"""Extracts a slice from a tensor. """Extracts a slice from a tensor.
@ -1062,6 +1071,7 @@ def slice(input_, begin, size, name=None):
# pylint: disable=invalid-name # pylint: disable=invalid-name
@tf_export("strided_slice") @tf_export("strided_slice")
@dispatch.add_dispatch_support
def strided_slice(input_, def strided_slice(input_,
begin, begin,
end, end,
@ -1253,6 +1263,7 @@ ops.Tensor._override_operator("__getitem__", _slice_helper)
@tf_export("parallel_stack") @tf_export("parallel_stack")
@dispatch.add_dispatch_support
def parallel_stack(values, name="parallel_stack"): def parallel_stack(values, name="parallel_stack"):
"""Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor in parallel. """Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor in parallel.
@ -1489,6 +1500,7 @@ ops.register_tensor_conversion_function((list, tuple),
@tf_export("unstack") @tf_export("unstack")
@dispatch.add_dispatch_support
def unstack(value, num=None, axis=0, name="unstack"): def unstack(value, num=None, axis=0, name="unstack"):
"""Unpacks the given dimension of a rank-`R` tensor into rank-`(R-1)` tensors. """Unpacks the given dimension of a rank-`R` tensor into rank-`(R-1)` tensors.
@ -1632,6 +1644,7 @@ def concat(values, axis, name="concat"):
@tf_export(v1=["boolean_mask"]) @tf_export(v1=["boolean_mask"])
@dispatch.add_dispatch_support
def boolean_mask(tensor, mask, name="boolean_mask", axis=None): def boolean_mask(tensor, mask, name="boolean_mask", axis=None):
"""Apply boolean mask to tensor. """Apply boolean mask to tensor.
@ -1824,6 +1837,7 @@ def sparse_mask(a, mask_indices, name=None):
@tf_export("unique") @tf_export("unique")
@dispatch.add_dispatch_support
def unique(x, out_idx=dtypes.int32, name=None): def unique(x, out_idx=dtypes.int32, name=None):
"""Finds unique elements in a 1-D tensor. """Finds unique elements in a 1-D tensor.
@ -1871,6 +1885,7 @@ unique.__doc__ = gen_array_ops.unique.__doc__
@tf_export("unique_with_counts") @tf_export("unique_with_counts")
@dispatch.add_dispatch_support
def unique_with_counts(x, out_idx=dtypes.int32, name=None): def unique_with_counts(x, out_idx=dtypes.int32, name=None):
"""Finds unique elements in a 1-D tensor. """Finds unique elements in a 1-D tensor.
@ -1923,6 +1938,7 @@ unique_with_counts.__doc__ = gen_array_ops.unique_with_counts.__doc__
@tf_export("split") @tf_export("split")
@dispatch.add_dispatch_support
def split(value, num_or_size_splits, axis=0, num=None, name="split"): def split(value, num_or_size_splits, axis=0, num=None, name="split"):
"""Splits a tensor `value` into a list of sub tensors. """Splits a tensor `value` into a list of sub tensors.
@ -2000,6 +2016,7 @@ def split(value, num_or_size_splits, axis=0, num=None, name="split"):
@tf_export("transpose", v1=[]) @tf_export("transpose", v1=[])
@dispatch.add_dispatch_support
def transpose_v2(a, perm=None, conjugate=False, name="transpose"): def transpose_v2(a, perm=None, conjugate=False, name="transpose"):
"""Transposes `a`, where `a` is a Tensor. """Transposes `a`, where `a` is a Tensor.
@ -2080,6 +2097,7 @@ def transpose_v2(a, perm=None, conjugate=False, name="transpose"):
@tf_export(v1=["transpose"]) @tf_export(v1=["transpose"])
@dispatch.add_dispatch_support
def transpose(a, perm=None, name="transpose", conjugate=False): def transpose(a, perm=None, name="transpose", conjugate=False):
"""Transposes `a`. """Transposes `a`.
@ -2170,6 +2188,7 @@ def transpose(a, perm=None, name="transpose", conjugate=False):
@tf_export( @tf_export(
"linalg.matrix_transpose", "linalg.matrix_transpose",
v1=["linalg.transpose", "linalg.matrix_transpose", "matrix_transpose"]) v1=["linalg.transpose", "linalg.matrix_transpose", "matrix_transpose"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("matrix_transpose", "linalg.transpose") @deprecation.deprecated_endpoints("matrix_transpose", "linalg.transpose")
def matrix_transpose(a, name="matrix_transpose", conjugate=False): def matrix_transpose(a, name="matrix_transpose", conjugate=False):
"""Transposes last two dimensions of tensor `a`. """Transposes last two dimensions of tensor `a`.
@ -2248,6 +2267,7 @@ def matrix_transpose(a, name="matrix_transpose", conjugate=False):
@tf_export("linalg.diag", v1=["linalg.diag", "matrix_diag"]) @tf_export("linalg.diag", v1=["linalg.diag", "matrix_diag"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("matrix_diag") @deprecation.deprecated_endpoints("matrix_diag")
def matrix_diag(diagonal, def matrix_diag(diagonal,
name="diag", name="diag",
@ -2416,6 +2436,7 @@ def matrix_diag(diagonal,
@tf_export("linalg.diag_part", v1=["linalg.diag_part", "matrix_diag_part"]) @tf_export("linalg.diag_part", v1=["linalg.diag_part", "matrix_diag_part"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("matrix_diag_part") @deprecation.deprecated_endpoints("matrix_diag_part")
@dispatch.add_dispatch_support @dispatch.add_dispatch_support
def matrix_diag_part( def matrix_diag_part(
@ -2556,6 +2577,7 @@ def matrix_diag_part(
@tf_export("linalg.set_diag", v1=["linalg.set_diag", "matrix_set_diag"]) @tf_export("linalg.set_diag", v1=["linalg.set_diag", "matrix_set_diag"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("matrix_set_diag") @deprecation.deprecated_endpoints("matrix_set_diag")
def matrix_set_diag( def matrix_set_diag(
input, # pylint:disable=redefined-builtin input, # pylint:disable=redefined-builtin
@ -2719,6 +2741,7 @@ def _tag_zeros_tensor(fun):
@tf_export("zeros") @tf_export("zeros")
@dispatch.add_dispatch_support
@_tag_zeros_tensor @_tag_zeros_tensor
def zeros(shape, dtype=dtypes.float32, name=None): def zeros(shape, dtype=dtypes.float32, name=None):
"""Creates a tensor with all elements set to zero. """Creates a tensor with all elements set to zero.
@ -2971,6 +2994,7 @@ def ones_like_impl(tensor, dtype, name, optimize=True):
@tf_export("ones") @tf_export("ones")
@dispatch.add_dispatch_support
def ones(shape, dtype=dtypes.float32, name=None): def ones(shape, dtype=dtypes.float32, name=None):
"""Creates a tensor with all elements set to one (1). """Creates a tensor with all elements set to one (1).
@ -3182,6 +3206,7 @@ def sparse_placeholder(dtype, shape=None, name=None):
@tf_export("pad", v1=[]) @tf_export("pad", v1=[])
@dispatch.add_dispatch_support
def pad_v2(tensor, paddings, mode="CONSTANT", constant_values=0, name=None): def pad_v2(tensor, paddings, mode="CONSTANT", constant_values=0, name=None):
"""Pads a tensor. """Pads a tensor.
@ -3240,6 +3265,7 @@ def pad_v2(tensor, paddings, mode="CONSTANT", constant_values=0, name=None):
@tf_export(v1=["pad"]) @tf_export(v1=["pad"])
@dispatch.add_dispatch_support
def pad(tensor, paddings, mode="CONSTANT", name=None, constant_values=0): # pylint: disable=invalid-name def pad(tensor, paddings, mode="CONSTANT", name=None, constant_values=0): # pylint: disable=invalid-name
"""Pads a tensor. """Pads a tensor.
@ -3357,6 +3383,7 @@ def _get_paddings_constant(paddings):
@tf_export("meshgrid") @tf_export("meshgrid")
@dispatch.add_dispatch_support
def meshgrid(*args, **kwargs): def meshgrid(*args, **kwargs):
"""Broadcasts parameters for evaluation on an N-D grid. """Broadcasts parameters for evaluation on an N-D grid.
@ -3500,6 +3527,7 @@ def _TileGradShape(op):
@tf_export("edit_distance") @tf_export("edit_distance")
@dispatch.add_dispatch_support
def edit_distance(hypothesis, truth, normalize=True, name="edit_distance"): def edit_distance(hypothesis, truth, normalize=True, name="edit_distance"):
"""Computes the Levenshtein distance between sequences. """Computes the Levenshtein distance between sequences.
@ -3694,6 +3722,7 @@ def required_space_to_batch_paddings(input_shape,
@tf_export(v1=["nn.space_to_batch", "space_to_batch"]) @tf_export(v1=["nn.space_to_batch", "space_to_batch"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("space_to_batch") @deprecation.deprecated_endpoints("space_to_batch")
def space_to_batch( # pylint: disable=missing-docstring def space_to_batch( # pylint: disable=missing-docstring
input, # pylint: disable=redefined-builtin input, # pylint: disable=redefined-builtin
@ -3717,6 +3746,7 @@ space_to_batch.__doc__ = gen_array_ops.space_to_batch.__doc__
@tf_export("space_to_batch", "nn.space_to_batch", v1=[]) @tf_export("space_to_batch", "nn.space_to_batch", v1=[])
@dispatch.add_dispatch_support
def space_to_batch_v2(input, block_shape, paddings, name=None): # pylint: disable=redefined-builtin def space_to_batch_v2(input, block_shape, paddings, name=None): # pylint: disable=redefined-builtin
return space_to_batch_nd(input, block_shape, paddings, name) return space_to_batch_nd(input, block_shape, paddings, name)
@ -3725,6 +3755,7 @@ space_to_batch_v2.__doc__ = gen_array_ops.space_to_batch_nd.__doc__
@tf_export(v1=["nn.space_to_depth", "space_to_depth"]) @tf_export(v1=["nn.space_to_depth", "space_to_depth"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("space_to_depth") @deprecation.deprecated_endpoints("space_to_depth")
def space_to_depth(input, block_size, name=None, data_format="NHWC"): # pylint: disable=redefined-builtin def space_to_depth(input, block_size, name=None, data_format="NHWC"): # pylint: disable=redefined-builtin
return gen_array_ops.space_to_depth(input, block_size, data_format, name=name) return gen_array_ops.space_to_depth(input, block_size, data_format, name=name)
@ -3734,6 +3765,7 @@ space_to_depth.__doc__ = gen_array_ops.space_to_depth.__doc__
@tf_export("nn.space_to_depth", v1=[]) @tf_export("nn.space_to_depth", v1=[])
@dispatch.add_dispatch_support
def space_to_depth_v2(input, block_size, data_format="NHWC", name=None): # pylint: disable=redefined-builtin def space_to_depth_v2(input, block_size, data_format="NHWC", name=None): # pylint: disable=redefined-builtin
return gen_array_ops.space_to_depth(input, block_size, data_format, name=name) return gen_array_ops.space_to_depth(input, block_size, data_format, name=name)
@ -3742,6 +3774,7 @@ space_to_depth_v2.__doc__ = gen_array_ops.space_to_depth.__doc__
@tf_export(v1=["nn.depth_to_space", "depth_to_space"]) @tf_export(v1=["nn.depth_to_space", "depth_to_space"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("depth_to_space") @deprecation.deprecated_endpoints("depth_to_space")
def depth_to_space(input, block_size, name=None, data_format="NHWC"): # pylint: disable=redefined-builtin def depth_to_space(input, block_size, name=None, data_format="NHWC"): # pylint: disable=redefined-builtin
return gen_array_ops.depth_to_space(input, block_size, data_format, name=name) return gen_array_ops.depth_to_space(input, block_size, data_format, name=name)
@ -3751,6 +3784,7 @@ depth_to_space.__doc__ = gen_array_ops.depth_to_space.__doc__
@tf_export("nn.depth_to_space", v1=[]) @tf_export("nn.depth_to_space", v1=[])
@dispatch.add_dispatch_support
def depth_to_space_v2(input, block_size, data_format="NHWC", name=None): # pylint: disable=redefined-builtin def depth_to_space_v2(input, block_size, data_format="NHWC", name=None): # pylint: disable=redefined-builtin
return gen_array_ops.depth_to_space(input, block_size, data_format, name=name) return gen_array_ops.depth_to_space(input, block_size, data_format, name=name)
@ -3759,6 +3793,7 @@ depth_to_space_v2.__doc__ = gen_array_ops.depth_to_space.__doc__
@tf_export(v1=["batch_to_space"]) @tf_export(v1=["batch_to_space"])
@dispatch.add_dispatch_support
def batch_to_space(input, crops, block_size, name=None, block_shape=None): # pylint: disable=redefined-builtin,missing-docstring def batch_to_space(input, crops, block_size, name=None, block_shape=None): # pylint: disable=redefined-builtin,missing-docstring
block_size = deprecation.deprecated_argument_lookup("block_shape", block_size = deprecation.deprecated_argument_lookup("block_shape",
block_shape, "block_size", block_shape, "block_size",
@ -3776,6 +3811,7 @@ batch_to_space.__doc__ = gen_array_ops.batch_to_space.__doc__
@tf_export("batch_to_space", v1=[]) @tf_export("batch_to_space", v1=[])
@dispatch.add_dispatch_support
def batch_to_space_v2(input, block_shape, crops, name=None): # pylint: disable=redefined-builtin def batch_to_space_v2(input, block_shape, crops, name=None): # pylint: disable=redefined-builtin
"""BatchToSpace for N-D tensors of type T. """BatchToSpace for N-D tensors of type T.
@ -4091,6 +4127,7 @@ def _all_dimensions(x):
@tf_export("sequence_mask") @tf_export("sequence_mask")
@dispatch.add_dispatch_support
def sequence_mask(lengths, maxlen=None, dtype=dtypes.bool, name=None): def sequence_mask(lengths, maxlen=None, dtype=dtypes.bool, name=None):
"""Returns a mask tensor representing the first N positions of each cell. """Returns a mask tensor representing the first N positions of each cell.
@ -4317,6 +4354,7 @@ def where(condition, x=None, y=None, name=None):
@tf_export("where", v1=["where_v2"]) @tf_export("where", v1=["where_v2"])
@dispatch.add_dispatch_support
def where_v2(condition, x=None, y=None, name=None): def where_v2(condition, x=None, y=None, name=None):
"""Return the elements where `condition` is `True` (multiplexing `x` and `y`). """Return the elements where `condition` is `True` (multiplexing `x` and `y`).
@ -5003,6 +5041,7 @@ def batch_gather_nd(params, indices, batch_dims, name=None):
# because round_mode was added later. # because round_mode was added later.
# (And also now because of 'axis' processing). # (And also now because of 'axis' processing).
@tf_export(v1=["quantize_v2"]) @tf_export(v1=["quantize_v2"])
@dispatch.add_dispatch_support
@deprecation.deprecated( @deprecation.deprecated(
"2017-10-25", "2017-10-25",
"`tf.quantize_v2` is deprecated, please use `tf.quantization.quantize` " "`tf.quantize_v2` is deprecated, please use `tf.quantization.quantize` "
@ -5056,6 +5095,7 @@ quantize_v2.__doc__ = """Please use `tf.quantization.quantize` instead."""
# tf.quantization.quantize; we can deprecate tf.quantization.quantize in next # tf.quantization.quantize; we can deprecate tf.quantization.quantize in next
# version of TensorFlow. # version of TensorFlow.
@tf_export("quantization.quantize", v1=["quantization.quantize", "quantize"]) @tf_export("quantization.quantize", v1=["quantization.quantize", "quantize"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("quantize") @deprecation.deprecated_endpoints("quantize")
def quantize( def quantize(
input, # pylint: disable=redefined-builtin input, # pylint: disable=redefined-builtin
@ -5095,6 +5135,7 @@ def quantize(
@tf_export("quantization.dequantize", v1=["quantization.dequantize", @tf_export("quantization.dequantize", v1=["quantization.dequantize",
"dequantize"]) "dequantize"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("dequantize") @deprecation.deprecated_endpoints("dequantize")
def dequantize( # pylint: disable=missing-docstring def dequantize( # pylint: disable=missing-docstring
input, # pylint: disable=redefined-builtin input, # pylint: disable=redefined-builtin
@ -5130,6 +5171,7 @@ dequantize.__doc__ = gen_array_ops.dequantize.__doc__
@tf_export("quantization.quantize_and_dequantize") @tf_export("quantization.quantize_and_dequantize")
@dispatch.add_dispatch_support
def quantize_and_dequantize( def quantize_and_dequantize(
input, # pylint: disable=redefined-builtin input, # pylint: disable=redefined-builtin
input_min, input_min,
@ -5189,6 +5231,7 @@ def quantize_and_dequantize(
@tf_export("searchsorted") @tf_export("searchsorted")
@dispatch.add_dispatch_support
def searchsorted(sorted_sequence, def searchsorted(sorted_sequence,
values, values,
side="left", side="left",
@ -5253,6 +5296,7 @@ quantize.__doc__ = gen_array_ops.quantize_v2.__doc__
@tf_export("image.extract_patches") @tf_export("image.extract_patches")
@dispatch.add_dispatch_support
def extract_image_patches_v2(images, sizes, strides, rates, padding, name=None): def extract_image_patches_v2(images, sizes, strides, rates, padding, name=None):
r"""Extract `patches` from `images`. r"""Extract `patches` from `images`.
@ -5374,6 +5418,7 @@ def extract_image_patches_v2(images, sizes, strides, rates, padding, name=None):
@tf_export(v1=["image.extract_image_patches", "extract_image_patches"]) @tf_export(v1=["image.extract_image_patches", "extract_image_patches"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None, "ksizes is deprecated, use sizes instead", @deprecation.deprecated_args(None, "ksizes is deprecated, use sizes instead",
"ksizes") "ksizes")
def extract_image_patches( # pylint: disable=missing-docstring def extract_image_patches( # pylint: disable=missing-docstring
@ -5422,6 +5467,7 @@ extract_image_patches.__doc__ = gen_array_ops.extract_image_patches.__doc__
@tf_export("fingerprint") @tf_export("fingerprint")
@dispatch.add_dispatch_support
def fingerprint(data, method="farmhash64", name=None): def fingerprint(data, method="farmhash64", name=None):
r"""Generates fingerprint values. r"""Generates fingerprint values.
@ -5668,6 +5714,7 @@ def _with_nonzero_rank(data):
@tf_export("repeat") @tf_export("repeat")
@dispatch.add_dispatch_support
def repeat(input, repeats, axis=None, name=None): # pylint: disable=redefined-builtin def repeat(input, repeats, axis=None, name=None): # pylint: disable=redefined-builtin
"""Repeat elements of `input`. """Repeat elements of `input`.

View File

@ -24,12 +24,14 @@ from tensorflow.python.ops import array_ops # pylint: disable=unused-import
from tensorflow.python.ops import gen_candidate_sampling_ops from tensorflow.python.ops import gen_candidate_sampling_ops
from tensorflow.python.ops import math_ops # pylint: disable=unused-import from tensorflow.python.ops import math_ops # pylint: disable=unused-import
from tensorflow.python.util import deprecation from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@tf_export( @tf_export(
'random.uniform_candidate_sampler', 'random.uniform_candidate_sampler',
v1=['random.uniform_candidate_sampler', 'nn.uniform_candidate_sampler']) v1=['random.uniform_candidate_sampler', 'nn.uniform_candidate_sampler'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('nn.uniform_candidate_sampler') @deprecation.deprecated_endpoints('nn.uniform_candidate_sampler')
def uniform_candidate_sampler(true_classes, num_true, num_sampled, unique, def uniform_candidate_sampler(true_classes, num_true, num_sampled, unique,
range_max, seed=None, name=None): range_max, seed=None, name=None):
@ -92,6 +94,7 @@ def uniform_candidate_sampler(true_classes, num_true, num_sampled, unique,
'random.log_uniform_candidate_sampler', 'random.log_uniform_candidate_sampler',
'nn.log_uniform_candidate_sampler' 'nn.log_uniform_candidate_sampler'
]) ])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('nn.log_uniform_candidate_sampler') @deprecation.deprecated_endpoints('nn.log_uniform_candidate_sampler')
def log_uniform_candidate_sampler(true_classes, num_true, num_sampled, unique, def log_uniform_candidate_sampler(true_classes, num_true, num_sampled, unique,
range_max, seed=None, name=None): range_max, seed=None, name=None):
@ -154,6 +157,7 @@ def log_uniform_candidate_sampler(true_classes, num_true, num_sampled, unique,
@tf_export( @tf_export(
'random.learned_unigram_candidate_sampler', 'random.learned_unigram_candidate_sampler',
'nn.learned_unigram_candidate_sampler') 'nn.learned_unigram_candidate_sampler')
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints(['nn.learned_unigram_candidate_sampler']) @deprecation.deprecated_endpoints(['nn.learned_unigram_candidate_sampler'])
def learned_unigram_candidate_sampler(true_classes, num_true, num_sampled, def learned_unigram_candidate_sampler(true_classes, num_true, num_sampled,
unique, range_max, seed=None, name=None): unique, range_max, seed=None, name=None):
@ -213,6 +217,7 @@ def learned_unigram_candidate_sampler(true_classes, num_true, num_sampled,
@tf_export('random.fixed_unigram_candidate_sampler', @tf_export('random.fixed_unigram_candidate_sampler',
'nn.fixed_unigram_candidate_sampler') 'nn.fixed_unigram_candidate_sampler')
@dispatch.add_dispatch_support
def fixed_unigram_candidate_sampler(true_classes, def fixed_unigram_candidate_sampler(true_classes,
num_true, num_true,
num_sampled, num_sampled,
@ -341,6 +346,7 @@ def all_candidate_sampler(true_classes, num_true, num_sampled, unique,
@tf_export('nn.compute_accidental_hits') @tf_export('nn.compute_accidental_hits')
@dispatch.add_dispatch_support
def compute_accidental_hits(true_classes, sampled_candidates, num_true, def compute_accidental_hits(true_classes, sampled_candidates, num_true,
seed=None, name=None): seed=None, name=None):
"""Compute the position ids in `sampled_candidates` matching `true_classes`. """Compute the position ids in `sampled_candidates` matching `true_classes`.

View File

@ -35,6 +35,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.util import compat from tensorflow.python.util import compat
from tensorflow.python.util import deprecation from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
NUMERIC_TYPES = frozenset( NUMERIC_TYPES = frozenset(
@ -375,6 +376,7 @@ def _binary_assert(sym, opname, op_func, static_func, x, y, data, summarize,
@tf_export( @tf_export(
'debugging.assert_proper_iterable', 'debugging.assert_proper_iterable',
v1=['debugging.assert_proper_iterable', 'assert_proper_iterable']) v1=['debugging.assert_proper_iterable', 'assert_proper_iterable'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_proper_iterable') @deprecation.deprecated_endpoints('assert_proper_iterable')
def assert_proper_iterable(values): def assert_proper_iterable(values):
"""Static assert that values is a "proper" iterable. """Static assert that values is a "proper" iterable.
@ -404,6 +406,7 @@ def assert_proper_iterable(values):
@tf_export('debugging.assert_negative', v1=[]) @tf_export('debugging.assert_negative', v1=[])
@dispatch.add_dispatch_support
def assert_negative_v2(x, message=None, summarize=None, name=None): def assert_negative_v2(x, message=None, summarize=None, name=None):
"""Assert the condition `x < 0` holds element-wise. """Assert the condition `x < 0` holds element-wise.
@ -436,6 +439,7 @@ def assert_negative_v2(x, message=None, summarize=None, name=None):
@tf_export(v1=['debugging.assert_negative', 'assert_negative']) @tf_export(v1=['debugging.assert_negative', 'assert_negative'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_negative') @deprecation.deprecated_endpoints('assert_negative')
@_unary_assert_doc('< 0', 'negative') @_unary_assert_doc('< 0', 'negative')
def assert_negative(x, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring def assert_negative(x, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring
@ -456,6 +460,7 @@ def assert_negative(x, data=None, summarize=None, message=None, name=None): # p
@tf_export('debugging.assert_positive', v1=[]) @tf_export('debugging.assert_positive', v1=[])
@dispatch.add_dispatch_support
def assert_positive_v2(x, message=None, summarize=None, name=None): def assert_positive_v2(x, message=None, summarize=None, name=None):
"""Assert the condition `x > 0` holds element-wise. """Assert the condition `x > 0` holds element-wise.
@ -488,6 +493,7 @@ def assert_positive_v2(x, message=None, summarize=None, name=None):
@tf_export(v1=['debugging.assert_positive', 'assert_positive']) @tf_export(v1=['debugging.assert_positive', 'assert_positive'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_positive') @deprecation.deprecated_endpoints('assert_positive')
@_unary_assert_doc('> 0', 'positive') @_unary_assert_doc('> 0', 'positive')
def assert_positive(x, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring def assert_positive(x, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring
@ -507,6 +513,7 @@ def assert_positive(x, data=None, summarize=None, message=None, name=None): # p
@tf_export('debugging.assert_non_negative', v1=[]) @tf_export('debugging.assert_non_negative', v1=[])
@dispatch.add_dispatch_support
def assert_non_negative_v2(x, message=None, summarize=None, name=None): def assert_non_negative_v2(x, message=None, summarize=None, name=None):
"""Assert the condition `x >= 0` holds element-wise. """Assert the condition `x >= 0` holds element-wise.
@ -541,6 +548,7 @@ def assert_non_negative_v2(x, message=None, summarize=None, name=None):
@tf_export(v1=['debugging.assert_non_negative', 'assert_non_negative']) @tf_export(v1=['debugging.assert_non_negative', 'assert_non_negative'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_non_negative') @deprecation.deprecated_endpoints('assert_non_negative')
@_unary_assert_doc('>= 0', 'non-negative') @_unary_assert_doc('>= 0', 'non-negative')
def assert_non_negative(x, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring def assert_non_negative(x, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring
@ -561,6 +569,7 @@ def assert_non_negative(x, data=None, summarize=None, message=None, name=None):
@tf_export('debugging.assert_non_positive', v1=[]) @tf_export('debugging.assert_non_positive', v1=[])
@dispatch.add_dispatch_support
def assert_non_positive_v2(x, message=None, summarize=None, name=None): def assert_non_positive_v2(x, message=None, summarize=None, name=None):
"""Assert the condition `x <= 0` holds element-wise. """Assert the condition `x <= 0` holds element-wise.
@ -595,6 +604,7 @@ def assert_non_positive_v2(x, message=None, summarize=None, name=None):
@tf_export(v1=['debugging.assert_non_positive', 'assert_non_positive']) @tf_export(v1=['debugging.assert_non_positive', 'assert_non_positive'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_non_positive') @deprecation.deprecated_endpoints('assert_non_positive')
@_unary_assert_doc('<= 0', 'non-positive') @_unary_assert_doc('<= 0', 'non-positive')
def assert_non_positive(x, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring def assert_non_positive(x, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring
@ -615,6 +625,7 @@ def assert_non_positive(x, data=None, summarize=None, message=None, name=None):
@tf_export('debugging.assert_equal', 'assert_equal', v1=[]) @tf_export('debugging.assert_equal', 'assert_equal', v1=[])
@dispatch.add_dispatch_support
def assert_equal_v2(x, y, message=None, summarize=None, name=None): def assert_equal_v2(x, y, message=None, summarize=None, name=None):
"""Assert the condition `x == y` holds element-wise. """Assert the condition `x == y` holds element-wise.
@ -649,6 +660,7 @@ def assert_equal_v2(x, y, message=None, summarize=None, name=None):
@tf_export(v1=['debugging.assert_equal', 'assert_equal']) @tf_export(v1=['debugging.assert_equal', 'assert_equal'])
@dispatch.add_dispatch_support
@_binary_assert_doc('==') @_binary_assert_doc('==')
def assert_equal(x, y, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring def assert_equal(x, y, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring
with ops.name_scope(name, 'assert_equal', [x, y, data]): with ops.name_scope(name, 'assert_equal', [x, y, data]):
@ -660,6 +672,7 @@ def assert_equal(x, y, data=None, summarize=None, message=None, name=None): # p
@tf_export('debugging.assert_none_equal', v1=[]) @tf_export('debugging.assert_none_equal', v1=[])
@dispatch.add_dispatch_support
def assert_none_equal_v2(x, y, summarize=None, message=None, name=None): def assert_none_equal_v2(x, y, summarize=None, message=None, name=None):
"""Assert the condition `x != y` holds for all elements. """Assert the condition `x != y` holds for all elements.
@ -698,6 +711,7 @@ def assert_none_equal_v2(x, y, summarize=None, message=None, name=None):
@tf_export(v1=['debugging.assert_none_equal', 'assert_none_equal']) @tf_export(v1=['debugging.assert_none_equal', 'assert_none_equal'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_none_equal') @deprecation.deprecated_endpoints('assert_none_equal')
@_binary_assert_doc('!=') @_binary_assert_doc('!=')
def assert_none_equal( def assert_none_equal(
@ -707,6 +721,7 @@ def assert_none_equal(
@tf_export('debugging.assert_near', v1=[]) @tf_export('debugging.assert_near', v1=[])
@dispatch.add_dispatch_support
def assert_near_v2(x, y, rtol=None, atol=None, message=None, summarize=None, def assert_near_v2(x, y, rtol=None, atol=None, message=None, summarize=None,
name=None): name=None):
"""Assert the condition `x` and `y` are close element-wise. """Assert the condition `x` and `y` are close element-wise.
@ -760,6 +775,7 @@ def assert_near_v2(x, y, rtol=None, atol=None, message=None, summarize=None,
@tf_export(v1=['debugging.assert_near', 'assert_near']) @tf_export(v1=['debugging.assert_near', 'assert_near'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_near') @deprecation.deprecated_endpoints('assert_near')
def assert_near( def assert_near(
x, y, rtol=None, atol=None, data=None, summarize=None, message=None, x, y, rtol=None, atol=None, data=None, summarize=None, message=None,
@ -839,6 +855,7 @@ def assert_near(
@tf_export('debugging.assert_less', 'assert_less', v1=[]) @tf_export('debugging.assert_less', 'assert_less', v1=[])
@dispatch.add_dispatch_support
def assert_less_v2(x, y, message=None, summarize=None, name=None): def assert_less_v2(x, y, message=None, summarize=None, name=None):
"""Assert the condition `x < y` holds element-wise. """Assert the condition `x < y` holds element-wise.
@ -874,6 +891,7 @@ def assert_less_v2(x, y, message=None, summarize=None, name=None):
@tf_export(v1=['debugging.assert_less', 'assert_less']) @tf_export(v1=['debugging.assert_less', 'assert_less'])
@dispatch.add_dispatch_support
@_binary_assert_doc('<') @_binary_assert_doc('<')
def assert_less(x, y, data=None, summarize=None, message=None, name=None): def assert_less(x, y, data=None, summarize=None, message=None, name=None):
return _binary_assert('<', 'assert_less', math_ops.less, np.less, x, y, data, return _binary_assert('<', 'assert_less', math_ops.less, np.less, x, y, data,
@ -881,6 +899,7 @@ def assert_less(x, y, data=None, summarize=None, message=None, name=None):
@tf_export('debugging.assert_less_equal', v1=[]) @tf_export('debugging.assert_less_equal', v1=[])
@dispatch.add_dispatch_support
def assert_less_equal_v2(x, y, message=None, summarize=None, name=None): def assert_less_equal_v2(x, y, message=None, summarize=None, name=None):
"""Assert the condition `x <= y` holds element-wise. """Assert the condition `x <= y` holds element-wise.
@ -917,6 +936,7 @@ def assert_less_equal_v2(x, y, message=None, summarize=None, name=None):
@tf_export(v1=['debugging.assert_less_equal', 'assert_less_equal']) @tf_export(v1=['debugging.assert_less_equal', 'assert_less_equal'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_less_equal') @deprecation.deprecated_endpoints('assert_less_equal')
@_binary_assert_doc('<=') @_binary_assert_doc('<=')
def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None): def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None):
@ -925,6 +945,7 @@ def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None):
@tf_export('debugging.assert_greater', 'assert_greater', v1=[]) @tf_export('debugging.assert_greater', 'assert_greater', v1=[])
@dispatch.add_dispatch_support
def assert_greater_v2(x, y, message=None, summarize=None, name=None): def assert_greater_v2(x, y, message=None, summarize=None, name=None):
"""Assert the condition `x > y` holds element-wise. """Assert the condition `x > y` holds element-wise.
@ -961,6 +982,7 @@ def assert_greater_v2(x, y, message=None, summarize=None, name=None):
@tf_export(v1=['debugging.assert_greater', 'assert_greater']) @tf_export(v1=['debugging.assert_greater', 'assert_greater'])
@dispatch.add_dispatch_support
@_binary_assert_doc('>') @_binary_assert_doc('>')
def assert_greater(x, y, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring def assert_greater(x, y, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring
return _binary_assert('>', 'assert_greater', math_ops.greater, np.greater, x, return _binary_assert('>', 'assert_greater', math_ops.greater, np.greater, x,
@ -968,6 +990,7 @@ def assert_greater(x, y, data=None, summarize=None, message=None, name=None): #
@tf_export('debugging.assert_greater_equal', v1=[]) @tf_export('debugging.assert_greater_equal', v1=[])
@dispatch.add_dispatch_support
def assert_greater_equal_v2(x, y, message=None, summarize=None, name=None): def assert_greater_equal_v2(x, y, message=None, summarize=None, name=None):
"""Assert the condition `x >= y` holds element-wise. """Assert the condition `x >= y` holds element-wise.
@ -1005,6 +1028,7 @@ def assert_greater_equal_v2(x, y, message=None, summarize=None, name=None):
@tf_export(v1=['debugging.assert_greater_equal', 'assert_greater_equal']) @tf_export(v1=['debugging.assert_greater_equal', 'assert_greater_equal'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_greater_equal') @deprecation.deprecated_endpoints('assert_greater_equal')
@_binary_assert_doc('>=') @_binary_assert_doc('>=')
def assert_greater_equal(x, y, data=None, summarize=None, message=None, def assert_greater_equal(x, y, data=None, summarize=None, message=None,
@ -1062,6 +1086,7 @@ def _assert_rank_condition(
@tf_export('debugging.assert_rank', 'assert_rank', v1=[]) @tf_export('debugging.assert_rank', 'assert_rank', v1=[])
@dispatch.add_dispatch_support
def assert_rank_v2(x, rank, message=None, name=None): def assert_rank_v2(x, rank, message=None, name=None):
"""Assert that `x` has rank equal to `rank`. """Assert that `x` has rank equal to `rank`.
@ -1095,6 +1120,7 @@ def assert_rank_v2(x, rank, message=None, name=None):
@tf_export(v1=['debugging.assert_rank', 'assert_rank']) @tf_export(v1=['debugging.assert_rank', 'assert_rank'])
@dispatch.add_dispatch_support
def assert_rank(x, rank, data=None, summarize=None, message=None, name=None): def assert_rank(x, rank, data=None, summarize=None, message=None, name=None):
"""Assert `x` has rank equal to `rank`. """Assert `x` has rank equal to `rank`.
@ -1157,6 +1183,7 @@ def assert_rank(x, rank, data=None, summarize=None, message=None, name=None):
@tf_export('debugging.assert_rank_at_least', v1=[]) @tf_export('debugging.assert_rank_at_least', v1=[])
@dispatch.add_dispatch_support
def assert_rank_at_least_v2(x, rank, message=None, name=None): def assert_rank_at_least_v2(x, rank, message=None, name=None):
"""Assert that `x` has rank of at least `rank`. """Assert that `x` has rank of at least `rank`.
@ -1190,6 +1217,7 @@ def assert_rank_at_least_v2(x, rank, message=None, name=None):
@tf_export(v1=['debugging.assert_rank_at_least', 'assert_rank_at_least']) @tf_export(v1=['debugging.assert_rank_at_least', 'assert_rank_at_least'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_rank_at_least') @deprecation.deprecated_endpoints('assert_rank_at_least')
def assert_rank_at_least( def assert_rank_at_least(
x, rank, data=None, summarize=None, message=None, name=None): x, rank, data=None, summarize=None, message=None, name=None):
@ -1322,6 +1350,7 @@ def _assert_ranks_condition(
@tf_export('debugging.assert_rank_in', v1=[]) @tf_export('debugging.assert_rank_in', v1=[])
@dispatch.add_dispatch_support
def assert_rank_in_v2(x, ranks, message=None, name=None): def assert_rank_in_v2(x, ranks, message=None, name=None):
"""Assert that `x` has a rank in `ranks`. """Assert that `x` has a rank in `ranks`.
@ -1354,6 +1383,7 @@ def assert_rank_in_v2(x, ranks, message=None, name=None):
@tf_export(v1=['debugging.assert_rank_in', 'assert_rank_in']) @tf_export(v1=['debugging.assert_rank_in', 'assert_rank_in'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_rank_in') @deprecation.deprecated_endpoints('assert_rank_in')
def assert_rank_in( def assert_rank_in(
x, ranks, data=None, summarize=None, message=None, name=None): x, ranks, data=None, summarize=None, message=None, name=None):
@ -1417,6 +1447,7 @@ def assert_rank_in(
@tf_export('debugging.assert_integer', v1=[]) @tf_export('debugging.assert_integer', v1=[])
@dispatch.add_dispatch_support
def assert_integer_v2(x, message=None, name=None): def assert_integer_v2(x, message=None, name=None):
"""Assert that `x` is of integer dtype. """Assert that `x` is of integer dtype.
@ -1437,6 +1468,7 @@ def assert_integer_v2(x, message=None, name=None):
@tf_export(v1=['debugging.assert_integer', 'assert_integer']) @tf_export(v1=['debugging.assert_integer', 'assert_integer'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_integer') @deprecation.deprecated_endpoints('assert_integer')
def assert_integer(x, message=None, name=None): def assert_integer(x, message=None, name=None):
"""Assert that `x` is of integer dtype. """Assert that `x` is of integer dtype.
@ -1476,6 +1508,7 @@ def assert_integer(x, message=None, name=None):
@tf_export('debugging.assert_type', v1=[]) @tf_export('debugging.assert_type', v1=[])
@dispatch.add_dispatch_support
def assert_type_v2(tensor, tf_type, message=None, name=None): def assert_type_v2(tensor, tf_type, message=None, name=None):
"""Asserts that the given `Tensor` is of the specified type. """Asserts that the given `Tensor` is of the specified type.
@ -1495,6 +1528,7 @@ def assert_type_v2(tensor, tf_type, message=None, name=None):
@tf_export(v1=['debugging.assert_type', 'assert_type']) @tf_export(v1=['debugging.assert_type', 'assert_type'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_type') @deprecation.deprecated_endpoints('assert_type')
def assert_type(tensor, tf_type, message=None, name=None): def assert_type(tensor, tf_type, message=None, name=None):
"""Statically asserts that the given `Tensor` is of the specified type. """Statically asserts that the given `Tensor` is of the specified type.
@ -1584,6 +1618,7 @@ _TensorDimSizes = collections.namedtuple(
@tf_export('debugging.assert_shapes', v1=[]) @tf_export('debugging.assert_shapes', v1=[])
@dispatch.add_dispatch_support
def assert_shapes_v2(shapes, data=None, summarize=None, message=None, def assert_shapes_v2(shapes, data=None, summarize=None, message=None,
name=None): name=None):
"""Assert tensor shapes and dimension size relationships between tensors. """Assert tensor shapes and dimension size relationships between tensors.
@ -1650,6 +1685,7 @@ def assert_shapes_v2(shapes, data=None, summarize=None, message=None,
@tf_export(v1=['debugging.assert_shapes']) @tf_export(v1=['debugging.assert_shapes'])
@dispatch.add_dispatch_support
def assert_shapes(shapes, data=None, summarize=None, message=None, name=None): def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
"""Assert tensor shapes and dimension size relationships between tensors. """Assert tensor shapes and dimension size relationships between tensors.
@ -1939,6 +1975,7 @@ def is_numeric_tensor(tensor):
'math.is_non_decreasing', 'debugging.is_non_decreasing', 'math.is_non_decreasing', 'debugging.is_non_decreasing',
'is_non_decreasing' 'is_non_decreasing'
]) ])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('debugging.is_non_decreasing', @deprecation.deprecated_endpoints('debugging.is_non_decreasing',
'is_non_decreasing') 'is_non_decreasing')
def is_non_decreasing(x, name=None): def is_non_decreasing(x, name=None):
@ -1980,6 +2017,7 @@ def is_non_decreasing(x, name=None):
'math.is_strictly_increasing', 'debugging.is_strictly_increasing', 'math.is_strictly_increasing', 'debugging.is_strictly_increasing',
'is_strictly_increasing' 'is_strictly_increasing'
]) ])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('debugging.is_strictly_increasing', @deprecation.deprecated_endpoints('debugging.is_strictly_increasing',
'is_strictly_increasing') 'is_strictly_increasing')
def is_strictly_increasing(x, name=None): def is_strictly_increasing(x, name=None):
@ -2066,6 +2104,7 @@ def _assert_same_base_type(items, expected_type=None):
@tf_export( @tf_export(
'debugging.assert_same_float_dtype', 'debugging.assert_same_float_dtype',
v1=['debugging.assert_same_float_dtype', 'assert_same_float_dtype']) v1=['debugging.assert_same_float_dtype', 'assert_same_float_dtype'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_same_float_dtype') @deprecation.deprecated_endpoints('assert_same_float_dtype')
def assert_same_float_dtype(tensors=None, dtype=None): def assert_same_float_dtype(tensors=None, dtype=None):
"""Validate and return float type based on `tensors` and `dtype`. """Validate and return float type based on `tensors` and `dtype`.
@ -2098,6 +2137,7 @@ def assert_same_float_dtype(tensors=None, dtype=None):
@tf_export('debugging.assert_scalar', v1=[]) @tf_export('debugging.assert_scalar', v1=[])
@dispatch.add_dispatch_support
def assert_scalar_v2(tensor, message=None, name=None): def assert_scalar_v2(tensor, message=None, name=None):
"""Asserts that the given `tensor` is a scalar. """Asserts that the given `tensor` is a scalar.
@ -2120,6 +2160,7 @@ def assert_scalar_v2(tensor, message=None, name=None):
@tf_export(v1=['debugging.assert_scalar', 'assert_scalar']) @tf_export(v1=['debugging.assert_scalar', 'assert_scalar'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_scalar') @deprecation.deprecated_endpoints('assert_scalar')
def assert_scalar(tensor, name=None, message=None): def assert_scalar(tensor, name=None, message=None):
"""Asserts that the given `tensor` is a scalar (i.e. zero-dimensional). """Asserts that the given `tensor` is a scalar (i.e. zero-dimensional).
@ -2154,6 +2195,7 @@ def assert_scalar(tensor, name=None, message=None):
@tf_export('ensure_shape') @tf_export('ensure_shape')
@dispatch.add_dispatch_support
def ensure_shape(x, shape, name=None): def ensure_shape(x, shape, name=None):
"""Updates the shape of a tensor and checks at runtime that the shape holds. """Updates the shape of a tensor and checks at runtime that the shape holds.

View File

@ -152,6 +152,7 @@ def _clip_by_value_grad(op, grad):
@tf_export("clip_by_norm") @tf_export("clip_by_norm")
@dispatch.add_dispatch_support
def clip_by_norm(t, clip_norm, axes=None, name=None): def clip_by_norm(t, clip_norm, axes=None, name=None):
"""Clips tensor values to a maximum L2-norm. """Clips tensor values to a maximum L2-norm.
@ -235,6 +236,7 @@ def clip_by_norm(t, clip_norm, axes=None, name=None):
@tf_export("linalg.global_norm", v1=["linalg.global_norm", "global_norm"]) @tf_export("linalg.global_norm", v1=["linalg.global_norm", "global_norm"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("global_norm") @deprecation.deprecated_endpoints("global_norm")
def global_norm(t_list, name=None): def global_norm(t_list, name=None):
"""Computes the global norm of multiple tensors. """Computes the global norm of multiple tensors.
@ -285,6 +287,7 @@ def global_norm(t_list, name=None):
@tf_export("clip_by_global_norm") @tf_export("clip_by_global_norm")
@dispatch.add_dispatch_support
def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None): def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):
"""Clips values of multiple tensors by the ratio of the sum of their norms. """Clips values of multiple tensors by the ratio of the sum of their norms.
@ -382,6 +385,7 @@ def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):
"use clip_by_norm(t, clip_norm * tf.cast(tf.size(t), tf.float32), name) " "use clip_by_norm(t, clip_norm * tf.cast(tf.size(t), tf.float32), name) "
"instead.") "instead.")
@tf_export(v1=["clip_by_average_norm"]) @tf_export(v1=["clip_by_average_norm"])
@dispatch.add_dispatch_support
def clip_by_average_norm(t, clip_norm, name=None): def clip_by_average_norm(t, clip_norm, name=None):
"""Clips tensor values to a maximum average L2-norm. """Clips tensor values to a maximum average L2-norm.

View File

@ -27,6 +27,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import sparse_ops
from tensorflow.python.util import deprecation from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -93,6 +94,7 @@ def remove_squeezable_dimensions(
@tf_export('math.confusion_matrix', v1=[]) @tf_export('math.confusion_matrix', v1=[])
@dispatch.add_dispatch_support
def confusion_matrix(labels, def confusion_matrix(labels,
predictions, predictions,
num_classes=None, num_classes=None,
@ -202,6 +204,7 @@ def confusion_matrix(labels,
@tf_export(v1=['math.confusion_matrix', 'confusion_matrix']) @tf_export(v1=['math.confusion_matrix', 'confusion_matrix'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('confusion_matrix', 'train.confusion_matrix') @deprecation.deprecated_endpoints('confusion_matrix', 'train.confusion_matrix')
def confusion_matrix_v1(labels, def confusion_matrix_v1(labels,
predictions, predictions,

View File

@ -54,6 +54,7 @@ from tensorflow.python.ops.gen_control_flow_ops import *
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat from tensorflow.python.util import compat
from tensorflow.python.util import deprecation from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util import tf_should_use from tensorflow.python.util import tf_should_use
from tensorflow.python.util.lazy_loader import LazyLoader from tensorflow.python.util.lazy_loader import LazyLoader
@ -110,6 +111,7 @@ def _summarize_eager(tensor, summarize=None):
# Assert and Print are special symbols in python, so we must # Assert and Print are special symbols in python, so we must
# use an upper-case version of them. # use an upper-case version of them.
@tf_export("debugging.Assert", "Assert") @tf_export("debugging.Assert", "Assert")
@dispatch.add_dispatch_support
@tf_should_use.should_use_result @tf_should_use.should_use_result
def Assert(condition, data, summarize=None, name=None): def Assert(condition, data, summarize=None, name=None):
"""Asserts that the given condition is true. """Asserts that the given condition is true.
@ -1095,6 +1097,7 @@ def _UnpackIfSingleton(res):
# pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name
# pylint: disable=g-doc-args # pylint: disable=g-doc-args
@tf_export(v1=["cond"]) @tf_export(v1=["cond"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args( @deprecation.deprecated_args(
None, "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.", None, "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.",
"fn1", "fn2") "fn1", "fn2")
@ -1318,6 +1321,7 @@ def _cast_indexed_slice_indices(a, b):
@tf_export("cond", v1=[]) @tf_export("cond", v1=[])
@dispatch.add_dispatch_support
def cond_for_tf_v2(pred, true_fn=None, false_fn=None, name=None): def cond_for_tf_v2(pred, true_fn=None, false_fn=None, name=None):
"""Return `true_fn()` if the predicate `pred` is true else `false_fn()`. """Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
@ -2942,6 +2946,7 @@ def group(*inputs, **kwargs):
@tf_export("tuple", v1=[]) @tf_export("tuple", v1=[])
@dispatch.add_dispatch_support
def tuple_v2(tensors, control_inputs=None, name=None): def tuple_v2(tensors, control_inputs=None, name=None):
"""Group tensors together. """Group tensors together.
@ -2978,6 +2983,7 @@ def tuple_v2(tensors, control_inputs=None, name=None):
@tf_export(v1=["tuple"]) @tf_export(v1=["tuple"])
@dispatch.add_dispatch_support
def tuple(tensors, name=None, control_inputs=None): # pylint: disable=redefined-builtin def tuple(tensors, name=None, control_inputs=None): # pylint: disable=redefined-builtin
"""Group tensors together. """Group tensors together.
@ -3312,6 +3318,7 @@ def _indexed_case_helper(branch_fns, default, branch_index, name):
@tf_export("case", v1=[]) @tf_export("case", v1=[])
@dispatch.add_dispatch_support
def case_v2(pred_fn_pairs, def case_v2(pred_fn_pairs,
default=None, default=None,
exclusive=False, exclusive=False,
@ -3416,6 +3423,7 @@ def case_v2(pred_fn_pairs,
@tf_export(v1=["case"]) @tf_export(v1=["case"])
@dispatch.add_dispatch_support
def case(pred_fn_pairs, def case(pred_fn_pairs,
default=None, default=None,
exclusive=False, exclusive=False,

View File

@ -43,6 +43,7 @@ from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.nn_grad import _BroadcastMul from tensorflow.python.ops.nn_grad import _BroadcastMul
from tensorflow.python.util import deprecation from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -70,6 +71,7 @@ def _generate_defun_backend(unique_api_name, preferred_device, func):
# pylint: disable=protected-access, invalid-name # pylint: disable=protected-access, invalid-name
@tf_export(v1=["nn.ctc_loss"]) @tf_export(v1=["nn.ctc_loss"])
@dispatch.add_dispatch_support
def ctc_loss(labels, def ctc_loss(labels,
inputs=None, inputs=None,
sequence_length=None, sequence_length=None,
@ -284,6 +286,7 @@ def _CTCLossV2Grad(op, grad_loss, _):
@tf_export("nn.ctc_greedy_decoder") @tf_export("nn.ctc_greedy_decoder")
@dispatch.add_dispatch_support
def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True): def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
"""Performs greedy decoding on the logits given in input (best path). """Performs greedy decoding on the logits given in input (best path).
@ -333,6 +336,7 @@ def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
@tf_export(v1=["nn.ctc_beam_search_decoder"]) @tf_export(v1=["nn.ctc_beam_search_decoder"])
@dispatch.add_dispatch_support
def ctc_beam_search_decoder(inputs, def ctc_beam_search_decoder(inputs,
sequence_length, sequence_length,
beam_width=100, beam_width=100,
@ -395,6 +399,7 @@ def ctc_beam_search_decoder(inputs,
@tf_export("nn.ctc_beam_search_decoder", v1=["nn.ctc_beam_search_decoder_v2"]) @tf_export("nn.ctc_beam_search_decoder", v1=["nn.ctc_beam_search_decoder_v2"])
@dispatch.add_dispatch_support
def ctc_beam_search_decoder_v2(inputs, def ctc_beam_search_decoder_v2(inputs,
sequence_length, sequence_length,
beam_width=100, beam_width=100,
@ -731,6 +736,7 @@ def _ctc_loss_shape(op):
# pylint: disable=protected-access, invalid-name # pylint: disable=protected-access, invalid-name
@tf_export(v1=["nn.ctc_loss_v2"]) @tf_export(v1=["nn.ctc_loss_v2"])
@dispatch.add_dispatch_support
def ctc_loss_v2(labels, def ctc_loss_v2(labels,
logits, logits,
label_length, label_length,
@ -825,6 +831,7 @@ def ctc_loss_v2(labels,
@tf_export("nn.ctc_loss", v1=[]) @tf_export("nn.ctc_loss", v1=[])
@dispatch.add_dispatch_support
def ctc_loss_v3(labels, def ctc_loss_v3(labels,
logits, logits,
label_length, label_length,
@ -1056,6 +1063,7 @@ def ctc_loss_dense(labels,
@tf_export("nn.collapse_repeated") @tf_export("nn.collapse_repeated")
@dispatch.add_dispatch_support
def collapse_repeated(labels, seq_length, name=None): def collapse_repeated(labels, seq_length, name=None):
"""Merge repeated labels into single labels. """Merge repeated labels into single labels.
@ -1153,6 +1161,7 @@ def dense_labels_to_sparse(dense, length):
@tf_export("nn.ctc_unique_labels") @tf_export("nn.ctc_unique_labels")
@dispatch.add_dispatch_support
def ctc_unique_labels(labels, name=None): def ctc_unique_labels(labels, name=None):
"""Get unique labels and indices for batched labels for `tf.nn.ctc_loss`. """Get unique labels and indices for batched labels for `tf.nn.ctc_loss`.

View File

@ -36,6 +36,7 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_functional_ops from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -250,6 +251,7 @@ def _embedding_lookup_and_transform(params,
@tf_export(v1=["nn.embedding_lookup"]) @tf_export(v1=["nn.embedding_lookup"])
@dispatch.add_dispatch_support
def embedding_lookup( def embedding_lookup(
params, params,
ids, ids,
@ -327,6 +329,7 @@ def embedding_lookup(
@tf_export("nn.embedding_lookup", v1=[]) @tf_export("nn.embedding_lookup", v1=[])
@dispatch.add_dispatch_support
def embedding_lookup_v2(params, ids, max_norm=None, name=None): def embedding_lookup_v2(params, ids, max_norm=None, name=None):
"""Looks up embeddings for the given `ids` from a list of tensors. """Looks up embeddings for the given `ids` from a list of tensors.
@ -392,6 +395,7 @@ def embedding_lookup_v2(params, ids, max_norm=None, name=None):
@tf_export(v1=["nn.embedding_lookup_sparse"]) @tf_export(v1=["nn.embedding_lookup_sparse"])
@dispatch.add_dispatch_support
def embedding_lookup_sparse(params, def embedding_lookup_sparse(params,
sp_ids, sp_ids,
sp_weights, sp_weights,
@ -574,6 +578,7 @@ def embedding_lookup_sparse(params,
@tf_export("nn.embedding_lookup_sparse", v1=[]) @tf_export("nn.embedding_lookup_sparse", v1=[])
@dispatch.add_dispatch_support
def embedding_lookup_sparse_v2(params, def embedding_lookup_sparse_v2(params,
sp_ids, sp_ids,
sp_weights, sp_weights,
@ -664,6 +669,7 @@ def embedding_lookup_sparse_v2(params,
@tf_export("nn.safe_embedding_lookup_sparse", v1=[]) @tf_export("nn.safe_embedding_lookup_sparse", v1=[])
@dispatch.add_dispatch_support
def safe_embedding_lookup_sparse_v2(embedding_weights, def safe_embedding_lookup_sparse_v2(embedding_weights,
sparse_ids, sparse_ids,
sparse_weights=None, sparse_weights=None,
@ -765,6 +771,7 @@ def safe_embedding_lookup_sparse_v2(embedding_weights,
@tf_export(v1=["nn.safe_embedding_lookup_sparse"]) @tf_export(v1=["nn.safe_embedding_lookup_sparse"])
@dispatch.add_dispatch_support
def safe_embedding_lookup_sparse(embedding_weights, def safe_embedding_lookup_sparse(embedding_weights,
sparse_ids, sparse_ids,
sparse_weights=None, sparse_weights=None,

View File

@ -38,6 +38,7 @@ from tensorflow.python.ops.gen_functional_ops import remote_call
from tensorflow.python.ops.gen_functional_ops import symbolic_gradient from tensorflow.python.ops.gen_functional_ops import symbolic_gradient
from tensorflow.python.util import compat from tensorflow.python.util import compat
from tensorflow.python.util import deprecation from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import function_utils from tensorflow.python.util import function_utils
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -45,6 +46,7 @@ from tensorflow.python.util.tf_export import tf_export
# TODO(yuanbyu, mrry): Handle stride to support sliding windows. # TODO(yuanbyu, mrry): Handle stride to support sliding windows.
@tf_export(v1=["foldl"]) @tf_export(v1=["foldl"])
@dispatch.add_dispatch_support
def foldl(fn, def foldl(fn,
elems, elems,
initializer=None, initializer=None,
@ -162,6 +164,7 @@ def foldl(fn,
@tf_export("foldl", v1=[]) @tf_export("foldl", v1=[])
@dispatch.add_dispatch_support
@deprecation.deprecated_arg_values( @deprecation.deprecated_arg_values(
None, None,
"""back_prop=False is deprecated. Consider using tf.stop_gradient instead. """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
@ -238,6 +241,7 @@ def foldl_v2(fn,
@tf_export(v1=["foldr"]) @tf_export(v1=["foldr"])
@dispatch.add_dispatch_support
def foldr(fn, def foldr(fn,
elems, elems,
initializer=None, initializer=None,
@ -356,6 +360,7 @@ def foldr(fn,
@tf_export("foldr", v1=[]) @tf_export("foldr", v1=[])
@dispatch.add_dispatch_support
@deprecation.deprecated_arg_values( @deprecation.deprecated_arg_values(
None, None,
"""back_prop=False is deprecated. Consider using tf.stop_gradient instead. """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
@ -432,6 +437,7 @@ def foldr_v2(fn,
@tf_export(v1=["scan"]) @tf_export(v1=["scan"])
@dispatch.add_dispatch_support
def scan(fn, def scan(fn,
elems, elems,
initializer=None, initializer=None,
@ -686,6 +692,7 @@ def scan(fn,
@tf_export("scan", v1=[]) @tf_export("scan", v1=[])
@dispatch.add_dispatch_support
@deprecation.deprecated_arg_values( @deprecation.deprecated_arg_values(
None, None,
"""back_prop=False is deprecated. Consider using tf.stop_gradient instead. """back_prop=False is deprecated. Consider using tf.stop_gradient instead.

View File

@ -26,10 +26,12 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@tf_export('histogram_fixed_width_bins') @tf_export('histogram_fixed_width_bins')
@dispatch.add_dispatch_support
def histogram_fixed_width_bins(values, def histogram_fixed_width_bins(values,
value_range, value_range,
nbins=100, nbins=100,
@ -101,6 +103,7 @@ def histogram_fixed_width_bins(values,
@tf_export('histogram_fixed_width') @tf_export('histogram_fixed_width')
@dispatch.add_dispatch_support
def histogram_fixed_width(values, def histogram_fixed_width(values,
value_range, value_range,
nbins=100, nbins=100,

View File

@ -39,6 +39,7 @@ from tensorflow.python.ops import sort_ops
from tensorflow.python.ops import string_ops from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.util import deprecation from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
ops.NotDifferentiable('RandomCrop') ops.NotDifferentiable('RandomCrop')
@ -323,6 +324,7 @@ def fix_image_flip_shape(image, result):
@tf_export('image.random_flip_up_down') @tf_export('image.random_flip_up_down')
@dispatch.add_dispatch_support
def random_flip_up_down(image, seed=None): def random_flip_up_down(image, seed=None):
"""Randomly flips an image vertically (upside down). """Randomly flips an image vertically (upside down).
@ -363,6 +365,7 @@ def random_flip_up_down(image, seed=None):
@tf_export('image.random_flip_left_right') @tf_export('image.random_flip_left_right')
@dispatch.add_dispatch_support
def random_flip_left_right(image, seed=None): def random_flip_left_right(image, seed=None):
"""Randomly flip an image horizontally (left to right). """Randomly flip an image horizontally (left to right).
@ -450,6 +453,7 @@ def _random_flip(image, flip_index, seed, scope_name):
@tf_export('image.flip_left_right') @tf_export('image.flip_left_right')
@dispatch.add_dispatch_support
def flip_left_right(image): def flip_left_right(image):
"""Flip an image horizontally (left to right). """Flip an image horizontally (left to right).
@ -484,6 +488,7 @@ def flip_left_right(image):
@tf_export('image.flip_up_down') @tf_export('image.flip_up_down')
@dispatch.add_dispatch_support
def flip_up_down(image): def flip_up_down(image):
"""Flip an image vertically (upside down). """Flip an image vertically (upside down).
@ -549,6 +554,7 @@ def _flip(image, flip_index, scope_name):
@tf_export('image.rot90') @tf_export('image.rot90')
@dispatch.add_dispatch_support
def rot90(image, k=1, name=None): def rot90(image, k=1, name=None):
"""Rotate image(s) counter-clockwise by 90 degrees. """Rotate image(s) counter-clockwise by 90 degrees.
@ -660,6 +666,7 @@ def _rot90_4D(images, k, name_scope):
@tf_export('image.transpose', v1=['image.transpose', 'image.transpose_image']) @tf_export('image.transpose', v1=['image.transpose', 'image.transpose_image'])
@dispatch.add_dispatch_support
def transpose(image, name=None): def transpose(image, name=None):
"""Transpose image(s) by swapping the height and width dimension. """Transpose image(s) by swapping the height and width dimension.
@ -718,6 +725,7 @@ def transpose(image, name=None):
@tf_export('image.central_crop') @tf_export('image.central_crop')
@dispatch.add_dispatch_support
def central_crop(image, central_fraction): def central_crop(image, central_fraction):
"""Crop the central region of the image(s). """Crop the central region of the image(s).
@ -850,6 +858,7 @@ def central_crop(image, central_fraction):
@tf_export('image.pad_to_bounding_box') @tf_export('image.pad_to_bounding_box')
@dispatch.add_dispatch_support
def pad_to_bounding_box(image, offset_height, offset_width, target_height, def pad_to_bounding_box(image, offset_height, offset_width, target_height,
target_width): target_width):
"""Pad `image` with zeros to the specified `height` and `width`. """Pad `image` with zeros to the specified `height` and `width`.
@ -959,6 +968,7 @@ def pad_to_bounding_box(image, offset_height, offset_width, target_height,
@tf_export('image.crop_to_bounding_box') @tf_export('image.crop_to_bounding_box')
@dispatch.add_dispatch_support
def crop_to_bounding_box(image, offset_height, offset_width, target_height, def crop_to_bounding_box(image, offset_height, offset_width, target_height,
target_width): target_width):
"""Crops an image to a specified bounding box. """Crops an image to a specified bounding box.
@ -1041,6 +1051,7 @@ def crop_to_bounding_box(image, offset_height, offset_width, target_height,
@tf_export( @tf_export(
'image.resize_with_crop_or_pad', 'image.resize_with_crop_or_pad',
v1=['image.resize_with_crop_or_pad', 'image.resize_image_with_crop_or_pad']) v1=['image.resize_with_crop_or_pad', 'image.resize_image_with_crop_or_pad'])
@dispatch.add_dispatch_support
def resize_image_with_crop_or_pad(image, target_height, target_width): def resize_image_with_crop_or_pad(image, target_height, target_width):
"""Crops and/or pads an image to a target width and height. """Crops and/or pads an image to a target width and height.
@ -1258,6 +1269,7 @@ def _resize_images_common(images, resizer_fn, size, preserve_aspect_ratio, name,
@tf_export(v1=['image.resize_images', 'image.resize']) @tf_export(v1=['image.resize_images', 'image.resize'])
@dispatch.add_dispatch_support
def resize_images(images, def resize_images(images,
size, size,
method=ResizeMethodV1.BILINEAR, method=ResizeMethodV1.BILINEAR,
@ -1343,6 +1355,7 @@ def resize_images(images,
@tf_export('image.resize', v1=[]) @tf_export('image.resize', v1=[])
@dispatch.add_dispatch_support
def resize_images_v2(images, def resize_images_v2(images,
size, size,
method=ResizeMethod.BILINEAR, method=ResizeMethod.BILINEAR,
@ -1594,6 +1607,7 @@ def _resize_image_with_pad_common(image, target_height, target_width,
@tf_export(v1=['image.resize_image_with_pad']) @tf_export(v1=['image.resize_image_with_pad'])
@dispatch.add_dispatch_support
def resize_image_with_pad_v1(image, def resize_image_with_pad_v1(image,
target_height, target_height,
target_width, target_width,
@ -1636,6 +1650,7 @@ def resize_image_with_pad_v1(image,
@tf_export('image.resize_with_pad', v1=[]) @tf_export('image.resize_with_pad', v1=[])
@dispatch.add_dispatch_support
def resize_image_with_pad_v2(image, def resize_image_with_pad_v2(image,
target_height, target_height,
target_width, target_width,
@ -1676,6 +1691,7 @@ def resize_image_with_pad_v2(image,
@tf_export('image.per_image_standardization') @tf_export('image.per_image_standardization')
@dispatch.add_dispatch_support
def per_image_standardization(image): def per_image_standardization(image):
"""Linearly scales each image in `image` to have mean 0 and variance 1. """Linearly scales each image in `image` to have mean 0 and variance 1.
@ -1721,6 +1737,7 @@ def per_image_standardization(image):
@tf_export('image.random_brightness') @tf_export('image.random_brightness')
@dispatch.add_dispatch_support
def random_brightness(image, max_delta, seed=None): def random_brightness(image, max_delta, seed=None):
"""Adjust the brightness of images by a random factor. """Adjust the brightness of images by a random factor.
@ -1756,6 +1773,7 @@ def random_brightness(image, max_delta, seed=None):
@tf_export('image.random_contrast') @tf_export('image.random_contrast')
@dispatch.add_dispatch_support
def random_contrast(image, lower, upper, seed=None): def random_contrast(image, lower, upper, seed=None):
"""Adjust the contrast of an image or images by a random factor. """Adjust the contrast of an image or images by a random factor.
@ -1796,6 +1814,7 @@ def random_contrast(image, lower, upper, seed=None):
@tf_export('image.adjust_brightness') @tf_export('image.adjust_brightness')
@dispatch.add_dispatch_support
def adjust_brightness(image, delta): def adjust_brightness(image, delta):
"""Adjust the brightness of RGB or Grayscale images. """Adjust the brightness of RGB or Grayscale images.
@ -1847,6 +1866,7 @@ def adjust_brightness(image, delta):
@tf_export('image.adjust_contrast') @tf_export('image.adjust_contrast')
@dispatch.add_dispatch_support
def adjust_contrast(images, contrast_factor): def adjust_contrast(images, contrast_factor):
"""Adjust contrast of RGB or grayscale images. """Adjust contrast of RGB or grayscale images.
@ -1903,6 +1923,7 @@ def adjust_contrast(images, contrast_factor):
@tf_export('image.adjust_gamma') @tf_export('image.adjust_gamma')
@dispatch.add_dispatch_support
def adjust_gamma(image, gamma=1, gain=1): def adjust_gamma(image, gamma=1, gain=1):
"""Performs [Gamma Correction](http://en.wikipedia.org/wiki/Gamma_correction). """Performs [Gamma Correction](http://en.wikipedia.org/wiki/Gamma_correction).
@ -1967,6 +1988,7 @@ def adjust_gamma(image, gamma=1, gain=1):
@tf_export('image.convert_image_dtype') @tf_export('image.convert_image_dtype')
@dispatch.add_dispatch_support
def convert_image_dtype(image, dtype, saturate=False, name=None): def convert_image_dtype(image, dtype, saturate=False, name=None):
"""Convert `image` to `dtype`, scaling its values if needed. """Convert `image` to `dtype`, scaling its values if needed.
@ -2066,6 +2088,7 @@ def convert_image_dtype(image, dtype, saturate=False, name=None):
@tf_export('image.rgb_to_grayscale') @tf_export('image.rgb_to_grayscale')
@dispatch.add_dispatch_support
def rgb_to_grayscale(images, name=None): def rgb_to_grayscale(images, name=None):
"""Converts one or more images from RGB to Grayscale. """Converts one or more images from RGB to Grayscale.
@ -2101,6 +2124,7 @@ def rgb_to_grayscale(images, name=None):
@tf_export('image.grayscale_to_rgb') @tf_export('image.grayscale_to_rgb')
@dispatch.add_dispatch_support
def grayscale_to_rgb(images, name=None): def grayscale_to_rgb(images, name=None):
"""Converts one or more images from Grayscale to RGB. """Converts one or more images from Grayscale to RGB.
@ -2137,6 +2161,7 @@ def grayscale_to_rgb(images, name=None):
# pylint: disable=invalid-name # pylint: disable=invalid-name
@tf_export('image.random_hue') @tf_export('image.random_hue')
@dispatch.add_dispatch_support
def random_hue(image, max_delta, seed=None): def random_hue(image, max_delta, seed=None):
"""Adjust the hue of RGB images by a random factor. """Adjust the hue of RGB images by a random factor.
@ -2179,6 +2204,7 @@ def random_hue(image, max_delta, seed=None):
@tf_export('image.adjust_hue') @tf_export('image.adjust_hue')
@dispatch.add_dispatch_support
def adjust_hue(image, delta, name=None): def adjust_hue(image, delta, name=None):
"""Adjust hue of RGB images. """Adjust hue of RGB images.
@ -2246,6 +2272,7 @@ def adjust_hue(image, delta, name=None):
# pylint: disable=invalid-name # pylint: disable=invalid-name
@tf_export('image.random_jpeg_quality') @tf_export('image.random_jpeg_quality')
@dispatch.add_dispatch_support
def random_jpeg_quality(image, min_jpeg_quality, max_jpeg_quality, seed=None): def random_jpeg_quality(image, min_jpeg_quality, max_jpeg_quality, seed=None):
"""Randomly changes jpeg encoding quality for inducing jpeg noise. """Randomly changes jpeg encoding quality for inducing jpeg noise.
@ -2293,6 +2320,7 @@ def random_jpeg_quality(image, min_jpeg_quality, max_jpeg_quality, seed=None):
@tf_export('image.adjust_jpeg_quality') @tf_export('image.adjust_jpeg_quality')
@dispatch.add_dispatch_support
def adjust_jpeg_quality(image, jpeg_quality, name=None): def adjust_jpeg_quality(image, jpeg_quality, name=None):
"""Adjust jpeg encoding quality of an image. """Adjust jpeg encoding quality of an image.
@ -2343,6 +2371,7 @@ def adjust_jpeg_quality(image, jpeg_quality, name=None):
@tf_export('image.random_saturation') @tf_export('image.random_saturation')
@dispatch.add_dispatch_support
def random_saturation(image, lower, upper, seed=None): def random_saturation(image, lower, upper, seed=None):
"""Adjust the saturation of RGB images by a random factor. """Adjust the saturation of RGB images by a random factor.
@ -2389,6 +2418,7 @@ def random_saturation(image, lower, upper, seed=None):
@tf_export('image.adjust_saturation') @tf_export('image.adjust_saturation')
@dispatch.add_dispatch_support
def adjust_saturation(image, saturation_factor, name=None): def adjust_saturation(image, saturation_factor, name=None):
"""Adjust saturation of RGB images. """Adjust saturation of RGB images.
@ -2480,42 +2510,43 @@ tf_export(
'io.decode_and_crop_jpeg', 'io.decode_and_crop_jpeg',
'image.decode_and_crop_jpeg', 'image.decode_and_crop_jpeg',
v1=['io.decode_and_crop_jpeg', 'image.decode_and_crop_jpeg'])( v1=['io.decode_and_crop_jpeg', 'image.decode_and_crop_jpeg'])(
gen_image_ops.decode_and_crop_jpeg) dispatch.add_dispatch_support(gen_image_ops.decode_and_crop_jpeg))
tf_export( tf_export(
'io.decode_bmp', 'io.decode_bmp',
'image.decode_bmp', 'image.decode_bmp',
v1=['io.decode_bmp', 'image.decode_bmp'])( v1=['io.decode_bmp', 'image.decode_bmp'])(
gen_image_ops.decode_bmp) dispatch.add_dispatch_support(gen_image_ops.decode_bmp))
tf_export( tf_export(
'io.decode_gif', 'io.decode_gif',
'image.decode_gif', 'image.decode_gif',
v1=['io.decode_gif', 'image.decode_gif'])( v1=['io.decode_gif', 'image.decode_gif'])(
gen_image_ops.decode_gif) dispatch.add_dispatch_support(gen_image_ops.decode_gif))
tf_export( tf_export(
'io.decode_jpeg', 'io.decode_jpeg',
'image.decode_jpeg', 'image.decode_jpeg',
v1=['io.decode_jpeg', 'image.decode_jpeg'])( v1=['io.decode_jpeg', 'image.decode_jpeg'])(
gen_image_ops.decode_jpeg) dispatch.add_dispatch_support(gen_image_ops.decode_jpeg))
tf_export( tf_export(
'io.decode_png', 'io.decode_png',
'image.decode_png', 'image.decode_png',
v1=['io.decode_png', 'image.decode_png'])( v1=['io.decode_png', 'image.decode_png'])(
gen_image_ops.decode_png) dispatch.add_dispatch_support(gen_image_ops.decode_png))
tf_export( tf_export(
'io.encode_jpeg', 'io.encode_jpeg',
'image.encode_jpeg', 'image.encode_jpeg',
v1=['io.encode_jpeg', 'image.encode_jpeg'])( v1=['io.encode_jpeg', 'image.encode_jpeg'])(
gen_image_ops.encode_jpeg) dispatch.add_dispatch_support(gen_image_ops.encode_jpeg))
tf_export( tf_export(
'io.extract_jpeg_shape', 'io.extract_jpeg_shape',
'image.extract_jpeg_shape', 'image.extract_jpeg_shape',
v1=['io.extract_jpeg_shape', 'image.extract_jpeg_shape'])( v1=['io.extract_jpeg_shape', 'image.extract_jpeg_shape'])(
gen_image_ops.extract_jpeg_shape) dispatch.add_dispatch_support(gen_image_ops.extract_jpeg_shape))
@tf_export('io.encode_png', 'image.encode_png') @tf_export('io.encode_png', 'image.encode_png')
@dispatch.add_dispatch_support
def encode_png(image, compression=-1, name=None): def encode_png(image, compression=-1, name=None):
r"""PNG-encode an image. r"""PNG-encode an image.
@ -2548,6 +2579,7 @@ def encode_png(image, compression=-1, name=None):
'io.decode_image', 'io.decode_image',
'image.decode_image', 'image.decode_image',
v1=['io.decode_image', 'image.decode_image']) v1=['io.decode_image', 'image.decode_image'])
@dispatch.add_dispatch_support
def decode_image(contents, def decode_image(contents,
channels=None, channels=None,
dtype=dtypes.uint8, dtype=dtypes.uint8,
@ -2661,6 +2693,7 @@ def decode_image(contents,
@tf_export('image.total_variation') @tf_export('image.total_variation')
@dispatch.add_dispatch_support
def total_variation(images, name=None): def total_variation(images, name=None):
"""Calculate and return the total variation for one or more images. """Calculate and return the total variation for one or more images.
@ -2732,6 +2765,7 @@ def total_variation(images, name=None):
@tf_export('image.sample_distorted_bounding_box', v1=[]) @tf_export('image.sample_distorted_bounding_box', v1=[])
@dispatch.add_dispatch_support
def sample_distorted_bounding_box_v2(image_size, def sample_distorted_bounding_box_v2(image_size,
bounding_boxes, bounding_boxes,
seed=0, seed=0,
@ -2831,6 +2865,7 @@ def sample_distorted_bounding_box_v2(image_size,
@tf_export(v1=['image.sample_distorted_bounding_box']) @tf_export(v1=['image.sample_distorted_bounding_box'])
@dispatch.add_dispatch_support
@deprecation.deprecated( @deprecation.deprecated(
date=None, date=None,
instructions='`seed2` arg is deprecated.' instructions='`seed2` arg is deprecated.'
@ -2945,6 +2980,7 @@ def sample_distorted_bounding_box(image_size,
@tf_export('image.non_max_suppression') @tf_export('image.non_max_suppression')
@dispatch.add_dispatch_support
def non_max_suppression(boxes, def non_max_suppression(boxes,
scores, scores,
max_output_size, max_output_size,
@ -2997,6 +3033,7 @@ def non_max_suppression(boxes,
@tf_export('image.non_max_suppression_with_scores') @tf_export('image.non_max_suppression_with_scores')
@dispatch.add_dispatch_support
def non_max_suppression_with_scores(boxes, def non_max_suppression_with_scores(boxes,
scores, scores,
max_output_size, max_output_size,
@ -3083,6 +3120,7 @@ def non_max_suppression_with_scores(boxes,
@tf_export('image.non_max_suppression_overlaps') @tf_export('image.non_max_suppression_overlaps')
@dispatch.add_dispatch_support
def non_max_suppression_with_overlaps(overlaps, def non_max_suppression_with_overlaps(overlaps,
scores, scores,
max_output_size, max_output_size,
@ -3134,6 +3172,7 @@ _rgb_to_yiq_kernel = [[0.299, 0.59590059, 0.2115],
@tf_export('image.rgb_to_yiq') @tf_export('image.rgb_to_yiq')
@dispatch.add_dispatch_support
def rgb_to_yiq(images): def rgb_to_yiq(images):
"""Converts one or more images from RGB to YIQ. """Converts one or more images from RGB to YIQ.
@ -3167,6 +3206,7 @@ _yiq_to_rgb_kernel = [[1, 1, 1], [0.95598634, -0.27201283, -1.10674021],
@tf_export('image.yiq_to_rgb') @tf_export('image.yiq_to_rgb')
@dispatch.add_dispatch_support
def yiq_to_rgb(images): def yiq_to_rgb(images):
"""Converts one or more images from YIQ to RGB. """Converts one or more images from YIQ to RGB.
@ -3195,6 +3235,7 @@ _rgb_to_yuv_kernel = [[0.299, -0.14714119, 0.61497538],
@tf_export('image.rgb_to_yuv') @tf_export('image.rgb_to_yuv')
@dispatch.add_dispatch_support
def rgb_to_yuv(images): def rgb_to_yuv(images):
"""Converts one or more images from RGB to YUV. """Converts one or more images from RGB to YUV.
@ -3221,6 +3262,7 @@ _yuv_to_rgb_kernel = [[1, 1, 1], [0, -0.394642334, 2.03206185],
@tf_export('image.yuv_to_rgb') @tf_export('image.yuv_to_rgb')
@dispatch.add_dispatch_support
def yuv_to_rgb(images): def yuv_to_rgb(images):
"""Converts one or more images from YUV to RGB. """Converts one or more images from YUV to RGB.
@ -3314,6 +3356,7 @@ def _verify_compatible_image_shapes(img1, img2):
@tf_export('image.psnr') @tf_export('image.psnr')
@dispatch.add_dispatch_support
def psnr(a, b, max_val, name=None): def psnr(a, b, max_val, name=None):
"""Returns the Peak Signal-to-Noise Ratio between a and b. """Returns the Peak Signal-to-Noise Ratio between a and b.
@ -3525,6 +3568,7 @@ def _ssim_per_channel(img1,
@tf_export('image.ssim') @tf_export('image.ssim')
@dispatch.add_dispatch_support
def ssim(img1, def ssim(img1,
img2, img2,
max_val, max_val,
@ -3604,6 +3648,7 @@ _MSSSIM_WEIGHTS = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)
@tf_export('image.ssim_multiscale') @tf_export('image.ssim_multiscale')
@dispatch.add_dispatch_support
def ssim_multiscale(img1, def ssim_multiscale(img1,
img2, img2,
max_val, max_val,
@ -3731,6 +3776,7 @@ def ssim_multiscale(img1,
@tf_export('image.image_gradients') @tf_export('image.image_gradients')
@dispatch.add_dispatch_support
def image_gradients(image): def image_gradients(image):
"""Returns image gradients (dy, dx) for each color channel. """Returns image gradients (dy, dx) for each color channel.
@ -3804,6 +3850,7 @@ def image_gradients(image):
@tf_export('image.sobel_edges') @tf_export('image.sobel_edges')
@dispatch.add_dispatch_support
def sobel_edges(image): def sobel_edges(image):
"""Returns a tensor holding Sobel edge maps. """Returns a tensor holding Sobel edge maps.
@ -3888,21 +3935,22 @@ resize_area_deprecation = deprecation.deprecated(
instructions=( instructions=(
'Use `tf.image.resize(...method=ResizeMethod.AREA...)` instead.')) 'Use `tf.image.resize(...method=ResizeMethod.AREA...)` instead.'))
tf_export(v1=['image.resize_area'])( tf_export(v1=['image.resize_area'])(
resize_area_deprecation(gen_image_ops.resize_area)) resize_area_deprecation(
dispatch.add_dispatch_support(gen_image_ops.resize_area)))
resize_bicubic_deprecation = deprecation.deprecated( resize_bicubic_deprecation = deprecation.deprecated(
date=None, date=None,
instructions=( instructions=(
'Use `tf.image.resize(...method=ResizeMethod.BICUBIC...)` instead.')) 'Use `tf.image.resize(...method=ResizeMethod.BICUBIC...)` instead.'))
tf_export(v1=['image.resize_bicubic'])( tf_export(v1=['image.resize_bicubic'])(
resize_bicubic_deprecation(resize_bicubic)) dispatch.add_dispatch_support(resize_bicubic_deprecation(resize_bicubic)))
resize_bilinear_deprecation = deprecation.deprecated( resize_bilinear_deprecation = deprecation.deprecated(
date=None, date=None,
instructions=( instructions=(
'Use `tf.image.resize(...method=ResizeMethod.BILINEAR...)` instead.')) 'Use `tf.image.resize(...method=ResizeMethod.BILINEAR...)` instead.'))
tf_export(v1=['image.resize_bilinear'])( tf_export(v1=['image.resize_bilinear'])(
resize_bilinear_deprecation(resize_bilinear)) dispatch.add_dispatch_support(resize_bilinear_deprecation(resize_bilinear)))
resize_nearest_neighbor_deprecation = deprecation.deprecated( resize_nearest_neighbor_deprecation = deprecation.deprecated(
date=None, date=None,
@ -3910,10 +3958,12 @@ resize_nearest_neighbor_deprecation = deprecation.deprecated(
'Use `tf.image.resize(...method=ResizeMethod.NEAREST_NEIGHBOR...)` ' 'Use `tf.image.resize(...method=ResizeMethod.NEAREST_NEIGHBOR...)` '
'instead.')) 'instead.'))
tf_export(v1=['image.resize_nearest_neighbor'])( tf_export(v1=['image.resize_nearest_neighbor'])(
resize_nearest_neighbor_deprecation(resize_nearest_neighbor)) dispatch.add_dispatch_support(
resize_nearest_neighbor_deprecation(resize_nearest_neighbor)))
@tf_export('image.crop_and_resize', v1=[]) @tf_export('image.crop_and_resize', v1=[])
@dispatch.add_dispatch_support
def crop_and_resize_v2(image, def crop_and_resize_v2(image,
boxes, boxes,
box_indices, box_indices,
@ -3997,6 +4047,7 @@ def crop_and_resize_v2(image,
@tf_export(v1=['image.crop_and_resize']) @tf_export(v1=['image.crop_and_resize'])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None, @deprecation.deprecated_args(None,
'box_ind is deprecated, use box_indices instead', 'box_ind is deprecated, use box_indices instead',
'box_ind') 'box_ind')
@ -4019,6 +4070,7 @@ crop_and_resize_v1.__doc__ = gen_image_ops.crop_and_resize.__doc__
@tf_export(v1=['image.extract_glimpse']) @tf_export(v1=['image.extract_glimpse'])
@dispatch.add_dispatch_support
def extract_glimpse( def extract_glimpse(
input, # pylint: disable=redefined-builtin input, # pylint: disable=redefined-builtin
size, size,
@ -4104,6 +4156,7 @@ def extract_glimpse(
@tf_export('image.extract_glimpse', v1=[]) @tf_export('image.extract_glimpse', v1=[])
@dispatch.add_dispatch_support
def extract_glimpse_v2( def extract_glimpse_v2(
input, # pylint: disable=redefined-builtin input, # pylint: disable=redefined-builtin
size, size,
@ -4190,6 +4243,7 @@ def extract_glimpse_v2(
@tf_export('image.combined_non_max_suppression') @tf_export('image.combined_non_max_suppression')
@dispatch.add_dispatch_support
def combined_non_max_suppression(boxes, def combined_non_max_suppression(boxes,
scores, scores,
max_output_size_per_class, max_output_size_per_class,
@ -4442,6 +4496,7 @@ def _suppression_loop_body(boxes, iou_threshold, output_size, idx, tile_size):
@tf_export('image.non_max_suppression_padded') @tf_export('image.non_max_suppression_padded')
@dispatch.add_dispatch_support
def non_max_suppression_padded(boxes, def non_max_suppression_padded(boxes,
scores, scores,
max_output_size, max_output_size,
@ -4816,6 +4871,7 @@ def non_max_suppression_padded_v1(boxes,
@tf_export('image.draw_bounding_boxes', v1=[]) @tf_export('image.draw_bounding_boxes', v1=[])
@dispatch.add_dispatch_support
def draw_bounding_boxes_v2(images, boxes, colors, name=None): def draw_bounding_boxes_v2(images, boxes, colors, name=None):
"""Draw bounding boxes on a batch of images. """Draw bounding boxes on a batch of images.
@ -4870,6 +4926,7 @@ def draw_bounding_boxes_v2(images, boxes, colors, name=None):
@tf_export(v1=['image.draw_bounding_boxes']) @tf_export(v1=['image.draw_bounding_boxes'])
@dispatch.add_dispatch_support
def draw_bounding_boxes(images, boxes, name=None, colors=None): def draw_bounding_boxes(images, boxes, name=None, colors=None):
"""Draw bounding boxes on a batch of images. """Draw bounding boxes on a batch of images.
@ -4922,6 +4979,7 @@ def draw_bounding_boxes(images, boxes, name=None, colors=None):
@tf_export('image.generate_bounding_box_proposals') @tf_export('image.generate_bounding_box_proposals')
@dispatch.add_dispatch_support
def generate_bounding_box_proposals(scores, def generate_bounding_box_proposals(scores,
bbox_deltas, bbox_deltas,
image_info, image_info,

View File

@ -41,7 +41,7 @@ cholesky = linalg_ops.cholesky
cholesky_solve = linalg_ops.cholesky_solve cholesky_solve = linalg_ops.cholesky_solve
det = linalg_ops.matrix_determinant det = linalg_ops.matrix_determinant
slogdet = gen_linalg_ops.log_matrix_determinant slogdet = gen_linalg_ops.log_matrix_determinant
tf_export('linalg.slogdet')(slogdet) tf_export('linalg.slogdet')(dispatch.add_dispatch_support(slogdet))
diag = array_ops.matrix_diag diag = array_ops.matrix_diag
diag_part = array_ops.matrix_diag_part diag_part = array_ops.matrix_diag_part
eigh = linalg_ops.self_adjoint_eig eigh = linalg_ops.self_adjoint_eig
@ -51,7 +51,7 @@ eye = linalg_ops.eye
inv = linalg_ops.matrix_inverse inv = linalg_ops.matrix_inverse
logm = gen_linalg_ops.matrix_logarithm logm = gen_linalg_ops.matrix_logarithm
lu = gen_linalg_ops.lu lu = gen_linalg_ops.lu
tf_export('linalg.logm')(logm) tf_export('linalg.logm')(dispatch.add_dispatch_support(logm))
lstsq = linalg_ops.matrix_solve_ls lstsq = linalg_ops.matrix_solve_ls
norm = linalg_ops.norm norm = linalg_ops.norm
qr = linalg_ops.qr qr = linalg_ops.qr
@ -230,6 +230,7 @@ def _matrix_exp_pade13(matrix):
@tf_export('linalg.expm') @tf_export('linalg.expm')
@dispatch.add_dispatch_support
def matrix_exponential(input, name=None): # pylint: disable=redefined-builtin def matrix_exponential(input, name=None): # pylint: disable=redefined-builtin
r"""Computes the matrix exponential of one or more square matrices. r"""Computes the matrix exponential of one or more square matrices.
@ -340,6 +341,7 @@ def matrix_exponential(input, name=None): # pylint: disable=redefined-builtin
@tf_export('linalg.tridiagonal_solve') @tf_export('linalg.tridiagonal_solve')
@dispatch.add_dispatch_support
def tridiagonal_solve(diagonals, def tridiagonal_solve(diagonals,
rhs, rhs,
diagonals_format='compact', diagonals_format='compact',
@ -541,6 +543,7 @@ def _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
@tf_export('linalg.tridiagonal_matmul') @tf_export('linalg.tridiagonal_matmul')
@dispatch.add_dispatch_support
def tridiagonal_matmul(diagonals, rhs, diagonals_format='compact', name=None): def tridiagonal_matmul(diagonals, rhs, diagonals_format='compact', name=None):
r"""Multiplies tridiagonal matrix by matrix. r"""Multiplies tridiagonal matrix by matrix.
@ -638,6 +641,7 @@ def _maybe_validate_matrix(a, validate_args):
@tf_export('linalg.matrix_rank') @tf_export('linalg.matrix_rank')
@dispatch.add_dispatch_support
def matrix_rank(a, tol=None, validate_args=False, name=None): def matrix_rank(a, tol=None, validate_args=False, name=None):
"""Compute the matrix rank of one or more matrices. """Compute the matrix rank of one or more matrices.
@ -676,6 +680,7 @@ def matrix_rank(a, tol=None, validate_args=False, name=None):
@tf_export('linalg.pinv') @tf_export('linalg.pinv')
@dispatch.add_dispatch_support
def pinv(a, rcond=None, validate_args=False, name=None): def pinv(a, rcond=None, validate_args=False, name=None):
"""Compute the Moore-Penrose pseudo-inverse of one or more matrices. """Compute the Moore-Penrose pseudo-inverse of one or more matrices.
@ -805,6 +810,7 @@ def pinv(a, rcond=None, validate_args=False, name=None):
@tf_export('linalg.lu_solve') @tf_export('linalg.lu_solve')
@dispatch.add_dispatch_support
def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None): def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None):
"""Solves systems of linear eqns `A X = RHS`, given LU factorizations. """Solves systems of linear eqns `A X = RHS`, given LU factorizations.
@ -902,6 +908,7 @@ def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None):
@tf_export('linalg.lu_matrix_inverse') @tf_export('linalg.lu_matrix_inverse')
@dispatch.add_dispatch_support
def lu_matrix_inverse(lower_upper, perm, validate_args=False, name=None): def lu_matrix_inverse(lower_upper, perm, validate_args=False, name=None):
"""Computes the inverse given the LU decomposition(s) of one or more matrices. """Computes the inverse given the LU decomposition(s) of one or more matrices.
@ -966,6 +973,7 @@ def lu_matrix_inverse(lower_upper, perm, validate_args=False, name=None):
@tf_export('linalg.lu_reconstruct') @tf_export('linalg.lu_reconstruct')
@dispatch.add_dispatch_support
def lu_reconstruct(lower_upper, perm, validate_args=False, name=None): def lu_reconstruct(lower_upper, perm, validate_args=False, name=None):
"""The reconstruct one or more matrices from their LU decomposition(s). """The reconstruct one or more matrices from their LU decomposition(s).

View File

@ -27,10 +27,12 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as linalg from tensorflow.python.ops.linalg import linalg_impl as linalg
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@tf_export('linalg.experimental.conjugate_gradient') @tf_export('linalg.experimental.conjugate_gradient')
@dispatch.add_dispatch_support
def conjugate_gradient(operator, def conjugate_gradient(operator,
rhs, rhs,
preconditioner=None, preconditioner=None,

View File

@ -32,6 +32,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops.gen_linalg_ops import * from tensorflow.python.ops.gen_linalg_ops import *
# pylint: enable=wildcard-import # pylint: enable=wildcard-import
from tensorflow.python.util import deprecation from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
# Names below are lower_case. # Names below are lower_case.
@ -82,6 +83,7 @@ def _RegularizedGramianCholesky(matrix, l2_regularizer, first_kind):
@tf_export( @tf_export(
'linalg.triangular_solve', 'linalg.triangular_solve',
v1=['linalg.triangular_solve', 'matrix_triangular_solve']) v1=['linalg.triangular_solve', 'matrix_triangular_solve'])
@dispatch.add_dispatch_support
def matrix_triangular_solve(matrix, rhs, lower=True, adjoint=False, name=None): def matrix_triangular_solve(matrix, rhs, lower=True, adjoint=False, name=None):
"""Solve systems of linear equations with upper or lower triangular matrices. """Solve systems of linear equations with upper or lower triangular matrices.
@ -143,6 +145,7 @@ def matrix_triangular_solve(matrix, rhs, lower=True, adjoint=False, name=None):
@tf_export( @tf_export(
'linalg.cholesky_solve', v1=['linalg.cholesky_solve', 'cholesky_solve']) 'linalg.cholesky_solve', v1=['linalg.cholesky_solve', 'cholesky_solve'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('cholesky_solve') @deprecation.deprecated_endpoints('cholesky_solve')
def cholesky_solve(chol, rhs, name=None): def cholesky_solve(chol, rhs, name=None):
"""Solves systems of linear eqns `A X = RHS`, given Cholesky factorizations. """Solves systems of linear eqns `A X = RHS`, given Cholesky factorizations.
@ -187,6 +190,7 @@ def cholesky_solve(chol, rhs, name=None):
@tf_export('eye', 'linalg.eye') @tf_export('eye', 'linalg.eye')
@dispatch.add_dispatch_support
def eye(num_rows, def eye(num_rows,
num_columns=None, num_columns=None,
batch_shape=None, batch_shape=None,
@ -234,6 +238,7 @@ def eye(num_rows,
@tf_export('linalg.lstsq', v1=['linalg.lstsq', 'matrix_solve_ls']) @tf_export('linalg.lstsq', v1=['linalg.lstsq', 'matrix_solve_ls'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('matrix_solve_ls') @deprecation.deprecated_endpoints('matrix_solve_ls')
def matrix_solve_ls(matrix, rhs, l2_regularizer=0.0, fast=True, name=None): def matrix_solve_ls(matrix, rhs, l2_regularizer=0.0, fast=True, name=None):
r"""Solves one or more linear least-squares problems. r"""Solves one or more linear least-squares problems.
@ -371,6 +376,7 @@ def matrix_solve_ls(matrix, rhs, l2_regularizer=0.0, fast=True, name=None):
@tf_export('linalg.eig', 'eig', v1=[]) @tf_export('linalg.eig', 'eig', v1=[])
@dispatch.add_dispatch_support
def eig(tensor, name=None): def eig(tensor, name=None):
"""Computes the eigen decomposition of a batch of matrices. """Computes the eigen decomposition of a batch of matrices.
@ -401,6 +407,7 @@ def eig(tensor, name=None):
@tf_export('linalg.eigvals', 'eigvals', v1=[]) @tf_export('linalg.eigvals', 'eigvals', v1=[])
@dispatch.add_dispatch_support
def eigvals(tensor, name=None): def eigvals(tensor, name=None):
"""Computes the eigenvalues of one or more matrices. """Computes the eigenvalues of one or more matrices.
@ -427,6 +434,7 @@ def eigvals(tensor, name=None):
@tf_export('linalg.eigh', v1=['linalg.eigh', 'self_adjoint_eig']) @tf_export('linalg.eigh', v1=['linalg.eigh', 'self_adjoint_eig'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('self_adjoint_eig') @deprecation.deprecated_endpoints('self_adjoint_eig')
def self_adjoint_eig(tensor, name=None): def self_adjoint_eig(tensor, name=None):
"""Computes the eigen decomposition of a batch of self-adjoint matrices. """Computes the eigen decomposition of a batch of self-adjoint matrices.
@ -450,6 +458,7 @@ def self_adjoint_eig(tensor, name=None):
@tf_export('linalg.eigvalsh', v1=['linalg.eigvalsh', 'self_adjoint_eigvals']) @tf_export('linalg.eigvalsh', v1=['linalg.eigvalsh', 'self_adjoint_eigvals'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('self_adjoint_eigvals') @deprecation.deprecated_endpoints('self_adjoint_eigvals')
def self_adjoint_eigvals(tensor, name=None): def self_adjoint_eigvals(tensor, name=None):
"""Computes the eigenvalues of one or more self-adjoint matrices. """Computes the eigenvalues of one or more self-adjoint matrices.
@ -473,6 +482,7 @@ def self_adjoint_eigvals(tensor, name=None):
@tf_export('linalg.svd', v1=['linalg.svd', 'svd']) @tf_export('linalg.svd', v1=['linalg.svd', 'svd'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('svd') @deprecation.deprecated_endpoints('svd')
def svd(tensor, full_matrices=False, compute_uv=True, name=None): def svd(tensor, full_matrices=False, compute_uv=True, name=None):
r"""Computes the singular value decompositions of one or more matrices. r"""Computes the singular value decompositions of one or more matrices.
@ -544,6 +554,7 @@ def svd(tensor, full_matrices=False, compute_uv=True, name=None):
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin
@tf_export('norm', 'linalg.norm', v1=[]) @tf_export('norm', 'linalg.norm', v1=[])
@dispatch.add_dispatch_support
def norm_v2(tensor, def norm_v2(tensor,
ord='euclidean', ord='euclidean',
axis=None, axis=None,
@ -615,6 +626,7 @@ def norm_v2(tensor,
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin
@tf_export(v1=['norm', 'linalg.norm']) @tf_export(v1=['norm', 'linalg.norm'])
@dispatch.add_dispatch_support
@deprecation.deprecated_args( @deprecation.deprecated_args(
None, 'keep_dims is deprecated, use keepdims instead', 'keep_dims') None, 'keep_dims is deprecated, use keepdims instead', 'keep_dims')
def norm(tensor, def norm(tensor,

View File

@ -38,6 +38,7 @@ from tensorflow.python.ops import string_ops
from tensorflow.python.ops.gen_logging_ops import * from tensorflow.python.ops.gen_logging_ops import *
# pylint: enable=wildcard-import # pylint: enable=wildcard-import
from tensorflow.python.platform import tf_logging from tensorflow.python.platform import tf_logging
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util.deprecation import deprecated from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -71,6 +72,7 @@ except NameError:
"only a concern in graph mode. Below is an example " "only a concern in graph mode. Below is an example "
"of how to ensure tf.print executes in graph mode:\n") "of how to ensure tf.print executes in graph mode:\n")
@tf_export(v1=["Print"]) @tf_export(v1=["Print"])
@dispatch.add_dispatch_support
def Print(input_, data, message=None, first_n=None, summarize=None, name=None): def Print(input_, data, message=None, first_n=None, summarize=None, name=None):
"""Prints a list of tensors. """Prints a list of tensors.
@ -136,6 +138,7 @@ def _is_filepath(output_stream):
# function definition. # function definition.
# pylint: disable=g-doc-args # pylint: disable=g-doc-args
@tf_export("print") @tf_export("print")
@dispatch.add_dispatch_support
def print_v2(*inputs, **kwargs): def print_v2(*inputs, **kwargs):
"""Print the specified inputs. """Print the specified inputs.

View File

@ -29,6 +29,7 @@ from tensorflow.python.ops import nn
from tensorflow.python.ops import nn_ops from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import weights_broadcast_ops from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.ops.losses import util from tensorflow.python.ops.losses import util
from tensorflow.python.util import dispatch
from tensorflow.python.util.deprecation import deprecated_args from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.deprecation import deprecated_argument_lookup from tensorflow.python.util.deprecation import deprecated_argument_lookup
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -136,6 +137,7 @@ def _num_elements(losses):
@tf_export(v1=["losses.compute_weighted_loss"]) @tf_export(v1=["losses.compute_weighted_loss"])
@dispatch.add_dispatch_support
def compute_weighted_loss( def compute_weighted_loss(
losses, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES, losses, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES,
reduction=Reduction.SUM_BY_NONZERO_WEIGHTS): reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
@ -204,6 +206,7 @@ def compute_weighted_loss(
@tf_export(v1=["losses.absolute_difference"]) @tf_export(v1=["losses.absolute_difference"])
@dispatch.add_dispatch_support
def absolute_difference( def absolute_difference(
labels, predictions, weights=1.0, scope=None, labels, predictions, weights=1.0, scope=None,
loss_collection=ops.GraphKeys.LOSSES, loss_collection=ops.GraphKeys.LOSSES,
@ -257,6 +260,7 @@ def absolute_difference(
@tf_export(v1=["losses.cosine_distance"]) @tf_export(v1=["losses.cosine_distance"])
@dispatch.add_dispatch_support
@deprecated_args(None, "dim is deprecated, use axis instead", "dim") @deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def cosine_distance( def cosine_distance(
labels, predictions, axis=None, weights=1.0, scope=None, labels, predictions, axis=None, weights=1.0, scope=None,
@ -313,6 +317,7 @@ def cosine_distance(
@tf_export(v1=["losses.hinge_loss"]) @tf_export(v1=["losses.hinge_loss"])
@dispatch.add_dispatch_support
def hinge_loss(labels, logits, weights=1.0, scope=None, def hinge_loss(labels, logits, weights=1.0, scope=None,
loss_collection=ops.GraphKeys.LOSSES, loss_collection=ops.GraphKeys.LOSSES,
reduction=Reduction.SUM_BY_NONZERO_WEIGHTS): reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
@ -363,6 +368,7 @@ def hinge_loss(labels, logits, weights=1.0, scope=None,
@tf_export(v1=["losses.huber_loss"]) @tf_export(v1=["losses.huber_loss"])
@dispatch.add_dispatch_support
def huber_loss(labels, predictions, weights=1.0, delta=1.0, scope=None, def huber_loss(labels, predictions, weights=1.0, delta=1.0, scope=None,
loss_collection=ops.GraphKeys.LOSSES, loss_collection=ops.GraphKeys.LOSSES,
reduction=Reduction.SUM_BY_NONZERO_WEIGHTS): reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
@ -439,6 +445,7 @@ def huber_loss(labels, predictions, weights=1.0, delta=1.0, scope=None,
@tf_export(v1=["losses.log_loss"]) @tf_export(v1=["losses.log_loss"])
@dispatch.add_dispatch_support
def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None, def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None,
loss_collection=ops.GraphKeys.LOSSES, loss_collection=ops.GraphKeys.LOSSES,
reduction=Reduction.SUM_BY_NONZERO_WEIGHTS): reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
@ -496,6 +503,7 @@ def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None,
# TODO(b/37208492): Add reduction arg. # TODO(b/37208492): Add reduction arg.
@tf_export(v1=["losses.mean_pairwise_squared_error"]) @tf_export(v1=["losses.mean_pairwise_squared_error"])
@dispatch.add_dispatch_support
def mean_pairwise_squared_error( def mean_pairwise_squared_error(
labels, predictions, weights=1.0, scope=None, labels, predictions, weights=1.0, scope=None,
loss_collection=ops.GraphKeys.LOSSES): loss_collection=ops.GraphKeys.LOSSES):
@ -592,6 +600,7 @@ def mean_pairwise_squared_error(
@tf_export(v1=["losses.mean_squared_error"]) @tf_export(v1=["losses.mean_squared_error"])
@dispatch.add_dispatch_support
def mean_squared_error( def mean_squared_error(
labels, predictions, weights=1.0, scope=None, labels, predictions, weights=1.0, scope=None,
loss_collection=ops.GraphKeys.LOSSES, loss_collection=ops.GraphKeys.LOSSES,
@ -645,6 +654,7 @@ def mean_squared_error(
@tf_export(v1=["losses.sigmoid_cross_entropy"]) @tf_export(v1=["losses.sigmoid_cross_entropy"])
@dispatch.add_dispatch_support
def sigmoid_cross_entropy( def sigmoid_cross_entropy(
multi_class_labels, logits, weights=1.0, label_smoothing=0, scope=None, multi_class_labels, logits, weights=1.0, label_smoothing=0, scope=None,
loss_collection=ops.GraphKeys.LOSSES, loss_collection=ops.GraphKeys.LOSSES,
@ -709,6 +719,7 @@ def sigmoid_cross_entropy(
@tf_export(v1=["losses.softmax_cross_entropy"]) @tf_export(v1=["losses.softmax_cross_entropy"])
@dispatch.add_dispatch_support
def softmax_cross_entropy( def softmax_cross_entropy(
onehot_labels, logits, weights=1.0, label_smoothing=0, scope=None, onehot_labels, logits, weights=1.0, label_smoothing=0, scope=None,
loss_collection=ops.GraphKeys.LOSSES, loss_collection=ops.GraphKeys.LOSSES,
@ -831,6 +842,7 @@ def _remove_squeezable_dimensions(
@tf_export(v1=["losses.sparse_softmax_cross_entropy"]) @tf_export(v1=["losses.sparse_softmax_cross_entropy"])
@dispatch.add_dispatch_support
def sparse_softmax_cross_entropy( def sparse_softmax_cross_entropy(
labels, logits, weights=1.0, scope=None, labels, logits, weights=1.0, scope=None,
loss_collection=ops.GraphKeys.LOSSES, loss_collection=ops.GraphKeys.LOSSES,

View File

@ -20,11 +20,13 @@ from __future__ import print_function
from tensorflow.python.ops import gen_manip_ops as _gen_manip_ops from tensorflow.python.ops import gen_manip_ops as _gen_manip_ops
from tensorflow.python.util import deprecation from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
# pylint: disable=protected-access # pylint: disable=protected-access
@tf_export('roll', v1=['roll', 'manip.roll']) @tf_export('roll', v1=['roll', 'manip.roll'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('manip.roll') @deprecation.deprecated_endpoints('manip.roll')
def roll(input, shift, axis, name=None): # pylint: disable=redefined-builtin def roll(input, shift, axis, name=None): # pylint: disable=redefined-builtin
return _gen_manip_ops.roll(input, shift, axis, name) return _gen_manip_ops.roll(input, shift, axis, name)

View File

@ -104,6 +104,7 @@ nextafter = gen_math_ops.next_after
@tf_export("linspace", v1=["lin_space", "linspace"]) @tf_export("linspace", v1=["lin_space", "linspace"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("lin_space") @deprecation.deprecated_endpoints("lin_space")
def linspace_nd(start, stop, num, name=None, axis=0): def linspace_nd(start, stop, num, name=None, axis=0):
r"""Generates evenly-spaced values in an interval along a given axis. r"""Generates evenly-spaced values in an interval along a given axis.
@ -214,8 +215,8 @@ linspace = linspace_nd
arg_max = deprecation.deprecated(None, "Use `tf.math.argmax` instead")(arg_max) # pylint: disable=used-before-assignment arg_max = deprecation.deprecated(None, "Use `tf.math.argmax` instead")(arg_max) # pylint: disable=used-before-assignment
arg_min = deprecation.deprecated(None, "Use `tf.math.argmin` instead")(arg_min) # pylint: disable=used-before-assignment arg_min = deprecation.deprecated(None, "Use `tf.math.argmin` instead")(arg_min) # pylint: disable=used-before-assignment
tf_export(v1=["arg_max"])(arg_max) tf_export(v1=["arg_max"])(dispatch.add_dispatch_support(arg_max))
tf_export(v1=["arg_min"])(arg_min) tf_export(v1=["arg_min"])(dispatch.add_dispatch_support(arg_min))
# This is set by resource_variable_ops.py. It is included in this way since # This is set by resource_variable_ops.py. It is included in this way since
@ -234,6 +235,7 @@ def _set_doc(doc):
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin
@tf_export(v1=["math.argmax", "argmax"]) @tf_export(v1=["math.argmax", "argmax"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None, "Use the `axis` argument instead", @deprecation.deprecated_args(None, "Use the `axis` argument instead",
"dimension") "dimension")
@_set_doc( @_set_doc(
@ -250,6 +252,7 @@ def argmax(input,
@tf_export("math.argmax", "argmax", v1=[]) @tf_export("math.argmax", "argmax", v1=[])
@dispatch.add_dispatch_support
def argmax_v2(input, axis=None, output_type=dtypes.int64, name=None): def argmax_v2(input, axis=None, output_type=dtypes.int64, name=None):
"""Returns the index with the largest value across axes of a tensor. """Returns the index with the largest value across axes of a tensor.
@ -283,6 +286,7 @@ def argmax_v2(input, axis=None, output_type=dtypes.int64, name=None):
@tf_export(v1=["math.argmin", "argmin"]) @tf_export(v1=["math.argmin", "argmin"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None, "Use the `axis` argument instead", @deprecation.deprecated_args(None, "Use the `axis` argument instead",
"dimension") "dimension")
@_set_doc( @_set_doc(
@ -299,6 +303,7 @@ def argmin(input,
@tf_export("math.argmin", "argmin", v1=[]) @tf_export("math.argmin", "argmin", v1=[])
@dispatch.add_dispatch_support
def argmin_v2(input, axis=None, output_type=dtypes.int64, name=None): def argmin_v2(input, axis=None, output_type=dtypes.int64, name=None):
"""Returns the index with the smallest value across axes of a tensor. """Returns the index with the smallest value across axes of a tensor.
@ -549,6 +554,7 @@ def _neg(x, name=None):
@tf_export(v1=["math.scalar_mul", "scalar_mul"]) @tf_export(v1=["math.scalar_mul", "scalar_mul"])
@dispatch.add_dispatch_support
def scalar_mul(scalar, x, name=None): def scalar_mul(scalar, x, name=None):
"""Multiplies a scalar times a `Tensor` or `IndexedSlices` object. """Multiplies a scalar times a `Tensor` or `IndexedSlices` object.
@ -581,6 +587,7 @@ def scalar_mul(scalar, x, name=None):
@tf_export("math.scalar_mul", "scalar_mul", v1=[]) @tf_export("math.scalar_mul", "scalar_mul", v1=[])
@dispatch.add_dispatch_support
@_set_doc(scalar_mul.__doc__) @_set_doc(scalar_mul.__doc__)
def scalar_mul_v2(scalar, x, name=None): def scalar_mul_v2(scalar, x, name=None):
with ops.name_scope(name, "scalar_mul", [x]) as name: with ops.name_scope(name, "scalar_mul", [x]) as name:
@ -701,6 +708,7 @@ def sign(x, name=None):
@tf_export("math.real", v1=["math.real", "real"]) @tf_export("math.real", v1=["math.real", "real"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("real") @deprecation.deprecated_endpoints("real")
@dispatch.add_dispatch_support @dispatch.add_dispatch_support
def real(input, name=None): def real(input, name=None):
@ -735,6 +743,7 @@ def real(input, name=None):
@tf_export("math.imag", v1=["math.imag", "imag"]) @tf_export("math.imag", v1=["math.imag", "imag"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("imag") @deprecation.deprecated_endpoints("imag")
@dispatch.add_dispatch_support @dispatch.add_dispatch_support
def imag(input, name=None): def imag(input, name=None):
@ -768,6 +777,7 @@ def imag(input, name=None):
@tf_export("math.angle", v1=["math.angle", "angle"]) @tf_export("math.angle", v1=["math.angle", "angle"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("angle") @deprecation.deprecated_endpoints("angle")
@dispatch.add_dispatch_support @dispatch.add_dispatch_support
def angle(input, name=None): def angle(input, name=None):
@ -937,6 +947,7 @@ def saturate_cast(value, dtype, name=None):
@deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.") @deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.")
@tf_export(v1=["to_float"]) @tf_export(v1=["to_float"])
@dispatch.add_dispatch_support
def to_float(x, name="ToFloat"): def to_float(x, name="ToFloat"):
"""Casts a tensor to type `float32`. """Casts a tensor to type `float32`.
@ -956,6 +967,7 @@ def to_float(x, name="ToFloat"):
@deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.") @deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.")
@tf_export(v1=["to_double"]) @tf_export(v1=["to_double"])
@dispatch.add_dispatch_support
def to_double(x, name="ToDouble"): def to_double(x, name="ToDouble"):
"""Casts a tensor to type `float64`. """Casts a tensor to type `float64`.
@ -975,6 +987,7 @@ def to_double(x, name="ToDouble"):
@deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.") @deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.")
@tf_export(v1=["to_int32"]) @tf_export(v1=["to_int32"])
@dispatch.add_dispatch_support
def to_int32(x, name="ToInt32"): def to_int32(x, name="ToInt32"):
"""Casts a tensor to type `int32`. """Casts a tensor to type `int32`.
@ -994,6 +1007,7 @@ def to_int32(x, name="ToInt32"):
@deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.") @deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.")
@tf_export(v1=["to_int64"]) @tf_export(v1=["to_int64"])
@dispatch.add_dispatch_support
def to_int64(x, name="ToInt64"): def to_int64(x, name="ToInt64"):
"""Casts a tensor to type `int64`. """Casts a tensor to type `int64`.
@ -1013,6 +1027,7 @@ def to_int64(x, name="ToInt64"):
@deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.") @deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.")
@tf_export(v1=["to_bfloat16"]) @tf_export(v1=["to_bfloat16"])
@dispatch.add_dispatch_support
def to_bfloat16(x, name="ToBFloat16"): def to_bfloat16(x, name="ToBFloat16"):
"""Casts a tensor to type `bfloat16`. """Casts a tensor to type `bfloat16`.
@ -1032,6 +1047,7 @@ def to_bfloat16(x, name="ToBFloat16"):
@deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.") @deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.")
@tf_export(v1=["to_complex64"]) @tf_export(v1=["to_complex64"])
@dispatch.add_dispatch_support
def to_complex64(x, name="ToComplex64"): def to_complex64(x, name="ToComplex64"):
"""Casts a tensor to type `complex64`. """Casts a tensor to type `complex64`.
@ -1051,6 +1067,7 @@ def to_complex64(x, name="ToComplex64"):
@deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.") @deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.")
@tf_export(v1=["to_complex128"]) @tf_export(v1=["to_complex128"])
@dispatch.add_dispatch_support
def to_complex128(x, name="ToComplex128"): def to_complex128(x, name="ToComplex128"):
"""Casts a tensor to type `complex128`. """Casts a tensor to type `complex128`.
@ -1265,6 +1282,7 @@ def truediv(x, y, name=None):
date=None, date=None,
instructions="Deprecated in favor of operator or tf.math.divide.") instructions="Deprecated in favor of operator or tf.math.divide.")
@tf_export(v1=["div"]) @tf_export(v1=["div"])
@dispatch.add_dispatch_support
def div(x, y, name=None): def div(x, y, name=None):
"""Divides x / y elementwise (using Python 2 division operator semantics). """Divides x / y elementwise (using Python 2 division operator semantics).
@ -1288,6 +1306,7 @@ def div(x, y, name=None):
@tf_export("math.divide_no_nan", v1=["math.divide_no_nan", "div_no_nan"]) @tf_export("math.divide_no_nan", v1=["math.divide_no_nan", "div_no_nan"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("div_no_nan") @deprecation.deprecated_endpoints("div_no_nan")
@dispatch.add_dispatch_support @dispatch.add_dispatch_support
def div_no_nan(x, y, name=None): def div_no_nan(x, y, name=None):
@ -1620,6 +1639,7 @@ ops.Tensor._override_operator("__ne__", tensor_not_equals)
@tf_export("range") @tf_export("range")
@dispatch.add_dispatch_support
def range(start, limit=None, delta=1, dtype=None, name="range"): # pylint: disable=redefined-builtin def range(start, limit=None, delta=1, dtype=None, name="range"): # pylint: disable=redefined-builtin
"""Creates a sequence of numbers. """Creates a sequence of numbers.
@ -1751,6 +1771,7 @@ def _may_reduce_to_scalar(keepdims, axis, output):
@tf_export(v1=["math.reduce_sum", "reduce_sum"]) @tf_export(v1=["math.reduce_sum", "reduce_sum"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None, @deprecation.deprecated_args(None,
"keep_dims is deprecated, use keepdims instead", "keep_dims is deprecated, use keepdims instead",
"keep_dims") "keep_dims")
@ -1885,6 +1906,7 @@ def reduce_sum_with_dims(input_tensor,
@tf_export("math.reduce_euclidean_norm") @tf_export("math.reduce_euclidean_norm")
@dispatch.add_dispatch_support
def reduce_euclidean_norm(input_tensor, axis=None, keepdims=False, name=None): def reduce_euclidean_norm(input_tensor, axis=None, keepdims=False, name=None):
"""Computes the Euclidean norm of elements across dimensions of a tensor. """Computes the Euclidean norm of elements across dimensions of a tensor.
@ -1928,6 +1950,7 @@ def reduce_euclidean_norm(input_tensor, axis=None, keepdims=False, name=None):
@tf_export(v1=["math.count_nonzero", "count_nonzero"]) @tf_export(v1=["math.count_nonzero", "count_nonzero"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None, @deprecation.deprecated_args(None,
"keep_dims is deprecated, use keepdims instead", "keep_dims is deprecated, use keepdims instead",
"keep_dims") "keep_dims")
@ -2005,6 +2028,7 @@ def count_nonzero(input_tensor=None,
@tf_export("math.count_nonzero", v1=[]) @tf_export("math.count_nonzero", v1=[])
@dispatch.add_dispatch_support
def count_nonzero_v2( def count_nonzero_v2(
input, # pylint: disable=redefined-builtin input, # pylint: disable=redefined-builtin
axis=None, axis=None,
@ -2072,6 +2096,7 @@ def count_nonzero_v2(
@tf_export(v1=["math.reduce_mean", "reduce_mean"]) @tf_export(v1=["math.reduce_mean", "reduce_mean"])
@dispatch.add_dispatch_support
def reduce_mean_v1(input_tensor, def reduce_mean_v1(input_tensor,
axis=None, axis=None,
keepdims=None, keepdims=None,
@ -2198,6 +2223,7 @@ def reduce_mean(input_tensor, axis=None, keepdims=False, name=None):
@tf_export("math.reduce_variance") @tf_export("math.reduce_variance")
@dispatch.add_dispatch_support
def reduce_variance(input_tensor, axis=None, keepdims=False, name=None): def reduce_variance(input_tensor, axis=None, keepdims=False, name=None):
"""Computes the variance of elements across dimensions of a tensor. """Computes the variance of elements across dimensions of a tensor.
@ -2246,6 +2272,7 @@ def reduce_variance(input_tensor, axis=None, keepdims=False, name=None):
@tf_export("math.reduce_std") @tf_export("math.reduce_std")
@dispatch.add_dispatch_support
def reduce_std(input_tensor, axis=None, keepdims=False, name=None): def reduce_std(input_tensor, axis=None, keepdims=False, name=None):
"""Computes the standard deviation of elements across dimensions of a tensor. """Computes the standard deviation of elements across dimensions of a tensor.
@ -2328,6 +2355,7 @@ def reduce_prod(input_tensor, axis=None, keepdims=False, name=None):
@tf_export(v1=["math.reduce_prod", "reduce_prod"]) @tf_export(v1=["math.reduce_prod", "reduce_prod"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None, @deprecation.deprecated_args(None,
"keep_dims is deprecated, use keepdims instead", "keep_dims is deprecated, use keepdims instead",
"keep_dims") "keep_dims")
@ -2373,6 +2401,7 @@ def reduce_prod_v1(input_tensor,
@tf_export(v1=["math.reduce_min", "reduce_min"]) @tf_export(v1=["math.reduce_min", "reduce_min"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None, @deprecation.deprecated_args(None,
"keep_dims is deprecated, use keepdims instead", "keep_dims is deprecated, use keepdims instead",
"keep_dims") "keep_dims")
@ -2459,6 +2488,7 @@ def reduce_min(input_tensor, axis=None, keepdims=False, name=None):
@tf_export(v1=["math.reduce_max", "reduce_max"]) @tf_export(v1=["math.reduce_max", "reduce_max"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None, @deprecation.deprecated_args(None,
"keep_dims is deprecated, use keepdims instead", "keep_dims is deprecated, use keepdims instead",
"keep_dims") "keep_dims")
@ -2563,6 +2593,7 @@ def reduce_max_with_dims(input_tensor,
@tf_export(v1=["math.reduce_all", "reduce_all"]) @tf_export(v1=["math.reduce_all", "reduce_all"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None, @deprecation.deprecated_args(None,
"keep_dims is deprecated, use keepdims instead", "keep_dims is deprecated, use keepdims instead",
"keep_dims") "keep_dims")
@ -2662,6 +2693,7 @@ def reduce_all(input_tensor, axis=None, keepdims=False, name=None):
@tf_export(v1=["math.reduce_any", "reduce_any"]) @tf_export(v1=["math.reduce_any", "reduce_any"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None, @deprecation.deprecated_args(None,
"keep_dims is deprecated, use keepdims instead", "keep_dims is deprecated, use keepdims instead",
"keep_dims") "keep_dims")
@ -2761,6 +2793,7 @@ def reduce_any(input_tensor, axis=None, keepdims=False, name=None):
@tf_export(v1=["math.reduce_logsumexp", "reduce_logsumexp"]) @tf_export(v1=["math.reduce_logsumexp", "reduce_logsumexp"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None, @deprecation.deprecated_args(None,
"keep_dims is deprecated, use keepdims instead", "keep_dims is deprecated, use keepdims instead",
"keep_dims") "keep_dims")
@ -2817,6 +2850,7 @@ def reduce_logsumexp_v1(input_tensor,
@tf_export("math.reduce_logsumexp", "reduce_logsumexp", v1=[]) @tf_export("math.reduce_logsumexp", "reduce_logsumexp", v1=[])
@dispatch.add_dispatch_support
def reduce_logsumexp(input_tensor, axis=None, keepdims=False, name=None): def reduce_logsumexp(input_tensor, axis=None, keepdims=False, name=None):
"""Computes log(sum(exp(elements across dimensions of a tensor))). """Computes log(sum(exp(elements across dimensions of a tensor))).
@ -2877,6 +2911,7 @@ def reduce_logsumexp(input_tensor, axis=None, keepdims=False, name=None):
@tf_export("linalg.trace", v1=["linalg.trace", "trace"]) @tf_export("linalg.trace", v1=["linalg.trace", "trace"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("trace") @deprecation.deprecated_endpoints("trace")
@dispatch.add_dispatch_support @dispatch.add_dispatch_support
def trace(x, name=None): def trace(x, name=None):
@ -3116,6 +3151,7 @@ def matmul(a,
@tf_export("linalg.matvec") @tf_export("linalg.matvec")
@dispatch.add_dispatch_support
def matvec(a, def matvec(a,
b, b,
transpose_a=False, transpose_a=False,
@ -3219,6 +3255,7 @@ _OverrideBinaryOperatorHelper(matmul, "matmul")
sparse_matmul = deprecation.deprecated(None, "Use `tf.linalg.matmul` instead")( sparse_matmul = deprecation.deprecated(None, "Use `tf.linalg.matmul` instead")(
gen_math_ops.sparse_mat_mul) gen_math_ops.sparse_mat_mul)
tf_export(v1=["sparse_matmul"])(sparse_matmul) tf_export(v1=["sparse_matmul"])(sparse_matmul)
@dispatch.add_dispatch_support
@ops.RegisterStatistics("MatMul", "flops") @ops.RegisterStatistics("MatMul", "flops")
@ -3371,6 +3408,7 @@ def add_n(inputs, name=None):
@tf_export("math.accumulate_n", v1=["math.accumulate_n", "accumulate_n"]) @tf_export("math.accumulate_n", v1=["math.accumulate_n", "accumulate_n"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("accumulate_n") @deprecation.deprecated_endpoints("accumulate_n")
def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None): def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None):
"""Returns the element-wise sum of a list of tensors. """Returns the element-wise sum of a list of tensors.
@ -3449,6 +3487,7 @@ def _accumulate_n_grad(op, grad):
@tf_export("math.sigmoid", "nn.sigmoid", "sigmoid") @tf_export("math.sigmoid", "nn.sigmoid", "sigmoid")
@dispatch.add_dispatch_support
def sigmoid(x, name=None): def sigmoid(x, name=None):
r"""Computes sigmoid of `x` element-wise. r"""Computes sigmoid of `x` element-wise.
@ -3521,6 +3560,7 @@ def log_sigmoid(x, name=None):
@tf_export("math.bincount", v1=[]) @tf_export("math.bincount", v1=[])
@dispatch.add_dispatch_support
def bincount(arr, def bincount(arr,
weights=None, weights=None,
minlength=None, minlength=None,
@ -3596,6 +3636,7 @@ def bincount(arr,
@tf_export(v1=["math.bincount", "bincount"]) @tf_export(v1=["math.bincount", "bincount"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("bincount") @deprecation.deprecated_endpoints("bincount")
def bincount_v1(arr, def bincount_v1(arr,
weights=None, weights=None,
@ -3629,6 +3670,7 @@ def bincount_v1(arr,
@tf_export("math.cumsum", "cumsum") @tf_export("math.cumsum", "cumsum")
@dispatch.add_dispatch_support
def cumsum(x, axis=0, exclusive=False, reverse=False, name=None): def cumsum(x, axis=0, exclusive=False, reverse=False, name=None):
"""Compute the cumulative sum of the tensor `x` along `axis`. """Compute the cumulative sum of the tensor `x` along `axis`.
@ -3700,6 +3742,7 @@ def cumsum(x, axis=0, exclusive=False, reverse=False, name=None):
@tf_export("math.cumprod", v1=["math.cumprod", "cumprod"]) @tf_export("math.cumprod", v1=["math.cumprod", "cumprod"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("cumprod") @deprecation.deprecated_endpoints("cumprod")
def cumprod(x, axis=0, exclusive=False, reverse=False, name=None): def cumprod(x, axis=0, exclusive=False, reverse=False, name=None):
"""Compute the cumulative product of the tensor `x` along `axis`. """Compute the cumulative product of the tensor `x` along `axis`.
@ -3753,6 +3796,7 @@ def cumprod(x, axis=0, exclusive=False, reverse=False, name=None):
@tf_export("math.cumulative_logsumexp", v1=["math.cumulative_logsumexp"]) @tf_export("math.cumulative_logsumexp", v1=["math.cumulative_logsumexp"])
@dispatch.add_dispatch_support
def cumulative_logsumexp(x, axis=0, exclusive=False, reverse=False, name=None): def cumulative_logsumexp(x, axis=0, exclusive=False, reverse=False, name=None):
"""Compute the cumulative log-sum-exp of the tensor `x` along `axis`. """Compute the cumulative log-sum-exp of the tensor `x` along `axis`.
@ -3912,6 +3956,7 @@ def _unsorted_segment_N(data, segment_ids, num_segments):
@tf_export( @tf_export(
"math.unsorted_segment_mean", "math.unsorted_segment_mean",
v1=["math.unsorted_segment_mean", "unsorted_segment_mean"]) v1=["math.unsorted_segment_mean", "unsorted_segment_mean"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("unsorted_segment_mean") @deprecation.deprecated_endpoints("unsorted_segment_mean")
@dispatch.add_dispatch_support @dispatch.add_dispatch_support
def unsorted_segment_mean(data, segment_ids, num_segments, name=None): def unsorted_segment_mean(data, segment_ids, num_segments, name=None):
@ -3958,6 +4003,7 @@ def unsorted_segment_mean(data, segment_ids, num_segments, name=None):
@tf_export( @tf_export(
"math.unsorted_segment_sqrt_n", "math.unsorted_segment_sqrt_n",
v1=["math.unsorted_segment_sqrt_n", "unsorted_segment_sqrt_n"]) v1=["math.unsorted_segment_sqrt_n", "unsorted_segment_sqrt_n"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("unsorted_segment_sqrt_n") @deprecation.deprecated_endpoints("unsorted_segment_sqrt_n")
@dispatch.add_dispatch_support @dispatch.add_dispatch_support
def unsorted_segment_sqrt_n(data, segment_ids, num_segments, name=None): def unsorted_segment_sqrt_n(data, segment_ids, num_segments, name=None):
@ -4307,6 +4353,7 @@ def sparse_segment_sqrt_n_v2(data,
@tf_export("tensordot", "linalg.tensordot") @tf_export("tensordot", "linalg.tensordot")
@dispatch.add_dispatch_support
def tensordot(a, b, axes, name=None): def tensordot(a, b, axes, name=None):
r"""Tensor contraction of a and b along specified axes and outer product. r"""Tensor contraction of a and b along specified axes and outer product.
@ -4493,6 +4540,7 @@ def tensordot(a, b, axes, name=None):
@tf_export("math.polyval") @tf_export("math.polyval")
@dispatch.add_dispatch_support
def polyval(coeffs, x, name=None): def polyval(coeffs, x, name=None):
r"""Computes the elementwise value of a polynomial. r"""Computes the elementwise value of a polynomial.
@ -4563,6 +4611,7 @@ def polyval(coeffs, x, name=None):
@tf_export("math.reciprocal_no_nan") @tf_export("math.reciprocal_no_nan")
@dispatch.add_dispatch_support
def reciprocal_no_nan(x, name=None): def reciprocal_no_nan(x, name=None):
"""Performs a safe reciprocal operation, element wise. """Performs a safe reciprocal operation, element wise.
@ -4665,6 +4714,7 @@ def ndtri(x, name=None):
@tf_export("math.ceil", v1=["math.ceil", "ceil"]) @tf_export("math.ceil", v1=["math.ceil", "ceil"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("ceil") @deprecation.deprecated_endpoints("ceil")
@dispatch.add_dispatch_support @dispatch.add_dispatch_support
def ceil(x, name=None): def ceil(x, name=None):
@ -4778,6 +4828,7 @@ def exp(x, name=None):
@tf_export("math.sobol_sample") @tf_export("math.sobol_sample")
@dispatch.add_dispatch_support
def sobol_sample(dim, num_results, skip=0, dtype=dtypes.float32, name=None): def sobol_sample(dim, num_results, skip=0, dtype=dtypes.float32, name=None):
"""Generates points from the Sobol sequence. """Generates points from the Sobol sequence.
@ -4802,6 +4853,7 @@ def sobol_sample(dim, num_results, skip=0, dtype=dtypes.float32, name=None):
@tf_export("math.rsqrt", v1=["math.rsqrt", "rsqrt"]) @tf_export("math.rsqrt", v1=["math.rsqrt", "rsqrt"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("rsqrt") @deprecation.deprecated_endpoints("rsqrt")
@dispatch.add_dispatch_support @dispatch.add_dispatch_support
def rsqrt(x, name=None): def rsqrt(x, name=None):

View File

@ -39,12 +39,14 @@ from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import util as losses_util from tensorflow.python.ops.losses import util as losses_util
from tensorflow.python.platform import device_context from tensorflow.python.platform import device_context
from tensorflow.python.util import dispatch
from tensorflow.python.util.deprecation import deprecated_args from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.deprecation import deprecated_argument_lookup from tensorflow.python.util.deprecation import deprecated_argument_lookup
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@tf_export("nn.log_poisson_loss") @tf_export("nn.log_poisson_loss")
@dispatch.add_dispatch_support
def log_poisson_loss(targets, log_input, compute_full_loss=False, name=None): def log_poisson_loss(targets, log_input, compute_full_loss=False, name=None):
"""Computes log Poisson loss given `log_input`. """Computes log Poisson loss given `log_input`.
@ -110,6 +112,7 @@ def log_poisson_loss(targets, log_input, compute_full_loss=False, name=None):
@tf_export(v1=["nn.sigmoid_cross_entropy_with_logits"]) @tf_export(v1=["nn.sigmoid_cross_entropy_with_logits"])
@dispatch.add_dispatch_support
def sigmoid_cross_entropy_with_logits( # pylint: disable=invalid-name def sigmoid_cross_entropy_with_logits( # pylint: disable=invalid-name
_sentinel=None, _sentinel=None,
labels=None, labels=None,
@ -192,6 +195,7 @@ def sigmoid_cross_entropy_with_logits( # pylint: disable=invalid-name
# Note: intentionally calling this v2 to not allow existing code with indirect # Note: intentionally calling this v2 to not allow existing code with indirect
# imports to ignore the sentinel behavior. # imports to ignore the sentinel behavior.
@tf_export("nn.sigmoid_cross_entropy_with_logits", v1=[]) @tf_export("nn.sigmoid_cross_entropy_with_logits", v1=[])
@dispatch.add_dispatch_support
def sigmoid_cross_entropy_with_logits_v2( # pylint: disable=invalid-name def sigmoid_cross_entropy_with_logits_v2( # pylint: disable=invalid-name
labels=None, labels=None,
logits=None, logits=None,
@ -242,6 +246,7 @@ def sigmoid_cross_entropy_with_logits_v2( # pylint: disable=invalid-name
@tf_export("nn.weighted_cross_entropy_with_logits", v1=[]) @tf_export("nn.weighted_cross_entropy_with_logits", v1=[])
@dispatch.add_dispatch_support
def weighted_cross_entropy_with_logits_v2(labels, logits, pos_weight, def weighted_cross_entropy_with_logits_v2(labels, logits, pos_weight,
name=None): name=None):
"""Computes a weighted cross entropy. """Computes a weighted cross entropy.
@ -320,6 +325,7 @@ def weighted_cross_entropy_with_logits_v2(labels, logits, pos_weight,
@tf_export(v1=["nn.weighted_cross_entropy_with_logits"]) @tf_export(v1=["nn.weighted_cross_entropy_with_logits"])
@dispatch.add_dispatch_support
@deprecated_args(None, "targets is deprecated, use labels instead", "targets") @deprecated_args(None, "targets is deprecated, use labels instead", "targets")
def weighted_cross_entropy_with_logits(labels=None, def weighted_cross_entropy_with_logits(labels=None,
logits=None, logits=None,
@ -384,6 +390,7 @@ def weighted_cross_entropy_with_logits(labels=None,
@tf_export("nn.compute_average_loss") @tf_export("nn.compute_average_loss")
@dispatch.add_dispatch_support
def compute_average_loss(per_example_loss, def compute_average_loss(per_example_loss,
sample_weight=None, sample_weight=None,
global_batch_size=None): global_batch_size=None):
@ -440,6 +447,7 @@ def compute_average_loss(per_example_loss,
@tf_export("nn.scale_regularization_loss") @tf_export("nn.scale_regularization_loss")
@dispatch.add_dispatch_support
def scale_regularization_loss(regularization_loss): def scale_regularization_loss(regularization_loss):
"""Scales the sum of the given regularization losses by number of replicas. """Scales the sum of the given regularization losses by number of replicas.
@ -478,6 +486,7 @@ def scale_regularization_loss(regularization_loss):
@tf_export(v1=["nn.relu_layer"]) @tf_export(v1=["nn.relu_layer"])
@dispatch.add_dispatch_support
def relu_layer(x, weights, biases, name=None): def relu_layer(x, weights, biases, name=None):
"""Computes Relu(x * weight + biases). """Computes Relu(x * weight + biases).
@ -501,6 +510,7 @@ def relu_layer(x, weights, biases, name=None):
@tf_export("nn.swish") @tf_export("nn.swish")
@dispatch.add_dispatch_support
@custom_gradient.custom_gradient @custom_gradient.custom_gradient
def swish(features): def swish(features):
# pylint: disable=g-doc-args # pylint: disable=g-doc-args
@ -538,6 +548,7 @@ def swish(features):
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin
@tf_export("linalg.normalize") @tf_export("linalg.normalize")
@dispatch.add_dispatch_support
def normalize(tensor, ord="euclidean", axis=None, name=None): def normalize(tensor, ord="euclidean", axis=None, name=None):
"""Normalizes `tensor` along dimension `axis` using specified norm. """Normalizes `tensor` along dimension `axis` using specified norm.
@ -590,6 +601,7 @@ def normalize(tensor, ord="euclidean", axis=None, name=None):
@tf_export(v1=["math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize"]) @tf_export(v1=["math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize"])
@dispatch.add_dispatch_support
@deprecated_args(None, "dim is deprecated, use axis instead", "dim") @deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None): def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None):
"""Normalizes along dimension `axis` using an L2 norm. """Normalizes along dimension `axis` using an L2 norm.
@ -618,6 +630,7 @@ def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None):
@tf_export("math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize", v1=[]) @tf_export("math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize", v1=[])
@dispatch.add_dispatch_support
def l2_normalize_v2(x, axis=None, epsilon=1e-12, name=None): def l2_normalize_v2(x, axis=None, epsilon=1e-12, name=None):
"""Normalizes along dimension `axis` using an L2 norm. """Normalizes along dimension `axis` using an L2 norm.
@ -668,6 +681,7 @@ def _count_nonzero(input_tensor, dtype=dtypes.int64):
@tf_export("math.zero_fraction", "nn.zero_fraction") @tf_export("math.zero_fraction", "nn.zero_fraction")
@dispatch.add_dispatch_support
def zero_fraction(value, name=None): def zero_fraction(value, name=None):
"""Returns the fraction of zeros in `value`. """Returns the fraction of zeros in `value`.
@ -710,6 +724,7 @@ def zero_fraction(value, name=None):
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin
@tf_export(v1=["nn.depthwise_conv2d"]) @tf_export(v1=["nn.depthwise_conv2d"])
@dispatch.add_dispatch_support
def depthwise_conv2d(input, def depthwise_conv2d(input,
filter, filter,
strides, strides,
@ -838,6 +853,7 @@ def depthwise_conv2d(input,
@tf_export("nn.depthwise_conv2d", v1=[]) @tf_export("nn.depthwise_conv2d", v1=[])
@dispatch.add_dispatch_support
def depthwise_conv2d_v2(input, def depthwise_conv2d_v2(input,
filter, filter,
strides, strides,
@ -935,6 +951,7 @@ def depthwise_conv2d_v2(input,
# pylint: disable=redefined-builtin,line-too-long # pylint: disable=redefined-builtin,line-too-long
@tf_export(v1=["nn.separable_conv2d"]) @tf_export(v1=["nn.separable_conv2d"])
@dispatch.add_dispatch_support
def separable_conv2d(input, def separable_conv2d(input,
depthwise_filter, depthwise_filter,
pointwise_filter, pointwise_filter,
@ -1042,6 +1059,7 @@ def separable_conv2d(input,
@tf_export("nn.separable_conv2d", v1=[]) @tf_export("nn.separable_conv2d", v1=[])
@dispatch.add_dispatch_support
def separable_conv2d_v2( def separable_conv2d_v2(
input, input,
depthwise_filter, depthwise_filter,
@ -1117,6 +1135,7 @@ def separable_conv2d_v2(
@tf_export(v1=["nn.sufficient_statistics"]) @tf_export(v1=["nn.sufficient_statistics"])
@dispatch.add_dispatch_support
def sufficient_statistics(x, axes, shift=None, keep_dims=None, name=None, def sufficient_statistics(x, axes, shift=None, keep_dims=None, name=None,
keepdims=None): keepdims=None):
"""Calculate the sufficient statistics for the mean and variance of `x`. """Calculate the sufficient statistics for the mean and variance of `x`.
@ -1174,6 +1193,7 @@ def sufficient_statistics(x, axes, shift=None, keep_dims=None, name=None,
@tf_export("nn.sufficient_statistics", v1=[]) @tf_export("nn.sufficient_statistics", v1=[])
@dispatch.add_dispatch_support
def sufficient_statistics_v2(x, axes, shift=None, keepdims=False, name=None): def sufficient_statistics_v2(x, axes, shift=None, keepdims=False, name=None):
"""Calculate the sufficient statistics for the mean and variance of `x`. """Calculate the sufficient statistics for the mean and variance of `x`.
@ -1203,6 +1223,7 @@ def sufficient_statistics_v2(x, axes, shift=None, keepdims=False, name=None):
@tf_export("nn.normalize_moments") @tf_export("nn.normalize_moments")
@dispatch.add_dispatch_support
def normalize_moments(counts, mean_ss, variance_ss, shift, name=None): def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
"""Calculate the mean and variance of based on the sufficient statistics. """Calculate the mean and variance of based on the sufficient statistics.
@ -1235,6 +1256,7 @@ def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
@tf_export(v1=["nn.moments"]) @tf_export(v1=["nn.moments"])
@dispatch.add_dispatch_support
def moments( def moments(
x, x,
axes, axes,
@ -1300,6 +1322,7 @@ def moments(
@tf_export("nn.moments", v1=[]) @tf_export("nn.moments", v1=[])
@dispatch.add_dispatch_support
def moments_v2( def moments_v2(
x, x,
axes, axes,
@ -1336,6 +1359,7 @@ def moments_v2(
@tf_export(v1=["nn.weighted_moments"]) @tf_export(v1=["nn.weighted_moments"])
@dispatch.add_dispatch_support
def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=None, def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=None,
keepdims=None): keepdims=None):
"""Returns the frequency-weighted mean and variance of `x`. """Returns the frequency-weighted mean and variance of `x`.
@ -1414,6 +1438,7 @@ def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=None,
@tf_export("nn.weighted_moments", v1=[]) @tf_export("nn.weighted_moments", v1=[])
@dispatch.add_dispatch_support
def weighted_moments_v2(x, axes, frequency_weights, keepdims=False, name=None): def weighted_moments_v2(x, axes, frequency_weights, keepdims=False, name=None):
"""Returns the frequency-weighted mean and variance of `x`. """Returns the frequency-weighted mean and variance of `x`.
@ -1438,6 +1463,7 @@ def weighted_moments_v2(x, axes, frequency_weights, keepdims=False, name=None):
@tf_export("nn.batch_normalization") @tf_export("nn.batch_normalization")
@dispatch.add_dispatch_support
def batch_normalization(x, def batch_normalization(x,
mean, mean,
variance, variance,
@ -1508,6 +1534,7 @@ def batch_normalization(x,
@tf_export(v1=["nn.fused_batch_norm"]) @tf_export(v1=["nn.fused_batch_norm"])
@dispatch.add_dispatch_support
def fused_batch_norm( def fused_batch_norm(
x, x,
scale, scale,
@ -1631,6 +1658,7 @@ def fused_batch_norm(
@tf_export(v1=["nn.batch_norm_with_global_normalization"]) @tf_export(v1=["nn.batch_norm_with_global_normalization"])
@dispatch.add_dispatch_support
def batch_norm_with_global_normalization(t=None, def batch_norm_with_global_normalization(t=None,
m=None, m=None,
v=None, v=None,
@ -1685,6 +1713,7 @@ def batch_norm_with_global_normalization(t=None,
# pylint: disable=redefined-builtin,line-too-long # pylint: disable=redefined-builtin,line-too-long
@tf_export("nn.batch_norm_with_global_normalization", v1=[]) @tf_export("nn.batch_norm_with_global_normalization", v1=[])
@dispatch.add_dispatch_support
def batch_norm_with_global_normalization_v2(input, def batch_norm_with_global_normalization_v2(input,
mean, mean,
variance, variance,
@ -1934,6 +1963,7 @@ def _compute_sampled_logits(weights,
@tf_export("nn.nce_loss", v1=[]) @tf_export("nn.nce_loss", v1=[])
@dispatch.add_dispatch_support
def nce_loss_v2(weights, def nce_loss_v2(weights,
biases, biases,
labels, labels,
@ -2038,6 +2068,7 @@ def nce_loss_v2(weights,
@tf_export(v1=["nn.nce_loss"]) @tf_export(v1=["nn.nce_loss"])
@dispatch.add_dispatch_support
def nce_loss(weights, def nce_loss(weights,
biases, biases,
labels, labels,
@ -2149,6 +2180,7 @@ def nce_loss(weights,
@tf_export("nn.sampled_softmax_loss", v1=[]) @tf_export("nn.sampled_softmax_loss", v1=[])
@dispatch.add_dispatch_support
def sampled_softmax_loss_v2(weights, def sampled_softmax_loss_v2(weights,
biases, biases,
labels, labels,
@ -2240,6 +2272,7 @@ def sampled_softmax_loss_v2(weights,
@tf_export(v1=["nn.sampled_softmax_loss"]) @tf_export(v1=["nn.sampled_softmax_loss"])
@dispatch.add_dispatch_support
def sampled_softmax_loss(weights, def sampled_softmax_loss(weights,
biases, biases,
labels, labels,

View File

@ -239,6 +239,7 @@ class _NonAtrousConvolution(object):
@tf_export("nn.dilation2d", v1=[]) @tf_export("nn.dilation2d", v1=[])
@dispatch.add_dispatch_support
def dilation2d_v2( def dilation2d_v2(
input, # pylint: disable=redefined-builtin input, # pylint: disable=redefined-builtin
filters, # pylint: disable=redefined-builtin filters, # pylint: disable=redefined-builtin
@ -306,6 +307,7 @@ def dilation2d_v2(
@tf_export(v1=["nn.dilation2d"]) @tf_export(v1=["nn.dilation2d"])
@dispatch.add_dispatch_support
def dilation2d_v1( # pylint: disable=missing-docstring def dilation2d_v1( # pylint: disable=missing-docstring
input, # pylint: disable=redefined-builtin input, # pylint: disable=redefined-builtin
filter=None, # pylint: disable=redefined-builtin filter=None, # pylint: disable=redefined-builtin
@ -324,6 +326,7 @@ dilation2d_v1.__doc__ = gen_nn_ops.dilation2d.__doc__
@tf_export("nn.with_space_to_batch") @tf_export("nn.with_space_to_batch")
@dispatch.add_dispatch_support
def with_space_to_batch( def with_space_to_batch(
input, # pylint: disable=redefined-builtin input, # pylint: disable=redefined-builtin
dilation_rate, dilation_rate,
@ -772,6 +775,7 @@ def _get_strides_and_dilation_rate(num_spatial_dims, strides, dilation_rate):
@tf_export(v1=["nn.convolution"]) @tf_export(v1=["nn.convolution"])
@dispatch.add_dispatch_support
def convolution( def convolution(
input, # pylint: disable=redefined-builtin input, # pylint: disable=redefined-builtin
filter, # pylint: disable=redefined-builtin filter, # pylint: disable=redefined-builtin
@ -907,7 +911,8 @@ def convolution(
@tf_export("nn.convolution", v1=[]) @tf_export("nn.convolution", v1=[])
def convolution_v2( @dispatch.add_dispatch_support
def convolution_v2( # pylint: disable=missing-docstring
input, # pylint: disable=redefined-builtin input, # pylint: disable=redefined-builtin
filters, filters,
strides=None, strides=None,
@ -1116,6 +1121,7 @@ class Convolution(object):
@tf_export(v1=["nn.pool"]) @tf_export(v1=["nn.pool"])
@dispatch.add_dispatch_support
def pool( def pool(
input, # pylint: disable=redefined-builtin input, # pylint: disable=redefined-builtin
window_shape, window_shape,
@ -1290,6 +1296,7 @@ def pool(
@tf_export("nn.pool", v1=[]) @tf_export("nn.pool", v1=[])
@dispatch.add_dispatch_support
def pool_v2( def pool_v2(
input, # pylint: disable=redefined-builtin input, # pylint: disable=redefined-builtin
window_shape, window_shape,
@ -1389,6 +1396,7 @@ def pool_v2(
@tf_export("nn.atrous_conv2d") @tf_export("nn.atrous_conv2d")
@dispatch.add_dispatch_support
def atrous_conv2d(value, filters, rate, padding, name=None): def atrous_conv2d(value, filters, rate, padding, name=None):
"""Atrous convolution (a.k.a. convolution with holes or dilated convolution). """Atrous convolution (a.k.a. convolution with holes or dilated convolution).
@ -1576,6 +1584,7 @@ def convert_padding(padding):
@tf_export(v1=["nn.conv1d"]) @tf_export(v1=["nn.conv1d"])
@dispatch.add_dispatch_support
@deprecation.deprecated_arg_values( @deprecation.deprecated_arg_values(
None, None,
"`NCHW` for data_format is deprecated, use `NCW` instead", "`NCHW` for data_format is deprecated, use `NCW` instead",
@ -1674,6 +1683,7 @@ def conv1d(
@tf_export("nn.conv1d", v1=[]) @tf_export("nn.conv1d", v1=[])
@dispatch.add_dispatch_support
def conv1d_v2( def conv1d_v2(
input, # pylint: disable=redefined-builtin input, # pylint: disable=redefined-builtin
filters, filters,
@ -1739,6 +1749,7 @@ def conv1d_v2(
@tf_export("nn.conv1d_transpose") @tf_export("nn.conv1d_transpose")
@dispatch.add_dispatch_support
def conv1d_transpose( def conv1d_transpose(
input, # pylint: disable=redefined-builtin input, # pylint: disable=redefined-builtin
filters, filters,
@ -1827,6 +1838,7 @@ def conv1d_transpose(
@tf_export("nn.conv2d", v1=[]) @tf_export("nn.conv2d", v1=[])
@dispatch.add_dispatch_support
def conv2d_v2(input, # pylint: disable=redefined-builtin def conv2d_v2(input, # pylint: disable=redefined-builtin
filters, filters,
strides, strides,
@ -1927,6 +1939,7 @@ def conv2d_v2(input, # pylint: disable=redefined-builtin
@tf_export(v1=["nn.conv2d"]) @tf_export(v1=["nn.conv2d"])
@dispatch.add_dispatch_support
def conv2d( # pylint: disable=redefined-builtin,dangerous-default-value def conv2d( # pylint: disable=redefined-builtin,dangerous-default-value
input, input,
filter=None, filter=None,
@ -2024,6 +2037,7 @@ def conv2d( # pylint: disable=redefined-builtin,dangerous-default-value
@tf_export(v1=["nn.conv2d_backprop_filter"]) @tf_export(v1=["nn.conv2d_backprop_filter"])
@dispatch.add_dispatch_support
def conv2d_backprop_filter( # pylint: disable=redefined-builtin,dangerous-default-value def conv2d_backprop_filter( # pylint: disable=redefined-builtin,dangerous-default-value
input, input,
filter_sizes, filter_sizes,
@ -2084,6 +2098,7 @@ def conv2d_backprop_filter( # pylint: disable=redefined-builtin,dangerous-defau
@tf_export(v1=["nn.conv2d_backprop_input"]) @tf_export(v1=["nn.conv2d_backprop_input"])
@dispatch.add_dispatch_support
def conv2d_backprop_input( # pylint: disable=redefined-builtin,dangerous-default-value def conv2d_backprop_input( # pylint: disable=redefined-builtin,dangerous-default-value
input_sizes, input_sizes,
filter=None, filter=None,
@ -2148,6 +2163,7 @@ def conv2d_backprop_input( # pylint: disable=redefined-builtin,dangerous-defaul
@tf_export(v1=["nn.conv2d_transpose"]) @tf_export(v1=["nn.conv2d_transpose"])
@dispatch.add_dispatch_support
def conv2d_transpose( def conv2d_transpose(
value=None, value=None,
filter=None, # pylint: disable=redefined-builtin filter=None, # pylint: disable=redefined-builtin
@ -2224,6 +2240,7 @@ def conv2d_transpose(
@tf_export("nn.conv2d_transpose", v1=[]) @tf_export("nn.conv2d_transpose", v1=[])
@dispatch.add_dispatch_support
def conv2d_transpose_v2( def conv2d_transpose_v2(
input, # pylint: disable=redefined-builtin input, # pylint: disable=redefined-builtin
filters, # pylint: disable=redefined-builtin filters, # pylint: disable=redefined-builtin
@ -2301,6 +2318,7 @@ def conv2d_transpose_v2(
@tf_export("nn.atrous_conv2d_transpose") @tf_export("nn.atrous_conv2d_transpose")
@dispatch.add_dispatch_support
def atrous_conv2d_transpose(value, def atrous_conv2d_transpose(value,
filters, filters,
output_shape, output_shape,
@ -2459,6 +2477,7 @@ def atrous_conv2d_transpose(value,
@tf_export(v1=["nn.depthwise_conv2d_native"]) @tf_export(v1=["nn.depthwise_conv2d_native"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("nn.depthwise_conv2d_native") @deprecation.deprecated_endpoints("nn.depthwise_conv2d_native")
def depthwise_conv2d_native( # pylint: disable=redefined-builtin,dangerous-default-value def depthwise_conv2d_native( # pylint: disable=redefined-builtin,dangerous-default-value
input, input,
@ -2538,6 +2557,7 @@ def depthwise_conv2d_native( # pylint: disable=redefined-builtin,dangerous-defa
"nn.depthwise_conv2d_native_backprop_input", "nn.depthwise_conv2d_native_backprop_input",
"nn.depthwise_conv2d_backprop_input" "nn.depthwise_conv2d_backprop_input"
]) ])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("nn.depthwise_conv2d_native_backprop_input") @deprecation.deprecated_endpoints("nn.depthwise_conv2d_native_backprop_input")
def depthwise_conv2d_native_backprop_input( # pylint: disable=redefined-builtin,dangerous-default-value def depthwise_conv2d_native_backprop_input( # pylint: disable=redefined-builtin,dangerous-default-value
input_sizes, input_sizes,
@ -2607,6 +2627,7 @@ def depthwise_conv2d_native_backprop_input( # pylint: disable=redefined-builtin
"nn.depthwise_conv2d_native_backprop_filter", "nn.depthwise_conv2d_native_backprop_filter",
"nn.depthwise_conv2d_backprop_filter" "nn.depthwise_conv2d_backprop_filter"
]) ])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("nn.depthwise_conv2d_native_backprop_filter") @deprecation.deprecated_endpoints("nn.depthwise_conv2d_native_backprop_filter")
def depthwise_conv2d_native_backprop_filter( # pylint: disable=redefined-builtin,dangerous-default-value def depthwise_conv2d_native_backprop_filter( # pylint: disable=redefined-builtin,dangerous-default-value
input, input,
@ -2672,6 +2693,7 @@ def depthwise_conv2d_native_backprop_filter( # pylint: disable=redefined-builti
@tf_export("nn.conv3d", v1=[]) @tf_export("nn.conv3d", v1=[])
@dispatch.add_dispatch_support
def conv3d_v2(input, # pylint: disable=redefined-builtin,missing-docstring def conv3d_v2(input, # pylint: disable=redefined-builtin,missing-docstring
filters, filters,
strides, strides,
@ -2691,6 +2713,7 @@ def conv3d_v2(input, # pylint: disable=redefined-builtin,missing-docstring
@tf_export(v1=["nn.conv3d"]) @tf_export(v1=["nn.conv3d"])
@dispatch.add_dispatch_support
def conv3d_v1( # pylint: disable=missing-docstring,dangerous-default-value def conv3d_v1( # pylint: disable=missing-docstring,dangerous-default-value
input, # pylint: disable=redefined-builtin input, # pylint: disable=redefined-builtin
filter=None, # pylint: disable=redefined-builtin filter=None, # pylint: disable=redefined-builtin
@ -2711,6 +2734,7 @@ conv3d_v1.__doc__ = gen_nn_ops.conv3d.__doc__
@tf_export(v1=["nn.conv3d_transpose"]) @tf_export(v1=["nn.conv3d_transpose"])
@dispatch.add_dispatch_support
def conv3d_transpose( def conv3d_transpose(
value, value,
filter=None, # pylint: disable=redefined-builtin filter=None, # pylint: disable=redefined-builtin
@ -2782,6 +2806,7 @@ def conv3d_transpose(
@tf_export("nn.conv3d_transpose", v1=[]) @tf_export("nn.conv3d_transpose", v1=[])
@dispatch.add_dispatch_support
def conv3d_transpose_v2(input, # pylint: disable=redefined-builtin def conv3d_transpose_v2(input, # pylint: disable=redefined-builtin
filters, filters,
output_shape, output_shape,
@ -2861,6 +2886,7 @@ CONV_TRANSPOSE_OPS = (
@tf_export("nn.conv_transpose") @tf_export("nn.conv_transpose")
@dispatch.add_dispatch_support
def conv_transpose(input, # pylint: disable=redefined-builtin def conv_transpose(input, # pylint: disable=redefined-builtin
filters, filters,
output_shape, output_shape,
@ -2958,6 +2984,7 @@ _tf_deterministic_ops.value = None
@tf_export("nn.bias_add") @tf_export("nn.bias_add")
@dispatch.add_dispatch_support
def bias_add(value, bias, data_format=None, name=None): def bias_add(value, bias, data_format=None, name=None):
"""Adds `bias` to `value`. """Adds `bias` to `value`.
@ -3047,6 +3074,7 @@ def bias_add_v1(value, bias, name=None):
@tf_export(v1=["nn.crelu"]) @tf_export(v1=["nn.crelu"])
@dispatch.add_dispatch_support
def crelu(features, name=None, axis=-1): def crelu(features, name=None, axis=-1):
"""Computes Concatenated ReLU. """Computes Concatenated ReLU.
@ -3079,12 +3107,14 @@ def crelu(features, name=None, axis=-1):
@tf_export("nn.crelu", v1=[]) @tf_export("nn.crelu", v1=[])
@dispatch.add_dispatch_support
def crelu_v2(features, axis=-1, name=None): def crelu_v2(features, axis=-1, name=None):
return crelu(features, name=name, axis=axis) return crelu(features, name=name, axis=axis)
crelu_v2.__doc__ = crelu.__doc__ crelu_v2.__doc__ = crelu.__doc__
@tf_export("nn.relu6") @tf_export("nn.relu6")
@dispatch.add_dispatch_support
def relu6(features, name=None): def relu6(features, name=None):
"""Computes Rectified Linear 6: `min(max(features, 0), 6)`. """Computes Rectified Linear 6: `min(max(features, 0), 6)`.
@ -3107,6 +3137,7 @@ def relu6(features, name=None):
@tf_export("nn.leaky_relu") @tf_export("nn.leaky_relu")
@dispatch.add_dispatch_support
def leaky_relu(features, alpha=0.2, name=None): def leaky_relu(features, alpha=0.2, name=None):
"""Compute the Leaky ReLU activation function. """Compute the Leaky ReLU activation function.
@ -3245,6 +3276,7 @@ def _softmax(logits, compute_op, dim=-1, name=None):
@tf_export(v1=["nn.softmax", "math.softmax"]) @tf_export(v1=["nn.softmax", "math.softmax"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None, "dim is deprecated, use axis instead", "dim") @deprecation.deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def softmax(logits, axis=None, name=None, dim=None): def softmax(logits, axis=None, name=None, dim=None):
"""Computes softmax activations. """Computes softmax activations.
@ -3289,6 +3321,7 @@ def softmax(logits, axis=None, name=None, dim=None):
@tf_export("nn.softmax", "math.softmax", v1=[]) @tf_export("nn.softmax", "math.softmax", v1=[])
@dispatch.add_dispatch_support
def softmax_v2(logits, axis=None, name=None): def softmax_v2(logits, axis=None, name=None):
"""Computes softmax activations. """Computes softmax activations.
@ -3316,6 +3349,7 @@ def softmax_v2(logits, axis=None, name=None):
@tf_export(v1=["nn.log_softmax", "math.log_softmax"]) @tf_export(v1=["nn.log_softmax", "math.log_softmax"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None, "dim is deprecated, use axis instead", "dim") @deprecation.deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def log_softmax(logits, axis=None, name=None, dim=None): def log_softmax(logits, axis=None, name=None, dim=None):
"""Computes log softmax activations. """Computes log softmax activations.
@ -3346,6 +3380,7 @@ def log_softmax(logits, axis=None, name=None, dim=None):
@tf_export("nn.log_softmax", "math.log_softmax", v1=[]) @tf_export("nn.log_softmax", "math.log_softmax", v1=[])
@dispatch.add_dispatch_support
def log_softmax_v2(logits, axis=None, name=None): def log_softmax_v2(logits, axis=None, name=None):
"""Computes log softmax activations. """Computes log softmax activations.
@ -3382,6 +3417,7 @@ def _ensure_xent_args(name, sentinel, labels, logits):
@tf_export("nn.softmax_cross_entropy_with_logits", v1=[]) @tf_export("nn.softmax_cross_entropy_with_logits", v1=[])
@dispatch.add_dispatch_support
def softmax_cross_entropy_with_logits_v2(labels, logits, axis=-1, name=None): def softmax_cross_entropy_with_logits_v2(labels, logits, axis=-1, name=None):
"""Computes softmax cross entropy between `logits` and `labels`. """Computes softmax cross entropy between `logits` and `labels`.
@ -3444,6 +3480,7 @@ def softmax_cross_entropy_with_logits_v2(labels, logits, axis=-1, name=None):
@tf_export(v1=["nn.softmax_cross_entropy_with_logits_v2"]) @tf_export(v1=["nn.softmax_cross_entropy_with_logits_v2"])
@dispatch.add_dispatch_support
@deprecated_args(None, "dim is deprecated, use axis instead", "dim") @deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def softmax_cross_entropy_with_logits_v2_helper( def softmax_cross_entropy_with_logits_v2_helper(
labels, logits, axis=None, name=None, dim=None): labels, logits, axis=None, name=None, dim=None):
@ -3571,6 +3608,7 @@ See `tf.nn.softmax_cross_entropy_with_logits_v2`.
@tf_export(v1=["nn.softmax_cross_entropy_with_logits"]) @tf_export(v1=["nn.softmax_cross_entropy_with_logits"])
@dispatch.add_dispatch_support
@deprecation.deprecated(date=None, instructions=_XENT_DEPRECATION) @deprecation.deprecated(date=None, instructions=_XENT_DEPRECATION)
def softmax_cross_entropy_with_logits( def softmax_cross_entropy_with_logits(
_sentinel=None, # pylint: disable=invalid-name _sentinel=None, # pylint: disable=invalid-name
@ -3639,6 +3677,7 @@ def softmax_cross_entropy_with_logits(
@tf_export(v1=["nn.sparse_softmax_cross_entropy_with_logits"]) @tf_export(v1=["nn.sparse_softmax_cross_entropy_with_logits"])
@dispatch.add_dispatch_support
def sparse_softmax_cross_entropy_with_logits( def sparse_softmax_cross_entropy_with_logits(
_sentinel=None, # pylint: disable=invalid-name _sentinel=None, # pylint: disable=invalid-name
labels=None, labels=None,
@ -3764,6 +3803,7 @@ def sparse_softmax_cross_entropy_with_logits(
@tf_export("nn.sparse_softmax_cross_entropy_with_logits", v1=[]) @tf_export("nn.sparse_softmax_cross_entropy_with_logits", v1=[])
@dispatch.add_dispatch_support
def sparse_softmax_cross_entropy_with_logits_v2(labels, logits, name=None): def sparse_softmax_cross_entropy_with_logits_v2(labels, logits, name=None):
"""Computes sparse softmax cross entropy between `logits` and `labels`. """Computes sparse softmax cross entropy between `logits` and `labels`.
@ -3816,6 +3856,7 @@ def sparse_softmax_cross_entropy_with_logits_v2(labels, logits, name=None):
@tf_export("nn.avg_pool", v1=["nn.avg_pool_v2"]) @tf_export("nn.avg_pool", v1=["nn.avg_pool_v2"])
@dispatch.add_dispatch_support
def avg_pool_v2(input, ksize, strides, padding, data_format=None, name=None): # pylint: disable=redefined-builtin def avg_pool_v2(input, ksize, strides, padding, data_format=None, name=None): # pylint: disable=redefined-builtin
"""Performs the avg pooling on the input. """Performs the avg pooling on the input.
@ -3878,6 +3919,7 @@ def avg_pool_v2(input, ksize, strides, padding, data_format=None, name=None): #
@tf_export(v1=["nn.avg_pool", "nn.avg_pool2d"]) @tf_export(v1=["nn.avg_pool", "nn.avg_pool2d"])
@dispatch.add_dispatch_support
def avg_pool(value, ksize, strides, padding, data_format="NHWC", def avg_pool(value, ksize, strides, padding, data_format="NHWC",
name=None, input=None): # pylint: disable=redefined-builtin name=None, input=None): # pylint: disable=redefined-builtin
"""Performs the average pooling on the input. """Performs the average pooling on the input.
@ -3922,6 +3964,7 @@ def avg_pool(value, ksize, strides, padding, data_format="NHWC",
@tf_export("nn.avg_pool2d", v1=[]) @tf_export("nn.avg_pool2d", v1=[])
@dispatch.add_dispatch_support
def avg_pool2d(input, ksize, strides, padding, data_format="NHWC", name=None): # pylint: disable=redefined-builtin def avg_pool2d(input, ksize, strides, padding, data_format="NHWC", name=None): # pylint: disable=redefined-builtin
"""Performs the average pooling on the input. """Performs the average pooling on the input.
@ -3961,6 +4004,7 @@ def avg_pool2d(input, ksize, strides, padding, data_format="NHWC", name=None):
@tf_export("nn.avg_pool1d") @tf_export("nn.avg_pool1d")
@dispatch.add_dispatch_support
def avg_pool1d(input, ksize, strides, padding, data_format="NWC", name=None): # pylint: disable=redefined-builtin def avg_pool1d(input, ksize, strides, padding, data_format="NWC", name=None): # pylint: disable=redefined-builtin
"""Performs the average pooling on the input. """Performs the average pooling on the input.
@ -4006,6 +4050,7 @@ def avg_pool1d(input, ksize, strides, padding, data_format="NWC", name=None): #
@tf_export("nn.avg_pool3d") @tf_export("nn.avg_pool3d")
@dispatch.add_dispatch_support
def avg_pool3d(input, ksize, strides, padding, data_format="NDHWC", name=None): # pylint: disable=redefined-builtin def avg_pool3d(input, ksize, strides, padding, data_format="NDHWC", name=None): # pylint: disable=redefined-builtin
"""Performs the average pooling on the input. """Performs the average pooling on the input.
@ -4046,6 +4091,7 @@ def avg_pool3d(input, ksize, strides, padding, data_format="NDHWC", name=None):
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin
@tf_export("nn.max_pool", v1=["nn.max_pool_v2"]) @tf_export("nn.max_pool", v1=["nn.max_pool_v2"])
@dispatch.add_dispatch_support
def max_pool_v2(input, ksize, strides, padding, data_format=None, name=None): def max_pool_v2(input, ksize, strides, padding, data_format=None, name=None):
"""Performs the max pooling on the input. """Performs the max pooling on the input.
@ -4106,6 +4152,7 @@ def max_pool_v2(input, ksize, strides, padding, data_format=None, name=None):
@tf_export(v1=["nn.max_pool"]) @tf_export(v1=["nn.max_pool"])
@dispatch.add_dispatch_support
def max_pool(value, def max_pool(value,
ksize, ksize,
strides, strides,
@ -4155,6 +4202,7 @@ def max_pool(value,
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin
@tf_export("nn.max_pool1d") @tf_export("nn.max_pool1d")
@dispatch.add_dispatch_support
def max_pool1d(input, ksize, strides, padding, data_format="NWC", name=None): def max_pool1d(input, ksize, strides, padding, data_format="NWC", name=None):
"""Performs the max pooling on the input. """Performs the max pooling on the input.
@ -4199,6 +4247,7 @@ def max_pool1d(input, ksize, strides, padding, data_format="NWC", name=None):
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin
@tf_export("nn.max_pool2d") @tf_export("nn.max_pool2d")
@dispatch.add_dispatch_support
def max_pool2d(input, ksize, strides, padding, data_format="NHWC", name=None): def max_pool2d(input, ksize, strides, padding, data_format="NHWC", name=None):
"""Performs the max pooling on the input. """Performs the max pooling on the input.
@ -4237,6 +4286,7 @@ def max_pool2d(input, ksize, strides, padding, data_format="NHWC", name=None):
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin
@tf_export("nn.max_pool3d") @tf_export("nn.max_pool3d")
@dispatch.add_dispatch_support
def max_pool3d(input, ksize, strides, padding, data_format="NDHWC", name=None): def max_pool3d(input, ksize, strides, padding, data_format="NDHWC", name=None):
"""Performs the max pooling on the input. """Performs the max pooling on the input.
@ -4279,6 +4329,7 @@ def max_pool3d(input, ksize, strides, padding, data_format="NDHWC", name=None):
@tf_export("nn.max_pool_with_argmax", v1=[]) @tf_export("nn.max_pool_with_argmax", v1=[])
@dispatch.add_dispatch_support
def max_pool_with_argmax_v2( def max_pool_with_argmax_v2(
input, # pylint: disable=redefined-builtin input, # pylint: disable=redefined-builtin
ksize, ksize,
@ -4348,6 +4399,7 @@ def max_pool_with_argmax_v2(
@tf_export(v1=["nn.max_pool_with_argmax"]) @tf_export(v1=["nn.max_pool_with_argmax"])
@dispatch.add_dispatch_support
def max_pool_with_argmax_v1( # pylint: disable=missing-docstring,invalid-name def max_pool_with_argmax_v1( # pylint: disable=missing-docstring,invalid-name
input, # pylint: disable=redefined-builtin input, # pylint: disable=redefined-builtin
ksize, ksize,
@ -4442,6 +4494,7 @@ def _calc_bias_add_flops(graph, node):
@tf_export(v1=["nn.xw_plus_b"]) @tf_export(v1=["nn.xw_plus_b"])
@dispatch.add_dispatch_support
def xw_plus_b(x, weights, biases, name=None): # pylint: disable=invalid-name def xw_plus_b(x, weights, biases, name=None): # pylint: disable=invalid-name
"""Computes matmul(x, weights) + biases. """Computes matmul(x, weights) + biases.
@ -4691,6 +4744,7 @@ def dropout_v2(x, rate, noise_shape=None, seed=None, name=None):
@tf_export("math.top_k", "nn.top_k") @tf_export("math.top_k", "nn.top_k")
@dispatch.add_dispatch_support
def top_k(input, k=1, sorted=True, name=None): # pylint: disable=redefined-builtin def top_k(input, k=1, sorted=True, name=None): # pylint: disable=redefined-builtin
"""Finds values and indices of the `k` largest entries for the last dimension. """Finds values and indices of the `k` largest entries for the last dimension.
@ -4751,6 +4805,7 @@ def nth_element(input, n, reverse=False, name=None): # pylint: disable=redefine
@tf_export(v1=["nn.fractional_max_pool"]) @tf_export(v1=["nn.fractional_max_pool"])
@dispatch.add_dispatch_support
@deprecation.deprecated(date=None, instructions="`seed2` and `deterministic` " @deprecation.deprecated(date=None, instructions="`seed2` and `deterministic` "
"args are deprecated. Use fractional_max_pool_v2.") "args are deprecated. Use fractional_max_pool_v2.")
def fractional_max_pool(value, def fractional_max_pool(value,
@ -4837,6 +4892,7 @@ def fractional_max_pool(value,
@tf_export("nn.fractional_max_pool", v1=[]) @tf_export("nn.fractional_max_pool", v1=[])
@dispatch.add_dispatch_support
def fractional_max_pool_v2(value, def fractional_max_pool_v2(value,
pooling_ratio, pooling_ratio,
pseudo_random=False, pseudo_random=False,
@ -4922,6 +4978,7 @@ def fractional_max_pool_v2(value,
@tf_export(v1=["nn.fractional_avg_pool"]) @tf_export(v1=["nn.fractional_avg_pool"])
@dispatch.add_dispatch_support
@deprecation.deprecated(date=None, instructions="`seed2` and `deterministic` " @deprecation.deprecated(date=None, instructions="`seed2` and `deterministic` "
"args are deprecated. Use fractional_avg_pool_v2.") "args are deprecated. Use fractional_avg_pool_v2.")
def fractional_avg_pool(value, def fractional_avg_pool(value,
@ -4987,6 +5044,7 @@ def fractional_avg_pool(value,
@tf_export("nn.fractional_avg_pool", v1=[]) @tf_export("nn.fractional_avg_pool", v1=[])
@dispatch.add_dispatch_support
def fractional_avg_pool_v2(value, def fractional_avg_pool_v2(value,
pooling_ratio, pooling_ratio,
pseudo_random=False, pseudo_random=False,
@ -5065,6 +5123,7 @@ def _calc_dilation2d_flops(graph, node):
@tf_export(v1=["nn.erosion2d"]) @tf_export(v1=["nn.erosion2d"])
@dispatch.add_dispatch_support
def erosion2d(value, kernel, strides, rates, padding, name=None): def erosion2d(value, kernel, strides, rates, padding, name=None):
"""Computes the grayscale erosion of 4-D `value` and 3-D `kernel` tensors. """Computes the grayscale erosion of 4-D `value` and 3-D `kernel` tensors.
@ -5124,6 +5183,7 @@ def erosion2d(value, kernel, strides, rates, padding, name=None):
@tf_export("nn.erosion2d", v1=[]) @tf_export("nn.erosion2d", v1=[])
@dispatch.add_dispatch_support
def erosion2d_v2(value, def erosion2d_v2(value,
filters, filters,
strides, strides,
@ -5193,6 +5253,7 @@ def erosion2d_v2(value,
@tf_export(v1=["math.in_top_k", "nn.in_top_k"]) @tf_export(v1=["math.in_top_k", "nn.in_top_k"])
@dispatch.add_dispatch_support
def in_top_k(predictions, targets, k, name=None): def in_top_k(predictions, targets, k, name=None):
r"""Says whether the targets are in the top `K` predictions. r"""Says whether the targets are in the top `K` predictions.
@ -5227,6 +5288,7 @@ def in_top_k(predictions, targets, k, name=None):
@tf_export("math.in_top_k", "nn.in_top_k", v1=[]) @tf_export("math.in_top_k", "nn.in_top_k", v1=[])
@dispatch.add_dispatch_support
def in_top_k_v2(targets, predictions, k, name=None): def in_top_k_v2(targets, predictions, k, name=None):
return in_top_k(predictions, targets, k, name) return in_top_k(predictions, targets, k, name)
@ -5234,7 +5296,11 @@ def in_top_k_v2(targets, predictions, k, name=None):
in_top_k_v2.__doc__ = in_top_k.__doc__ in_top_k_v2.__doc__ = in_top_k.__doc__
tf_export(v1=["nn.quantized_avg_pool"])(gen_nn_ops.quantized_avg_pool) tf_export(v1=["nn.quantized_avg_pool"])(
tf_export(v1=["nn.quantized_conv2d"])(gen_nn_ops.quantized_conv2d) dispatch.add_dispatch_support(gen_nn_ops.quantized_avg_pool))
tf_export(v1=["nn.quantized_relu_x"])(gen_nn_ops.quantized_relu_x) tf_export(v1=["nn.quantized_conv2d"])(
tf_export(v1=["nn.quantized_max_pool"])(gen_nn_ops.quantized_max_pool) dispatch.add_dispatch_support(gen_nn_ops.quantized_conv2d))
tf_export(v1=["nn.quantized_relu_x"])(
dispatch.add_dispatch_support(gen_nn_ops.quantized_relu_x))
tf_export(v1=["nn.quantized_max_pool"])(
dispatch.add_dispatch_support(gen_nn_ops.quantized_max_pool))

View File

@ -25,10 +25,12 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.util import deprecation from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@tf_export(v1=["debugging.assert_all_finite", "verify_tensor_all_finite"]) @tf_export(v1=["debugging.assert_all_finite", "verify_tensor_all_finite"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("verify_tensor_all_finite") @deprecation.deprecated_endpoints("verify_tensor_all_finite")
def verify_tensor_all_finite(t=None, msg=None, name=None, x=None, message=None): def verify_tensor_all_finite(t=None, msg=None, name=None, x=None, message=None):
"""Assert that the tensor does not contain any NaN's or Inf's. """Assert that the tensor does not contain any NaN's or Inf's.
@ -50,6 +52,7 @@ def verify_tensor_all_finite(t=None, msg=None, name=None, x=None, message=None):
@tf_export("debugging.assert_all_finite", v1=[]) @tf_export("debugging.assert_all_finite", v1=[])
@dispatch.add_dispatch_support
def verify_tensor_all_finite_v2(x, message, name=None): def verify_tensor_all_finite_v2(x, message, name=None):
"""Assert that the tensor does not contain any NaN's or Inf's. """Assert that the tensor does not contain any NaN's or Inf's.

View File

@ -30,6 +30,7 @@ from tensorflow.python.ops import parsing_config
from tensorflow.python.ops.gen_parsing_ops import * from tensorflow.python.ops.gen_parsing_ops import *
# pylint: enable=wildcard-import,undefined-variable # pylint: enable=wildcard-import,undefined-variable
from tensorflow.python.util import deprecation from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -77,6 +78,7 @@ def _prepend_none_dimension(features):
@tf_export("io.parse_example", v1=[]) @tf_export("io.parse_example", v1=[])
@dispatch.add_dispatch_support
def parse_example_v2(serialized, features, example_names=None, name=None): def parse_example_v2(serialized, features, example_names=None, name=None):
# pylint: disable=line-too-long # pylint: disable=line-too-long
"""Parses `Example` protos into a `dict` of tensors. """Parses `Example` protos into a `dict` of tensors.
@ -314,6 +316,7 @@ def parse_example_v2(serialized, features, example_names=None, name=None):
@tf_export(v1=["io.parse_example", "parse_example"]) @tf_export(v1=["io.parse_example", "parse_example"])
@dispatch.add_dispatch_support
def parse_example(serialized, features, name=None, example_names=None): def parse_example(serialized, features, name=None, example_names=None):
return parse_example_v2(serialized, features, example_names, name) return parse_example_v2(serialized, features, example_names, name)
@ -373,6 +376,7 @@ def _parse_example_raw(serialized, names, params, name):
@tf_export(v1=["io.parse_single_example", "parse_single_example"]) @tf_export(v1=["io.parse_single_example", "parse_single_example"])
@dispatch.add_dispatch_support
def parse_single_example(serialized, features, name=None, example_names=None): def parse_single_example(serialized, features, name=None, example_names=None):
"""Parses a single `Example` proto. """Parses a single `Example` proto.
@ -407,6 +411,7 @@ def parse_single_example(serialized, features, name=None, example_names=None):
@tf_export("io.parse_single_example", v1=[]) @tf_export("io.parse_single_example", v1=[])
@dispatch.add_dispatch_support
def parse_single_example_v2( def parse_single_example_v2(
serialized, features, example_names=None, name=None serialized, features, example_names=None, name=None
): ):
@ -448,6 +453,7 @@ def parse_single_example_v2(
@tf_export("io.parse_sequence_example") @tf_export("io.parse_sequence_example")
@dispatch.add_dispatch_support
def parse_sequence_example(serialized, def parse_sequence_example(serialized,
context_features=None, context_features=None,
sequence_features=None, sequence_features=None,
@ -692,6 +698,7 @@ def _parse_sequence_example_raw(serialized,
@tf_export("io.parse_single_sequence_example", @tf_export("io.parse_single_sequence_example",
v1=["io.parse_single_sequence_example", v1=["io.parse_single_sequence_example",
"parse_single_sequence_example"]) "parse_single_sequence_example"])
@dispatch.add_dispatch_support
def parse_single_sequence_example( def parse_single_sequence_example(
serialized, context_features=None, sequence_features=None, serialized, context_features=None, sequence_features=None,
example_name=None, name=None): example_name=None, name=None):
@ -835,6 +842,7 @@ def _parse_single_sequence_example_raw(serialized,
@tf_export("io.decode_raw", v1=[]) @tf_export("io.decode_raw", v1=[])
@dispatch.add_dispatch_support
def decode_raw(input_bytes, def decode_raw(input_bytes,
out_type, out_type,
little_endian=True, little_endian=True,
@ -877,6 +885,7 @@ def decode_raw(input_bytes,
@tf_export(v1=["decode_raw", "io.decode_raw"]) @tf_export(v1=["decode_raw", "io.decode_raw"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None, @deprecation.deprecated_args(None,
"bytes is deprecated, use input_bytes instead", "bytes is deprecated, use input_bytes instead",
"bytes") "bytes")
@ -921,6 +930,7 @@ def decode_raw_v1(
# Swap `name` and `na_value` for backward compatibility. # Swap `name` and `na_value` for backward compatibility.
@tf_export(v1=["io.decode_csv", "decode_csv"]) @tf_export(v1=["io.decode_csv", "decode_csv"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("decode_csv") @deprecation.deprecated_endpoints("decode_csv")
def decode_csv(records, def decode_csv(records,
record_defaults, record_defaults,
@ -970,6 +980,7 @@ def decode_csv(records,
@tf_export("io.decode_csv", v1=[]) @tf_export("io.decode_csv", v1=[])
@dispatch.add_dispatch_support
def decode_csv_v2(records, def decode_csv_v2(records,
record_defaults, record_defaults,
field_delim=",", field_delim=",",

View File

@ -22,10 +22,11 @@ from __future__ import print_function
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops.gen_decode_proto_ops import decode_proto_v2 as decode_proto from tensorflow.python.ops.gen_decode_proto_ops import decode_proto_v2 as decode_proto
from tensorflow.python.ops.gen_encode_proto_ops import encode_proto from tensorflow.python.ops.gen_encode_proto_ops import encode_proto
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
tf_export("io.decode_proto")(decode_proto) tf_export("io.decode_proto")(dispatch.add_dispatch_support(decode_proto))
tf_export("io.encode_proto")(encode_proto) tf_export("io.encode_proto")(dispatch.add_dispatch_support(encode_proto))
ops.NotDifferentiable("DecodeProtoV2") ops.NotDifferentiable("DecodeProtoV2")
ops.NotDifferentiable("EncodeProto") ops.NotDifferentiable("EncodeProto")

View File

@ -32,6 +32,7 @@ from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_util from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.ops.ragged import segment_id_ops from tensorflow.python.ops.ragged import segment_id_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
#=============================================================================== #===============================================================================
@ -40,6 +41,7 @@ from tensorflow.python.util.tf_export import tf_export
@tf_export('ragged.boolean_mask') @tf_export('ragged.boolean_mask')
@dispatch.add_dispatch_support
def boolean_mask(data, mask, name=None): def boolean_mask(data, mask, name=None):
"""Applies a boolean mask to `data` without flattening the mask dimensions. """Applies a boolean mask to `data` without flattening the mask dimensions.
@ -538,6 +540,7 @@ def ragged_one_hot(indices,
# ragged.stack_dynamic_partitions # ragged.stack_dynamic_partitions
#=============================================================================== #===============================================================================
@tf_export('ragged.stack_dynamic_partitions') @tf_export('ragged.stack_dynamic_partitions')
@dispatch.add_dispatch_support
def stack_dynamic_partitions(data, partitions, num_partitions, name=None): def stack_dynamic_partitions(data, partitions, num_partitions, name=None):
"""Stacks dynamic partitions of a Tensor or RaggedTensor. """Stacks dynamic partitions of a Tensor or RaggedTensor.
@ -699,6 +702,7 @@ def reverse(tensor, axis, name=None):
@tf_export('ragged.cross') @tf_export('ragged.cross')
@dispatch.add_dispatch_support
def cross(inputs, name=None): def cross(inputs, name=None):
"""Generates feature cross from a list of tensors. """Generates feature cross from a list of tensors.
@ -725,6 +729,7 @@ def cross(inputs, name=None):
@tf_export('ragged.cross_hashed') @tf_export('ragged.cross_hashed')
@dispatch.add_dispatch_support
def cross_hashed(inputs, num_buckets=0, hash_key=None, name=None): def cross_hashed(inputs, num_buckets=0, hash_key=None, name=None):
"""Generates hashed feature cross from a list of tensors. """Generates hashed feature cross from a list of tensors.

View File

@ -27,6 +27,7 @@ from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_gather_ops from tensorflow.python.ops.ragged import ragged_gather_ops
from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_util from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -71,6 +72,7 @@ def concat(values, axis, name=None):
@tf_export('ragged.stack') @tf_export('ragged.stack')
@dispatch.add_dispatch_support
def stack(values, axis=0, name=None): def stack(values, axis=0, name=None):
"""Stacks a list of rank-`R` tensors into one rank-`(R+1)` `RaggedTensor`. """Stacks a list of rank-`R` tensors into one rank-`(R+1)` `RaggedTensor`.

View File

@ -27,6 +27,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_value from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -34,6 +35,7 @@ from tensorflow.python.util.tf_export import tf_export
# Op to construct a constant RaggedTensor from a nested Python list. # Op to construct a constant RaggedTensor from a nested Python list.
#=============================================================================== #===============================================================================
@tf_export("ragged.constant") @tf_export("ragged.constant")
@dispatch.add_dispatch_support
def constant(pylist, dtype=None, ragged_rank=None, inner_shape=None, def constant(pylist, dtype=None, ragged_rank=None, inner_shape=None,
name=None, row_splits_dtype=dtypes.int64): name=None, row_splits_dtype=dtypes.int64):
"""Constructs a constant RaggedTensor from a nested Python list. """Constructs a constant RaggedTensor from a nested Python list.
@ -86,6 +88,7 @@ def constant(pylist, dtype=None, ragged_rank=None, inner_shape=None,
@tf_export(v1=["ragged.constant_value"]) @tf_export(v1=["ragged.constant_value"])
@dispatch.add_dispatch_support
def constant_value(pylist, dtype=None, ragged_rank=None, inner_shape=None, def constant_value(pylist, dtype=None, ragged_rank=None, inner_shape=None,
row_splits_dtype="int64"): row_splits_dtype="int64"):
"""Constructs a RaggedTensorValue from a nested Python list. """Constructs a RaggedTensorValue from a nested Python list.
@ -311,6 +314,7 @@ def _default_inner_shape_for_pylist(pylist, ragged_rank):
@tf_export(v1=["ragged.placeholder"]) @tf_export(v1=["ragged.placeholder"])
@dispatch.add_dispatch_support
def placeholder(dtype, ragged_rank, value_shape=None, name=None): def placeholder(dtype, ragged_rank, value_shape=None, name=None):
"""Creates a placeholder for a `tf.RaggedTensor` that will always be fed. """Creates a placeholder for a `tf.RaggedTensor` that will always be fed.

View File

@ -24,10 +24,12 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_config from tensorflow.python.ops.ragged import ragged_config
from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_util from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@tf_export("ragged.map_flat_values") @tf_export("ragged.map_flat_values")
@dispatch.add_dispatch_support
def map_flat_values(op, *args, **kwargs): def map_flat_values(op, *args, **kwargs):
"""Applies `op` to the values of one or more RaggedTensors. """Applies `op` to the values of one or more RaggedTensors.

View File

@ -30,6 +30,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_functional_ops from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import segment_id_ops from tensorflow.python.ops.ragged import segment_id_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -38,6 +39,7 @@ from tensorflow.python.util.tf_export import tf_export
#=============================================================================== #===============================================================================
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin
@tf_export('ragged.range') @tf_export('ragged.range')
@dispatch.add_dispatch_support
def range(starts, limits=None, deltas=1, dtype=None, def range(starts, limits=None, deltas=1, dtype=None,
name=None, row_splits_dtype=dtypes.int64): name=None, row_splits_dtype=dtypes.int64):
"""Returns a `RaggedTensor` containing the specified sequences of numbers. """Returns a `RaggedTensor` containing the specified sequences of numbers.

View File

@ -29,10 +29,12 @@ from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import compat as util_compat from tensorflow.python.util import compat as util_compat
from tensorflow.python.util import deprecation from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@tf_export("strings.bytes_split") @tf_export("strings.bytes_split")
@dispatch.add_dispatch_support
def string_bytes_split(input, name=None): # pylint: disable=redefined-builtin def string_bytes_split(input, name=None): # pylint: disable=redefined-builtin
"""Split string elements of `input` into bytes. """Split string elements of `input` into bytes.
@ -80,6 +82,7 @@ def string_bytes_split(input, name=None): # pylint: disable=redefined-builtin
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin
@tf_export("strings.unicode_encode") @tf_export("strings.unicode_encode")
@dispatch.add_dispatch_support
def unicode_encode(input, def unicode_encode(input,
output_encoding, output_encoding,
errors="replace", errors="replace",
@ -177,6 +180,7 @@ def unicode_encode(input,
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin
@tf_export("strings.unicode_decode") @tf_export("strings.unicode_decode")
@dispatch.add_dispatch_support
def unicode_decode(input, def unicode_decode(input,
input_encoding, input_encoding,
errors="replace", errors="replace",
@ -222,6 +226,7 @@ def unicode_decode(input,
@tf_export("strings.unicode_decode_with_offsets") @tf_export("strings.unicode_decode_with_offsets")
@dispatch.add_dispatch_support
def unicode_decode_with_offsets(input, def unicode_decode_with_offsets(input,
input_encoding, input_encoding,
errors="replace", errors="replace",
@ -283,6 +288,7 @@ def unicode_decode_with_offsets(input,
@tf_export("strings.unicode_split") @tf_export("strings.unicode_split")
@dispatch.add_dispatch_support
def unicode_split(input, def unicode_split(input,
input_encoding, input_encoding,
errors="replace", errors="replace",
@ -330,6 +336,7 @@ def unicode_split(input,
@tf_export("strings.unicode_split_with_offsets") @tf_export("strings.unicode_split_with_offsets")
@dispatch.add_dispatch_support
def unicode_split_with_offsets(input, def unicode_split_with_offsets(input,
input_encoding, input_encoding,
errors="replace", errors="replace",
@ -453,6 +460,7 @@ def _unicode_decode(input, input_encoding, errors, replacement_char,
@tf_export("strings.split", v1=[]) @tf_export("strings.split", v1=[])
@dispatch.add_dispatch_support
def string_split_v2(input, sep=None, maxsplit=-1, name=None): # pylint: disable=redefined-builtin def string_split_v2(input, sep=None, maxsplit=-1, name=None): # pylint: disable=redefined-builtin
"""Split elements of `input` based on `sep` into a `RaggedTensor`. """Split elements of `input` based on `sep` into a `RaggedTensor`.
@ -514,6 +522,7 @@ def string_split_v2(input, sep=None, maxsplit=-1, name=None): # pylint: disable
@tf_export(v1=["string_split"]) @tf_export(v1=["string_split"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None, @deprecation.deprecated_args(None,
"delimiter is deprecated, please use sep instead.", "delimiter is deprecated, please use sep instead.",
"delimiter") "delimiter")
@ -578,6 +587,7 @@ def string_split(source, sep=None, skip_empty=True, delimiter=None,
# In TensorFlow 1.x, "tf.strings.split" uses the new signature (with maxsplit), # In TensorFlow 1.x, "tf.strings.split" uses the new signature (with maxsplit),
# but we need to add the result_type argument. # but we need to add the result_type argument.
@tf_export(v1=["strings.split"]) @tf_export(v1=["strings.split"])
@dispatch.add_dispatch_support
def strings_split_v1(input=None, sep=None, maxsplit=-1, # pylint: disable=redefined-builtin def strings_split_v1(input=None, sep=None, maxsplit=-1, # pylint: disable=redefined-builtin
result_type="SparseTensor", source=None, name=None): result_type="SparseTensor", source=None, name=None):
"""Split elements of `input` based on `sep`. """Split elements of `input` based on `sep`.
@ -651,6 +661,7 @@ def reduce_join(inputs, axis=None, keepdims=None, separator="", name=None):
@tf_export("strings.ngrams") @tf_export("strings.ngrams")
@dispatch.add_dispatch_support
def ngrams(data, def ngrams(data,
ngram_width, ngram_width,
separator=" ", separator=" ",

View File

@ -25,12 +25,14 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_util from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
# For background on "segments" and "segment ids", see: # For background on "segments" and "segment ids", see:
# https://www.tensorflow.org/api_docs/python/tf/math#Segmentation # https://www.tensorflow.org/api_docs/python/tf/math#Segmentation
@tf_export("ragged.row_splits_to_segment_ids") @tf_export("ragged.row_splits_to_segment_ids")
@dispatch.add_dispatch_support
def row_splits_to_segment_ids(splits, name=None, out_type=None): def row_splits_to_segment_ids(splits, name=None, out_type=None):
"""Generates the segmentation corresponding to a RaggedTensor `row_splits`. """Generates the segmentation corresponding to a RaggedTensor `row_splits`.
@ -74,6 +76,7 @@ def row_splits_to_segment_ids(splits, name=None, out_type=None):
# For background on "segments" and "segment ids", see: # For background on "segments" and "segment ids", see:
# https://www.tensorflow.org/api_docs/python/tf/math#Segmentation # https://www.tensorflow.org/api_docs/python/tf/math#Segmentation
@tf_export("ragged.segment_ids_to_row_splits") @tf_export("ragged.segment_ids_to_row_splits")
@dispatch.add_dispatch_support
def segment_ids_to_row_splits(segment_ids, num_segments=None, def segment_ids_to_row_splits(segment_ids, num_segments=None,
out_type=None, name=None): out_type=None, name=None):
"""Generates the RaggedTensor `row_splits` corresponding to a segmentation. """Generates the RaggedTensor `row_splits` corresponding to a segmentation.

View File

@ -36,10 +36,12 @@ from tensorflow.python.ops.gen_random_ops import *
# pylint: enable=wildcard-import # pylint: enable=wildcard-import
from tensorflow.python.util import deprecation from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@tf_export("random.normal", v1=["random.normal", "random_normal"]) @tf_export("random.normal", v1=["random.normal", "random_normal"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("random_normal") @deprecation.deprecated_endpoints("random_normal")
def random_normal(shape, def random_normal(shape,
mean=0.0, mean=0.0,
@ -155,6 +157,7 @@ def parameterized_truncated_normal(shape,
@tf_export("random.truncated_normal", @tf_export("random.truncated_normal",
v1=["random.truncated_normal", "truncated_normal"]) v1=["random.truncated_normal", "truncated_normal"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("truncated_normal") @deprecation.deprecated_endpoints("truncated_normal")
def truncated_normal(shape, def truncated_normal(shape,
mean=0.0, mean=0.0,
@ -202,6 +205,7 @@ ops.NotDifferentiable("TruncatedNormal")
@tf_export("random.uniform", v1=["random.uniform", "random_uniform"]) @tf_export("random.uniform", v1=["random.uniform", "random_uniform"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("random_uniform") @deprecation.deprecated_endpoints("random_uniform")
def random_uniform(shape, def random_uniform(shape,
minval=0, minval=0,
@ -313,6 +317,7 @@ ops.NotDifferentiable("RandomUniform")
@tf_export("random.shuffle", v1=["random.shuffle", "random_shuffle"]) @tf_export("random.shuffle", v1=["random.shuffle", "random_shuffle"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("random_shuffle") @deprecation.deprecated_endpoints("random_shuffle")
def random_shuffle(value, seed=None, name=None): def random_shuffle(value, seed=None, name=None):
"""Randomly shuffles a tensor along its first dimension. """Randomly shuffles a tensor along its first dimension.
@ -345,6 +350,7 @@ def random_shuffle(value, seed=None, name=None):
@tf_export("image.random_crop", v1=["image.random_crop", "random_crop"]) @tf_export("image.random_crop", v1=["image.random_crop", "random_crop"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("random_crop") @deprecation.deprecated_endpoints("random_crop")
def random_crop(value, size, seed=None, name=None): def random_crop(value, size, seed=None, name=None):
"""Randomly crops a tensor to a given size. """Randomly crops a tensor to a given size.
@ -389,6 +395,7 @@ def random_crop(value, size, seed=None, name=None):
@tf_export(v1=["random.multinomial", "multinomial"]) @tf_export(v1=["random.multinomial", "multinomial"])
@dispatch.add_dispatch_support
@deprecation.deprecated( @deprecation.deprecated(
date=None, instructions="Use `tf.random.categorical` instead.") date=None, instructions="Use `tf.random.categorical` instead.")
def multinomial(logits, num_samples, seed=None, name=None, output_dtype=None): def multinomial(logits, num_samples, seed=None, name=None, output_dtype=None):
@ -468,6 +475,7 @@ def _maybe_set_static_shape_helper(tensor, shape, postfix_tensor):
@tf_export("random.gamma", v1=["random.gamma", "random_gamma"]) @tf_export("random.gamma", v1=["random.gamma", "random_gamma"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("random_gamma") @deprecation.deprecated_endpoints("random_gamma")
def random_gamma(shape, def random_gamma(shape,
alpha, alpha,
@ -561,6 +569,7 @@ def random_gamma(shape,
@tf_export(v1=["random.poisson", "random_poisson"]) @tf_export(v1=["random.poisson", "random_poisson"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("random_poisson") @deprecation.deprecated_endpoints("random_poisson")
def random_poisson(lam, shape, dtype=dtypes.float32, seed=None, name=None): def random_poisson(lam, shape, dtype=dtypes.float32, seed=None, name=None):
"""Draws `shape` samples from each of the given Poisson distribution(s). """Draws `shape` samples from each of the given Poisson distribution(s).
@ -601,6 +610,7 @@ def random_poisson(lam, shape, dtype=dtypes.float32, seed=None, name=None):
@tf_export("random.poisson", v1=[]) @tf_export("random.poisson", v1=[])
@dispatch.add_dispatch_support
def random_poisson_v2(shape, lam, dtype=dtypes.float32, seed=None, name=None): def random_poisson_v2(shape, lam, dtype=dtypes.float32, seed=None, name=None):
"""Draws `shape` samples from each of the given Poisson distribution(s). """Draws `shape` samples from each of the given Poisson distribution(s).

View File

@ -32,6 +32,7 @@ from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import deprecation from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -342,6 +343,7 @@ def _reverse_seq(input_seq, lengths):
"keras.layers.RNN(cell))`, which is equivalent to " "keras.layers.RNN(cell))`, which is equivalent to "
"this API") "this API")
@tf_export(v1=["nn.bidirectional_dynamic_rnn"]) @tf_export(v1=["nn.bidirectional_dynamic_rnn"])
@dispatch.add_dispatch_support
def bidirectional_dynamic_rnn(cell_fw, def bidirectional_dynamic_rnn(cell_fw,
cell_bw, cell_bw,
inputs, inputs,
@ -499,6 +501,7 @@ def bidirectional_dynamic_rnn(cell_fw,
None, None,
"Please use `keras.layers.RNN(cell)`, which is equivalent to this API") "Please use `keras.layers.RNN(cell)`, which is equivalent to this API")
@tf_export(v1=["nn.dynamic_rnn"]) @tf_export(v1=["nn.dynamic_rnn"])
@dispatch.add_dispatch_support
def dynamic_rnn(cell, def dynamic_rnn(cell,
inputs, inputs,
sequence_length=None, sequence_length=None,
@ -912,6 +915,7 @@ def _dynamic_rnn_loop(cell,
@tf_export(v1=["nn.raw_rnn"]) @tf_export(v1=["nn.raw_rnn"])
@dispatch.add_dispatch_support
def raw_rnn(cell, def raw_rnn(cell,
loop_fn, loop_fn,
parallel_iterations=None, parallel_iterations=None,
@ -1238,6 +1242,7 @@ def raw_rnn(cell,
"Please use `keras.layers.RNN(cell, unroll=True)`, " "Please use `keras.layers.RNN(cell, unroll=True)`, "
"which is equivalent to this API") "which is equivalent to this API")
@tf_export(v1=["nn.static_rnn"]) @tf_export(v1=["nn.static_rnn"])
@dispatch.add_dispatch_support
def static_rnn(cell, def static_rnn(cell,
inputs, inputs,
initial_state=None, initial_state=None,
@ -1416,6 +1421,7 @@ def static_rnn(cell,
"Please use `keras.layers.RNN(cell, stateful=True)`, " "Please use `keras.layers.RNN(cell, stateful=True)`, "
"which is equivalent to this API") "which is equivalent to this API")
@tf_export(v1=["nn.static_state_saving_rnn"]) @tf_export(v1=["nn.static_state_saving_rnn"])
@dispatch.add_dispatch_support
def static_state_saving_rnn(cell, def static_state_saving_rnn(cell,
inputs, inputs,
state_saver, state_saver,
@ -1510,6 +1516,7 @@ def static_state_saving_rnn(cell,
"keras.layers.RNN(cell, unroll=True))`, which is " "keras.layers.RNN(cell, unroll=True))`, which is "
"equivalent to this API") "equivalent to this API")
@tf_export(v1=["nn.static_bidirectional_rnn"]) @tf_export(v1=["nn.static_bidirectional_rnn"])
@dispatch.add_dispatch_support
def static_bidirectional_rnn(cell_fw, def static_bidirectional_rnn(cell_fw,
cell_bw, cell_bw,
inputs, inputs,

View File

@ -39,6 +39,7 @@ from tensorflow.python.ops import gen_script_ops
from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.util import compat from tensorflow.python.util import compat
from tensorflow.python.util import deprecation from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import lazy_loader from tensorflow.python.util import lazy_loader
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect from tensorflow.python.util import tf_inspect
@ -370,6 +371,7 @@ def _EagerPyFuncGrad(op, *dy):
@tf_export("py_function") @tf_export("py_function")
@dispatch.add_dispatch_support
def eager_py_func(func, inp, Tout, name=None): def eager_py_func(func, inp, Tout, name=None):
"""Wraps a python function into a TensorFlow op that executes it eagerly. """Wraps a python function into a TensorFlow op that executes it eagerly.
@ -551,6 +553,7 @@ def py_func_common(func, inp, Tout, stateful=True, name=None):
stateful argument making all functions stateful. stateful argument making all functions stateful.
""") """)
@tf_export(v1=["py_func"]) @tf_export(v1=["py_func"])
@dispatch.add_dispatch_support
def py_func(func, inp, Tout, stateful=True, name=None): def py_func(func, inp, Tout, stateful=True, name=None):
return py_func_common(func, inp, Tout, stateful, name=name) return py_func_common(func, inp, Tout, stateful, name=name)
@ -559,6 +562,7 @@ py_func.__doc__ = "%s" % py_func_common.__doc__
@tf_export("numpy_function") @tf_export("numpy_function")
@dispatch.add_dispatch_support
def numpy_function(func, inp, Tout, name=None): def numpy_function(func, inp, Tout, name=None):
"""Wraps a python function and uses it as a TensorFlow op. """Wraps a python function and uses it as a TensorFlow op.

View File

@ -23,6 +23,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import gen_set_ops from tensorflow.python.ops import gen_set_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -32,6 +33,7 @@ _VALID_DTYPES = set([
@tf_export("sets.size", v1=["sets.size", "sets.set_size"]) @tf_export("sets.size", v1=["sets.size", "sets.set_size"])
@dispatch.add_dispatch_support
def set_size(a, validate_indices=True): def set_size(a, validate_indices=True):
"""Compute number of unique elements along last dimension of `a`. """Compute number of unique elements along last dimension of `a`.
@ -135,6 +137,7 @@ def _set_operation(a, b, set_operation, validate_indices=True):
@tf_export( @tf_export(
"sets.intersection", v1=["sets.intersection", "sets.set_intersection"]) "sets.intersection", v1=["sets.intersection", "sets.set_intersection"])
@dispatch.add_dispatch_support
def set_intersection(a, b, validate_indices=True): def set_intersection(a, b, validate_indices=True):
"""Compute set intersection of elements in last dimension of `a` and `b`. """Compute set intersection of elements in last dimension of `a` and `b`.
@ -205,6 +208,7 @@ def set_intersection(a, b, validate_indices=True):
@tf_export( @tf_export(
"sets.difference", v1=["sets.difference", "sets.set_difference"]) "sets.difference", v1=["sets.difference", "sets.set_difference"])
@dispatch.add_dispatch_support
def set_difference(a, b, aminusb=True, validate_indices=True): def set_difference(a, b, aminusb=True, validate_indices=True):
"""Compute set difference of elements in last dimension of `a` and `b`. """Compute set difference of elements in last dimension of `a` and `b`.
@ -286,6 +290,7 @@ def set_difference(a, b, aminusb=True, validate_indices=True):
@tf_export( @tf_export(
"sets.union", v1=["sets.union", "sets.set_union"]) "sets.union", v1=["sets.union", "sets.set_union"])
@dispatch.add_dispatch_support
def set_union(a, b, validate_indices=True): def set_union(a, b, validate_indices=True):
"""Compute set union of elements in last dimension of `a` and `b`. """Compute set union of elements in last dimension of `a` and `b`.

View File

@ -25,6 +25,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops as _array_ops from tensorflow.python.ops import array_ops as _array_ops
from tensorflow.python.ops import math_ops as _math_ops from tensorflow.python.ops import math_ops as _math_ops
from tensorflow.python.ops.signal import fft_ops from tensorflow.python.ops.signal import fft_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -50,6 +51,7 @@ def _validate_dct_arguments(input_tensor, dct_type, n, axis, norm):
# TODO(rjryan): Implement `axis` parameter. # TODO(rjryan): Implement `axis` parameter.
@tf_export("signal.dct", v1=["signal.dct", "spectral.dct"]) @tf_export("signal.dct", v1=["signal.dct", "spectral.dct"])
@dispatch.add_dispatch_support
def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin
"""Computes the 1D [Discrete Cosine Transform (DCT)][dct] of `input`. """Computes the 1D [Discrete Cosine Transform (DCT)][dct] of `input`.
@ -181,6 +183,7 @@ def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disabl
# TODO(rjryan): Implement `n` and `axis` parameters. # TODO(rjryan): Implement `n` and `axis` parameters.
@tf_export("signal.idct", v1=["signal.idct", "spectral.idct"]) @tf_export("signal.idct", v1=["signal.idct", "spectral.idct"])
@dispatch.add_dispatch_support
def idct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin def idct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin
"""Computes the 1D [Inverse Discrete Cosine Transform (DCT)][idct] of `input`. """Computes the 1D [Inverse Discrete Cosine Transform (DCT)][idct] of `input`.

View File

@ -26,6 +26,7 @@ from tensorflow.python.ops import array_ops as _array_ops
from tensorflow.python.ops import gen_spectral_ops from tensorflow.python.ops import gen_spectral_ops
from tensorflow.python.ops import manip_ops from tensorflow.python.ops import manip_ops
from tensorflow.python.ops import math_ops as _math_ops from tensorflow.python.ops import math_ops as _math_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -181,17 +182,23 @@ ifft2d = gen_spectral_ops.ifft2d
fft3d = gen_spectral_ops.fft3d fft3d = gen_spectral_ops.fft3d
ifft3d = gen_spectral_ops.ifft3d ifft3d = gen_spectral_ops.ifft3d
rfft = _rfft_wrapper(gen_spectral_ops.rfft, 1, "rfft") rfft = _rfft_wrapper(gen_spectral_ops.rfft, 1, "rfft")
tf_export("signal.rfft", v1=["signal.rfft", "spectral.rfft"])(rfft) tf_export("signal.rfft", v1=["signal.rfft", "spectral.rfft"])(
dispatch.add_dispatch_support(rfft))
irfft = _irfft_wrapper(gen_spectral_ops.irfft, 1, "irfft") irfft = _irfft_wrapper(gen_spectral_ops.irfft, 1, "irfft")
tf_export("signal.irfft", v1=["signal.irfft", "spectral.irfft"])(irfft) tf_export("signal.irfft", v1=["signal.irfft", "spectral.irfft"])(
dispatch.add_dispatch_support(irfft))
rfft2d = _rfft_wrapper(gen_spectral_ops.rfft2d, 2, "rfft2d") rfft2d = _rfft_wrapper(gen_spectral_ops.rfft2d, 2, "rfft2d")
tf_export("signal.rfft2d", v1=["signal.rfft2d", "spectral.rfft2d"])(rfft2d) tf_export("signal.rfft2d", v1=["signal.rfft2d", "spectral.rfft2d"])(
dispatch.add_dispatch_support(rfft2d))
irfft2d = _irfft_wrapper(gen_spectral_ops.irfft2d, 2, "irfft2d") irfft2d = _irfft_wrapper(gen_spectral_ops.irfft2d, 2, "irfft2d")
tf_export("signal.irfft2d", v1=["signal.irfft2d", "spectral.irfft2d"])(irfft2d) tf_export("signal.irfft2d", v1=["signal.irfft2d", "spectral.irfft2d"])(
dispatch.add_dispatch_support(irfft2d))
rfft3d = _rfft_wrapper(gen_spectral_ops.rfft3d, 3, "rfft3d") rfft3d = _rfft_wrapper(gen_spectral_ops.rfft3d, 3, "rfft3d")
tf_export("signal.rfft3d", v1=["signal.rfft3d", "spectral.rfft3d"])(rfft3d) tf_export("signal.rfft3d", v1=["signal.rfft3d", "spectral.rfft3d"])(
dispatch.add_dispatch_support(rfft3d))
irfft3d = _irfft_wrapper(gen_spectral_ops.irfft3d, 3, "irfft3d") irfft3d = _irfft_wrapper(gen_spectral_ops.irfft3d, 3, "irfft3d")
tf_export("signal.irfft3d", v1=["signal.irfft3d", "spectral.irfft3d"])(irfft3d) tf_export("signal.irfft3d", v1=["signal.irfft3d", "spectral.irfft3d"])(
dispatch.add_dispatch_support(irfft3d))
def _fft_size_for_grad(grad, rank): def _fft_size_for_grad(grad, rank):
@ -363,6 +370,7 @@ def _irfft_grad_helper(rank, rfft_fn):
@tf_export("signal.fftshift") @tf_export("signal.fftshift")
@dispatch.add_dispatch_support
def fftshift(x, axes=None, name=None): def fftshift(x, axes=None, name=None):
"""Shift the zero-frequency component to the center of the spectrum. """Shift the zero-frequency component to the center of the spectrum.
@ -407,6 +415,7 @@ def fftshift(x, axes=None, name=None):
@tf_export("signal.ifftshift") @tf_export("signal.ifftshift")
@dispatch.add_dispatch_support
def ifftshift(x, axes=None, name=None): def ifftshift(x, axes=None, name=None):
"""The inverse of fftshift. """The inverse of fftshift.

View File

@ -24,6 +24,7 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops.signal import shape_ops from tensorflow.python.ops.signal import shape_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -90,6 +91,7 @@ def _validate_arguments(num_mel_bins, sample_rate,
@tf_export('signal.linear_to_mel_weight_matrix') @tf_export('signal.linear_to_mel_weight_matrix')
@dispatch.add_dispatch_support
def linear_to_mel_weight_matrix(num_mel_bins=20, def linear_to_mel_weight_matrix(num_mel_bins=20,
num_spectrogram_bins=129, num_spectrogram_bins=129,
sample_rate=8000, sample_rate=8000,

View File

@ -22,10 +22,12 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops.signal import dct_ops from tensorflow.python.ops.signal import dct_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@tf_export('signal.mfccs_from_log_mel_spectrograms') @tf_export('signal.mfccs_from_log_mel_spectrograms')
@dispatch.add_dispatch_support
def mfccs_from_log_mel_spectrograms(log_mel_spectrograms, name=None): def mfccs_from_log_mel_spectrograms(log_mel_spectrograms, name=None):
"""Computes [MFCCs][mfcc] of `log_mel_spectrograms`. """Computes [MFCCs][mfcc] of `log_mel_spectrograms`.

View File

@ -23,10 +23,12 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@tf_export("signal.overlap_and_add") @tf_export("signal.overlap_and_add")
@dispatch.add_dispatch_support
def overlap_and_add(signal, frame_step, name=None): def overlap_and_add(signal, frame_step, name=None):
"""Reconstructs a signal from a framed representation. """Reconstructs a signal from a framed representation.

View File

@ -25,6 +25,7 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops.signal import util_ops from tensorflow.python.ops.signal import util_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -55,6 +56,7 @@ def _infer_frame_shape(signal, frame_length, frame_step, pad_end, axis):
@tf_export("signal.frame") @tf_export("signal.frame")
@dispatch.add_dispatch_support
def frame(signal, frame_length, frame_step, pad_end=False, pad_value=0, axis=-1, def frame(signal, frame_length, frame_step, pad_end=False, pad_value=0, axis=-1,
name=None): name=None):
"""Expands `signal`'s `axis` dimension into frames of `frame_length`. """Expands `signal`'s `axis` dimension into frames of `frame_length`.

View File

@ -31,10 +31,12 @@ from tensorflow.python.ops.signal import fft_ops
from tensorflow.python.ops.signal import reconstruction_ops from tensorflow.python.ops.signal import reconstruction_ops
from tensorflow.python.ops.signal import shape_ops from tensorflow.python.ops.signal import shape_ops
from tensorflow.python.ops.signal import window_ops from tensorflow.python.ops.signal import window_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@tf_export('signal.stft') @tf_export('signal.stft')
@dispatch.add_dispatch_support
def stft(signals, frame_length, frame_step, fft_length=None, def stft(signals, frame_length, frame_step, fft_length=None,
window_fn=window_ops.hann_window, window_fn=window_ops.hann_window,
pad_end=False, name=None): pad_end=False, name=None):
@ -95,6 +97,7 @@ def stft(signals, frame_length, frame_step, fft_length=None,
@tf_export('signal.inverse_stft_window_fn') @tf_export('signal.inverse_stft_window_fn')
@dispatch.add_dispatch_support
def inverse_stft_window_fn(frame_step, def inverse_stft_window_fn(frame_step,
forward_window_fn=window_ops.hann_window, forward_window_fn=window_ops.hann_window,
name=None): name=None):
@ -156,6 +159,7 @@ def inverse_stft_window_fn(frame_step,
@tf_export('signal.inverse_stft') @tf_export('signal.inverse_stft')
@dispatch.add_dispatch_support
def inverse_stft(stfts, def inverse_stft(stfts,
frame_length, frame_length,
frame_step, frame_step,
@ -291,6 +295,7 @@ def _enclosing_power_of_two(value):
@tf_export('signal.mdct') @tf_export('signal.mdct')
@dispatch.add_dispatch_support
def mdct(signals, frame_length, window_fn=window_ops.vorbis_window, def mdct(signals, frame_length, window_fn=window_ops.vorbis_window,
pad_end=False, norm=None, name=None): pad_end=False, norm=None, name=None):
"""Computes the [Modified Discrete Cosine Transform][mdct] of `signals`. """Computes the [Modified Discrete Cosine Transform][mdct] of `signals`.
@ -366,6 +371,7 @@ def mdct(signals, frame_length, window_fn=window_ops.vorbis_window,
@tf_export('signal.inverse_mdct') @tf_export('signal.inverse_mdct')
@dispatch.add_dispatch_support
def inverse_mdct(mdcts, def inverse_mdct(mdcts,
window_fn=window_ops.vorbis_window, window_fn=window_ops.vorbis_window,
norm=None, norm=None,

View File

@ -27,6 +27,7 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -52,6 +53,7 @@ def _check_params(window_length, dtype):
@tf_export('signal.kaiser_window') @tf_export('signal.kaiser_window')
@dispatch.add_dispatch_support
def kaiser_window(window_length, beta=12., dtype=dtypes.float32, name=None): def kaiser_window(window_length, beta=12., dtype=dtypes.float32, name=None):
"""Generate a [Kaiser window][kaiser]. """Generate a [Kaiser window][kaiser].
@ -91,6 +93,7 @@ def kaiser_window(window_length, beta=12., dtype=dtypes.float32, name=None):
@tf_export('signal.kaiser_bessel_derived_window') @tf_export('signal.kaiser_bessel_derived_window')
@dispatch.add_dispatch_support
def kaiser_bessel_derived_window(window_length, beta=12., def kaiser_bessel_derived_window(window_length, beta=12.,
dtype=dtypes.float32, name=None): dtype=dtypes.float32, name=None):
"""Generate a [Kaiser Bessel derived window][kbd]. """Generate a [Kaiser Bessel derived window][kbd].
@ -118,6 +121,7 @@ def kaiser_bessel_derived_window(window_length, beta=12.,
@tf_export('signal.vorbis_window') @tf_export('signal.vorbis_window')
@dispatch.add_dispatch_support
def vorbis_window(window_length, dtype=dtypes.float32, name=None): def vorbis_window(window_length, dtype=dtypes.float32, name=None):
"""Generate a [Vorbis power complementary window][vorbis]. """Generate a [Vorbis power complementary window][vorbis].
@ -142,6 +146,7 @@ def vorbis_window(window_length, dtype=dtypes.float32, name=None):
@tf_export('signal.hann_window') @tf_export('signal.hann_window')
@dispatch.add_dispatch_support
def hann_window(window_length, periodic=True, dtype=dtypes.float32, name=None): def hann_window(window_length, periodic=True, dtype=dtypes.float32, name=None):
"""Generate a [Hann window][hann]. """Generate a [Hann window][hann].
@ -167,6 +172,7 @@ def hann_window(window_length, periodic=True, dtype=dtypes.float32, name=None):
@tf_export('signal.hamming_window') @tf_export('signal.hamming_window')
@dispatch.add_dispatch_support
def hamming_window(window_length, periodic=True, dtype=dtypes.float32, def hamming_window(window_length, periodic=True, dtype=dtypes.float32,
name=None): name=None):
"""Generate a [Hamming][hamming] window. """Generate a [Hamming][hamming] window.

View File

@ -30,10 +30,12 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops from tensorflow.python.ops import nn_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@tf_export('sort') @tf_export('sort')
@dispatch.add_dispatch_support
def sort(values, axis=-1, direction='ASCENDING', name=None): def sort(values, axis=-1, direction='ASCENDING', name=None):
"""Sorts a tensor. """Sorts a tensor.
@ -67,6 +69,7 @@ def sort(values, axis=-1, direction='ASCENDING', name=None):
@tf_export('argsort') @tf_export('argsort')
@dispatch.add_dispatch_support
def argsort(values, axis=-1, direction='ASCENDING', stable=False, name=None): def argsort(values, axis=-1, direction='ASCENDING', stable=False, name=None):
"""Returns the indices of a tensor that give its sorted order along an axis. """Returns the indices of a tensor that give its sorted order along an axis.

View File

@ -1065,6 +1065,7 @@ def sparse_slice(sp_input, start, size, name=None):
@tf_export(v1=["sparse_to_dense"]) @tf_export(v1=["sparse_to_dense"])
@dispatch.add_dispatch_support
@deprecation.deprecated( @deprecation.deprecated(
None, None,
"Create a `tf.sparse.SparseTensor` and use `tf.sparse.to_dense` instead.") "Create a `tf.sparse.SparseTensor` and use `tf.sparse.to_dense` instead.")
@ -1994,6 +1995,7 @@ def sparse_fill_empty_rows(sp_input, default_value, name=None):
@tf_export(v1=["io.serialize_sparse", "serialize_sparse"]) @tf_export(v1=["io.serialize_sparse", "serialize_sparse"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("serialize_sparse") @deprecation.deprecated_endpoints("serialize_sparse")
def serialize_sparse(sp_input, name=None, out_type=dtypes.string): def serialize_sparse(sp_input, name=None, out_type=dtypes.string):
"""Serialize a `SparseTensor` into a 3-vector (1-D `Tensor`) object. """Serialize a `SparseTensor` into a 3-vector (1-D `Tensor`) object.
@ -2014,6 +2016,7 @@ def serialize_sparse(sp_input, name=None, out_type=dtypes.string):
@tf_export("io.serialize_sparse", v1=[]) @tf_export("io.serialize_sparse", v1=[])
@dispatch.add_dispatch_support
def serialize_sparse_v2(sp_input, out_type=dtypes.string, name=None): def serialize_sparse_v2(sp_input, out_type=dtypes.string, name=None):
"""Serialize a `SparseTensor` into a 3-vector (1-D `Tensor`) object. """Serialize a `SparseTensor` into a 3-vector (1-D `Tensor`) object.
@ -2040,6 +2043,7 @@ def serialize_sparse_v2(sp_input, out_type=dtypes.string, name=None):
@tf_export(v1=["io.serialize_many_sparse", "serialize_many_sparse"]) @tf_export(v1=["io.serialize_many_sparse", "serialize_many_sparse"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("serialize_many_sparse") @deprecation.deprecated_endpoints("serialize_many_sparse")
def serialize_many_sparse(sp_input, name=None, out_type=dtypes.string): def serialize_many_sparse(sp_input, name=None, out_type=dtypes.string):
"""Serialize `N`-minibatch `SparseTensor` into an `[N, 3]` `Tensor`. """Serialize `N`-minibatch `SparseTensor` into an `[N, 3]` `Tensor`.
@ -2069,6 +2073,7 @@ def serialize_many_sparse(sp_input, name=None, out_type=dtypes.string):
@tf_export("io.serialize_many_sparse", v1=[]) @tf_export("io.serialize_many_sparse", v1=[])
@dispatch.add_dispatch_support
def serialize_many_sparse_v2(sp_input, out_type=dtypes.string, name=None): def serialize_many_sparse_v2(sp_input, out_type=dtypes.string, name=None):
"""Serialize `N`-minibatch `SparseTensor` into an `[N, 3]` `Tensor`. """Serialize `N`-minibatch `SparseTensor` into an `[N, 3]` `Tensor`.
@ -2172,6 +2177,7 @@ def deserialize_sparse(serialized_sparse, dtype, rank=None, name=None):
@tf_export( @tf_export(
"io.deserialize_many_sparse", "io.deserialize_many_sparse",
v1=["io.deserialize_many_sparse", "deserialize_many_sparse"]) v1=["io.deserialize_many_sparse", "deserialize_many_sparse"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("deserialize_many_sparse") @deprecation.deprecated_endpoints("deserialize_many_sparse")
def deserialize_many_sparse(serialized_sparse, dtype, rank=None, name=None): def deserialize_many_sparse(serialized_sparse, dtype, rank=None, name=None):
"""Deserialize and concatenate `SparseTensors` from a serialized minibatch. """Deserialize and concatenate `SparseTensors` from a serialized minibatch.

View File

@ -42,11 +42,13 @@ from tensorflow.python.ops import gen_special_math_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
# TODO(b/27419586) Change docstring for required dtype of x once int allowed # TODO(b/27419586) Change docstring for required dtype of x once int allowed
@tf_export('math.lbeta', v1=['math.lbeta', 'lbeta']) @tf_export('math.lbeta', v1=['math.lbeta', 'lbeta'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('lbeta') @deprecation.deprecated_endpoints('lbeta')
def lbeta(x, name=None): def lbeta(x, name=None):
r"""Computes \\(ln(|Beta(x)|)\\), reducing along the last dimension. r"""Computes \\(ln(|Beta(x)|)\\), reducing along the last dimension.
@ -102,6 +104,7 @@ def lbeta(x, name=None):
@tf_export('math.special.dawsn') @tf_export('math.special.dawsn')
@dispatch.add_dispatch_support
def dawsn(x, name=None): def dawsn(x, name=None):
"""Computes Dawson's integral of `x` element-wise. """Computes Dawson's integral of `x` element-wise.
@ -131,6 +134,7 @@ def dawsn(x, name=None):
@tf_export('math.special.expint') @tf_export('math.special.expint')
@dispatch.add_dispatch_support
def expint(x, name=None): def expint(x, name=None):
"""Computes the Exponential integral of `x` element-wise. """Computes the Exponential integral of `x` element-wise.
@ -159,6 +163,7 @@ def expint(x, name=None):
@tf_export('math.special.fresnel_cos') @tf_export('math.special.fresnel_cos')
@dispatch.add_dispatch_support
def fresnel_cos(x, name=None): def fresnel_cos(x, name=None):
"""Computes Fresnel's cosine integral of `x` element-wise. """Computes Fresnel's cosine integral of `x` element-wise.
@ -188,6 +193,7 @@ def fresnel_cos(x, name=None):
@tf_export('math.special.fresnel_sin') @tf_export('math.special.fresnel_sin')
@dispatch.add_dispatch_support
def fresnel_sin(x, name=None): def fresnel_sin(x, name=None):
"""Computes Fresnel's sine integral of `x` element-wise. """Computes Fresnel's sine integral of `x` element-wise.
@ -216,6 +222,7 @@ def fresnel_sin(x, name=None):
@tf_export('math.special.spence') @tf_export('math.special.spence')
@dispatch.add_dispatch_support
def spence(x, name=None): def spence(x, name=None):
"""Computes Spence's integral of `x` element-wise. """Computes Spence's integral of `x` element-wise.
@ -244,6 +251,7 @@ def spence(x, name=None):
@tf_export('math.bessel_i0') @tf_export('math.bessel_i0')
@dispatch.add_dispatch_support
def bessel_i0(x, name=None): def bessel_i0(x, name=None):
"""Computes the Bessel i0 function of `x` element-wise. """Computes the Bessel i0 function of `x` element-wise.
@ -268,6 +276,7 @@ def bessel_i0(x, name=None):
@tf_export('math.bessel_i1') @tf_export('math.bessel_i1')
@dispatch.add_dispatch_support
def bessel_i1(x, name=None): def bessel_i1(x, name=None):
"""Computes the Bessel i1 function of `x` element-wise. """Computes the Bessel i1 function of `x` element-wise.
@ -325,6 +334,7 @@ def _enclosing_tpu_context():
@tf_export('einsum', 'linalg.einsum') @tf_export('einsum', 'linalg.einsum')
@dispatch.add_dispatch_support
def einsum(equation, *inputs, **kwargs): def einsum(equation, *inputs, **kwargs):
"""Tensor contraction over specified indices and outer product. """Tensor contraction over specified indices and outer product.

View File

@ -27,6 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_stateless_random_ops from tensorflow.python.ops import gen_stateless_random_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.util import deprecation from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
ops.NotDifferentiable("StatelessMultinomial") ops.NotDifferentiable("StatelessMultinomial")
@ -40,6 +41,7 @@ ops.NotDifferentiable("StatelessTruncatedNormal")
@tf_export("random.experimental.stateless_split") @tf_export("random.experimental.stateless_split")
@dispatch.add_dispatch_support
def split(seed, num=2): def split(seed, num=2):
"""Splits an RNG seed into `num` new seeds by adding a leading axis. """Splits an RNG seed into `num` new seeds by adding a leading axis.
@ -73,6 +75,7 @@ def split(seed, num=2):
@tf_export("random.experimental.stateless_fold_in") @tf_export("random.experimental.stateless_fold_in")
@dispatch.add_dispatch_support
def fold_in(seed, data): def fold_in(seed, data):
"""Folds in data to an RNG seed to form a new RNG seed. """Folds in data to an RNG seed to form a new RNG seed.
@ -111,6 +114,7 @@ def fold_in(seed, data):
@tf_export("random.stateless_uniform") @tf_export("random.stateless_uniform")
@dispatch.add_dispatch_support
def stateless_random_uniform(shape, def stateless_random_uniform(shape,
seed, seed,
minval=0, minval=0,
@ -205,6 +209,7 @@ def stateless_random_uniform(shape,
@tf_export("random.stateless_binomial") @tf_export("random.stateless_binomial")
@dispatch.add_dispatch_support
def stateless_random_binomial(shape, def stateless_random_binomial(shape,
seed, seed,
counts, counts,
@ -274,6 +279,7 @@ def stateless_random_binomial(shape,
@tf_export("random.stateless_gamma") @tf_export("random.stateless_gamma")
@dispatch.add_dispatch_support
def stateless_random_gamma(shape, def stateless_random_gamma(shape,
seed, seed,
alpha, alpha,
@ -372,6 +378,7 @@ def stateless_random_gamma(shape,
@tf_export("random.stateless_poisson") @tf_export("random.stateless_poisson")
@dispatch.add_dispatch_support
def stateless_random_poisson(shape, def stateless_random_poisson(shape,
seed, seed,
lam, lam,
@ -434,6 +441,7 @@ def stateless_random_poisson(shape,
@tf_export("random.stateless_normal") @tf_export("random.stateless_normal")
@dispatch.add_dispatch_support
def stateless_random_normal(shape, def stateless_random_normal(shape,
seed, seed,
mean=0.0, mean=0.0,
@ -474,6 +482,7 @@ def stateless_random_normal(shape,
@tf_export("random.stateless_truncated_normal") @tf_export("random.stateless_truncated_normal")
@dispatch.add_dispatch_support
def stateless_truncated_normal(shape, def stateless_truncated_normal(shape,
seed, seed,
mean=0.0, mean=0.0,
@ -520,6 +529,7 @@ def stateless_truncated_normal(shape,
@tf_export(v1=["random.stateless_multinomial"]) @tf_export(v1=["random.stateless_multinomial"])
@dispatch.add_dispatch_support
@deprecation.deprecated( @deprecation.deprecated(
date=None, instructions="Use `tf.random.stateless_categorical` instead.") date=None, instructions="Use `tf.random.stateless_categorical` instead.")
def stateless_multinomial(logits, def stateless_multinomial(logits,
@ -562,6 +572,7 @@ def stateless_multinomial(logits,
@tf_export("random.stateless_categorical") @tf_export("random.stateless_categorical")
@dispatch.add_dispatch_support
def stateless_categorical(logits, def stateless_categorical(logits,
num_samples, num_samples,
seed, seed,

View File

@ -73,6 +73,7 @@ regex_full_match.__doc__ = gen_string_ops.regex_full_match.__doc__
@tf_export( @tf_export(
"strings.regex_replace", v1=["strings.regex_replace", "regex_replace"]) "strings.regex_replace", v1=["strings.regex_replace", "regex_replace"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("regex_replace") @deprecation.deprecated_endpoints("regex_replace")
@dispatch.add_dispatch_support @dispatch.add_dispatch_support
def regex_replace(input, pattern, rewrite, replace_global=True, name=None): def regex_replace(input, pattern, rewrite, replace_global=True, name=None):
@ -112,6 +113,7 @@ def regex_replace(input, pattern, rewrite, replace_global=True, name=None):
@tf_export("strings.format") @tf_export("strings.format")
@dispatch.add_dispatch_support
def string_format(template, inputs, placeholder="{}", summarize=3, name=None): def string_format(template, inputs, placeholder="{}", summarize=3, name=None):
r"""Formats a string template using a list of tensors. r"""Formats a string template using a list of tensors.
@ -300,6 +302,7 @@ def _reduce_join_reduction_dims(x, axis):
@tf_export(v1=["strings.reduce_join", "reduce_join"]) @tf_export(v1=["strings.reduce_join", "reduce_join"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None, @deprecation.deprecated_args(None,
"keep_dims is deprecated, use keepdims instead", "keep_dims is deprecated, use keepdims instead",
"keep_dims") "keep_dims")
@ -412,6 +415,7 @@ string_length_v2.__doc__ = gen_string_ops.string_length.__doc__
@tf_export(v1=["substr"]) @tf_export(v1=["substr"])
@dispatch.add_dispatch_support
@deprecation.deprecated(None, "Use `tf.strings.substr` instead of `tf.substr`.") @deprecation.deprecated(None, "Use `tf.strings.substr` instead of `tf.substr`.")
def substr_deprecated(input, pos, len, name=None, unit="BYTE"): def substr_deprecated(input, pos, len, name=None, unit="BYTE"):
return substr(input, pos, len, name=name, unit=unit) return substr(input, pos, len, name=name, unit=unit)
@ -476,6 +480,7 @@ def string_to_number(input, out_type=dtypes.float32, name=None):
@tf_export(v1=["strings.to_number", "string_to_number"]) @tf_export(v1=["strings.to_number", "string_to_number"])
@dispatch.add_dispatch_support
def string_to_number_v1( def string_to_number_v1(
string_tensor=None, string_tensor=None,
out_type=dtypes.float32, out_type=dtypes.float32,
@ -519,6 +524,7 @@ def string_to_hash_bucket(input, num_buckets, name=None):
@tf_export(v1=["strings.to_hash_bucket", "string_to_hash_bucket"]) @tf_export(v1=["strings.to_hash_bucket", "string_to_hash_bucket"])
@dispatch.add_dispatch_support
def string_to_hash_bucket_v1( def string_to_hash_bucket_v1(
string_tensor=None, string_tensor=None,
num_buckets=None, num_buckets=None,
@ -532,6 +538,7 @@ string_to_hash_bucket_v1.__doc__ = gen_string_ops.string_to_hash_bucket.__doc__
@tf_export("strings.join", v1=["strings.join", "string_join"]) @tf_export("strings.join", v1=["strings.join", "string_join"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("string_join") @deprecation.deprecated_endpoints("string_join")
@dispatch.add_dispatch_support @dispatch.add_dispatch_support
def string_join(inputs, separator="", name=None): def string_join(inputs, separator="", name=None):