[TF2XLA] Remove the last unnecessary host-to-device memcpy, and remove the
HostTensorToLiteral function completely to prevent potential future misuse of
unnecessary memcpy.

PiperOrigin-RevId: 200750664
Kay Zhu, 2018-06-15 11:54:29 -07:00, committed by TensorFlower Gardener
parent 916c0aab83, commit f9b832d91f
12 changed files with 51 additions and 48 deletions
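For orientation, here is a minimal sketch of the two conversion paths this change is about. The wrapper names `CopyingPath` and `BorrowingPath` are illustrative only and not part of the tree; the copying body mirrors the `HostTensorToLiteral` implementation that this commit deletes, and the borrowing path is the `HostTensorToBorrowingLiteral` API that it keeps.

```cpp
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// The copying path (being removed): allocate a fresh xla::Literal and
// memcpy the tensor payload into it.
Status CopyingPath(const Tensor& host_tensor, xla::Literal* literal) {
  xla::Shape literal_shape;
  TF_RETURN_IF_ERROR(TensorShapeToXLAShape(
      host_tensor.dtype(), host_tensor.shape(), &literal_shape));
  *literal = xla::Literal(literal_shape);  // a second buffer is allocated here
  size_t total_bytes = host_tensor.TotalBytes();
  if (total_bytes > 0) {
    // This is the memcpy the commit eliminates from the common path.
    memcpy(literal->untyped_data(), DMAHelper::base(&host_tensor),
           total_bytes);
  }
  return Status::OK();
}

// The borrowing path (kept): the literal aliases host_tensor's existing
// buffer, so there is no allocation and no copy; host_tensor must outlive
// the literal.
Status BorrowingPath(const Tensor& host_tensor,
                     xla::BorrowingLiteral* literal) {
  return HostTensorToBorrowingLiteral(host_tensor, literal);
}

}  // namespace tensorflow
```

Read-only call sites in the diff below are switched to `const xla::LiteralSlice&`, which both `xla::Literal` and `xla::BorrowingLiteral` convert to, so those functions no longer force a particular ownership model on their callers.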


@@ -27,7 +27,7 @@ class MirrorPadOp : public XlaOpKernel {
   xla::StatusOr<xla::XlaOp> DoMirrorPad(const xla::XlaOp& t,
                                         const xla::Shape& original_shape,
-                                        const xla::Literal& pad_literal,
+                                        const xla::LiteralSlice& pad_literal,
                                         xla::XlaBuilder* b) {
     xla::XlaOp accum = t;
     for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0;


@@ -63,8 +63,8 @@ class PadOp : public XlaOpKernel {
       int before = pad_literal.Get<int32>({i, 0});
       int after = pad_literal.Get<int32>({i, 1});
       OP_REQUIRES(ctx, before >= 0 && after >= 0,
-                  errors::InvalidArgument("Paddings must be non-negative: ",
-                                          before, " ", after));
+                  errors::InvalidArgument(
+                      "Paddings must be non-negative: ", before, " ", after));
       dim->set_edge_padding_low(before);
       dim->set_edge_padding_high(after);
     }


@@ -56,9 +56,9 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
   // Evaluate the constant, reshaping to a 1-vector if it is a scalar.
   xla::Literal axes_literal;
-  OP_REQUIRES_OK(ctx,
-                 ctx->ConstantInputReshaped(
-                     1, {axes_tensor_shape.num_elements()}, &axes_literal));
+  OP_REQUIRES_OK(
+      ctx, ctx->ConstantInputReshaped(1, {axes_tensor_shape.num_elements()},
+                                      &axes_literal));

   VLOG(1) << "data shape: " << data_shape.DebugString();
   VLOG(1) << "axes      : " << axes_literal.ToString();


@@ -55,9 +55,10 @@ Status GetIntValue(int index, XlaOpKernelContext* ctx, int64* value) {
 // The type-specific part of the implementation of Range.
 template <typename T>
-Status CreateRangeTensor(const xla::Literal& start_literal,
-                         const xla::Literal& limit_literal,
-                         const xla::Literal& delta_literal, Tensor* output) {
+Status CreateRangeTensor(const xla::LiteralSlice& start_literal,
+                         const xla::LiteralSlice& limit_literal,
+                         const xla::LiteralSlice& delta_literal,
+                         Tensor* output) {
   T start = start_literal.Get<T>({});
   T limit = limit_literal.Get<T>({});
   T delta = delta_literal.Get<T>({});
@@ -67,13 +68,13 @@ Status CreateRangeTensor(const xla::Literal& start_literal,
   }
   if (delta > 0) {
     if (start > limit) {
-      return errors::InvalidArgument("Requires start <= limit when delta > 0: ",
-                                     start, "/", limit);
+      return errors::InvalidArgument(
+          "Requires start <= limit when delta > 0: ", start, "/", limit);
     }
   } else {
     if (start < limit) {
-      return errors::InvalidArgument("Requires start >= limit when delta < 0: ",
-                                     start, "/", limit);
+      return errors::InvalidArgument(
+          "Requires start >= limit when delta < 0: ", start, "/", limit);
     }
   }
   int64 size =


@@ -134,7 +134,7 @@ class SplitVOp : public XlaOpKernel {
                 errors::InvalidArgument(
                     "Number of ways to split should be > 0, but got ", num_split));

-    // check that sizes are correct
+    // Check that sizes are correct.
     int total_split_size = 0;
     int neg_one_dim = -1;
     std::vector<int64> split_sizes_vec(num_split, -1);
@@ -148,7 +148,7 @@ class SplitVOp : public XlaOpKernel {
                     " number of elements as the output. Got ",
                     split_size_shape.dims(), "-D and ",
                     split_size_shape.num_elements(), " elements"));

-    // get the dimension of this split
+    // Get the dimension of this split.
     xla::Literal split_size_literal;
     OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &split_size_literal));


@@ -22,24 +22,6 @@ limitations under the License.

 namespace tensorflow {

-Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) {
-  xla::Shape literal_shape;
-  TF_RETURN_IF_ERROR(TensorShapeToXLAShape(
-      host_tensor.dtype(), host_tensor.shape(), &literal_shape));
-
-  *literal = xla::Literal(literal_shape);
-
-  // memcpy over the payload ...
-  // TODO(phawkins): handle string types.
-  size_t total_bytes = host_tensor.TotalBytes();
-  if (total_bytes > 0) {
-    void* dst_ptr = literal->untyped_data();
-    const void* src_ptr = DMAHelper::base(&host_tensor);
-    memcpy(dst_ptr, src_ptr, total_bytes);
-  }
-  return Status::OK();
-}
-
 Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
                                     xla::BorrowingLiteral* literal) {
   xla::Shape xla_shape;


@@ -26,10 +26,6 @@ limitations under the License.

 namespace tensorflow {

-// Copies 'host_tensor' to an XLA Literal. Fails if host_tensor is of an
-// unsupported type.
-Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal);
-
 // Returns a BorrowingLiteral that utilizes the same underlying buffer owned by
 // 'host_tensor'.
 Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,


@@ -92,7 +92,7 @@ void XlaContext::AddRetval(int retval_index, DataType type,
 }

 Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
-                                  const xla::Literal& literal) {
+                                  const xla::LiteralSlice& literal) {
   VLOG(1) << "Adding retval index " << retval_index
           << " with non-data-dependent tensor to XLA computation";
   if (retvals_.size() <= retval_index) {


@@ -83,7 +83,7 @@ class XlaContext : public ResourceBase {
   // As for Retval, but for return values that are compile-time constants.
   Status AddConstRetval(int retval_index, DataType dtype,
-                        const xla::Literal& literal);
+                        const xla::LiteralSlice& literal);

   // Creates a resource with resource `kind` and initial value `handle`. `name`
   // is a descriptive name for use in error messages. See the `XlaResource`


@@ -25,7 +25,6 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
 #include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/common_runtime/dma_helper.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
@@ -248,6 +247,7 @@ Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis,
     return errors::InvalidArgument("Invalid argument type ",
                                    DataTypeString(index_type));
   }
   xla::BorrowingLiteral linspace_literal;
-  TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal));
+  TF_RETURN_IF_ERROR(
+      HostTensorToBorrowingLiteral(linspace, &linspace_literal));


@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/literal_util.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/xla_context.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"

 namespace tensorflow {
@@ -87,6 +88,25 @@ Status XlaOpKernelContext::ConstantInputReshaped(
   }
   const XlaExpression* expression = CastExpressionFromTensor(tensor);

+  auto copy_tensor_to_literal = [](const Tensor& tensor,
+                                   xla::Literal* literal) {
+    xla::Shape literal_shape;
+    TF_RETURN_IF_ERROR(
+        TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), &literal_shape));
+    *literal = xla::Literal(literal_shape);
+
+    // memcpy over the payload ...
+    // TODO(phawkins): handle string types.
+    size_t total_bytes = tensor.TotalBytes();
+    if (total_bytes > 0) {
+      void* dst_ptr = literal->untyped_data();
+      const void* src_ptr = DMAHelper::base(&tensor);
+      memcpy(dst_ptr, src_ptr, total_bytes);
+    }
+    return Status::OK();
+  };
+
   // If the tensor has a known constant value, there is no need to invoke XLA.
   if (expression->has_constant_value()) {
     Tensor temp(tensor.dtype());
@@ -95,13 +115,15 @@ Status XlaOpKernelContext::ConstantInputReshaped(
       // with the enclosing Tensor.
       return errors::Internal("Incompatible shapes in ConstantInputReshaped.");
     }
-    return HostTensorToLiteral(temp, constant_literal);
+    return copy_tensor_to_literal(temp, constant_literal);
   }

   // Make sure we treat zero-element tensors as constant.
   if (new_shape.num_elements() == 0) {
     Tensor temp(tensor.dtype(), new_shape);
-    return HostTensorToLiteral(temp, constant_literal);
+    return copy_tensor_to_literal(temp, constant_literal);
   }

   xla::XlaOp handle = expression->handle();
@@ -162,7 +184,8 @@ Status XlaOpKernelContext::ConstantInputReshaped(
 }

 // Converts an int32 or int64 scalar literal to an int64.
-static Status LiteralToInt64Scalar(const xla::Literal& literal, int64* out) {
+static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal,
+                                   int64* out) {
   if (xla::ShapeUtil::Rank(literal.shape()) != 0) {
     return errors::InvalidArgument("value is not a scalar");
   }
@@ -177,7 +200,8 @@ static Status LiteralToInt64Scalar(const xla::Literal& literal, int64* out) {
 }

 // Converts an float32 or float64 scalar literal to a float64.
-static Status LiteralToFloat64Scalar(const xla::Literal& literal, double* out) {
+static Status LiteralToFloat64Scalar(const xla::LiteralSlice& literal,
+                                     double* out) {
   if (xla::ShapeUtil::Rank(literal.shape()) != 0) {
     return errors::InvalidArgument("value is not a scalar");
   }
@@ -204,7 +228,7 @@ Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) {
 }

 // Converts an int32 or int64 1D literal to an int64 vector.
-static Status LiteralToInt64Vector(const xla::Literal& literal,
+static Status LiteralToInt64Vector(const xla::LiteralSlice& literal,
                                    std::vector<int64>* out) {
   if (xla::ShapeUtil::Rank(literal.shape()) != 1) {
     return errors::InvalidArgument("value is not 1D");
@@ -368,8 +392,9 @@ void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) {
 void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
   const TensorShape& shape = constant.shape();

-  xla::Literal literal;
-  OP_REQUIRES_OK(context_, HostTensorToLiteral(constant, &literal));
+  xla::BorrowingLiteral literal;
+  OP_REQUIRES_OK(context_, HostTensorToBorrowingLiteral(constant, &literal));
+
   xla::XlaOp handle = builder()->ConstantLiteral(literal);
   CHECK_NE(handle.builder(), nullptr);
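A note on why this file keeps a private copying lambda instead of borrowing everywhere (an inference from the code, not stated in the commit message): both remaining call sites materialize the constant into a local `temp` Tensor that is destroyed when ConstantInputReshaped returns, so a BorrowingLiteral aliasing that storage would dangle; copying is what makes `constant_literal` self-contained. SetConstantOutput, by contrast, can borrow safely, presumably because builder()->ConstantLiteral(literal) serializes the data into the computation before returning.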


@@ -2355,7 +2355,6 @@ LiteralSlice::LiteralSlice(const LiteralBase& literal,
 BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape)
     : LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
   CHECK(ShapeUtil::IsArray(*shape_));
-  CHECK_NE(src_buf_ptr, nullptr);
   CHECK(LayoutUtil::HasLayout(*shape_));

   root_piece_ = Piece();
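A plausible reading of this last hunk, not spelled out in the commit message: once host tensors are wrapped directly, a zero-element Tensor has no payload and its buffer base is null, so borrowing it (e.g. via the SetConstantOutput path above) hands BorrowingLiteral a null src_buf_ptr, which the constructor must now tolerate instead of CHECK-failing.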