Merge pull request #29691 from sleighsoft:master
PiperOrigin-RevId: 254776377
This commit is contained in:
commit
5b9ebedc98
@ -3540,6 +3540,25 @@ cuda_py_test(
|
|||||||
xla_enable_strict_auto_jit = True,
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cuda_py_test(
|
||||||
|
name = "normalize_op_test",
|
||||||
|
size = "medium",
|
||||||
|
srcs = ["normalize_op_test.py"],
|
||||||
|
additional_deps = [
|
||||||
|
"//third_party/py/numpy",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
|
"//tensorflow/python:nn",
|
||||||
|
],
|
||||||
|
shard_count = 20,
|
||||||
|
# TODO(b/117236102): Re-enable in msan build.
|
||||||
|
tags = [
|
||||||
|
"no_windows_gpu",
|
||||||
|
"nomsan",
|
||||||
|
],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
name = "tensordot_op_test",
|
name = "tensordot_op_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
|
100
tensorflow/python/kernel_tests/normalize_op_test.py
Normal file
100
tensorflow/python/kernel_tests/normalize_op_test.py
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for tensorflow.ops.tf.norm."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.ops import nn_impl
|
||||||
|
from tensorflow.python.platform import test as test_lib
|
||||||
|
|
||||||
|
|
||||||
|
def _AddTest(test, test_name, fn):
|
||||||
|
test_name = "_".join(["test", test_name])
|
||||||
|
if hasattr(test, test_name):
|
||||||
|
raise RuntimeError("Test %s defined more than once" % test_name)
|
||||||
|
setattr(test, test_name, fn)
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=redefined-builtin
|
||||||
|
def _Normalize(x, ord, axis):
|
||||||
|
if isinstance(axis, (list, tuple)):
|
||||||
|
norm = np.linalg.norm(x, ord, tuple(axis))
|
||||||
|
if axis[0] < axis[1]:
|
||||||
|
# This prevents axis to be inserted in-between
|
||||||
|
# e.g. when (-2, -1)
|
||||||
|
for d in reversed(axis):
|
||||||
|
norm = np.expand_dims(norm, d)
|
||||||
|
else:
|
||||||
|
for d in axis:
|
||||||
|
norm = np.expand_dims(norm, d)
|
||||||
|
return x / norm
|
||||||
|
elif axis is None:
|
||||||
|
# Tensorflow handles None differently
|
||||||
|
norm = np.linalg.norm(x.flatten(), ord, axis)
|
||||||
|
return x / norm
|
||||||
|
else:
|
||||||
|
norm = np.apply_along_axis(np.linalg.norm, axis, x, ord)
|
||||||
|
return x / np.expand_dims(norm, axis)
|
||||||
|
|
||||||
|
|
||||||
|
class NormalizeOpTest(test_lib.TestCase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _GetNormalizeOpTest(dtype_, shape_, ord_, axis_):
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
def Test(self):
|
||||||
|
is_matrix_norm = (isinstance(axis_, tuple) or
|
||||||
|
isinstance(axis_, list)) and len(axis_) == 2
|
||||||
|
is_fancy_p_norm = np.isreal(ord_) and np.floor(ord_) != ord_
|
||||||
|
if ((not is_matrix_norm and ord_ == "fro") or
|
||||||
|
(is_matrix_norm and is_fancy_p_norm)):
|
||||||
|
self.skipTest("Not supported by neither numpy.linalg.norm nor tf.norm")
|
||||||
|
if ord_ == "euclidean" or (axis_ is None and len(shape) > 2):
|
||||||
|
self.skipTest("Not supported by numpy.linalg.norm")
|
||||||
|
matrix = np.random.randn(*shape_).astype(dtype_)
|
||||||
|
if dtype_ in (np.complex64, np.complex128):
|
||||||
|
matrix += 1j * np.random.randn(*shape_).astype(dtype_)
|
||||||
|
tf_np_n, _ = self.evaluate(nn_impl.normalize(matrix, ord_, axis_))
|
||||||
|
np_n = _Normalize(matrix, ord_, axis_)
|
||||||
|
self.assertAllClose(tf_np_n, np_n, rtol=1e-5, atol=1e-5)
|
||||||
|
|
||||||
|
return Test
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=redefined-builtin
|
||||||
|
if __name__ == "__main__":
|
||||||
|
for dtype in np.float32, np.float64, np.complex64, np.complex128:
|
||||||
|
for rows in 2, 5:
|
||||||
|
for cols in 2, 5:
|
||||||
|
for batch in [], [2], [2, 3]:
|
||||||
|
shape = batch + [rows, cols]
|
||||||
|
for ord in "euclidean", "fro", 0.5, 1, 2, np.inf:
|
||||||
|
for axis in [
|
||||||
|
None, (-2, -1), (-1, -2), -len(shape), 0,
|
||||||
|
len(shape) - 1
|
||||||
|
]:
|
||||||
|
name = "%s_%s_ord_%s_axis_%s" % (dtype.__name__, "_".join(
|
||||||
|
map(str, shape)), ord, axis)
|
||||||
|
_AddTest(NormalizeOpTest, "Normalize_" + name,
|
||||||
|
_GetNormalizeOpTest(dtype, shape, ord, axis))
|
||||||
|
|
||||||
|
test_lib.main()
|
@ -32,6 +32,7 @@ from tensorflow.python.ops import control_flow_ops
|
|||||||
from tensorflow.python.ops import embedding_ops
|
from tensorflow.python.ops import embedding_ops
|
||||||
from tensorflow.python.ops import gen_array_ops # pylint: disable=unused-import
|
from tensorflow.python.ops import gen_array_ops # pylint: disable=unused-import
|
||||||
from tensorflow.python.ops import gen_nn_ops
|
from tensorflow.python.ops import gen_nn_ops
|
||||||
|
from tensorflow.python.ops import linalg_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import nn_ops
|
from tensorflow.python.ops import nn_ops
|
||||||
from tensorflow.python.ops import gen_sparse_ops
|
from tensorflow.python.ops import gen_sparse_ops
|
||||||
@ -531,6 +532,59 @@ def swish(features):
|
|||||||
return features * math_ops.sigmoid(features)
|
return features * math_ops.sigmoid(features)
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=redefined-builtin
|
||||||
|
@tf_export("linalg.normalize")
|
||||||
|
def normalize(tensor, ord="euclidean", axis=None, name=None):
|
||||||
|
"""Normalizes `tensor` along dimension `axis` using specified norm.
|
||||||
|
|
||||||
|
This uses `tf.linalg.norm` to compute the norm along `axis`.
|
||||||
|
|
||||||
|
This function can compute several different vector norms (the 1-norm, the
|
||||||
|
Euclidean or 2-norm, the inf-norm, and in general the p-norm for p > 0) and
|
||||||
|
matrix norms (Frobenius, 1-norm, 2-norm and inf-norm).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor: `Tensor` of types `float32`, `float64`, `complex64`, `complex128`
|
||||||
|
ord: Order of the norm. Supported values are `'fro'`, `'euclidean'`, `1`,
|
||||||
|
`2`, `np.inf` and any positive real number yielding the corresponding
|
||||||
|
p-norm. Default is `'euclidean'` which is equivalent to Frobenius norm if
|
||||||
|
`tensor` is a matrix and equivalent to 2-norm for vectors.
|
||||||
|
Some restrictions apply: a) The Frobenius norm `'fro'` is not defined for
|
||||||
|
vectors, b) If axis is a 2-tuple (matrix norm), only `'euclidean'`,
|
||||||
|
'`fro'`, `1`, `2`, `np.inf` are supported. See the description of `axis`
|
||||||
|
on how to compute norms for a batch of vectors or matrices stored in a
|
||||||
|
tensor.
|
||||||
|
axis: If `axis` is `None` (the default), the input is considered a vector
|
||||||
|
and a single vector norm is computed over the entire set of values in the
|
||||||
|
tensor, i.e. `norm(tensor, ord=ord)` is equivalent to
|
||||||
|
`norm(reshape(tensor, [-1]), ord=ord)`. If `axis` is a Python integer, the
|
||||||
|
input is considered a batch of vectors, and `axis` determines the axis in
|
||||||
|
`tensor` over which to compute vector norms. If `axis` is a 2-tuple of
|
||||||
|
Python integers it is considered a batch of matrices and `axis` determines
|
||||||
|
the axes in `tensor` over which to compute a matrix norm.
|
||||||
|
Negative indices are supported. Example: If you are passing a tensor that
|
||||||
|
can be either a matrix or a batch of matrices at runtime, pass
|
||||||
|
`axis=[-2,-1]` instead of `axis=None` to make sure that matrix norms are
|
||||||
|
computed.
|
||||||
|
name: The name of the op.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
normalized: A normalized `Tensor` with the same shape as `tensor`.
|
||||||
|
norm: The computed norms with the same shape and dtype `tensor` but the
|
||||||
|
final axis is 1 instead. Same as running
|
||||||
|
`tf.cast(tf.linalg.norm(tensor, ord, axis keepdims=True), tensor.dtype)`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If `ord` or `axis` is invalid.
|
||||||
|
"""
|
||||||
|
with ops.name_scope(name, "normalize", [tensor]) as name:
|
||||||
|
tensor = ops.convert_to_tensor(tensor)
|
||||||
|
norm = linalg_ops.norm(tensor, ord, axis, keepdims=True)
|
||||||
|
norm = math_ops.cast(norm, tensor.dtype)
|
||||||
|
normalized = tensor / norm
|
||||||
|
return normalized, norm
|
||||||
|
|
||||||
|
|
||||||
@tf_export(v1=["math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize"])
|
@tf_export(v1=["math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize"])
|
||||||
@deprecated_args(None, "dim is deprecated, use axis instead", "dim")
|
@deprecated_args(None, "dim is deprecated, use axis instead", "dim")
|
||||||
def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None):
|
def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None):
|
||||||
|
@ -168,6 +168,10 @@ tf_module {
|
|||||||
name: "norm"
|
name: "norm"
|
||||||
argspec: "args=[\'tensor\', \'ord\', \'axis\', \'keepdims\', \'name\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'euclidean\', \'None\', \'None\', \'None\', \'None\'], "
|
argspec: "args=[\'tensor\', \'ord\', \'axis\', \'keepdims\', \'name\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'euclidean\', \'None\', \'None\', \'None\', \'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "normalize"
|
||||||
|
argspec: "args=[\'tensor\', \'ord\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'euclidean\', \'None\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "qr"
|
name: "qr"
|
||||||
argspec: "args=[\'input\', \'full_matrices\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
argspec: "args=[\'input\', \'full_matrices\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||||
|
@ -168,6 +168,10 @@ tf_module {
|
|||||||
name: "norm"
|
name: "norm"
|
||||||
argspec: "args=[\'tensor\', \'ord\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'euclidean\', \'None\', \'None\', \'None\'], "
|
argspec: "args=[\'tensor\', \'ord\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'euclidean\', \'None\', \'None\', \'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "normalize"
|
||||||
|
argspec: "args=[\'tensor\', \'ord\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'euclidean\', \'None\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "qr"
|
name: "qr"
|
||||||
argspec: "args=[\'input\', \'full_matrices\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
argspec: "args=[\'input\', \'full_matrices\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||||
|
Loading…
Reference in New Issue
Block a user