129 lines
4.4 KiB
C++
129 lines
4.4 KiB
C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
// Native XLA implementations of the Relu family of TensorFlow ops (Relu, Relu6,
// LeakyRelu) and their gradients.
|
|
|
|
#include "tensorflow/compiler/tf2xla/kernels/relu_op.h"
|
|
|
|
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
|
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
|
#include "tensorflow/compiler/xla/literal.h"
|
|
|
|
namespace xla {
|
|
XlaOp Relu(XlaOp x) { return Max(ScalarLike(x, 0), x); }
|
|
|
|
XlaOp Relu6(XlaOp x) {
|
|
auto zero = ScalarLike(x, 0);
|
|
auto six = ScalarLike(x, 6);
|
|
return Clamp(zero, x, six);
|
|
}
|
|
} // namespace xla
|
|
|
|
namespace tensorflow {
|
|
namespace {
|
|
|
|
class ReluOp : public XlaOpKernel {
|
|
public:
|
|
explicit ReluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
|
// Computes the max of the scalar input x and 0.
|
|
void Compile(XlaOpKernelContext* ctx) override {
|
|
ctx->SetOutput(0, xla::Relu(ctx->Input(0)));
|
|
}
|
|
};
|
|
REGISTER_XLA_OP(Name("Relu"), ReluOp);
|
|
|
|
class Relu6Op : public XlaOpKernel {
|
|
public:
|
|
explicit Relu6Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
|
// Clamp the scalar input between 0 and 6.
|
|
void Compile(XlaOpKernelContext* ctx) override {
|
|
ctx->SetOutput(0, xla::Relu6(ctx->Input(0)));
|
|
}
|
|
};
|
|
REGISTER_XLA_OP(Name("Relu6"), Relu6Op);
|
|
|
|
class LeakyReluOp : public XlaOpKernel {
|
|
public:
|
|
explicit LeakyReluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &alpha_));
|
|
}
|
|
void Compile(XlaOpKernelContext* ctx) override {
|
|
auto features = ctx->Input("features");
|
|
auto prod_with_alpha = features * xla::ScalarLike(features, alpha_);
|
|
auto gt_zero = xla::Gt(features, xla::ScalarLike(features, 0));
|
|
auto output = xla::Select(gt_zero, features, prod_with_alpha);
|
|
ctx->SetOutput(0, output);
|
|
}
|
|
float alpha_;
|
|
};
|
|
REGISTER_XLA_OP(Name("LeakyRelu"), LeakyReluOp);
|
|
|
|
class ReluGradOp : public XlaOpKernel {
|
|
public:
|
|
explicit ReluGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
|
// Return the lhs (incoming gradient) if the rhs (input feature) > 0,
|
|
// otherwise return 0.
|
|
void Compile(XlaOpKernelContext* ctx) override {
|
|
xla::XlaBuilder* b = ctx->builder();
|
|
const TensorShape shape = ctx->InputShape(0);
|
|
const auto zero =
|
|
xla::Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes());
|
|
const auto pred = xla::Gt(ctx->Input(1), zero);
|
|
ctx->SetOutput(0, xla::Select(pred, ctx->Input(0), zero));
|
|
}
|
|
};
|
|
REGISTER_XLA_OP(Name("ReluGrad"), ReluGradOp);
|
|
|
|
class Relu6GradOp : public XlaOpKernel {
|
|
public:
|
|
explicit Relu6GradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
|
// Return the lhs (incoming gradient) if the rhs (input feature) > 0,
|
|
// otherwise return 0.
|
|
void Compile(XlaOpKernelContext* ctx) override {
|
|
xla::XlaBuilder* b = ctx->builder();
|
|
const TensorShape shape = ctx->InputShape(0);
|
|
const auto zero =
|
|
xla::Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes());
|
|
const auto six = xla::Broadcast(
|
|
XlaHelpers::IntegerLiteral(b, input_type(0), 6), shape.dim_sizes());
|
|
auto out = xla::Select(
|
|
xla::And(xla::Lt(ctx->Input(1), six), xla::Gt(ctx->Input(1), zero)),
|
|
ctx->Input(0), zero);
|
|
ctx->SetOutput(0, out);
|
|
}
|
|
};
|
|
REGISTER_XLA_OP(Name("Relu6Grad"), Relu6GradOp);
|
|
|
|
class LeakyReluGradOp : public XlaOpKernel {
|
|
public:
|
|
explicit LeakyReluGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &alpha_));
|
|
}
|
|
void Compile(XlaOpKernelContext* ctx) override {
|
|
auto gradients = ctx->Input("gradients");
|
|
auto features = ctx->Input("features");
|
|
auto output =
|
|
xla::Select(xla::Gt(features, xla::ScalarLike(features, 0)), gradients,
|
|
gradients * xla::ScalarLike(gradients, alpha_));
|
|
ctx->SetOutput(0, output);
|
|
}
|
|
float alpha_;
|
|
};
|
|
REGISTER_XLA_OP(Name("LeakyReluGrad"), LeakyReluGradOp);
|
|
|
|
} // namespace
|
|
} // namespace tensorflow
|