Add more dtypes to Einsum Op. Second try, as the first one broke global TAP because of long compilation times.

PiperOrigin-RevId: 267012419
This commit is contained in:
Anudhyan Boral 2019-09-03 14:25:17 -07:00 committed by TensorFlower Gardener
parent 45fe8a9dac
commit 6462b0e801
9 changed files with 170 additions and 6 deletions

View File

@ -116,6 +116,10 @@ tensorflow/core/kernels/depthwise_conv_op.cc
tensorflow/core/kernels/dequantize_op.cc
tensorflow/core/kernels/dynamic_partition_op.cc
tensorflow/core/kernels/dynamic_stitch_op.cc
tensorflow/core/kernels/einsum_op_impl_half.cc
tensorflow/core/kernels/einsum_op_impl_bfloat16.cc
tensorflow/core/kernels/einsum_op_impl_int32.cc
tensorflow/core/kernels/einsum_op_impl_int64.cc
tensorflow/core/kernels/einsum_op_impl_float.cc
tensorflow/core/kernels/einsum_op_impl_double.cc
tensorflow/core/kernels/einsum_op_impl_complex64.cc

View File

@ -6323,6 +6323,10 @@ filegroup(
"encode_wav_op.cc",
"eigen_contraction_kernel.cc",
"eigen_contraction_kernel.h",
"einsum_op_impl_half.cc",
"einsum_op_impl_bfloat16.cc",
"einsum_op_impl_int32.cc",
"einsum_op_impl_int64.cc",
"einsum_op_impl_float.cc",
"einsum_op_impl_double.cc",
"einsum_op_impl_complex64.cc",

View File

@ -33,6 +33,7 @@ namespace tensorflow {
DECLARE_GPU_SPECS_NDIM(T, 5); \
DECLARE_GPU_SPECS_NDIM(T, 6);
TF_CALL_half(DECLARE_GPU_SPECS);
TF_CALL_float(DECLARE_GPU_SPECS);
TF_CALL_double(DECLARE_GPU_SPECS);
TF_CALL_complex64(DECLARE_GPU_SPECS);

View File

@ -759,6 +759,7 @@ namespace functor {
DECLARE_GPU_SPEC(T, 5); \
DECLARE_GPU_SPEC(T, 6);
DECLARE_GPU_SPECS(Eigen::half);
DECLARE_GPU_SPECS(double);
DECLARE_GPU_SPECS(float);
// TODO(rocm): Enable once complex types are supported.

View File

@ -0,0 +1,31 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/einsum_op_impl.h"
namespace tensorflow {
// Kernel registrations for the "Einsum" op in this translation unit:
// bfloat16 on CPU only (no GPU registration appears in this file).
//
// REGISTER_EINSUM(D, TYPE) emits one REGISTER_KERNEL_BUILDER call for the op
// named "Einsum" on device DEVICE_<D>, with attribute "T" constrained to TYPE
// and EinsumOp<<D>Device, TYPE> as the kernel class. D is token-pasted on both
// sides, so it must be a bare device tag such as CPU.
// The macro body already ends in ';', so the ';' written after each use below
// is a redundant empty declaration.
#define REGISTER_EINSUM(D, TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("Einsum").Device(DEVICE_##D).TypeConstraint<TYPE>("T"), \
EinsumOp<D##Device, TYPE>);
// Single-argument adapter so the registration can be handed to a TF_CALL_*
// macro.
#define REGISTER_CPU(TYPE) REGISTER_EINSUM(CPU, TYPE)
// TF_CALL_bfloat16 is defined outside this file; presumably it invokes the
// given macro with the bfloat16 type -- confirm in the header that defines it.
TF_CALL_bfloat16(REGISTER_CPU);
// Drop the helper macros so they do not leak past this file's registrations.
#undef REGISTER_CPU
#undef REGISTER_EINSUM
} // namespace tensorflow

View File

@ -0,0 +1,37 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/einsum_op_impl.h"
namespace tensorflow {
// Kernel registrations for the "Einsum" op in this translation unit:
// half on CPU always, and half on GPU when built with CUDA or ROCm.
//
// REGISTER_EINSUM(D, TYPE) emits one REGISTER_KERNEL_BUILDER call for the op
// named "Einsum" on device DEVICE_<D>, with attribute "T" constrained to TYPE
// and EinsumOp<<D>Device, TYPE> as the kernel class. D is token-pasted on both
// sides, so it must be a bare device tag such as CPU or GPU.
// The macro body already ends in ';', so the ';' written after each use below
// is a redundant empty declaration.
#define REGISTER_EINSUM(D, TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("Einsum").Device(DEVICE_##D).TypeConstraint<TYPE>("T"), \
EinsumOp<D##Device, TYPE>);
// Single-argument adapter so the CPU registration can be handed to TF_CALL_*.
#define REGISTER_CPU(TYPE) REGISTER_EINSUM(CPU, TYPE)
// TF_CALL_half is defined outside this file; presumably it invokes the given
// macro with the half-precision type -- confirm in the defining header.
TF_CALL_half(REGISTER_CPU);
#undef REGISTER_CPU
// The GPU kernel is only registered when one of the GPU build flags is set.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Single-argument adapter so the GPU registration can be handed to TF_CALL_*.
#define REGISTER_GPU(TYPE) REGISTER_EINSUM(GPU, TYPE)
TF_CALL_half(REGISTER_GPU);
#undef REGISTER_GPU
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Drop the shared helper once both device registrations are done.
#undef REGISTER_EINSUM
} // namespace tensorflow

View File

@ -0,0 +1,31 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/einsum_op_impl.h"
namespace tensorflow {
// Kernel registrations for the "Einsum" op in this translation unit:
// int32 on CPU only (no GPU registration appears in this file).
//
// REGISTER_EINSUM(D, TYPE) emits one REGISTER_KERNEL_BUILDER call for the op
// named "Einsum" on device DEVICE_<D>, with attribute "T" constrained to TYPE
// and EinsumOp<<D>Device, TYPE> as the kernel class. D is token-pasted on both
// sides, so it must be a bare device tag such as CPU.
// The macro body already ends in ';', so the ';' written after each use below
// is a redundant empty declaration.
#define REGISTER_EINSUM(D, TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("Einsum").Device(DEVICE_##D).TypeConstraint<TYPE>("T"), \
EinsumOp<D##Device, TYPE>);
// Single-argument adapter so the registration can be handed to a TF_CALL_*
// macro.
#define REGISTER_CPU(TYPE) REGISTER_EINSUM(CPU, TYPE)
// TF_CALL_int32 is defined outside this file; presumably it invokes the given
// macro with the 32-bit integer type -- confirm in the defining header.
TF_CALL_int32(REGISTER_CPU);
// Drop the helper macros so they do not leak past this file's registrations.
#undef REGISTER_CPU
#undef REGISTER_EINSUM
} // namespace tensorflow

View File

@ -0,0 +1,31 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/einsum_op_impl.h"
namespace tensorflow {
// Kernel registrations for the "Einsum" op in this translation unit:
// int64 on CPU only (no GPU registration appears in this file).
//
// REGISTER_EINSUM(D, TYPE) emits one REGISTER_KERNEL_BUILDER call for the op
// named "Einsum" on device DEVICE_<D>, with attribute "T" constrained to TYPE
// and EinsumOp<<D>Device, TYPE> as the kernel class. D is token-pasted on both
// sides, so it must be a bare device tag such as CPU.
// The macro body already ends in ';', so the ';' written after each use below
// is a redundant empty declaration.
#define REGISTER_EINSUM(D, TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("Einsum").Device(DEVICE_##D).TypeConstraint<TYPE>("T"), \
EinsumOp<D##Device, TYPE>);
// Single-argument adapter so the registration can be handed to a TF_CALL_*
// macro.
#define REGISTER_CPU(TYPE) REGISTER_EINSUM(CPU, TYPE)
// TF_CALL_int64 is defined outside this file; presumably it invokes the given
// macro with the 64-bit integer type -- confirm in the defining header.
TF_CALL_int64(REGISTER_CPU);
// Drop the helper macros so they do not leak past this file's registrations.
#undef REGISTER_CPU
#undef REGISTER_EINSUM
} // namespace tensorflow

View File

@ -22,6 +22,7 @@ import numpy as np
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
@ -122,12 +123,35 @@ class EinsumOpTest(test.TestCase):
self._check('ab...,b->ab...', (2, 3, 1, 1, 5), (3,))
# Exercises the Einsum op across several numpy dtypes.
# NOTE(review): this span comes from a diff rendering with indentation and
# +/- markers stripped. The first loop (the five self._check calls) looks like
# the pre-change body and everything from `bfloat16 = ...` onward looks like
# its replacement -- confirm against the real file before relying on either.
def testDtypes(self):
for dtype in [np.float64, np.float32, np.complex64, np.complex128]:
self._check('ij,jk->ik', (2, 2), (2, 2), dtype=dtype)
self._check('ji,jk->ik', (2, 2), (2, 2), dtype=dtype)
self._check('ji,kj->ik', (2, 2), (2, 2), dtype=dtype)
self._check('ij,jk->ki', (2, 2), (2, 2), dtype=dtype)
self._check('ji,kj->ki', (2, 2), (2, 2), dtype=dtype)
# numpy-side type object for bfloat16, taken from dtypes.bfloat16.
bfloat16 = dtypes.bfloat16.as_numpy_dtype
# Compares gen_linalg_ops.einsum with np.einsum for 'ij,jk->ik' on two (2, 2)
# inputs of the given dtype.
def check(dtype):
# Seeded generator, created per call, so each dtype sees the same draws.
r = np.random.RandomState(0)
equation = 'ij,jk->ik'
input_shapes = [(2, 2), (2, 2)]
inputs = []
for shape in input_shapes:
# Draw standard-normal values and cast them to the dtype under test.
arr = np.array(r.randn(*shape)).astype(dtype)
# Complex dtypes additionally get a random imaginary part.
if dtype == np.complex64 or dtype == np.complex128:
arr += 1j * np.array(r.randn(*shape)).astype(dtype)
inputs.append(arr)
input_tensors = [constant_op.constant(x) for x in inputs]
# Reference result `a` comes from numpy; for bfloat16 it is computed in
# float32 and cast back to the dtype under test.
if dtype == bfloat16:
# np.einsum doesn't support bfloat16.
a = np.einsum(equation,
*[x.astype(np.float32) for x in inputs]).astype(dtype)
else:
a = np.einsum(equation, *inputs)
# Result `b` comes from the TensorFlow op being tested.
b = self.evaluate(gen_linalg_ops.einsum(input_tensors, equation))
# bfloat16 is compared with a looser tolerance than the other dtypes.
tol = 1e-2 if dtype == bfloat16 else 1e-4
self.assertAllClose(a, b, atol=tol, rtol=tol)
# Run the comparison once for each dtype in the list.
for dtype in [
bfloat16, np.float32, np.float64, np.complex64, np.complex128, np.int32,
np.int64
]:
check(dtype)
@test_util.run_in_graph_and_eager_modes
def testInvalid(self):