[XLA] Add IsFinite op in tf2xla.
PiperOrigin-RevId: 165734702
This commit is contained in:
parent
5f5c3eb0ab
commit
32f4c5b6e8
@ -260,6 +260,13 @@ class UnaryOpsTest(XLATestCase):
|
|||||||
np.array([[-2, 0, 8]], dtype=dtype),
|
np.array([[-2, 0, 8]], dtype=dtype),
|
||||||
expected=np.array([[0.126928, 0.6931472, 8.0003354]], dtype=dtype))
|
expected=np.array([[0.126928, 0.6931472, 8.0003354]], dtype=dtype))
|
||||||
|
|
||||||
|
self._assertOpOutputMatchesExpected(
|
||||||
|
math_ops.is_finite,
|
||||||
|
np.array(
|
||||||
|
[[42, float("inf"), -123], [float("nan"), 0, -0.0]], dtype=dtype),
|
||||||
|
expected=np.array(
|
||||||
|
[[True, False, True], [False, True, True]], dtype=np.bool))
|
||||||
|
|
||||||
def testNumericOps(self):
|
def testNumericOps(self):
|
||||||
for dtype in self.numeric_types:
|
for dtype in self.numeric_types:
|
||||||
self._assertOpOutputMatchesExpected(
|
self._assertOpOutputMatchesExpected(
|
||||||
|
@ -31,6 +31,7 @@ tf_kernel_library(
|
|||||||
"function_ops.cc",
|
"function_ops.cc",
|
||||||
"gather_op.cc",
|
"gather_op.cc",
|
||||||
"identity_op.cc",
|
"identity_op.cc",
|
||||||
|
"is_finite_op.cc",
|
||||||
"l2loss_op.cc",
|
"l2loss_op.cc",
|
||||||
"lrn_ops.cc",
|
"lrn_ops.cc",
|
||||||
"matmul_op.cc",
|
"matmul_op.cc",
|
||||||
|
43
tensorflow/compiler/tf2xla/kernels/is_finite_op.cc
Normal file
43
tensorflow/compiler/tf2xla/kernels/is_finite_op.cc
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||||
|
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||||
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||||
|
#include "tensorflow/compiler/xla/literal_util.h"
|
||||||
|
#include "tensorflow/core/platform/macros.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
#include "tensorflow/core/util/bcast.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class IsFiniteOp : public XlaOpKernel {
|
||||||
|
public:
|
||||||
|
explicit IsFiniteOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
||||||
|
|
||||||
|
void Compile(XlaOpKernelContext* ctx) override {
|
||||||
|
xla::ComputationDataHandle input = ctx->Input(0);
|
||||||
|
ctx->SetOutput(0, ctx->builder()->IsFinite(input));
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
TF_DISALLOW_COPY_AND_ASSIGN(IsFiniteOp);
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_XLA_OP(Name("IsFinite"), IsFiniteOp);
|
||||||
|
|
||||||
|
} // anonymous namespace
|
||||||
|
} // namespace tensorflow
|
Loading…
x
Reference in New Issue
Block a user