132 lines
5.1 KiB
C++
132 lines
5.1 KiB
C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
// This file defines helper routines for XLA compilation.
|
|
|
|
#include "tensorflow/compiler/tf2xla/xla_helpers.h"

#include <numeric>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
|
|
|
|
namespace tensorflow {
|
|
|
|
// Emits, on builder `b`, a constant holding xla::LiteralUtil::Zero for the
// XLA primitive type that corresponds to TensorFlow `data_type`.
// The dtype conversion is wrapped in TF_CHECK_OK, so an unconvertible
// `data_type` is a fatal error rather than a returned Status.
xla::XlaOp XlaHelpers::Zero(xla::XlaBuilder* b, DataType data_type) {
  xla::PrimitiveType xla_type;
  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &xla_type));
  const xla::Literal zero_literal = xla::LiteralUtil::Zero(xla_type);
  return xla::ConstantLiteral(b, zero_literal);
}
|
|
|
|
// Emits, on builder `b`, a constant holding xla::LiteralUtil::One for the
// XLA primitive type that corresponds to TensorFlow `data_type`.
// The dtype conversion is wrapped in TF_CHECK_OK, so an unconvertible
// `data_type` is a fatal error rather than a returned Status.
xla::XlaOp XlaHelpers::One(xla::XlaBuilder* b, DataType data_type) {
  xla::PrimitiveType xla_type;
  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &xla_type));
  const xla::Literal one_literal = xla::LiteralUtil::One(xla_type);
  return xla::ConstantLiteral(b, one_literal);
}
|
|
|
|
// Emits an integer constant `value` on builder `b`. It delegates to the
// free function ::tensorflow::IntegerLiteral (tf2xla/lib/util.h), passing the
// XLA primitive type that corresponds to `data_type`.
// The dtype conversion is wrapped in TF_CHECK_OK, so an unconvertible
// `data_type` is a fatal error rather than a returned Status.
xla::XlaOp XlaHelpers::IntegerLiteral(xla::XlaBuilder* b, DataType data_type,
                                      int64 value) {
  xla::PrimitiveType xla_type;
  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &xla_type));
  // Fully qualified on purpose: inside this member function an unqualified
  // `IntegerLiteral` would find XlaHelpers::IntegerLiteral itself.
  return ::tensorflow::IntegerLiteral(b, xla_type, value);
}
|
|
|
|
// Emits a floating-point constant `value` on builder `b`. It delegates to the
// free function ::tensorflow::FloatLiteral (tf2xla/lib/util.h), passing the
// XLA primitive type that corresponds to `data_type`.
// The dtype conversion is wrapped in TF_CHECK_OK, so an unconvertible
// `data_type` is a fatal error rather than a returned Status.
xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type,
                                    double value) {
  xla::PrimitiveType xla_type;
  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &xla_type));
  // Fully qualified on purpose: inside this member function an unqualified
  // `FloatLiteral` would find XlaHelpers::FloatLiteral itself.
  return ::tensorflow::FloatLiteral(b, xla_type, value);
}
|
|
|
|
// Writes into `*output` a copy of `input` whose shape is replaced by an array
// shape with the same element type and extents `dimensions`.
//
// Returns InvalidArgument when:
//   * `input` is a tuple, or otherwise not array-shaped;
//   * any entry of `dimensions` is negative;
//   * the element count of `dimensions` differs from that of `input`.
//
// The element data is cloned unchanged and only the shape is swapped, so
// elements keep their existing linear order.
// NOTE(review): this presumes `input` uses the same (default) layout that
// ShapeUtil::MakeShape gives the new shape; confirm for non-default layouts.
/* static */ Status XlaHelpers::ReshapeLiteral(
    const xla::Literal& input, absl::Span<const int64> dimensions,
    xla::Literal* output) {
  const xla::Shape& input_shape = input.shape();
  if (input_shape.IsTuple()) {
    return errors::InvalidArgument("ReshapeLiteral does not support tuples.");
  }
  // Reject other non-array shapes (e.g. token/opaque) up front instead of
  // handing their element type to ShapeUtil::MakeShape, which builds an
  // array shape.
  if (!input_shape.IsArray()) {
    return errors::InvalidArgument(
        "ReshapeLiteral only supports array-shaped literals.");
  }
  // Negative extents are never valid. Without this check, a pair of them
  // could still match the element count below (e.g. {-2, -2} against 4
  // elements, assuming ElementsIn multiplies the extents).
  for (const int64 dim : dimensions) {
    if (dim < 0) {
      return errors::InvalidArgument(
          "ReshapeLiteral dimensions must be non-negative, got ", dim, ".");
    }
  }

  xla::Shape shape =
      xla::ShapeUtil::MakeShape(input_shape.element_type(), dimensions);
  const int64 elements_before = xla::ShapeUtil::ElementsIn(input_shape);
  const int64 elements_after = xla::ShapeUtil::ElementsIn(shape);
  if (elements_before != elements_after) {
    return errors::InvalidArgument(
        "Shapes before and after ReshapeLiteral have different numbers of "
        "elements.");
  }

  *output = input.Clone();
  // Only the shape metadata changes; the cloned element buffer is untouched.
  output->mutable_shape_do_not_use()->Swap(&shape);
  return Status::OK();
}
|
|
|
|
// Builds a one-hot encoding of `indices` and stores it in `*one_hot`.
//
// The output shape is `indices_shape` with a new dimension of size `depth`
// inserted at position `axis`. Each output element takes `on_value` where
// the iota along `axis` equals the corresponding entry of `indices`, and
// `off_value` everywhere else. The iota is built with element type
// `index_type`.
//
// Returns InvalidArgument if `axis` is outside [0, indices_shape.dims()] or
// `depth` is negative. Also propagates any error from TensorShapeToXLAShape.
Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis,
                          DataType index_type, const TensorShape& indices_shape,
                          const xla::XlaOp& indices, const xla::XlaOp& on_value,
                          const xla::XlaOp& off_value, xla::XlaOp* one_hot) {
  const int indices_rank = indices_shape.dims();
  // `axis` drives the iterator arithmetic on broadcast_dims below. An
  // out-of-range value would make std::iota write past the vector, so
  // reject it before doing any work. axis == indices_rank (append at the
  // end) is valid.
  if (axis < 0 || axis > indices_rank) {
    return errors::InvalidArgument("OneHot axis must be in [0, ", indices_rank,
                                   "], got ", axis, ".");
  }
  // `depth` becomes the size of the inserted output dimension.
  if (depth < 0) {
    return errors::InvalidArgument("OneHot depth must be non-negative, got ",
                                   depth, ".");
  }

  // Broadcast the linspace constant across the indices along the new axis,
  // and test equality at each position. broadcast_dims[i] is the output
  // dimension that indices dimension i maps to. Every dimension at or after
  // `axis` shifts up by one to make room for the inserted one-hot dimension.
  std::vector<int64> broadcast_dims(indices_rank);
  std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
  std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);

  TensorShape output_shape = indices_shape;
  output_shape.InsertDim(axis, depth);
  xla::Shape iota_shape;
  TF_RETURN_IF_ERROR(
      TensorShapeToXLAShape(index_type, output_shape, &iota_shape));

  // Selects the user-provided off_value and on_value values.
  *one_hot = xla::Select(
      xla::Eq(indices, xla::Iota(builder, iota_shape, axis), broadcast_dims),
      xla::Broadcast(on_value, output_shape.dim_sizes()),
      xla::Broadcast(off_value, output_shape.dim_sizes()));
  return Status::OK();
}
|
|
|
|
// Returns the dtype in which a sum over elements of `dtype` should be
// accumulated:
//   * 16-bit floats (bfloat16, half) widen to float32, reducing the
//     precision lost across repeated floating-point additions.
//   * 8/16-bit signed ints widen to int32, and 8/16-bit unsigned ints widen
//     to uint32, to avoid overflow.
//   * Every other dtype is returned unchanged.
DataType XlaHelpers::SumAccumulationType(const DataType& dtype) {
  switch (dtype) {
    case DT_BFLOAT16:
    case DT_HALF:
      return DT_FLOAT;
    case DT_INT8:
    case DT_INT16:
      return DT_INT32;
    case DT_UINT8:
    case DT_UINT16:
      return DT_UINT32;
    default:
      return dtype;
  }
}
|
|
|
|
// Emits a conversion of `operand` to the XLA primitive type that corresponds
// to TensorFlow dtype `new_element_type`.
// The dtype conversion is wrapped in TF_CHECK_OK, so an unconvertible
// `new_element_type` is a fatal error rather than a returned Status.
xla::XlaOp XlaHelpers::ConvertElementType(const xla::XlaOp& operand,
                                          const DataType new_element_type) {
  xla::PrimitiveType target_type;
  TF_CHECK_OK(DataTypeToPrimitiveType(new_element_type, &target_type));
  // Namespace-qualified so the call reaches the XLA builder op, not this
  // identically named helper.
  return xla::ConvertElementType(operand, target_type);
}
|
|
|
|
} // end namespace tensorflow
|