From 58c796df1dfd750782a1d285ec0972969bd196c6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 5 Jun 2019 03:47:33 -0700 Subject: [PATCH] [TF:XLA] Broadcast NextAfter arguments when needed. PiperOrigin-RevId: 251613307 --- tensorflow/compiler/tf2xla/kernels/BUILD | 1 + .../compiler/tf2xla/kernels/binary_ops.cc | 4 +- .../compiler/tf2xla/kernels/next_after_op.cc | 43 +++++++++++++++++++ tensorflow/python/ops/math_ops_test.py | 1 - 4 files changed, 46 insertions(+), 3 deletions(-) create mode 100644 tensorflow/compiler/tf2xla/kernels/next_after_op.cc diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 06376f7174e..621c49fbce7 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -56,6 +56,7 @@ tf_kernel_library( "matrix_set_diag_op.cc", "matrix_triangular_solve_op.cc", "mirror_pad_op.cc", + "next_after_op.cc", "no_op.cc", "one_hot_op.cc", "pack_op.cc", diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 1fe05fba158..b923ff6c96c 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -16,6 +16,8 @@ limitations under the License. // Native XLA implementations of simple binary Ops #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" +#include "tensorflow/compiler/tf2xla/lib/broadcast.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" @@ -236,8 +238,6 @@ XLA_MAKE_BINARY(TanhGrad, XLA_MAKE_BINARY(Pow, xla::Pow(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(NextAfter, xla::NextAfter(lhs, rhs)); - #undef XLA_MAKE_BINARY class ApproximateEqualOp : public XlaOpKernel { diff --git a/tensorflow/compiler/tf2xla/kernels/next_after_op.cc b/tensorflow/compiler/tf2xla/kernels/next_after_op.cc new file mode 100644 index 00000000000..0801c52500f --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/next_after_op.cc @@ -0,0 +1,43 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/broadcast.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/util/bcast.h" + +namespace tensorflow { +namespace { + +class NextAfterOp : public XlaOpKernel { + public: + explicit NextAfterOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + auto lhs = ctx->Input(0); + auto rhs = ctx->Input(1); + OP_REQUIRES_OK(ctx, BroadcastOpsToSame(&lhs, &rhs)); + ctx->SetOutput(0, xla::NextAfter(lhs, rhs)); + } +}; +REGISTER_XLA_OP(Name("NextAfter"), NextAfterOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py index 6fb3badb588..68740b67374 100644 --- a/tensorflow/python/ops/math_ops_test.py +++ b/tensorflow/python/ops/math_ops_test.py @@ -668,7 +668,6 @@ class NextAfterTest(test_util.TensorFlowTestCase): self.assertAllEqual(math_ops.nextafter(one, one), one) @test_util.run_in_graph_and_eager_modes - @test_util.disable_xla("Broadcasting not supported for XLA") def testBroadcasting(self): for dtype in [dtypes.float32, dtypes.float64]: