Add more dtypes to Einsum Op. Second try, as the first one broke global TAP because of long compilation times.
PiperOrigin-RevId: 267012419
This commit is contained in:
parent
45fe8a9dac
commit
6462b0e801
tensorflow
contrib/makefile
core/kernels
BUILDeinsum_op_gpu.cu.cceinsum_op_impl.heinsum_op_impl_bfloat16.cceinsum_op_impl_half.cceinsum_op_impl_int32.cceinsum_op_impl_int64.cc
python/kernel_tests
@ -116,6 +116,10 @@ tensorflow/core/kernels/depthwise_conv_op.cc
|
||||
tensorflow/core/kernels/dequantize_op.cc
|
||||
tensorflow/core/kernels/dynamic_partition_op.cc
|
||||
tensorflow/core/kernels/dynamic_stitch_op.cc
|
||||
tensorflow/core/kernels/einsum_op_impl_half.cc
|
||||
tensorflow/core/kernels/einsum_op_impl_bfloat16.cc
|
||||
tensorflow/core/kernels/einsum_op_impl_int32.cc
|
||||
tensorflow/core/kernels/einsum_op_impl_int64.cc
|
||||
tensorflow/core/kernels/einsum_op_impl_float.cc
|
||||
tensorflow/core/kernels/einsum_op_impl_double.cc
|
||||
tensorflow/core/kernels/einsum_op_impl_complex64.cc
|
||||
|
@ -6323,6 +6323,10 @@ filegroup(
|
||||
"encode_wav_op.cc",
|
||||
"eigen_contraction_kernel.cc",
|
||||
"eigen_contraction_kernel.h",
|
||||
"einsum_op_impl_half.cc",
|
||||
"einsum_op_impl_bfloat16.cc",
|
||||
"einsum_op_impl_int32.cc",
|
||||
"einsum_op_impl_int64.cc",
|
||||
"einsum_op_impl_float.cc",
|
||||
"einsum_op_impl_double.cc",
|
||||
"einsum_op_impl_complex64.cc",
|
||||
|
@ -33,6 +33,7 @@ namespace tensorflow {
|
||||
DECLARE_GPU_SPECS_NDIM(T, 5); \
|
||||
DECLARE_GPU_SPECS_NDIM(T, 6);
|
||||
|
||||
TF_CALL_half(DECLARE_GPU_SPECS);
|
||||
TF_CALL_float(DECLARE_GPU_SPECS);
|
||||
TF_CALL_double(DECLARE_GPU_SPECS);
|
||||
TF_CALL_complex64(DECLARE_GPU_SPECS);
|
||||
|
@ -759,6 +759,7 @@ namespace functor {
|
||||
DECLARE_GPU_SPEC(T, 5); \
|
||||
DECLARE_GPU_SPEC(T, 6);
|
||||
|
||||
DECLARE_GPU_SPECS(Eigen::half);
|
||||
DECLARE_GPU_SPECS(double);
|
||||
DECLARE_GPU_SPECS(float);
|
||||
// TODO(rocm): Enable once complex types are supported.
|
||||
|
31
tensorflow/core/kernels/einsum_op_impl_bfloat16.cc
Normal file
31
tensorflow/core/kernels/einsum_op_impl_bfloat16.cc
Normal file
@ -0,0 +1,31 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/einsum_op_impl.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
#define REGISTER_EINSUM(D, TYPE) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Einsum").Device(DEVICE_##D).TypeConstraint<TYPE>("T"), \
|
||||
EinsumOp<D##Device, TYPE>);
|
||||
|
||||
#define REGISTER_CPU(TYPE) REGISTER_EINSUM(CPU, TYPE)
|
||||
TF_CALL_bfloat16(REGISTER_CPU);
|
||||
#undef REGISTER_CPU
|
||||
|
||||
#undef REGISTER_EINSUM
|
||||
|
||||
} // namespace tensorflow
|
37
tensorflow/core/kernels/einsum_op_impl_half.cc
Normal file
37
tensorflow/core/kernels/einsum_op_impl_half.cc
Normal file
@ -0,0 +1,37 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/einsum_op_impl.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
#define REGISTER_EINSUM(D, TYPE) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Einsum").Device(DEVICE_##D).TypeConstraint<TYPE>("T"), \
|
||||
EinsumOp<D##Device, TYPE>);
|
||||
|
||||
#define REGISTER_CPU(TYPE) REGISTER_EINSUM(CPU, TYPE)
|
||||
TF_CALL_half(REGISTER_CPU);
|
||||
#undef REGISTER_CPU
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define REGISTER_GPU(TYPE) REGISTER_EINSUM(GPU, TYPE)
|
||||
TF_CALL_half(REGISTER_GPU);
|
||||
#undef REGISTER_GPU
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#undef REGISTER_EINSUM
|
||||
|
||||
} // namespace tensorflow
|
31
tensorflow/core/kernels/einsum_op_impl_int32.cc
Normal file
31
tensorflow/core/kernels/einsum_op_impl_int32.cc
Normal file
@ -0,0 +1,31 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/einsum_op_impl.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
#define REGISTER_EINSUM(D, TYPE) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Einsum").Device(DEVICE_##D).TypeConstraint<TYPE>("T"), \
|
||||
EinsumOp<D##Device, TYPE>);
|
||||
|
||||
#define REGISTER_CPU(TYPE) REGISTER_EINSUM(CPU, TYPE)
|
||||
TF_CALL_int32(REGISTER_CPU);
|
||||
#undef REGISTER_CPU
|
||||
|
||||
#undef REGISTER_EINSUM
|
||||
|
||||
} // namespace tensorflow
|
31
tensorflow/core/kernels/einsum_op_impl_int64.cc
Normal file
31
tensorflow/core/kernels/einsum_op_impl_int64.cc
Normal file
@ -0,0 +1,31 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/einsum_op_impl.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
#define REGISTER_EINSUM(D, TYPE) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Einsum").Device(DEVICE_##D).TypeConstraint<TYPE>("T"), \
|
||||
EinsumOp<D##Device, TYPE>);
|
||||
|
||||
#define REGISTER_CPU(TYPE) REGISTER_EINSUM(CPU, TYPE)
|
||||
TF_CALL_int64(REGISTER_CPU);
|
||||
#undef REGISTER_CPU
|
||||
|
||||
#undef REGISTER_EINSUM
|
||||
|
||||
} // namespace tensorflow
|
@ -22,6 +22,7 @@ import numpy as np
|
||||
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
@ -122,12 +123,35 @@ class EinsumOpTest(test.TestCase):
|
||||
self._check('ab...,b->ab...', (2, 3, 1, 1, 5), (3,))
|
||||
|
||||
def testDtypes(self):
|
||||
for dtype in [np.float64, np.float32, np.complex64, np.complex128]:
|
||||
self._check('ij,jk->ik', (2, 2), (2, 2), dtype=dtype)
|
||||
self._check('ji,jk->ik', (2, 2), (2, 2), dtype=dtype)
|
||||
self._check('ji,kj->ik', (2, 2), (2, 2), dtype=dtype)
|
||||
self._check('ij,jk->ki', (2, 2), (2, 2), dtype=dtype)
|
||||
self._check('ji,kj->ki', (2, 2), (2, 2), dtype=dtype)
|
||||
bfloat16 = dtypes.bfloat16.as_numpy_dtype
|
||||
|
||||
def check(dtype):
|
||||
r = np.random.RandomState(0)
|
||||
equation = 'ij,jk->ik'
|
||||
input_shapes = [(2, 2), (2, 2)]
|
||||
inputs = []
|
||||
for shape in input_shapes:
|
||||
arr = np.array(r.randn(*shape)).astype(dtype)
|
||||
if dtype == np.complex64 or dtype == np.complex128:
|
||||
arr += 1j * np.array(r.randn(*shape)).astype(dtype)
|
||||
inputs.append(arr)
|
||||
input_tensors = [constant_op.constant(x) for x in inputs]
|
||||
if dtype == bfloat16:
|
||||
# np.einsum doesn't support bfloat16.
|
||||
a = np.einsum(equation,
|
||||
*[x.astype(np.float32) for x in inputs]).astype(dtype)
|
||||
else:
|
||||
a = np.einsum(equation, *inputs)
|
||||
|
||||
b = self.evaluate(gen_linalg_ops.einsum(input_tensors, equation))
|
||||
tol = 1e-2 if dtype == bfloat16 else 1e-4
|
||||
self.assertAllClose(a, b, atol=tol, rtol=tol)
|
||||
|
||||
for dtype in [
|
||||
bfloat16, np.float32, np.float64, np.complex64, np.complex128, np.int32,
|
||||
np.int64
|
||||
]:
|
||||
check(dtype)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testInvalid(self):
|
||||
|
Loading…
Reference in New Issue
Block a user