[TF2XLA] Remove the last unnecessary host-to-device memcpy, and remove the
HostTensorToLiteral function completely to prevent potential future misuse of
unnecessary memcpy.

PiperOrigin-RevId: 200750664
parent 916c0aab83
commit f9b832d91f
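Note on the pattern: the commit has two halves. Read-only literal parameters
move from const xla::Literal& to xla::LiteralSlice, a non-owning view, and
constants built from host Tensors now go through HostTensorToBorrowingLiteral,
which wraps the tensor's existing buffer instead of memcpy'ing it into a
freshly allocated xla::Literal. A minimal sketch of the borrowing pattern,
mirroring the SetConstantOutput hunk in the diff below; the helper name
ConstantFromHostTensor is hypothetical, and the statusor/errors include paths
are assumptions, not part of this commit:

    #include "tensorflow/compiler/tf2xla/literal_util.h"
    #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
    #include "tensorflow/compiler/xla/statusor.h"  // assumed path
    #include "tensorflow/core/framework/tensor.h"
    #include "tensorflow/core/lib/core/errors.h"   // assumed, for TF_RETURN_IF_ERROR

    // Hypothetical helper: build an XLA constant from a host tensor without
    // copying the payload on the host side.
    xla::StatusOr<xla::XlaOp> ConstantFromHostTensor(
        xla::XlaBuilder* builder, const tensorflow::Tensor& host_tensor) {
      // BorrowingLiteral points into host_tensor's buffer; no memcpy happens.
      xla::BorrowingLiteral literal;
      TF_RETURN_IF_ERROR(
          tensorflow::HostTensorToBorrowingLiteral(host_tensor, &literal));
      // ConstantLiteral reads the borrowed data while building the op, so
      // host_tensor only needs to outlive this call.
      return builder->ConstantLiteral(literal);
    }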
@@ -27,7 +27,7 @@ class MirrorPadOp : public XlaOpKernel {
 
   xla::StatusOr<xla::XlaOp> DoMirrorPad(const xla::XlaOp& t,
                                         const xla::Shape& original_shape,
-                                        const xla::Literal& pad_literal,
+                                        const xla::LiteralSlice& pad_literal,
                                         xla::XlaBuilder* b) {
     xla::XlaOp accum = t;
     for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0;
@@ -63,8 +63,8 @@ class PadOp : public XlaOpKernel {
       int before = pad_literal.Get<int32>({i, 0});
       int after = pad_literal.Get<int32>({i, 1});
       OP_REQUIRES(ctx, before >= 0 && after >= 0,
-                  errors::InvalidArgument("Paddings must be non-negative: ",
-                                          before, " ", after));
+                  errors::InvalidArgument(
+                      "Paddings must be non-negative: ", before, " ", after));
       dim->set_edge_padding_low(before);
       dim->set_edge_padding_high(after);
     }
@@ -56,9 +56,9 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
 
   // Evaluate the constant, reshaping to a 1-vector if it is a scalar.
   xla::Literal axes_literal;
-  OP_REQUIRES_OK(ctx,
-                 ctx->ConstantInputReshaped(
-                     1, {axes_tensor_shape.num_elements()}, &axes_literal));
+  OP_REQUIRES_OK(
+      ctx, ctx->ConstantInputReshaped(1, {axes_tensor_shape.num_elements()},
+                                      &axes_literal));
 
   VLOG(1) << "data shape: " << data_shape.DebugString();
   VLOG(1) << "axes : " << axes_literal.ToString();
@@ -55,9 +55,10 @@ Status GetIntValue(int index, XlaOpKernelContext* ctx, int64* value) {
 
 // The type-specific part of the implementation of Range.
 template <typename T>
-Status CreateRangeTensor(const xla::Literal& start_literal,
-                         const xla::Literal& limit_literal,
-                         const xla::Literal& delta_literal, Tensor* output) {
+Status CreateRangeTensor(const xla::LiteralSlice& start_literal,
+                         const xla::LiteralSlice& limit_literal,
+                         const xla::LiteralSlice& delta_literal,
+                         Tensor* output) {
   T start = start_literal.Get<T>({});
   T limit = limit_literal.Get<T>({});
   T delta = delta_literal.Get<T>({});
@@ -67,13 +68,13 @@ Status CreateRangeTensor(const xla::Literal& start_literal,
   }
   if (delta > 0) {
     if (start > limit) {
-      return errors::InvalidArgument("Requires start <= limit when delta > 0: ",
-                                     start, "/", limit);
+      return errors::InvalidArgument(
+          "Requires start <= limit when delta > 0: ", start, "/", limit);
     }
   } else {
     if (start < limit) {
-      return errors::InvalidArgument("Requires start >= limit when delta < 0: ",
-                                     start, "/", limit);
+      return errors::InvalidArgument(
+          "Requires start >= limit when delta < 0: ", start, "/", limit);
     }
   }
   int64 size =
@@ -134,7 +134,7 @@ class SplitVOp : public XlaOpKernel {
                 errors::InvalidArgument(
                     "Number of ways to split should be > 0, but got ", num_split));
 
-    // check that sizes are correct
+    // Check that sizes are correct.
     int total_split_size = 0;
     int neg_one_dim = -1;
     std::vector<int64> split_sizes_vec(num_split, -1);
@@ -148,7 +148,7 @@ class SplitVOp : public XlaOpKernel {
                     " number of elements as the output. Got ",
                     split_size_shape.dims(), "-D and ",
                     split_size_shape.num_elements(), " elements"));
-    // get the dimension of this split
+    // Get the dimension of this split.
     xla::Literal split_size_literal;
     OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &split_size_literal));
 
@@ -22,24 +22,6 @@ limitations under the License.
 
 namespace tensorflow {
 
-Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) {
-  xla::Shape literal_shape;
-  TF_RETURN_IF_ERROR(TensorShapeToXLAShape(
-      host_tensor.dtype(), host_tensor.shape(), &literal_shape));
-
-  *literal = xla::Literal(literal_shape);
-
-  // memcpy over the payload ...
-  // TODO(phawkins): handle string types.
-  size_t total_bytes = host_tensor.TotalBytes();
-  if (total_bytes > 0) {
-    void* dst_ptr = literal->untyped_data();
-    const void* src_ptr = DMAHelper::base(&host_tensor);
-    memcpy(dst_ptr, src_ptr, total_bytes);
-  }
-  return Status::OK();
-}
-
 Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
                                     xla::BorrowingLiteral* literal) {
   xla::Shape xla_shape;
@@ -26,10 +26,6 @@ limitations under the License.
 
 namespace tensorflow {
 
-// Copies 'host_tensor' to an XLA Literal. Fails if host_tensor is of an
-// unsupported type.
-Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal);
-
 // Returns a BorrowingLiteral that utilizes the same underlying buffer owned by
 // 'host_tensor'.
 Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
@@ -92,7 +92,7 @@ void XlaContext::AddRetval(int retval_index, DataType type,
 }
 
 Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
-                                  const xla::Literal& literal) {
+                                  const xla::LiteralSlice& literal) {
   VLOG(1) << "Adding retval index " << retval_index
           << " with non-data-dependent tensor to XLA computation";
   if (retvals_.size() <= retval_index) {
@@ -83,7 +83,7 @@ class XlaContext : public ResourceBase {
 
   // As for Retval, but for return values that are compile-time constants.
   Status AddConstRetval(int retval_index, DataType dtype,
-                        const xla::Literal& literal);
+                        const xla::LiteralSlice& literal);
 
   // Creates a resource with resource `kind` and initial value `handle`. `name`
   // is a descriptive name for use in error messages. See the `XlaResource`
@@ -25,7 +25,6 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
 #include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/common_runtime/dma_helper.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
@@ -248,6 +247,7 @@ Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis,
       return errors::InvalidArgument("Invalid argument type ",
                                      DataTypeString(index_type));
   }
 
+  xla::BorrowingLiteral linspace_literal;
   TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal));
 
@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/literal_util.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/xla_context.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
 
 namespace tensorflow {
 
@@ -87,6 +88,25 @@ Status XlaOpKernelContext::ConstantInputReshaped(
   }
   const XlaExpression* expression = CastExpressionFromTensor(tensor);
 
+  auto copy_tensor_to_literal = [](const Tensor& tensor,
+                                   xla::Literal* literal) {
+    xla::Shape literal_shape;
+    TF_RETURN_IF_ERROR(
+        TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), &literal_shape));
+
+    *literal = xla::Literal(literal_shape);
+
+    // memcpy over the payload ...
+    // TODO(phawkins): handle string types.
+    size_t total_bytes = tensor.TotalBytes();
+    if (total_bytes > 0) {
+      void* dst_ptr = literal->untyped_data();
+      const void* src_ptr = DMAHelper::base(&tensor);
+      memcpy(dst_ptr, src_ptr, total_bytes);
+    }
+    return Status::OK();
+  };
+
   // If the tensor has a known constant value, there is no need to invoke XLA.
   if (expression->has_constant_value()) {
     Tensor temp(tensor.dtype());
@@ -95,13 +115,15 @@ Status XlaOpKernelContext::ConstantInputReshaped(
       // with the enclosing Tensor.
       return errors::Internal("Incompatible shapes in ConstantInputReshaped.");
     }
-    return HostTensorToLiteral(temp, constant_literal);
+
+    return copy_tensor_to_literal(temp, constant_literal);
   }
 
   // Make sure we treat zero-element tensors as constant.
   if (new_shape.num_elements() == 0) {
     Tensor temp(tensor.dtype(), new_shape);
-    return HostTensorToLiteral(temp, constant_literal);
+
+    return copy_tensor_to_literal(temp, constant_literal);
   }
 
   xla::XlaOp handle = expression->handle();
@@ -162,7 +184,8 @@ Status XlaOpKernelContext::ConstantInputReshaped(
 }
 
 // Converts an int32 or int64 scalar literal to an int64.
-static Status LiteralToInt64Scalar(const xla::Literal& literal, int64* out) {
+static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal,
+                                   int64* out) {
   if (xla::ShapeUtil::Rank(literal.shape()) != 0) {
     return errors::InvalidArgument("value is not a scalar");
   }
@@ -177,7 +200,8 @@ static Status LiteralToInt64Scalar(const xla::Literal& literal, int64* out) {
 }
 
 // Converts an float32 or float64 scalar literal to a float64.
-static Status LiteralToFloat64Scalar(const xla::Literal& literal, double* out) {
+static Status LiteralToFloat64Scalar(const xla::LiteralSlice& literal,
+                                     double* out) {
   if (xla::ShapeUtil::Rank(literal.shape()) != 0) {
     return errors::InvalidArgument("value is not a scalar");
   }
@@ -204,7 +228,7 @@ Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) {
 }
 
 // Converts an int32 or int64 1D literal to an int64 vector.
-static Status LiteralToInt64Vector(const xla::Literal& literal,
+static Status LiteralToInt64Vector(const xla::LiteralSlice& literal,
                                    std::vector<int64>* out) {
   if (xla::ShapeUtil::Rank(literal.shape()) != 1) {
     return errors::InvalidArgument("value is not 1D");
@@ -368,8 +392,9 @@ void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) {
 void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
   const TensorShape& shape = constant.shape();
 
-  xla::Literal literal;
-  OP_REQUIRES_OK(context_, HostTensorToLiteral(constant, &literal));
+  xla::BorrowingLiteral literal;
+  OP_REQUIRES_OK(context_, HostTensorToBorrowingLiteral(constant, &literal));
+
   xla::XlaOp handle = builder()->ConstantLiteral(literal);
   CHECK_NE(handle.builder(), nullptr);
 
@@ -2355,7 +2355,6 @@ LiteralSlice::LiteralSlice(const LiteralBase& literal,
 BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape)
     : LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
   CHECK(ShapeUtil::IsArray(*shape_));
-  CHECK_NE(src_buf_ptr, nullptr);
   CHECK(LayoutUtil::HasLayout(*shape_));
 
   root_piece_ = Piece();
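Note: the signature changes above (LiteralToInt64Scalar, LiteralToFloat64Scalar,
LiteralToInt64Vector, CreateRangeTensor, AddConstRetval) all follow one idiom:
a function that only reads a literal takes const xla::LiteralSlice&, so callers
can pass either an owning xla::Literal or a BorrowingLiteral without a
conversion copy. A short sketch in the same style; SumInt64Elements is a
hypothetical helper, not part of this commit, and it assumes a rank-1,
S64-typed literal:

    // Hypothetical example modeled on LiteralToInt64Vector above.
    static Status SumInt64Elements(const xla::LiteralSlice& literal,
                                   int64* out) {
      if (xla::ShapeUtil::Rank(literal.shape()) != 1) {
        return errors::InvalidArgument("value is not 1D");
      }
      *out = 0;
      // Get<int64> assumes the literal's element type is S64.
      for (int64 i = 0; i < xla::ShapeUtil::ElementsIn(literal.shape()); ++i) {
        *out += literal.Get<int64>({i});
      }
      return Status::OK();
    }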