From 96d80dfa09b279ec191c3c97795d31a4f45433c4 Mon Sep 17 00:00:00 2001 From: sleighsoft Date: Fri, 31 May 2019 16:52:53 +0200 Subject: [PATCH 01/14] Added tf.math.normalize, tf.linalg.normalize and tf.nn.normalize as discussed in #28741 --- tensorflow/python/ops/nn_impl.py | 58 ++++++++++++++++++++++++++++++++ tensorflow/python/ops/nn_test.py | 52 ++++++++++++++++++++++++++++ 2 files changed, 110 insertions(+) diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index 126daeb3d14..6776aa150e0 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import gen_array_ops # pylint: disable=unused-import from tensorflow.python.ops import gen_nn_ops +from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import gen_sparse_ops @@ -435,6 +436,63 @@ def swish(features): return features * math_ops.sigmoid(features) +@tf_export("math.normalize", "linalg.normalize", "nn.normalize") +def normalize(tensor, + ord='euclidean', + axis=None, + name=None): + """Normalizes `tensor` along dimension `axis` using specified norm. + + This uses `tf.linalg.norm` to compute the norm along `axis`. + + This function can compute several different vector norms (the 1-norm, the + Euclidean or 2-norm, the inf-norm, and in general the p-norm for p > 0) and + matrix norms (Frobenius, 1-norm, 2-norm and inf-norm). + + Args: + tensor: `Tensor` of types `float32`, `float64`, `complex64`, `complex128` + ord: Order of the norm. Supported values are `'fro'`, `'euclidean'`, + `1`, `2`, `np.inf` and any positive real number yielding the corresponding + p-norm. Default is `'euclidean'` which is equivalent to Frobenius norm if + `tensor` is a matrix and equivalent to 2-norm for vectors. + Some restrictions apply: + a) The Frobenius norm `'fro'` is not defined for vectors, + b) If axis is a 2-tuple (matrix norm), only `'euclidean'`, '`fro'`, `1`, + `2`, `np.inf` are supported. + See the description of `axis` on how to compute norms for a batch of + vectors or matrices stored in a tensor. + axis: If `axis` is `None` (the default), the input is considered a vector + and a single vector norm is computed over the entire set of values in the + tensor, i.e. `norm(tensor, ord=ord)` is equivalent to + `norm(reshape(tensor, [-1]), ord=ord)`. + If `axis` is a Python integer, the input is considered a batch of vectors, + and `axis` determines the axis in `tensor` over which to compute vector + norms. + If `axis` is a 2-tuple of Python integers it is considered a batch of + matrices and `axis` determines the axes in `tensor` over which to compute + a matrix norm. + Negative indices are supported. Example: If you are passing a tensor that + can be either a matrix or a batch of matrices at runtime, pass + `axis=[-2,-1]` instead of `axis=None` to make sure that matrix norms are + computed. + name: The name of the op. + + Returns: + normalized: A normalized `Tensor` with the same shape as `tensor`. + norm: The computed norms with the same shape as `tensor` but the final axis + is 1 instead. Same as running + `tf.linalg.norm(tensor, ord, axis keepdims=True)`. + + Raises: + ValueError: If `ord` or `axis` is invalid. + """ + with ops.name_scope(name, "normalize", [tensor]) as name: + tensor = ops.convert_to_tensor(tensor) + norm = linalg_ops.norm(tensor, ord, axis, keepdims=True) + normalized = tensor / norm + return normalized, norm + + @tf_export(v1=["math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize"]) @deprecated_args(None, "dim is deprecated, use axis instead", "dim") def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None): diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index df07721e5d3..ae81cb22bf3 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -253,6 +253,58 @@ class L2LossTest(test_lib.TestCase): self.assertLess(err, err_tolerance) +class NormalizeTest(test_lib.TestCase): + + def _Normalize(x, ord, axis): + if isinstance(axis, (list, tuple)): + norm = np.linalg.norm(x, ord, tuple(axis)) + if axis[0] < axis[1]: + # This prevents axis to be inserted in-between + # e.g. when (-2, -1) + for d in reversed(axis): + norm = np.expand_dims(norm, d) + else: + for d in axis: + norm = np.expand_dims(norm, d) + return x / norm + elif axis is None: + # Tensorlfow handles None differently + norm = np.linalg.norm(x.flatten(), ord, axis) + return x / norm + else: + norm = np.apply_along_axis(np.linalg.norm, axis, x, ord) + return x / np.expand_dims(norm, axis) + + @test_util.run_in_graph_and_eager_modes + def testNormalize(self): + for use_static_shape in False, True: + for dtype in np.float32, np.float64, np.complex64, np.complex128: + for rows in 2, 5: + for cols in 2, 5: + for batch in [], [2], [2, 3]: + shape = batch + [rows, cols] + for ord in "euclidean", "fro", 0.5, 1, 2, np.inf: + for axis in [ + None, (-2, -1), (-1, -2), -len(shape), 0, len(shape) - 1 + ]: + is_matrix_norm = (isinstance(axis, tuple) or + isinstance(axis, list)) and len(axis) == 2 + is_fancy_p_norm = np.isreal(ord) and np.floor(ord) != ord + if ((not is_matrix_norm and ord == "fro") or + (is_matrix_norm and is_fancy_p_norm)): + # Not supported by neither numpy.linalg.norm nor tf.norm + continue + if ord == "euclidean" or (axis is None and len(shape) > 2): + # Not supported by numpy.linalg.norm" + continue + matrix = np.random.randn(*shape).astype(dtype) + if dtype in (np.complex64, np.complex128): + matrix += 1j * np.random.randn(*shape).astype(dtype) + tf_n = nn_impl.normalize(matrix, ord, axis) + np_n = _Normalize(matrix, ord, axis) + self.assertAllClose(tf_n, np_n, rtol=1e-5, atol=1e-5) + + class L2NormalizeTest(test_lib.TestCase): def _l2Normalize(self, x, dim): From 1f087eae871b0eac6bbaa6471e8d42f29ca9006c Mon Sep 17 00:00:00 2001 From: sleighsoft Date: Fri, 31 May 2019 16:59:04 +0200 Subject: [PATCH 02/14] Fixed typo --- tensorflow/python/ops/nn_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index ae81cb22bf3..000d53b8f21 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -268,7 +268,7 @@ class NormalizeTest(test_lib.TestCase): norm = np.expand_dims(norm, d) return x / norm elif axis is None: - # Tensorlfow handles None differently + # Tensorflow handles None differently norm = np.linalg.norm(x.flatten(), ord, axis) return x / norm else: From 9cd3b856a732c62e803ad60d2464e5043a9be7c1 Mon Sep 17 00:00:00 2001 From: sleighsoft Date: Fri, 31 May 2019 20:30:43 +0200 Subject: [PATCH 03/14] Just exporting linalg.normalize --- tensorflow/python/ops/nn_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index 6776aa150e0..8db18ae8f79 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -436,7 +436,7 @@ def swish(features): return features * math_ops.sigmoid(features) -@tf_export("math.normalize", "linalg.normalize", "nn.normalize") +@tf_export("linalg.normalize") def normalize(tensor, ord='euclidean', axis=None, From 54d1891e2c0a309b6112c1e66c91b478fd1f14dc Mon Sep 17 00:00:00 2001 From: sleighsoft Date: Tue, 4 Jun 2019 10:19:30 +0200 Subject: [PATCH 04/14] Added self, ignore redefined-builtin --- tensorflow/python/ops/nn_impl.py | 1 + tensorflow/python/ops/nn_test.py | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index 8db18ae8f79..c5296df35d0 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -436,6 +436,7 @@ def swish(features): return features * math_ops.sigmoid(features) +# pylint: disable=redefined-builtin @tf_export("linalg.normalize") def normalize(tensor, ord='euclidean', diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index 000d53b8f21..adb6546619e 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -255,7 +255,8 @@ class L2LossTest(test_lib.TestCase): class NormalizeTest(test_lib.TestCase): - def _Normalize(x, ord, axis): + # pylint: disable=redefined-builtin + def _Normalize(self, x, ord, axis): if isinstance(axis, (list, tuple)): norm = np.linalg.norm(x, ord, tuple(axis)) if axis[0] < axis[1]: @@ -275,6 +276,7 @@ class NormalizeTest(test_lib.TestCase): norm = np.apply_along_axis(np.linalg.norm, axis, x, ord) return x / np.expand_dims(norm, axis) + # pylint: disable=redefined-builtin @test_util.run_in_graph_and_eager_modes def testNormalize(self): for use_static_shape in False, True: @@ -301,7 +303,7 @@ class NormalizeTest(test_lib.TestCase): if dtype in (np.complex64, np.complex128): matrix += 1j * np.random.randn(*shape).astype(dtype) tf_n = nn_impl.normalize(matrix, ord, axis) - np_n = _Normalize(matrix, ord, axis) + np_n = self._Normalize(matrix, ord, axis) self.assertAllClose(tf_n, np_n, rtol=1e-5, atol=1e-5) From 418491d1082e5e95033ab09fc425ec68b1463493 Mon Sep 17 00:00:00 2001 From: sleighsoft Date: Tue, 4 Jun 2019 10:23:36 +0200 Subject: [PATCH 05/14] Fixed indentation and removed use_static_shape --- tensorflow/python/ops/nn_test.py | 51 ++++++++++++++++---------------- 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index adb6546619e..9ff5305216c 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -279,32 +279,31 @@ class NormalizeTest(test_lib.TestCase): # pylint: disable=redefined-builtin @test_util.run_in_graph_and_eager_modes def testNormalize(self): - for use_static_shape in False, True: - for dtype in np.float32, np.float64, np.complex64, np.complex128: - for rows in 2, 5: - for cols in 2, 5: - for batch in [], [2], [2, 3]: - shape = batch + [rows, cols] - for ord in "euclidean", "fro", 0.5, 1, 2, np.inf: - for axis in [ - None, (-2, -1), (-1, -2), -len(shape), 0, len(shape) - 1 - ]: - is_matrix_norm = (isinstance(axis, tuple) or - isinstance(axis, list)) and len(axis) == 2 - is_fancy_p_norm = np.isreal(ord) and np.floor(ord) != ord - if ((not is_matrix_norm and ord == "fro") or - (is_matrix_norm and is_fancy_p_norm)): - # Not supported by neither numpy.linalg.norm nor tf.norm - continue - if ord == "euclidean" or (axis is None and len(shape) > 2): - # Not supported by numpy.linalg.norm" - continue - matrix = np.random.randn(*shape).astype(dtype) - if dtype in (np.complex64, np.complex128): - matrix += 1j * np.random.randn(*shape).astype(dtype) - tf_n = nn_impl.normalize(matrix, ord, axis) - np_n = self._Normalize(matrix, ord, axis) - self.assertAllClose(tf_n, np_n, rtol=1e-5, atol=1e-5) + for dtype in np.float32, np.float64, np.complex64, np.complex128: + for rows in 2, 5: + for cols in 2, 5: + for batch in [], [2], [2, 3]: + shape = batch + [rows, cols] + for ord in "euclidean", "fro", 0.5, 1, 2, np.inf: + for axis in [ + None, (-2, -1), (-1, -2), -len(shape), 0, len(shape) - 1 + ]: + is_matrix_norm = (isinstance(axis, tuple) or + isinstance(axis, list)) and len(axis) == 2 + is_fancy_p_norm = np.isreal(ord) and np.floor(ord) != ord + if ((not is_matrix_norm and ord == "fro") or + (is_matrix_norm and is_fancy_p_norm)): + # Not supported by neither numpy.linalg.norm nor tf.norm + continue + if ord == "euclidean" or (axis is None and len(shape) > 2): + # Not supported by numpy.linalg.norm" + continue + matrix = np.random.randn(*shape).astype(dtype) + if dtype in (np.complex64, np.complex128): + matrix += 1j * np.random.randn(*shape).astype(dtype) + tf_n = nn_impl.normalize(matrix, ord, axis) + np_n = self._Normalize(matrix, ord, axis) + self.assertAllClose(tf_n, np_n, rtol=1e-5, atol=1e-5) class L2NormalizeTest(test_lib.TestCase): From a950fb44f2b6c6ce16b8804c2534b2f637fb0958 Mon Sep 17 00:00:00 2001 From: Julian Niedermeier Date: Fri, 7 Jun 2019 14:34:57 +0200 Subject: [PATCH 06/14] Properly unpack .normalize result which caused wrong shape --- tensorflow/python/ops/nn_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index 9ff5305216c..f8078ddd2e6 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -301,7 +301,7 @@ class NormalizeTest(test_lib.TestCase): matrix = np.random.randn(*shape).astype(dtype) if dtype in (np.complex64, np.complex128): matrix += 1j * np.random.randn(*shape).astype(dtype) - tf_n = nn_impl.normalize(matrix, ord, axis) + tf_n, _ = nn_impl.normalize(matrix, ord, axis) np_n = self._Normalize(matrix, ord, axis) self.assertAllClose(tf_n, np_n, rtol=1e-5, atol=1e-5) From 81cff046fc5ed28f9f17b4f45c9ad23bf9e7cd06 Mon Sep 17 00:00:00 2001 From: Julian Niedermeier Date: Fri, 7 Jun 2019 14:39:26 +0200 Subject: [PATCH 07/14] Cast norm to same dtype as tensor --- tensorflow/python/ops/nn_impl.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index c5296df35d0..c9c492cefe9 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -480,9 +480,9 @@ def normalize(tensor, Returns: normalized: A normalized `Tensor` with the same shape as `tensor`. - norm: The computed norms with the same shape as `tensor` but the final axis - is 1 instead. Same as running - `tf.linalg.norm(tensor, ord, axis keepdims=True)`. + norm: The computed norms with the same shape and dtype `tensor` but the + final axis is 1 instead. Same as running + `tf.cast(tf.linalg.norm(tensor, ord, axis keepdims=True), tensor.dtype)`. Raises: ValueError: If `ord` or `axis` is invalid. @@ -490,6 +490,7 @@ def normalize(tensor, with ops.name_scope(name, "normalize", [tensor]) as name: tensor = ops.convert_to_tensor(tensor) norm = linalg_ops.norm(tensor, ord, axis, keepdims=True) + norm = math_ops.cast(norm, tensor.dtype) normalized = tensor / norm return normalized, norm From 36fbc3c6db7bba84b86deced94eaf8c970454a29 Mon Sep 17 00:00:00 2001 From: sleighsoft Date: Mon, 10 Jun 2019 14:20:28 +0200 Subject: [PATCH 08/14] Added separate test to prevent out of time --- .../python/kernel_tests/normalize_op_test.py | 98 +++++++++++++++++++ tensorflow/python/ops/nn_test.py | 53 ---------- 2 files changed, 98 insertions(+), 53 deletions(-) create mode 100644 tensorflow/python/kernel_tests/normalize_op_test.py diff --git a/tensorflow/python/kernel_tests/normalize_op_test.py b/tensorflow/python/kernel_tests/normalize_op_test.py new file mode 100644 index 00000000000..37d3844835f --- /dev/null +++ b/tensorflow/python/kernel_tests/normalize_op_test.py @@ -0,0 +1,98 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.tf.norm.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import test_util +from tensorflow.python.ops import nn_impl +from tensorflow.python.platform import test as test_lib + + +def _AddTest(test, test_name, fn): + test_name = "_".join(["test", test_name]) + if hasattr(test, test_name): + raise RuntimeError("Test %s defined more than once" % test_name) + setattr(test, test_name, fn) + + +# pylint: disable=redefined-builtin +def _Normalize(x, ord, axis): + if isinstance(axis, (list, tuple)): + norm = np.linalg.norm(x, ord, tuple(axis)) + if axis[0] < axis[1]: + # This prevents axis to be inserted in-between + # e.g. when (-2, -1) + for d in reversed(axis): + norm = np.expand_dims(norm, d) + else: + for d in axis: + norm = np.expand_dims(norm, d) + return x / norm + elif axis is None: + # Tensorflow handles None differently + norm = np.linalg.norm(x.flatten(), ord, axis) + return x / norm + else: + norm = np.apply_along_axis(np.linalg.norm, axis, x, ord) + return x / np.expand_dims(norm, axis) + + +class NormalizeOpTest(test_lib.TestCase): + pass + + +def _GetNormalizeOpTest(dtype_, shape_, ord_, axis_): + + @test_util.run_in_graph_and_eager_modes + def Test(self): + is_matrix_norm = (isinstance(axis_, tuple) or + isinstance(axis_, list)) and len(axis_) == 2 + is_fancy_p_norm = np.isreal(ord_) and np.floor(ord_) != ord_ + if ((not is_matrix_norm and ord_ == "fro") or + (is_matrix_norm and is_fancy_p_norm)): + self.skipTest("Not supported by neither numpy.linalg.norm nor tf.norm") + if ord_ == "euclidean" or (axis_ is None and len(shape) > 2): + self.skipTest("Not supported by numpy.linalg.norm") + matrix = np.random.randn(*shape_).astype(dtype_) + if dtype_ in (np.complex64, np.complex128): + matrix += 1j * np.random.randn(*shape_).astype(dtype_) + tf_np_n, _ = self.evaluate(nn_impl.normalize(matrix, ord_, axis_)) + np_n = _Normalize(matrix, ord_, axis_) + self.assertAllClose(tf_np_n, np_n, rtol=1e-5, atol=1e-5) + + return Test + +# pylint: disable=redefined-builtin +if __name__ == "__main__": + for dtype in np.float32, np.float64, np.complex64, np.complex128: + for rows in 2, 5: + for cols in 2, 5: + for batch in [], [2], [2, 3]: + shape = batch + [rows, cols] + for ord in "euclidean", "fro", 0.5, 1, 2, np.inf: + for axis in [ + None, (-2, -1), (-1, -2), -len(shape), 0, len(shape) - 1 + ]: + name = "%s_%s_ord_%s_axis_%s" % ( + dtype.__name__, "_".join(map(str, shape)), ord, axis) + _AddTest(NormalizeOpTest, "Normalize_" + name, + _GetNormalizeOpTest(dtype, shape, ord, axis)) + + test_lib.main() diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index f8078ddd2e6..df07721e5d3 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -253,59 +253,6 @@ class L2LossTest(test_lib.TestCase): self.assertLess(err, err_tolerance) -class NormalizeTest(test_lib.TestCase): - - # pylint: disable=redefined-builtin - def _Normalize(self, x, ord, axis): - if isinstance(axis, (list, tuple)): - norm = np.linalg.norm(x, ord, tuple(axis)) - if axis[0] < axis[1]: - # This prevents axis to be inserted in-between - # e.g. when (-2, -1) - for d in reversed(axis): - norm = np.expand_dims(norm, d) - else: - for d in axis: - norm = np.expand_dims(norm, d) - return x / norm - elif axis is None: - # Tensorflow handles None differently - norm = np.linalg.norm(x.flatten(), ord, axis) - return x / norm - else: - norm = np.apply_along_axis(np.linalg.norm, axis, x, ord) - return x / np.expand_dims(norm, axis) - - # pylint: disable=redefined-builtin - @test_util.run_in_graph_and_eager_modes - def testNormalize(self): - for dtype in np.float32, np.float64, np.complex64, np.complex128: - for rows in 2, 5: - for cols in 2, 5: - for batch in [], [2], [2, 3]: - shape = batch + [rows, cols] - for ord in "euclidean", "fro", 0.5, 1, 2, np.inf: - for axis in [ - None, (-2, -1), (-1, -2), -len(shape), 0, len(shape) - 1 - ]: - is_matrix_norm = (isinstance(axis, tuple) or - isinstance(axis, list)) and len(axis) == 2 - is_fancy_p_norm = np.isreal(ord) and np.floor(ord) != ord - if ((not is_matrix_norm and ord == "fro") or - (is_matrix_norm and is_fancy_p_norm)): - # Not supported by neither numpy.linalg.norm nor tf.norm - continue - if ord == "euclidean" or (axis is None and len(shape) > 2): - # Not supported by numpy.linalg.norm" - continue - matrix = np.random.randn(*shape).astype(dtype) - if dtype in (np.complex64, np.complex128): - matrix += 1j * np.random.randn(*shape).astype(dtype) - tf_n, _ = nn_impl.normalize(matrix, ord, axis) - np_n = self._Normalize(matrix, ord, axis) - self.assertAllClose(tf_n, np_n, rtol=1e-5, atol=1e-5) - - class L2NormalizeTest(test_lib.TestCase): def _l2Normalize(self, x, dim): From 45cdac0bfca7f7dd5a9ebb1bffe5d21993083148 Mon Sep 17 00:00:00 2001 From: sleighsoft Date: Mon, 10 Jun 2019 20:28:34 +0200 Subject: [PATCH 09/14] Add normalize_op_test to BUILD --- tensorflow/python/kernel_tests/BUILD | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 0c4e47b7972..22148f7f908 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -3506,6 +3506,20 @@ cuda_py_test( xla_enable_strict_auto_jit = True, ) +cuda_py_test( + name = "normalize_op_test", + size = "medium", + srcs = ["normalize_op_test.py"], + additional_deps = [ + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python/ops:nn_impl", + ], + shard_count = 20, + xla_enable_strict_auto_jit = True, +) + cuda_py_test( name = "tensordot_op_test", size = "medium", From 57473526b0d754e8aef9f914ab115c8473d807e3 Mon Sep 17 00:00:00 2001 From: sleighsoft Date: Mon, 10 Jun 2019 21:49:21 +0200 Subject: [PATCH 10/14] Fix BUILD ops:nn --- tensorflow/python/kernel_tests/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 22148f7f908..de3ad00b80b 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -3514,7 +3514,7 @@ cuda_py_test( "//third_party/py/numpy", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python/ops:nn_impl", + "//tensorflow/python/ops:nn", ], shard_count = 20, xla_enable_strict_auto_jit = True, From 12dc5c0ef2632c3bfa727393a65564671c6f98d2 Mon Sep 17 00:00:00 2001 From: sleighsoft Date: Tue, 11 Jun 2019 22:25:35 +0200 Subject: [PATCH 11/14] Fixed BUILD --- tensorflow/python/kernel_tests/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index de3ad00b80b..4f9d4a16159 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -3514,7 +3514,7 @@ cuda_py_test( "//third_party/py/numpy", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python/ops:nn", + "//tensorflow/python:nn", ], shard_count = 20, xla_enable_strict_auto_jit = True, From 888a2eb2982b6a1e974f71f9192231777231ca75 Mon Sep 17 00:00:00 2001 From: sleighsoft Date: Wed, 12 Jun 2019 13:08:20 +0200 Subject: [PATCH 12/14] Updated golden v1,v2 --- tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt | 4 ++++ tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt index a5b312343ed..018ab07cc20 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt @@ -168,6 +168,10 @@ tf_module { name: "norm" argspec: "args=[\'tensor\', \'ord\', \'axis\', \'keepdims\', \'name\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'euclidean\', \'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "normalize" + argspec: "args=[\'tensor\', \'ord\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'euclidean\', \'None\', \'None\'], " + } member_method { name: "qr" argspec: "args=[\'input\', \'full_matrices\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt index d5ab294a317..9b4909937d7 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt @@ -168,6 +168,10 @@ tf_module { name: "norm" argspec: "args=[\'tensor\', \'ord\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'euclidean\', \'None\', \'None\', \'None\'], " } + member_method { + name: "normalize" + argspec: "args=[\'tensor\', \'ord\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'euclidean\', \'None\', \'None\'], " + } member_method { name: "qr" argspec: "args=[\'input\', \'full_matrices\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " From ab325a80f865ac285bd42715af6ca08f82e89f29 Mon Sep 17 00:00:00 2001 From: sleighsoft Date: Thu, 13 Jun 2019 16:08:45 +0200 Subject: [PATCH 13/14] Updated indentation --- tensorflow/python/kernel_tests/normalize_op_test.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/kernel_tests/normalize_op_test.py b/tensorflow/python/kernel_tests/normalize_op_test.py index 37d3844835f..30e5c43c319 100644 --- a/tensorflow/python/kernel_tests/normalize_op_test.py +++ b/tensorflow/python/kernel_tests/normalize_op_test.py @@ -40,15 +40,15 @@ def _Normalize(x, ord, axis): # This prevents axis to be inserted in-between # e.g. when (-2, -1) for d in reversed(axis): - norm = np.expand_dims(norm, d) + norm = np.expand_dims(norm, d) else: for d in axis: - norm = np.expand_dims(norm, d) + norm = np.expand_dims(norm, d) return x / norm elif axis is None: - # Tensorflow handles None differently - norm = np.linalg.norm(x.flatten(), ord, axis) - return x / norm + # Tensorflow handles None differently + norm = np.linalg.norm(x.flatten(), ord, axis) + return x / norm else: norm = np.apply_along_axis(np.linalg.norm, axis, x, ord) return x / np.expand_dims(norm, axis) @@ -93,6 +93,6 @@ if __name__ == "__main__": name = "%s_%s_ord_%s_axis_%s" % ( dtype.__name__, "_".join(map(str, shape)), ord, axis) _AddTest(NormalizeOpTest, "Normalize_" + name, - _GetNormalizeOpTest(dtype, shape, ord, axis)) + _GetNormalizeOpTest(dtype, shape, ord, axis)) test_lib.main() From afefdce8bc7558f0df2eaef8c098740129c3e77c Mon Sep 17 00:00:00 2001 From: sleighsoft Date: Thu, 13 Jun 2019 20:07:24 +0200 Subject: [PATCH 14/14] Add BUILD tags of norm_op_test to normalize_op_test --- tensorflow/python/kernel_tests/BUILD | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 4f9d4a16159..bf8fc38a29e 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -3517,6 +3517,11 @@ cuda_py_test( "//tensorflow/python:nn", ], shard_count = 20, + # TODO(b/117236102): Re-enable in msan build. + tags = [ + "no_windows_gpu", + "nomsan", + ], xla_enable_strict_auto_jit = True, )