diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index 28d3fe9e39d..299db8bc871 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -116,6 +116,10 @@ tensorflow/core/kernels/depthwise_conv_op.cc
 tensorflow/core/kernels/dequantize_op.cc
 tensorflow/core/kernels/dynamic_partition_op.cc
 tensorflow/core/kernels/dynamic_stitch_op.cc
+tensorflow/core/kernels/einsum_op_impl_half.cc
+tensorflow/core/kernels/einsum_op_impl_bfloat16.cc
+tensorflow/core/kernels/einsum_op_impl_int32.cc
+tensorflow/core/kernels/einsum_op_impl_int64.cc
 tensorflow/core/kernels/einsum_op_impl_float.cc
 tensorflow/core/kernels/einsum_op_impl_double.cc
 tensorflow/core/kernels/einsum_op_impl_complex64.cc
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 02a630413cb..6b233eabe46 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -6323,6 +6323,10 @@ filegroup(
         "encode_wav_op.cc",
         "eigen_contraction_kernel.cc",
         "eigen_contraction_kernel.h",
+        "einsum_op_impl_half.cc",
+        "einsum_op_impl_bfloat16.cc",
+        "einsum_op_impl_int32.cc",
+        "einsum_op_impl_int64.cc",
         "einsum_op_impl_float.cc",
         "einsum_op_impl_double.cc",
         "einsum_op_impl_complex64.cc",
diff --git a/tensorflow/core/kernels/einsum_op_gpu.cu.cc b/tensorflow/core/kernels/einsum_op_gpu.cu.cc
index fa1c8cbb4a5..36a97691297 100644
--- a/tensorflow/core/kernels/einsum_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/einsum_op_gpu.cu.cc
@@ -33,6 +33,7 @@ namespace tensorflow {
   DECLARE_GPU_SPECS_NDIM(T, 5); \
   DECLARE_GPU_SPECS_NDIM(T, 6);
 
+TF_CALL_half(DECLARE_GPU_SPECS);
 TF_CALL_float(DECLARE_GPU_SPECS);
 TF_CALL_double(DECLARE_GPU_SPECS);
 TF_CALL_complex64(DECLARE_GPU_SPECS);
diff --git a/tensorflow/core/kernels/einsum_op_impl.h b/tensorflow/core/kernels/einsum_op_impl.h
index 4b35cf3f20f..0139ec735da 100644
--- a/tensorflow/core/kernels/einsum_op_impl.h
+++ b/tensorflow/core/kernels/einsum_op_impl.h
@@ -759,6 +759,7 @@ namespace functor {
   DECLARE_GPU_SPEC(T, 5); \
   DECLARE_GPU_SPEC(T, 6);
 
+DECLARE_GPU_SPECS(Eigen::half);
 DECLARE_GPU_SPECS(double);
 DECLARE_GPU_SPECS(float);
 // TODO(rocm): Enable once complex types are supported.
diff --git a/tensorflow/core/kernels/einsum_op_impl_bfloat16.cc b/tensorflow/core/kernels/einsum_op_impl_bfloat16.cc
new file mode 100644
index 00000000000..44508f86a5e
--- /dev/null
+++ b/tensorflow/core/kernels/einsum_op_impl_bfloat16.cc
@@ -0,0 +1,31 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/einsum_op_impl.h"
+
+namespace tensorflow {
+
+#define REGISTER_EINSUM(D, TYPE) \
+  REGISTER_KERNEL_BUILDER( \
+      Name("Einsum").Device(DEVICE_##D).TypeConstraint<TYPE>("T"), \
+      EinsumOp<D##Device, TYPE>);
+
+#define REGISTER_CPU(TYPE) REGISTER_EINSUM(CPU, TYPE)
+TF_CALL_bfloat16(REGISTER_CPU);  // CPU only; this file adds no GPU kernel.
+#undef REGISTER_CPU
+
+#undef REGISTER_EINSUM
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/einsum_op_impl_half.cc b/tensorflow/core/kernels/einsum_op_impl_half.cc
new file mode 100644
index 00000000000..0486b133e62
--- /dev/null
+++ b/tensorflow/core/kernels/einsum_op_impl_half.cc
@@ -0,0 +1,37 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/einsum_op_impl.h"
+
+namespace tensorflow {
+
+#define REGISTER_EINSUM(D, TYPE) \
+  REGISTER_KERNEL_BUILDER( \
+      Name("Einsum").Device(DEVICE_##D).TypeConstraint<TYPE>("T"), \
+      EinsumOp<D##Device, TYPE>);
+
+#define REGISTER_CPU(TYPE) REGISTER_EINSUM(CPU, TYPE)
+TF_CALL_half(REGISTER_CPU);
+#undef REGISTER_CPU
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#define REGISTER_GPU(TYPE) REGISTER_EINSUM(GPU, TYPE)
+TF_CALL_half(REGISTER_GPU);  // Compiled only under the #if guard above.
+#undef REGISTER_GPU
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
+#undef REGISTER_EINSUM
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/einsum_op_impl_int32.cc b/tensorflow/core/kernels/einsum_op_impl_int32.cc
new file mode 100644
index 00000000000..db5169498d9
--- /dev/null
+++ b/tensorflow/core/kernels/einsum_op_impl_int32.cc
@@ -0,0 +1,31 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/einsum_op_impl.h"
+
+namespace tensorflow {
+
+#define REGISTER_EINSUM(D, TYPE) \
+  REGISTER_KERNEL_BUILDER( \
+      Name("Einsum").Device(DEVICE_##D).TypeConstraint<TYPE>("T"), \
+      EinsumOp<D##Device, TYPE>);
+
+#define REGISTER_CPU(TYPE) REGISTER_EINSUM(CPU, TYPE)
+TF_CALL_int32(REGISTER_CPU);  // CPU only; this file adds no GPU kernel.
+#undef REGISTER_CPU
+
+#undef REGISTER_EINSUM
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/einsum_op_impl_int64.cc b/tensorflow/core/kernels/einsum_op_impl_int64.cc
new file mode 100644
index 00000000000..7f1a1eac411
--- /dev/null
+++ b/tensorflow/core/kernels/einsum_op_impl_int64.cc
@@ -0,0 +1,31 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/einsum_op_impl.h"
+
+namespace tensorflow {
+
+#define REGISTER_EINSUM(D, TYPE) \
+  REGISTER_KERNEL_BUILDER( \
+      Name("Einsum").Device(DEVICE_##D).TypeConstraint<TYPE>("T"), \
+      EinsumOp<D##Device, TYPE>);
+
+#define REGISTER_CPU(TYPE) REGISTER_EINSUM(CPU, TYPE)
+TF_CALL_int64(REGISTER_CPU);  // CPU only; this file adds no GPU kernel.
+#undef REGISTER_CPU
+
+#undef REGISTER_EINSUM
+
+}  // namespace tensorflow
diff --git a/tensorflow/python/kernel_tests/einsum_op_test.py b/tensorflow/python/kernel_tests/einsum_op_test.py
index b51b91ddbf4..c4fffa1b5a5 100644
--- a/tensorflow/python/kernel_tests/einsum_op_test.py
+++ b/tensorflow/python/kernel_tests/einsum_op_test.py
@@ -22,6 +22,7 @@ import numpy as np
 
 from tensorflow.python.client import session
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
@@ -122,12 +123,45 @@ class EinsumOpTest(test.TestCase):
     self._check('ab...,b->ab...', (2, 3, 1, 1, 5), (3,))
 
   def testDtypes(self):
-    for dtype in [np.float64, np.float32, np.complex64, np.complex128]:
-      self._check('ij,jk->ik', (2, 2), (2, 2), dtype=dtype)
-      self._check('ji,jk->ik', (2, 2), (2, 2), dtype=dtype)
-      self._check('ji,kj->ik', (2, 2), (2, 2), dtype=dtype)
-      self._check('ij,jk->ki', (2, 2), (2, 2), dtype=dtype)
-      self._check('ji,kj->ki', (2, 2), (2, 2), dtype=dtype)
+    """Compares Einsum with np.einsum for each dtype in the list below."""
+    bfloat16 = dtypes.bfloat16.as_numpy_dtype
+    low_precision_dtypes = (bfloat16, np.float16)
+    integer_dtypes = (np.int32, np.int64)
+    complex_dtypes = (np.complex64, np.complex128)
+
+    def check(dtype):
+      r = np.random.RandomState(0)
+      equation = 'ij,jk->ik'
+      input_shapes = [(2, 2), (2, 2)]
+      inputs = []
+      for shape in input_shapes:
+        if dtype in integer_dtypes:
+          # Casting randn() samples to an integer type truncates most of them
+          # to 0, so draw integers directly to get a non-degenerate product.
+          arr = r.randint(-10, 10, size=shape).astype(dtype)
+        else:
+          arr = np.array(r.randn(*shape)).astype(dtype)
+        if dtype in complex_dtypes:
+          arr += 1j * np.array(r.randn(*shape)).astype(dtype)
+        inputs.append(arr)
+      input_tensors = [constant_op.constant(x) for x in inputs]
+      if dtype == bfloat16:
+        # np.einsum doesn't support bfloat16.
+        a = np.einsum(equation,
+                      *[x.astype(np.float32) for x in inputs]).astype(dtype)
+      else:
+        a = np.einsum(equation, *inputs)
+
+      b = self.evaluate(gen_linalg_ops.einsum(input_tensors, equation))
+      tol = 1e-2 if dtype in low_precision_dtypes else 1e-4
+      self.assertAllClose(
+          a, b, atol=tol, rtol=tol, msg='dtype=%s' % np.dtype(dtype).name)
+
+    for dtype in [
+        bfloat16, np.float16, np.float32, np.float64, np.complex64,
+        np.complex128, np.int32, np.int64
+    ]:
+      check(dtype)
 
   @test_util.run_in_graph_and_eager_modes
   def testInvalid(self):