[TF:XLA] Broadcast NextAfter arguments when needed.
PiperOrigin-RevId: 251613307
This commit is contained in:
		
							parent
							
								
									c47ad082f4
								
							
						
					
					
						commit
						58c796df1d
					
				| @ -56,6 +56,7 @@ tf_kernel_library( | |||||||
|         "matrix_set_diag_op.cc", |         "matrix_set_diag_op.cc", | ||||||
|         "matrix_triangular_solve_op.cc", |         "matrix_triangular_solve_op.cc", | ||||||
|         "mirror_pad_op.cc", |         "mirror_pad_op.cc", | ||||||
|  |         "next_after_op.cc", | ||||||
|         "no_op.cc", |         "no_op.cc", | ||||||
|         "one_hot_op.cc", |         "one_hot_op.cc", | ||||||
|         "pack_op.cc", |         "pack_op.cc", | ||||||
|  | |||||||
| @ -16,6 +16,8 @@ limitations under the License. | |||||||
| // Native XLA implementations of simple binary Ops
 | // Native XLA implementations of simple binary Ops
 | ||||||
| 
 | 
 | ||||||
| #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" | #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" | ||||||
|  | #include "tensorflow/compiler/tf2xla/lib/broadcast.h" | ||||||
|  | #include "tensorflow/compiler/tf2xla/shape_util.h" | ||||||
| #include "tensorflow/compiler/tf2xla/xla_helpers.h" | #include "tensorflow/compiler/tf2xla/xla_helpers.h" | ||||||
| #include "tensorflow/compiler/tf2xla/xla_op_registry.h" | #include "tensorflow/compiler/tf2xla/xla_op_registry.h" | ||||||
| #include "tensorflow/compiler/xla/client/client_library.h" | #include "tensorflow/compiler/xla/client/client_library.h" | ||||||
| @ -236,8 +238,6 @@ XLA_MAKE_BINARY(TanhGrad, | |||||||
| 
 | 
 | ||||||
| XLA_MAKE_BINARY(Pow, xla::Pow(lhs, rhs, extend_dimensions)); | XLA_MAKE_BINARY(Pow, xla::Pow(lhs, rhs, extend_dimensions)); | ||||||
| 
 | 
 | ||||||
| XLA_MAKE_BINARY(NextAfter, xla::NextAfter(lhs, rhs)); |  | ||||||
| 
 |  | ||||||
| #undef XLA_MAKE_BINARY | #undef XLA_MAKE_BINARY | ||||||
| 
 | 
 | ||||||
| class ApproximateEqualOp : public XlaOpKernel { | class ApproximateEqualOp : public XlaOpKernel { | ||||||
|  | |||||||
							
								
								
									
										43
									
								
								tensorflow/compiler/tf2xla/kernels/next_after_op.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								tensorflow/compiler/tf2xla/kernels/next_after_op.cc
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,43 @@ | |||||||
|  | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 | ||||||
|  | 
 | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  | 
 | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | 
 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | ==============================================================================*/ | ||||||
|  | 
 | ||||||
|  | #include "tensorflow/compiler/tf2xla/lib/broadcast.h" | ||||||
|  | #include "tensorflow/compiler/tf2xla/shape_util.h" | ||||||
|  | #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" | ||||||
|  | #include "tensorflow/compiler/tf2xla/xla_op_registry.h" | ||||||
|  | #include "tensorflow/compiler/xla/client/lib/math.h" | ||||||
|  | #include "tensorflow/compiler/xla/client/xla_builder.h" | ||||||
|  | #include "tensorflow/core/framework/op_kernel.h" | ||||||
|  | #include "tensorflow/core/lib/core/errors.h" | ||||||
|  | #include "tensorflow/core/util/bcast.h" | ||||||
|  | 
 | ||||||
|  | namespace tensorflow { | ||||||
|  | namespace { | ||||||
|  | 
 | ||||||
|  | class NextAfterOp : public XlaOpKernel { | ||||||
|  |  public: | ||||||
|  |   explicit NextAfterOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} | ||||||
|  | 
 | ||||||
|  |   void Compile(XlaOpKernelContext* ctx) override { | ||||||
|  |     auto lhs = ctx->Input(0); | ||||||
|  |     auto rhs = ctx->Input(1); | ||||||
|  |     OP_REQUIRES_OK(ctx, BroadcastOpsToSame(&lhs, &rhs)); | ||||||
|  |     ctx->SetOutput(0, xla::NextAfter(lhs, rhs)); | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  | REGISTER_XLA_OP(Name("NextAfter"), NextAfterOp); | ||||||
|  | 
 | ||||||
|  | }  // namespace
 | ||||||
|  | }  // namespace tensorflow
 | ||||||
| @ -668,7 +668,6 @@ class NextAfterTest(test_util.TensorFlowTestCase): | |||||||
|       self.assertAllEqual(math_ops.nextafter(one, one), one) |       self.assertAllEqual(math_ops.nextafter(one, one), one) | ||||||
| 
 | 
 | ||||||
|   @test_util.run_in_graph_and_eager_modes |   @test_util.run_in_graph_and_eager_modes | ||||||
|   @test_util.disable_xla("Broadcasting not supported for XLA") |  | ||||||
|   def testBroadcasting(self): |   def testBroadcasting(self): | ||||||
| 
 | 
 | ||||||
|     for dtype in [dtypes.float32, dtypes.float64]: |     for dtype in [dtypes.float32, dtypes.float64]: | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user