Add a Euclidean norm reduction kernel. This implements a fused sqrt(reduce_sum(x * conj(x))) kernels for CPU (using Eigen) and GPU (using CUB), which is more efficient than the composite implementation at the TF level. It will also be easier to avoid the issue of producing NaNs in the gradient at the origin.
Adds tf.math.reduce_euclidian_norm() Python interface to call the fused reduction kernel directly. Gradients will be added in a followup change. PiperOrigin-RevId: 234188431
This commit is contained in:
parent
5113410e10
commit
398fce0307
39
tensorflow/core/api_def/base_api/api_def_EuclideanNorm.pbtxt
Normal file
39
tensorflow/core/api_def/base_api/api_def_EuclideanNorm.pbtxt
Normal file
@ -0,0 +1,39 @@
|
||||
op {
|
||||
graph_op_name: "EuclideanNorm"
|
||||
endpoint {
|
||||
name: "EuclideanNorm"
|
||||
}
|
||||
in_arg {
|
||||
name: "input"
|
||||
description: <<END
|
||||
The tensor to reduce.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "reduction_indices"
|
||||
rename_to: "axis"
|
||||
description: <<END
|
||||
The dimensions to reduce. Must be in the range
|
||||
`[-rank(input), rank(input))`.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "output"
|
||||
description: <<END
|
||||
The reduced tensor.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "keep_dims"
|
||||
description: <<END
|
||||
If true, retain reduced dimensions with length 1.
|
||||
END
|
||||
}
|
||||
summary: "Computes the euclidean norm of elements across dimensions of a tensor."
|
||||
description: <<END
|
||||
Reduces `input` along the dimensions given in `reduction_indices`. Unless
|
||||
`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
|
||||
`reduction_indices`. If `keep_dims` is true, the reduced dimensions are
|
||||
retained with length 1.
|
||||
END
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "EuclideanNorm"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -40,6 +40,20 @@ namespace functor {
|
||||
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
||||
template <typename T>
|
||||
struct Square {
|
||||
__host__ __device__ T operator()(const T& a) const {
|
||||
return a * Eigen::numext::conj(a);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct Sqrt {
|
||||
__host__ __device__ T operator()(const T& a) const {
|
||||
return Eigen::numext::sqrt(a);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct Sum {
|
||||
__host__ __device__ T operator()(const T& a, const T& b) const {
|
||||
@ -884,6 +898,31 @@ struct ReduceFunctor<GPUDevice, Eigen::internal::SumReducer<T>> {
|
||||
}
|
||||
};
|
||||
|
||||
// TODO(rmlarsen): Specialize for float16.
|
||||
template <typename T>
|
||||
struct ReduceFunctor<GPUDevice, functor::EuclideanNormReducer<T>> {
|
||||
template <typename OUT_T, typename IN_T, typename ReductionAxes>
|
||||
static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
|
||||
const ReductionAxes& reduction_axes,
|
||||
const functor::EuclideanNormReducer<T>& reducer) {
|
||||
typedef cub::TransformInputIterator<T, Square<T>, T*> inputIterType;
|
||||
inputIterType input_itr((T*)in.data(), Square<T>());
|
||||
typedef TransformOutputIterator<T, T, Sqrt<T>> outputIterType;
|
||||
outputIterType output_itr((T*)out.data(), Sqrt<T>());
|
||||
ReduceImpl<T, Sum<T>, outputIterType, inputIterType, ReductionAxes>(
|
||||
ctx, output_itr, input_itr, in.rank(), in.dimension(0),
|
||||
in.rank() >= 2 ? in.dimension(1) : 1,
|
||||
in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
|
||||
Sum<T>());
|
||||
}
|
||||
|
||||
template <typename OUT_T>
|
||||
static void FillIdentity(const GPUDevice& d, OUT_T out,
|
||||
const functor::EuclideanNormReducer<T>& reducer) {
|
||||
FillIdentityEigenImpl(d, To32Bit(out), reducer);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceFunctor<GPUDevice, functor::MeanReducer<T>> {
|
||||
template <typename OUT_T, typename IN_T, typename ReductionAxes>
|
||||
|
@ -33,6 +33,12 @@ struct MeanReducer {
|
||||
Scalar initialize() const { return Scalar(0); }
|
||||
};
|
||||
|
||||
// Dummy class used for template specialization for l2-norm reduction.
|
||||
template <typename Scalar>
|
||||
struct EuclideanNormReducer {
|
||||
Scalar initialize() const { return Scalar(0); }
|
||||
};
|
||||
|
||||
template <typename Device, typename OUT_T, typename IN_T,
|
||||
typename ReductionAxes, typename Reducer>
|
||||
struct ReduceEigenImpl {
|
||||
@ -56,6 +62,39 @@ struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,
|
||||
}
|
||||
};
|
||||
|
||||
// TODO(rmlarsen): Refactor this such that taking the sqrt can be optional
|
||||
// controlled by an attribute.
|
||||
template <typename Device, typename OUT_T, typename IN_T,
|
||||
typename ReductionAxes, typename Scalar>
|
||||
struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,
|
||||
functor::EuclideanNormReducer<Scalar>> {
|
||||
void operator()(const Device& d, OUT_T out, IN_T in,
|
||||
const ReductionAxes& reduction_axes,
|
||||
const functor::EuclideanNormReducer<Scalar>& reducer) {
|
||||
static_assert(std::is_same<Scalar, typename OUT_T::Scalar>::value, "");
|
||||
Eigen::internal::SumReducer<Scalar> sum_reducer;
|
||||
out.device(d) =
|
||||
(in * in.conjugate()).reduce(reduction_axes, sum_reducer).sqrt();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Device, typename OUT_T, typename IN_T,
|
||||
typename ReductionAxes>
|
||||
struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,
|
||||
functor::EuclideanNormReducer<bfloat16>> {
|
||||
void operator()(const Device& d, OUT_T out, IN_T in,
|
||||
const ReductionAxes& reduction_axes,
|
||||
const functor::EuclideanNormReducer<bfloat16>& reducer) {
|
||||
static_assert(std::is_same<bfloat16, typename OUT_T::Scalar>::value, "");
|
||||
Eigen::internal::SumReducer<float> sum_reducer;
|
||||
auto in_as_float = in.template cast<float>();
|
||||
out.device(d) = (in_as_float * in_as_float.conjugate())
|
||||
.reduce(reduction_axes, sum_reducer)
|
||||
.sqrt()
|
||||
.template cast<bfloat16>();
|
||||
}
|
||||
};
|
||||
|
||||
// For most reducers, the identity is Reducer::initialize()
|
||||
template <typename Reducer>
|
||||
struct Identity {
|
||||
|
81
tensorflow/core/kernels/reduction_ops_euclidean.cc
Normal file
81
tensorflow/core/kernels/reduction_ops_euclidean.cc
Normal file
@ -0,0 +1,81 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/reduction_ops_common.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
#define REGISTER_CPU_KERNELS(type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("EuclideanNorm") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.TypeConstraint<int32>("Tidx"), \
|
||||
ReductionOp<CPUDevice, type, int32, \
|
||||
functor::EuclideanNormReducer<type>>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("EuclideanNorm") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.TypeConstraint<int64>("Tidx"), \
|
||||
ReductionOp<CPUDevice, type, int64, \
|
||||
functor::EuclideanNormReducer<type>>);
|
||||
TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
|
||||
#undef REGISTER_CPU_KERNELS
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
|
||||
#define REGISTER_GPU_KERNELS(type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("EuclideanNorm") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.TypeConstraint<int32>("Tidx") \
|
||||
.HostMemory("reduction_indices"), \
|
||||
ReductionOp<GPUDevice, type, int32, \
|
||||
functor::EuclideanNormReducer<type>>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("EuclideanNorm") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.TypeConstraint<int64>("Tidx") \
|
||||
.HostMemory("reduction_indices"), \
|
||||
ReductionOp<GPUDevice, type, int64, \
|
||||
functor::EuclideanNormReducer<type>>);
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
|
||||
TF_CALL_complex64(REGISTER_GPU_KERNELS);
|
||||
TF_CALL_complex128(REGISTER_GPU_KERNELS);
|
||||
#undef REGISTER_GPU_KERNELS
|
||||
|
||||
#endif
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
#define REGISTER_SYCL_KERNELS(type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("EuclideanNorm") \
|
||||
.Device(DEVICE_SYCL) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.TypeConstraint<int32>("Tidx") \
|
||||
.HostMemory("reduction_indices"), \
|
||||
ReductionOp<SYCLDevice, type, int32, \
|
||||
functor::EuclideanNormReducer<type>>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("EuclideanNorm") \
|
||||
.Device(DEVICE_SYCL) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.TypeConstraint<int64>("Tidx") \
|
||||
.HostMemory("reduction_indices"), \
|
||||
ReductionOp<SYCLDevice, type, int64, \
|
||||
functor::EuclideanNormReducer<type>>);
|
||||
REGISTER_SYCL_KERNELS(float);
|
||||
REGISTER_SYCL_KERNELS(double);
|
||||
#undef REGISTER_SYCL_KERNELS
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
} // namespace tensorflow
|
@ -53,6 +53,7 @@ typedef TTypes<float>::Tensor::Index Index;
|
||||
|
||||
DEFINE_FOR_TYPE_AND_R(complex128, Eigen::internal::SumReducer<complex128>);
|
||||
DEFINE_FOR_TYPE_AND_R(complex128, functor::MeanReducer<complex128>);
|
||||
DEFINE_FOR_TYPE_AND_R(complex128, functor::EuclideanNormReducer<complex128>);
|
||||
DEFINE_FOR_TYPE_AND_R(complex128, Eigen::internal::ProdReducer<complex128>);
|
||||
#undef DEFINE_FOR_TYPE_AND_R
|
||||
#undef DEFINE
|
||||
|
@ -53,6 +53,7 @@ typedef TTypes<float>::Tensor::Index Index;
|
||||
|
||||
DEFINE_FOR_TYPE_AND_R(complex64, Eigen::internal::SumReducer<complex64>);
|
||||
DEFINE_FOR_TYPE_AND_R(complex64, functor::MeanReducer<complex64>);
|
||||
DEFINE_FOR_TYPE_AND_R(complex64, functor::EuclideanNormReducer<complex64>);
|
||||
DEFINE_FOR_TYPE_AND_R(complex64, Eigen::internal::ProdReducer<complex64>);
|
||||
#undef DEFINE_FOR_TYPE_AND_R
|
||||
#undef DEFINE
|
||||
|
@ -51,11 +51,12 @@ typedef TTypes<float>::Tensor::Index Index;
|
||||
DEFINE(T, R, 3, 2); \
|
||||
DEFINE_IDENTITY(T, R)
|
||||
|
||||
#define DEFINE_FOR_ALL_REDUCERS(T) \
|
||||
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::SumReducer<T>); \
|
||||
DEFINE_FOR_TYPE_AND_R(T, functor::MeanReducer<T>); \
|
||||
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MinReducer<T>); \
|
||||
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MaxReducer<T>); \
|
||||
#define DEFINE_FOR_ALL_REDUCERS(T) \
|
||||
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::SumReducer<T>); \
|
||||
DEFINE_FOR_TYPE_AND_R(T, functor::MeanReducer<T>); \
|
||||
DEFINE_FOR_TYPE_AND_R(T, functor::EuclideanNormReducer<T>); \
|
||||
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MinReducer<T>); \
|
||||
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MaxReducer<T>); \
|
||||
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::ProdReducer<T>)
|
||||
|
||||
DEFINE_FOR_ALL_REDUCERS(double);
|
||||
|
@ -51,11 +51,12 @@ typedef TTypes<float>::Tensor::Index Index;
|
||||
DEFINE(T, R, 3, 2); \
|
||||
DEFINE_IDENTITY(T, R)
|
||||
|
||||
#define DEFINE_FOR_ALL_REDUCERS(T) \
|
||||
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::SumReducer<T>); \
|
||||
DEFINE_FOR_TYPE_AND_R(T, functor::MeanReducer<T>); \
|
||||
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MinReducer<T>); \
|
||||
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MaxReducer<T>); \
|
||||
#define DEFINE_FOR_ALL_REDUCERS(T) \
|
||||
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::SumReducer<T>); \
|
||||
DEFINE_FOR_TYPE_AND_R(T, functor::MeanReducer<T>); \
|
||||
DEFINE_FOR_TYPE_AND_R(T, functor::EuclideanNormReducer<T>); \
|
||||
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MinReducer<T>); \
|
||||
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MaxReducer<T>); \
|
||||
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::ProdReducer<T>)
|
||||
|
||||
DEFINE_FOR_ALL_REDUCERS(float);
|
||||
|
@ -51,11 +51,12 @@ typedef TTypes<float>::Tensor::Index Index;
|
||||
DEFINE(T, R, 3, 2); \
|
||||
DEFINE_IDENTITY(T, R)
|
||||
|
||||
#define DEFINE_FOR_ALL_REDUCERS(T) \
|
||||
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::SumReducer<T>); \
|
||||
DEFINE_FOR_TYPE_AND_R(T, functor::MeanReducer<T>); \
|
||||
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MinReducer<T>); \
|
||||
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MaxReducer<T>); \
|
||||
#define DEFINE_FOR_ALL_REDUCERS(T) \
|
||||
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::SumReducer<T>); \
|
||||
DEFINE_FOR_TYPE_AND_R(T, functor::MeanReducer<T>); \
|
||||
DEFINE_FOR_TYPE_AND_R(T, functor::EuclideanNormReducer<T>); \
|
||||
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MinReducer<T>); \
|
||||
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MaxReducer<T>); \
|
||||
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::ProdReducer<T>)
|
||||
|
||||
DEFINE_FOR_ALL_REDUCERS(int32);
|
||||
|
@ -51,8 +51,9 @@ typedef TTypes<float>::Tensor::Index Index;
|
||||
DEFINE(T, R, 3, 2); \
|
||||
DEFINE_IDENTITY(T, R)
|
||||
|
||||
#define DEFINE_FOR_ALL_REDUCERS(T) \
|
||||
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::SumReducer<T>); \
|
||||
#define DEFINE_FOR_ALL_REDUCERS(T) \
|
||||
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::SumReducer<T>); \
|
||||
DEFINE_FOR_TYPE_AND_R(T, functor::EuclideanNormReducer<T>); \
|
||||
DEFINE_FOR_TYPE_AND_R(T, functor::MeanReducer<T>);
|
||||
|
||||
DEFINE_FOR_ALL_REDUCERS(Eigen::half);
|
||||
|
@ -164,6 +164,11 @@ static void BM_Mean2DToScalarGPU(int iters, int num_x, int num_y) {
|
||||
}
|
||||
BENCHMARK(BM_Mean2DToScalarGPU)->RangePair(2048, 8192, 2048, 8192);
|
||||
|
||||
static void BM_EuclideanNorm2DToScalarGPU(int iters, int num_x, int num_y) {
|
||||
ReduceToScalar<float>(iters, "gpu", "EuclideanNorm", num_x, num_y);
|
||||
}
|
||||
BENCHMARK(BM_EuclideanNorm2DToScalarGPU)->RangePair(2048, 8192, 2048, 8192);
|
||||
|
||||
static void BM_Max2DToScalarGPU(int iters, int num_x, int num_y) {
|
||||
ReduceToScalar<float>(iters, "gpu", "Max", num_x, num_y);
|
||||
}
|
||||
|
@ -850,6 +850,15 @@ REGISTER_OP("Sum")
|
||||
.Attr("Tidx: {int32, int64} = DT_INT32")
|
||||
.SetShapeFn(shape_inference::ReductionShape);
|
||||
|
||||
REGISTER_OP("EuclideanNorm")
|
||||
.Input("input: T")
|
||||
.Input("reduction_indices: Tidx")
|
||||
.Output("output: T")
|
||||
.Attr("keep_dims: bool = false")
|
||||
.Attr("T: numbertype")
|
||||
.Attr("Tidx: {int32, int64} = DT_INT32")
|
||||
.SetShapeFn(shape_inference::ReductionShape);
|
||||
|
||||
REGISTER_OP("Mean")
|
||||
.Input("input: T")
|
||||
.Input("reduction_indices: Tidx")
|
||||
|
@ -2199,7 +2199,7 @@ cuda_py_test(
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:math_ops",
|
||||
],
|
||||
shard_count = 4,
|
||||
shard_count = 6,
|
||||
tags = ["no_windows_gpu"],
|
||||
xla_enable_strict_auto_jit = True,
|
||||
)
|
||||
|
@ -104,7 +104,8 @@ class ReductionUnknownShape(test.TestCase):
|
||||
for dtype, reductions in [(dtypes.float32,
|
||||
(math_ops.reduce_sum, math_ops.reduce_mean,
|
||||
math_ops.reduce_prod, math_ops.reduce_max,
|
||||
math_ops.reduce_min)),
|
||||
math_ops.reduce_min,
|
||||
math_ops.reduce_euclidean_norm)),
|
||||
(dtypes.bool, (math_ops.reduce_all,
|
||||
math_ops.reduce_any))]:
|
||||
for reduction in reductions:
|
||||
@ -487,6 +488,79 @@ class MeanReductionTest(BaseReductionTest):
|
||||
self.assertTrue(np.all(np.isnan(y)))
|
||||
|
||||
|
||||
class EuclideanNormReductionTest(BaseReductionTest):
|
||||
|
||||
def _tf_reduce(self, x, reduction_axes, keepdims):
|
||||
return math_ops.reduce_euclidean_norm(x, reduction_axes, keepdims)
|
||||
|
||||
def _np_reduce(self, x, reduction_axes, keepdims):
|
||||
if isinstance(reduction_axes, list) or isinstance(reduction_axes,
|
||||
np.ndarray):
|
||||
reduction_axes = tuple(reduction_axes)
|
||||
if reduction_axes is None or reduction_axes != tuple():
|
||||
np_fro = np.sqrt(
|
||||
np.sum(x * np.conj(x), axis=reduction_axes, keepdims=keepdims))
|
||||
else:
|
||||
np_fro = x
|
||||
if np.issubdtype(x.dtype, np.integer):
|
||||
np_fro = np.floor(np_fro)
|
||||
return np_fro
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testAxesType(self):
|
||||
for dtype in [dtypes.int64, dtypes.int32]:
|
||||
with self.cached_session(use_gpu=True):
|
||||
v = math_ops.reduce_mean([0, 0], constant_op.constant(0, dtype=dtype))
|
||||
tf_v = self.evaluate(v)
|
||||
self.assertAllEqual(tf_v, 0)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testInfinity(self):
|
||||
for dtype in [np.float32, np.float64]:
|
||||
for special_value_x in [-np.inf, np.inf]:
|
||||
for special_value_y in [-np.inf, np.inf]:
|
||||
np_arr = np.array([special_value_x, special_value_y]).astype(dtype)
|
||||
self._compareAll(np_arr, None)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testInt32(self):
|
||||
for rank in range(1, _MAX_RANK + 1):
|
||||
np_arr = self._makeIncremental((2,) * rank, dtypes.int32)
|
||||
self._compareAllAxes(np_arr)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testFloat32(self):
|
||||
for rank in range(1, _MAX_RANK + 1):
|
||||
np_arr = self._makeIncremental((2,) * rank, dtypes.float32)
|
||||
self._compareAllAxes(np_arr)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testFloat64(self):
|
||||
for rank in range(1, _MAX_RANK + 1):
|
||||
np_arr = self._makeIncremental((2,) * rank, dtypes.float64)
|
||||
self._compareAllAxes(np_arr)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testComplex64(self):
|
||||
for rank in range(1, _MAX_RANK + 1):
|
||||
np_arr = self._makeIncremental((2,) * rank, dtypes.complex64)
|
||||
self._compareAllAxes(np_arr)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testComplex128(self):
|
||||
for rank in range(1, _MAX_RANK + 1):
|
||||
np_arr = self._makeIncremental((2,) * rank, dtypes.complex128)
|
||||
self._compareAllAxes(np_arr)
|
||||
|
||||
with self.session(use_gpu=True):
|
||||
for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
|
||||
# A large number is needed to get Eigen to die
|
||||
x = array_ops.zeros((0, 9938), dtype=dtype)
|
||||
y = math_ops.reduce_euclidean_norm(x, [0]).eval()
|
||||
self.assertEqual(y.shape, (9938,))
|
||||
self.assertAllEqual(y, np.zeros(9938))
|
||||
|
||||
|
||||
class ProdReductionTest(BaseReductionTest):
|
||||
|
||||
def _tf_reduce(self, x, reduction_axes, keepdims):
|
||||
|
@ -619,6 +619,8 @@ def norm(tensor,
|
||||
result = math_ops.sqrt(
|
||||
math_ops.reduce_sum(
|
||||
tensor * math_ops.conj(tensor), axis, keepdims=True))
|
||||
# TODO(rmlarsen): Replace with the following, once gradients are defined
|
||||
# result = math_ops.reduce_euclidean_norm(tensor, axis, keepdims=True)
|
||||
else:
|
||||
result = math_ops.abs(tensor)
|
||||
if ord == 1:
|
||||
|
@ -47,6 +47,10 @@ def _ArgMinGrad(op, grad):
|
||||
return [None, None]
|
||||
|
||||
|
||||
# TODO(rmlarsen): Implement gradient.
|
||||
ops.NotDifferentiable("EuclideanNorm")
|
||||
|
||||
|
||||
@ops.RegisterGradient("Sum")
|
||||
def _SumGrad(op, grad):
|
||||
"""Gradient for Sum."""
|
||||
|
@ -1386,6 +1386,47 @@ def reduce_sum(input_tensor, axis=None, keepdims=False, name=None):
|
||||
name=name))
|
||||
|
||||
|
||||
@tf_export("math.reduce_euclidean_norm")
|
||||
def reduce_euclidean_norm(input_tensor, axis=None, keepdims=False, name=None):
|
||||
"""Computes the Euclidean norm of elements across dimensions of a tensor.
|
||||
|
||||
Reduces `input_tensor` along the dimensions given in `axis`.
|
||||
Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
|
||||
entry in `axis`. If `keepdims` is true, the reduced dimensions
|
||||
are retained with length 1.
|
||||
|
||||
If `axis` is None, all dimensions are reduced, and a
|
||||
tensor with a single element is returned.
|
||||
|
||||
For example:
|
||||
|
||||
```python
|
||||
x = tf.constant([[1, 2, 3], [1, 1, 1]])
|
||||
tf.reduce_euclidean_norm(x) # sqrt(17)
|
||||
tf.reduce_euclidean_norm(x, 0) # [sqrt(2), sqrt(5), sqrt(10)]
|
||||
tf.reduce_euclidean_norm(x, 1) # [sqrt(14), sqrt(3)]
|
||||
tf.reduce_euclidean_norm(x, 1, keepdims=True) # [[sqrt(14)], [sqrt(3)]]
|
||||
tf.reduce_euclidean_norm(x, [0, 1]) # sqrt(17)
|
||||
```
|
||||
|
||||
Args:
|
||||
input_tensor: The tensor to reduce. Should have numeric type.
|
||||
axis: The dimensions to reduce. If `None` (the default), reduces all
|
||||
dimensions. Must be in the range `[-rank(input_tensor),
|
||||
rank(input_tensor))`.
|
||||
keepdims: If true, retains reduced dimensions with length 1.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
The reduced tensor, of the same dtype as the input_tensor.
|
||||
"""
|
||||
return _may_reduce_to_scalar(
|
||||
keepdims, axis,
|
||||
gen_math_ops.euclidean_norm(
|
||||
input_tensor, _ReductionDims(input_tensor, axis), keepdims,
|
||||
name=name))
|
||||
|
||||
|
||||
@tf_export(v1=["math.count_nonzero", "count_nonzero"])
|
||||
@deprecation.deprecated_args(
|
||||
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
|
||||
|
@ -300,6 +300,10 @@ tf_module {
|
||||
name: "reduce_any"
|
||||
argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "reduce_euclidean_norm"
|
||||
argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "reduce_logsumexp"
|
||||
argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
|
@ -980,6 +980,10 @@ tf_module {
|
||||
name: "Erfc"
|
||||
argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "EuclideanNorm"
|
||||
argspec: "args=[\'input\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "Exit"
|
||||
argspec: "args=[\'data\'], varargs=None, keywords=None, defaults=None"
|
||||
@ -2659,7 +2663,7 @@ tf_module {
|
||||
member_method {
|
||||
name: "Requantize"
|
||||
argspec: "args=[\'input\', \'input_min\', \'input_max\', \'requested_output_min\', \'requested_output_max\', \'out_type\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
member_method {
|
||||
name: "RequantizePerChannel"
|
||||
argspec: "args=[\'input\', \'input_min\', \'input_max\', \'requested_output_min\', \'requested_output_max\', \'out_type\'], varargs=None, keywords=None, defaults=None"
|
||||
|
@ -300,6 +300,10 @@ tf_module {
|
||||
name: "reduce_any"
|
||||
argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "reduce_euclidean_norm"
|
||||
argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "reduce_logsumexp"
|
||||
argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
|
||||
|
@ -980,6 +980,10 @@ tf_module {
|
||||
name: "Erfc"
|
||||
argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "EuclideanNorm"
|
||||
argspec: "args=[\'input\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "Exit"
|
||||
argspec: "args=[\'data\'], varargs=None, keywords=None, defaults=None"
|
||||
@ -2659,7 +2663,7 @@ tf_module {
|
||||
member_method {
|
||||
name: "Requantize"
|
||||
argspec: "args=[\'input\', \'input_min\', \'input_max\', \'requested_output_min\', \'requested_output_max\', \'out_type\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
member_method {
|
||||
name: "RequantizePerChannel"
|
||||
argspec: "args=[\'input\', \'input_min\', \'input_max\', \'requested_output_min\', \'requested_output_max\', \'out_type\'], varargs=None, keywords=None, defaults=None"
|
||||
|
Loading…
Reference in New Issue
Block a user