Allow specifying fast memory space for device shape.

PiperOrigin-RevId: 263570415
Author: Berkin Ilbeyi (2019-08-15 08:41:42 -07:00), committed by TensorFlower Gardener
parent f4e6ee7c09
commit 0e271c3e39
5 changed files with 38 additions and 23 deletions
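Note (illustrative, not part of the commit): every `shape_representation_fn` callback now takes a third `use_fast_memory` argument so a backend can be told to compute the on-device shape for placement in fast memory; all call sites updated here pass `false`, so behavior is unchanged. The sketch below shows a callback written against the new signature. The `ExampleShapeRepresentationFn` name is hypothetical, the two `#include` paths are assumed, and the fast-memory branch is left as a placeholder comment because this commit only threads the flag through.

#include "tensorflow/compiler/tf2xla/shape_util.h"    // assumed header for TensorShapeToXLAShape
#include "tensorflow/compiler/tf2xla/xla_compiler.h"  // assumed header for ShapeRepresentationFn

// Sketch of a shape representation function using the new three-argument
// signature. Mirrors the identity default in this commit, plus a placeholder
// branch for the new flag.
xla::StatusOr<xla::Shape> ExampleShapeRepresentationFn(
    const TensorShape& shape, DataType dtype, bool use_fast_memory) {
  xla::Shape xla_shape;
  TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
  if (use_fast_memory) {
    // Hypothetical: a backend could annotate the layout here (e.g. with a
    // non-default memory space) once it supports fast memory.
  }
  return xla_shape;
}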

View File

@@ -90,8 +90,9 @@ XlaDeviceContext::XlaDeviceContext(
   CHECK(host_to_device_stream_ != nullptr);
   CHECK(stream_ != nullptr);
   if (!shape_representation_fn_) {
-    shape_representation_fn_ = [](const TensorShape& shape,
-                                  DataType dtype) -> xla::StatusOr<xla::Shape> {
+    shape_representation_fn_ =
+        [](const TensorShape& shape, DataType dtype,
+           bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
       xla::Shape xla_shape;
       TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
       return xla_shape;
@@ -130,9 +131,10 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
   CHECK(xla_tensor);
   Status status = [&]() -> Status {
-    TF_ASSIGN_OR_RETURN(xla::Shape shape,
-                        shape_representation_fn_(device_tensor->shape(),
-                                                 device_tensor->dtype()));
+    TF_ASSIGN_OR_RETURN(
+        xla::Shape shape,
+        shape_representation_fn_(device_tensor->shape(), device_tensor->dtype(),
+                                 /*use_fast_memory=*/false));
     // The device tensor should always be fresh.
     TF_RET_CHECK(!xla_tensor->has_shaped_buffer());

View File

@@ -220,7 +220,8 @@ Status BuildComputation(
       // If there is a shape representation function, reshape the output
       // tensor to the shape given by the representation shape function.
-      TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn(
-                                                output.shape, output.type));
+      TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn(
+                                                output.shape, output.type,
+                                                /*use_fast_memory=*/false));
       value = xla::Reshape(value, xla::AsInt64Slice(shape.dimensions()));
       retval_index_and_layout.emplace_back(elems.size(), shape.layout());
     } else if (it != retval_shardings.end()) {
@@ -301,7 +302,8 @@ Status BuildComputation(
     if (shape_representation_fn) {
       TF_ASSIGN_OR_RETURN(
           xla::Shape xla_shape,
-          shape_representation_fn(resource->shape(), resource->type()));
+          shape_representation_fn(resource->shape(), resource->type(),
+                                  /*use_fast_memory=*/false));
       representation_shape = xla_shape;
     }
     if (resource->representation_shape().has_value()) {
@@ -477,8 +479,8 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
   // The default shape representation function is the identity.
   if (!options_.shape_representation_fn) {
     options_.shape_representation_fn =
-        [](const TensorShape& shape,
-           DataType dtype) -> xla::StatusOr<xla::Shape> {
+        [](const TensorShape& shape, DataType dtype,
+           bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
       xla::Shape xla_shape;
       TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
       return xla_shape;
@@ -711,8 +713,9 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
         TF_RETURN_IF_ERROR(
             XLAShapeToTensorShape(absl::get<xla::Shape>(arg.shape), &shape));
       }
-      TF_ASSIGN_OR_RETURN(*xla_shape,
-                          options_.shape_representation_fn(shape, arg.type));
+      TF_ASSIGN_OR_RETURN(*xla_shape, options_.shape_representation_fn(
+                                          shape, arg.type,
+                                          /*use_fast_memory=*/false));
     } else {
       if (absl::holds_alternative<xla::Shape>(arg.shape)) {
         *xla_shape = absl::get<xla::Shape>(arg.shape);
@@ -736,7 +739,8 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
      TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
      TF_ASSIGN_OR_RETURN(*xla_shape,
                          options_.shape_representation_fn(
-                              absl::get<TensorShape>(arg.shape), arg.type));
+                              absl::get<TensorShape>(arg.shape), arg.type,
+                              /*use_fast_memory=*/false));
      return Status::OK();
    }

View File

@@ -286,7 +286,8 @@ class XlaCompiler {
     std::shared_ptr<xla::XlaComputation> computation;
   };
-  typedef std::function<xla::StatusOr<xla::Shape>(const TensorShape&, DataType)>
+  typedef std::function<xla::StatusOr<xla::Shape>(const TensorShape&, DataType,
+                                                  bool)>
       ShapeRepresentationFn;
   struct Options {
     // Name of the compilation device to use. It must be set by the caller.
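For reference, a short usage sketch of the updated typedef, mirroring the identity default installed in XlaCompiler's constructor and the test fixtures below: any existing two-argument callback must gain the trailing bool to keep compiling against this header.

// Sketch: registering a callback with the new three-argument signature.
XlaCompiler::Options options;
options.shape_representation_fn =
    [](const TensorShape& shape, DataType dtype,
       bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
  xla::Shape xla_shape;
  TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
  return xla_shape;  // `use_fast_memory` is deliberately ignored here.
};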

View File

@@ -304,7 +304,8 @@ TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForUnwrittenResource) {
   auto options = DefaultOptions();
   options.shape_representation_fn =
-      [](const TensorShape& shape, DataType dt) -> xla::StatusOr<xla::Shape> {
+      [](const TensorShape& shape, DataType dt,
+         bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
     xla::Shape xla_shape;
     TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dt, shape, &xla_shape));
     *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
@@ -357,7 +358,8 @@ TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForRetVal) {
   auto options = DefaultOptions();
   options.shape_representation_fn =
-      [](const TensorShape& shape, DataType dt) -> xla::StatusOr<xla::Shape> {
+      [](const TensorShape& shape, DataType dt,
+         bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
     xla::Shape xla_shape;
     TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dt, shape, &xla_shape));
     *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
@@ -1080,7 +1082,8 @@ TEST_F(XlaCompilerTest, ResultLayoutSingle) {
   auto options = DefaultOptions();
   // Sets the representation function to return a non-default layout.
   options.shape_representation_fn =
-      [](const TensorShape& shape, DataType type) -> xla::StatusOr<xla::Shape> {
+      [](const TensorShape& shape, DataType type,
+         bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
     xla::Shape xla_shape;
     TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape));
     *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
@@ -1118,7 +1121,8 @@ TEST_F(XlaCompilerTest, ResultLayoutMultiple) {
   auto options = DefaultOptions();
   // Sets the representation function to return a non-default layout.
   options.shape_representation_fn =
-      [](const TensorShape& shape, DataType type) -> xla::StatusOr<xla::Shape> {
+      [](const TensorShape& shape, DataType type,
+         bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
     xla::Shape xla_shape;
     TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape));
     *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
@@ -1252,7 +1256,8 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
   // Compiles the graph.
   XlaCompiler::Options options = DefaultOptions();
   options.shape_representation_fn =
-      [](const TensorShape& shape, DataType type) -> xla::StatusOr<xla::Shape> {
+      [](const TensorShape& shape, DataType type,
+         bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
     xla::PrimitiveType ptype;
     TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(type, &ptype));
     return xla::ShapeUtil::MakeShape(ptype, {shape.num_elements()});
@@ -1322,7 +1327,8 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
   // Compiles the graph.
   XlaCompiler::Options options = DefaultOptions();
   options.shape_representation_fn =
-      [](const TensorShape& shape, DataType type) -> xla::StatusOr<xla::Shape> {
+      [](const TensorShape& shape, DataType type,
+         bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
     xla::PrimitiveType ptype;
     TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(type, &ptype));
     return xla::ShapeUtil::MakeShape(ptype, {shape.num_elements()});

View File

@@ -415,7 +415,8 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type,
     TF_ASSIGN_OR_RETURN(xla::Shape representation_shape,
                         ctx->compiler()->options().shape_representation_fn(
-                            variable->shape(), variable->type()));
+                            variable->shape(), variable->type(),
+                            /*use_fast_memory=*/false));
     xla::Shape xla_shape;
     TF_RETURN_IF_ERROR(
         TensorShapeToXLAShape(variable->type(), variable->shape(), &xla_shape));
@@ -550,9 +551,10 @@ Status AssignVariableTensor(const Tensor& tensor, DataType type,
   TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape));
-  TF_ASSIGN_OR_RETURN(
-      xla::Shape representation_shape,
-      ctx->compiler()->options().shape_representation_fn(shape, type));
+  TF_ASSIGN_OR_RETURN(xla::Shape representation_shape,
+                      ctx->compiler()->options().shape_representation_fn(
+                          shape, type,
+                          /*use_fast_memory=*/false));
   xla::Shape xla_shape;
   TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape));
   if (!xla::ShapeUtil::Compatible(xla_shape, representation_shape)) {