[XLA] Add IsFinite op in tf2xla.
PiperOrigin-RevId: 165734702
This commit is contained in:
parent
5f5c3eb0ab
commit
32f4c5b6e8
@ -260,6 +260,13 @@ class UnaryOpsTest(XLATestCase):
|
||||
np.array([[-2, 0, 8]], dtype=dtype),
|
||||
expected=np.array([[0.126928, 0.6931472, 8.0003354]], dtype=dtype))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.is_finite,
|
||||
np.array(
|
||||
[[42, float("inf"), -123], [float("nan"), 0, -0.0]], dtype=dtype),
|
||||
expected=np.array(
|
||||
[[True, False, True], [False, True, True]], dtype=np.bool))
|
||||
|
||||
def testNumericOps(self):
|
||||
for dtype in self.numeric_types:
|
||||
self._assertOpOutputMatchesExpected(
|
||||
|
@ -31,6 +31,7 @@ tf_kernel_library(
|
||||
"function_ops.cc",
|
||||
"gather_op.cc",
|
||||
"identity_op.cc",
|
||||
"is_finite_op.cc",
|
||||
"l2loss_op.cc",
|
||||
"lrn_ops.cc",
|
||||
"matmul_op.cc",
|
||||
|
43
tensorflow/compiler/tf2xla/kernels/is_finite_op.cc
Normal file
43
tensorflow/compiler/tf2xla/kernels/is_finite_op.cc
Normal file
@ -0,0 +1,43 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/bcast.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
class IsFiniteOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit IsFiniteOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
xla::ComputationDataHandle input = ctx->Input(0);
|
||||
ctx->SetOutput(0, ctx->builder()->IsFinite(input));
|
||||
}
|
||||
|
||||
private:
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(IsFiniteOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("IsFinite"), IsFiniteOp);
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace tensorflow
|
Loading…
x
Reference in New Issue
Block a user