diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 9eaede7f406..83cfd2ea75b 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -765,6 +765,24 @@ class BinaryOpsTest(XLATestCase): np.array([1, 0], dtype=np.int32), expected=np.array([[1, 3], [2, 4]], dtype=dtype)) + def testCross(self): + for dtype in self.float_types: + self._testBinary( + gen_math_ops.cross, + np.zeros((4, 3), dtype=dtype), + np.zeros((4, 3), dtype=dtype), + expected=np.zeros((4, 3), dtype=dtype)) + self._testBinary( + gen_math_ops.cross, + np.array([1, 2, 3], dtype=dtype), + np.array([4, 5, 6], dtype=dtype), + expected=np.array([-3, 6, -3], dtype=dtype)) + self._testBinary( + gen_math_ops.cross, + np.array([[1, 2, 3], [10, 11, 12]], dtype=dtype), + np.array([[4, 5, 6], [40, 50, 60]], dtype=dtype), + expected=np.array([[-3, 6, -3], [60, -120, 60]], dtype=dtype)) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 546e9be8647..b114b7e6f83 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -24,6 +24,7 @@ tf_kernel_library( "concat_op.cc", "const_op.cc", "conv_ops.cc", + "cross_op.cc", "cwise_ops.cc", "depthwise_conv_ops.cc", "diag_op.cc", diff --git a/tensorflow/compiler/tf2xla/kernels/cross_op.cc b/tensorflow/compiler/tf2xla/kernels/cross_op.cc new file mode 100644 index 00000000000..3df8c00f1b8 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/cross_op.cc @@ -0,0 +1,87 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { +namespace { + +class CrossOp : public XlaOpKernel { + public: + explicit CrossOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape in0_shape = ctx->InputShape(0); + TensorShape in1_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, in0_shape == in1_shape, + errors::InvalidArgument("Both inputs must be of same shape: ", + in0_shape.DebugString(), " vs. ", + in1_shape.DebugString())); + OP_REQUIRES(ctx, in0_shape.dims() >= 1, + errors::InvalidArgument("Input must be at least 1D", + in0_shape.DebugString())); + + auto inner_dim = in0_shape.dim_size(in0_shape.dims() - 1); + OP_REQUIRES(ctx, inner_dim == 3, + errors::FailedPrecondition( + "Cross-products are only defined for 3-element vectors.")); + + // in0 is a [...,X,Y,Z,3] + // in1 is the same shape as in0 + // So slice 0 is: in0[...,:,:,:,0:1] + // So slice 1 is: in0[...,:,:,:,1:2] + // So slice 2 is: in0[...,:,:,:,2:3] + + std::vector starts(in0_shape.dims(), 0); + std::vector limits; + for (auto dim_size : in0_shape.dim_sizes()) { + limits.push_back(dim_size); + } + std::vector strides(in0_shape.dims(), 1); + + xla::ComputationBuilder* b = ctx->builder(); + auto in0 = ctx->Input(0); + auto in1 = ctx->Input(1); + starts.back() = 0; + limits.back() = 1; + auto u1 = b->Slice(in0, starts, limits, strides); + auto v1 = b->Slice(in1, starts, limits, strides); + starts.back() = 1; + limits.back() = 2; + auto u2 = b->Slice(in0, starts, limits, strides); + auto v2 = b->Slice(in1, starts, limits, strides); + starts.back() = 2; + limits.back() = 3; + auto u3 = b->Slice(in0, starts, limits, strides); + auto v3 = b->Slice(in1, starts, limits, strides); + + auto s1 = b->Sub(b->Mul(u2, v3), b->Mul(u3, v2)); + auto s2 = b->Sub(b->Mul(u3, v1), b->Mul(u1, v3)); + auto s3 = b->Sub(b->Mul(u1, v2), b->Mul(u2, v1)); + auto output = b->ConcatInDim({s1, s2, s3}, in0_shape.dims() - 1); + + ctx->SetOutput(0, output); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(CrossOp); +}; + +REGISTER_XLA_OP(Name("Cross"), CrossOp); + +} // namespace +} // namespace tensorflow