StringPiece is an alias for absl::string_view, and InlinedVector is aliased to absl::InlinedVector. StrCat is compatible, so swapping it out is safe.

PiperOrigin-RevId: 211691840
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA-specific Shape Ops.

#include <limits>
#include <unordered_set>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2xla/kernels/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/kernels/bounds_check.h"

namespace tensorflow {
namespace {
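
// Shape op: returns the shape of its input as a compile-time constant 1-D
// tensor of type out_type.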
class ShapeOp : public XlaOpKernel {
 public:
  explicit ShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);
    Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()}));
    OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant));
    ctx->SetConstantOutput(0, shape_constant);
  }

 private:
  DataType out_dtype_;
};

REGISTER_XLA_OP(Name("Shape").CompilationOnly(), ShapeOp);
class ShapeNOp : public XlaOpKernel {
 public:
  explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      const TensorShape input_shape = ctx->InputShape(i);
      Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()}));
      OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant));
      ctx->SetConstantOutput(i, shape_constant);
    }
  }
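
  // All outputs are compile-time constants, so this kernel is cheap to
  // evaluate.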
  bool IsExpensive() override { return false; }

 private:
  DataType out_dtype_;
};

REGISTER_XLA_OP(Name("ShapeN").CompilationOnly(), ShapeNOp);
class RankOp : public XlaOpKernel {
 public:
  explicit RankOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);
    const int rank = input_shape.dims();
    Tensor rank_constant(DT_INT32, TensorShape({}));
    rank_constant.scalar<int32>()() = rank;

    ctx->SetConstantOutput(0, rank_constant);
  }
};

REGISTER_XLA_OP(Name("Rank").CompilationOnly(), RankOp);
class SizeOp : public XlaOpKernel {
 public:
  explicit SizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);
    const int64 size = input_shape.num_elements();
    OP_REQUIRES(ctx, FastBoundsCheck(size, std::numeric_limits<int32>::max()),
                errors::InvalidArgument("Size does not work for tensors > "
                                        "int32 max."));
    Tensor size_constant(DT_INT32, TensorShape({}));
    size_constant.scalar<int32>()() = static_cast<int32>(size);

    ctx->SetConstantOutput(0, size_constant);
  }
};

REGISTER_XLA_OP(Name("Size").CompilationOnly(), SizeOp);
class ExpandDimsOp : public XlaOpKernel {
 public:
  explicit ExpandDimsOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);
    const TensorShape dim_shape = ctx->InputShape(1);

    // TODO(phawkins): the standard implementation of ExpandDimsOp seems to
    // accept legacy scalars, even when they should be forbidden by the
    // graphdef version.
    OP_REQUIRES(ctx, dim_shape.num_elements() == 1,
                errors::InvalidArgument(absl::StrCat(
                    "dim input to ExpandDims must be a scalar; got ",
                    dim_shape.DebugString())));

    xla::Literal literal;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {1}, &literal));

    int dim = literal.data<int32>()[0];

    OP_REQUIRES(ctx,
                (dim >= -1 - input_shape.dims() && dim <= input_shape.dims()),
                errors::InvalidArgument("Tried to expand dim index ", dim,
                                        " for tensor with ", input_shape.dims(),
                                        " dimensions."));

    auto existing_dims = input_shape.dim_sizes();
    // Safe: the number of dimensions in a tensor is bounded, so the cast to
    // int cannot overflow.
    const int existing_dims_size = static_cast<int>(existing_dims.size());
    std::vector<int64> new_shape(existing_dims_size);
    for (size_t i = 0; i < new_shape.size(); ++i) {
      new_shape[i] = existing_dims[i];
    }

    // We emulate numpy's interpretation of the dim axis when
    // -input.dims() - 1 <= dim <= input.dims(): negative values count back
    // from the end.
    if (dim < 0) {
      dim += existing_dims.size() + 1;
    }

    // Clamp to the end if needed.
    dim = std::min<int32>(dim, existing_dims_size);
    new_shape.emplace(new_shape.begin() + dim, 1);

    ctx->SetOutput(0, xla::Reshape(ctx->Input(0), new_shape));
  }
};
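
// "dim" must be a compile-time constant, since the output shape of an XLA op
// has to be known when the computation is built.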
REGISTER_XLA_OP(Name("ExpandDims").CompileTimeConstInput("dim"), ExpandDimsOp);
class SqueezeOp : public XlaOpKernel {
 public:
  explicit SqueezeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    std::vector<int32> squeeze_dims;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("squeeze_dims", &squeeze_dims));
    squeeze_dims_.insert(squeeze_dims.begin(), squeeze_dims.end());
  }

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);
    auto existing_dims = input_shape.dim_sizes();
    int existing_dims_size = input_shape.dims();
    std::vector<int64> new_shape;

    std::unordered_set<int32> wrapped_squeeze_dims;
    wrapped_squeeze_dims.reserve(squeeze_dims_.size());
    // Validate squeeze dims against the input.
    for (int32 dim : squeeze_dims_) {
      OP_REQUIRES(ctx,
                  (dim >= -input_shape.dims() && dim < input_shape.dims()),
                  errors::InvalidArgument("Tried to squeeze dim index ", dim,
                                          " for tensor with ",
                                          input_shape.dims(), " dimensions."));
      // If dim is < 0, we wrap around (-1 means the last element).
      if (dim < 0) {
        dim = existing_dims_size + dim;
      }

      wrapped_squeeze_dims.insert(dim);
    }

    for (int i = 0; i < existing_dims_size; ++i) {
      auto existing_dim = existing_dims[i];

      // If squeeze_dims is non-empty, only squeeze those dimensions.
      if (!wrapped_squeeze_dims.empty()) {
        if (wrapped_squeeze_dims.count(i) > 0) {
          OP_REQUIRES(ctx, existing_dim == 1,
                      errors::InvalidArgument(
                          "Tried to explicitly squeeze dimension ", i,
                          " but dimension was not 1: ", existing_dim));
        } else {
          // This dimension is not being squeezed.
          new_shape.push_back(existing_dim);
        }
      } else {
        // Copy over all non-1-length dimensions.
        if (existing_dim != 1) {
          new_shape.push_back(existing_dim);
        }
      }
    }

    ctx->SetOutput(0, xla::Reshape(ctx->Input(0), new_shape));
  }

 private:
  std::unordered_set<int32> squeeze_dims_;
};

REGISTER_XLA_OP(Name("Squeeze"), SqueezeOp);
class ZerosLikeOp : public XlaOpKernel {
 public:
  explicit ZerosLikeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);

    auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0));
    ctx->SetOutput(0, xla::Broadcast(zero, input_shape.dim_sizes()));
  }
};

REGISTER_XLA_OP(Name("ZerosLike"), ZerosLikeOp);
class OnesLikeOp : public XlaOpKernel {
 public:
  explicit OnesLikeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);

    auto one = XlaHelpers::One(ctx->builder(), input_type(0));
    ctx->SetOutput(0, xla::Broadcast(one, input_shape.dim_sizes()));
  }
};

REGISTER_XLA_OP(Name("OnesLike"), OnesLikeOp);
}  // namespace
}  // namespace tensorflow