Allow specifying fast memory space for device shape.
PiperOrigin-RevId: 263570415
This commit is contained in:
parent
f4e6ee7c09
commit
0e271c3e39
tensorflow/compiler
@ -90,8 +90,9 @@ XlaDeviceContext::XlaDeviceContext(
|
||||
CHECK(host_to_device_stream_ != nullptr);
|
||||
CHECK(stream_ != nullptr);
|
||||
if (!shape_representation_fn_) {
|
||||
shape_representation_fn_ = [](const TensorShape& shape,
|
||||
DataType dtype) -> xla::StatusOr<xla::Shape> {
|
||||
shape_representation_fn_ =
|
||||
[](const TensorShape& shape, DataType dtype,
|
||||
bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
|
||||
xla::Shape xla_shape;
|
||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
|
||||
return xla_shape;
|
||||
@ -130,9 +131,10 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
|
||||
CHECK(xla_tensor);
|
||||
|
||||
Status status = [&]() -> Status {
|
||||
TF_ASSIGN_OR_RETURN(xla::Shape shape,
|
||||
shape_representation_fn_(device_tensor->shape(),
|
||||
device_tensor->dtype()));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
xla::Shape shape,
|
||||
shape_representation_fn_(device_tensor->shape(), device_tensor->dtype(),
|
||||
/*use_fast_memory=*/false));
|
||||
|
||||
// The device tensor should always be fresh.
|
||||
TF_RET_CHECK(!xla_tensor->has_shaped_buffer());
|
||||
|
@ -220,7 +220,8 @@ Status BuildComputation(
|
||||
// If there is a shape representation function, reshape the output
|
||||
// tensor to the shape given by the representation shape function.
|
||||
TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn(
|
||||
output.shape, output.type));
|
||||
output.shape, output.type,
|
||||
/*use_fast_memory=*/false));
|
||||
value = xla::Reshape(value, xla::AsInt64Slice(shape.dimensions()));
|
||||
retval_index_and_layout.emplace_back(elems.size(), shape.layout());
|
||||
} else if (it != retval_shardings.end()) {
|
||||
@ -301,7 +302,8 @@ Status BuildComputation(
|
||||
if (shape_representation_fn) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
xla::Shape xla_shape,
|
||||
shape_representation_fn(resource->shape(), resource->type()));
|
||||
shape_representation_fn(resource->shape(), resource->type(),
|
||||
/*use_fast_memory=*/false));
|
||||
representation_shape = xla_shape;
|
||||
}
|
||||
if (resource->representation_shape().has_value()) {
|
||||
@ -477,8 +479,8 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
|
||||
// The default shape representation function is the identity.
|
||||
if (!options_.shape_representation_fn) {
|
||||
options_.shape_representation_fn =
|
||||
[](const TensorShape& shape,
|
||||
DataType dtype) -> xla::StatusOr<xla::Shape> {
|
||||
[](const TensorShape& shape, DataType dtype,
|
||||
bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
|
||||
xla::Shape xla_shape;
|
||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
|
||||
return xla_shape;
|
||||
@ -711,8 +713,9 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
|
||||
TF_RETURN_IF_ERROR(
|
||||
XLAShapeToTensorShape(absl::get<xla::Shape>(arg.shape), &shape));
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(*xla_shape,
|
||||
options_.shape_representation_fn(shape, arg.type));
|
||||
TF_ASSIGN_OR_RETURN(*xla_shape, options_.shape_representation_fn(
|
||||
shape, arg.type,
|
||||
/*use_fast_memory=*/false));
|
||||
} else {
|
||||
if (absl::holds_alternative<xla::Shape>(arg.shape)) {
|
||||
*xla_shape = absl::get<xla::Shape>(arg.shape);
|
||||
@ -736,7 +739,8 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
|
||||
TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
|
||||
TF_ASSIGN_OR_RETURN(*xla_shape,
|
||||
options_.shape_representation_fn(
|
||||
absl::get<TensorShape>(arg.shape), arg.type));
|
||||
absl::get<TensorShape>(arg.shape), arg.type,
|
||||
/*use_fast_memory=*/false));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -286,7 +286,8 @@ class XlaCompiler {
|
||||
std::shared_ptr<xla::XlaComputation> computation;
|
||||
};
|
||||
|
||||
typedef std::function<xla::StatusOr<xla::Shape>(const TensorShape&, DataType)>
|
||||
typedef std::function<xla::StatusOr<xla::Shape>(const TensorShape&, DataType,
|
||||
bool)>
|
||||
ShapeRepresentationFn;
|
||||
struct Options {
|
||||
// Name of the compilation device to use. It must be set by the caller.
|
||||
|
@ -304,7 +304,8 @@ TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForUnwrittenResource) {
|
||||
|
||||
auto options = DefaultOptions();
|
||||
options.shape_representation_fn =
|
||||
[](const TensorShape& shape, DataType dt) -> xla::StatusOr<xla::Shape> {
|
||||
[](const TensorShape& shape, DataType dt,
|
||||
bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
|
||||
xla::Shape xla_shape;
|
||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dt, shape, &xla_shape));
|
||||
*xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
|
||||
@ -357,7 +358,8 @@ TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForRetVal) {
|
||||
|
||||
auto options = DefaultOptions();
|
||||
options.shape_representation_fn =
|
||||
[](const TensorShape& shape, DataType dt) -> xla::StatusOr<xla::Shape> {
|
||||
[](const TensorShape& shape, DataType dt,
|
||||
bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
|
||||
xla::Shape xla_shape;
|
||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dt, shape, &xla_shape));
|
||||
*xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
|
||||
@ -1080,7 +1082,8 @@ TEST_F(XlaCompilerTest, ResultLayoutSingle) {
|
||||
auto options = DefaultOptions();
|
||||
// Sets the representation function to return a non-default layout.
|
||||
options.shape_representation_fn =
|
||||
[](const TensorShape& shape, DataType type) -> xla::StatusOr<xla::Shape> {
|
||||
[](const TensorShape& shape, DataType type,
|
||||
bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
|
||||
xla::Shape xla_shape;
|
||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape));
|
||||
*xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
|
||||
@ -1118,7 +1121,8 @@ TEST_F(XlaCompilerTest, ResultLayoutMultiple) {
|
||||
auto options = DefaultOptions();
|
||||
// Sets the representation function to return a non-default layout.
|
||||
options.shape_representation_fn =
|
||||
[](const TensorShape& shape, DataType type) -> xla::StatusOr<xla::Shape> {
|
||||
[](const TensorShape& shape, DataType type,
|
||||
bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
|
||||
xla::Shape xla_shape;
|
||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape));
|
||||
*xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
|
||||
@ -1252,7 +1256,8 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
|
||||
// Compiles the graph.
|
||||
XlaCompiler::Options options = DefaultOptions();
|
||||
options.shape_representation_fn =
|
||||
[](const TensorShape& shape, DataType type) -> xla::StatusOr<xla::Shape> {
|
||||
[](const TensorShape& shape, DataType type,
|
||||
bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
|
||||
xla::PrimitiveType ptype;
|
||||
TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(type, &ptype));
|
||||
return xla::ShapeUtil::MakeShape(ptype, {shape.num_elements()});
|
||||
@ -1322,7 +1327,8 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
|
||||
// Compiles the graph.
|
||||
XlaCompiler::Options options = DefaultOptions();
|
||||
options.shape_representation_fn =
|
||||
[](const TensorShape& shape, DataType type) -> xla::StatusOr<xla::Shape> {
|
||||
[](const TensorShape& shape, DataType type,
|
||||
bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
|
||||
xla::PrimitiveType ptype;
|
||||
TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(type, &ptype));
|
||||
return xla::ShapeUtil::MakeShape(ptype, {shape.num_elements()});
|
||||
|
@ -415,7 +415,8 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type,
|
||||
|
||||
TF_ASSIGN_OR_RETURN(xla::Shape representation_shape,
|
||||
ctx->compiler()->options().shape_representation_fn(
|
||||
variable->shape(), variable->type()));
|
||||
variable->shape(), variable->type(),
|
||||
/*use_fast_memory=*/false));
|
||||
xla::Shape xla_shape;
|
||||
TF_RETURN_IF_ERROR(
|
||||
TensorShapeToXLAShape(variable->type(), variable->shape(), &xla_shape));
|
||||
@ -550,9 +551,10 @@ Status AssignVariableTensor(const Tensor& tensor, DataType type,
|
||||
|
||||
TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
xla::Shape representation_shape,
|
||||
ctx->compiler()->options().shape_representation_fn(shape, type));
|
||||
TF_ASSIGN_OR_RETURN(xla::Shape representation_shape,
|
||||
ctx->compiler()->options().shape_representation_fn(
|
||||
shape, type,
|
||||
/*use_fast_memory=*/false));
|
||||
xla::Shape xla_shape;
|
||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape));
|
||||
if (!xla::ShapeUtil::Compatible(xla_shape, representation_shape)) {
|
||||
|
Loading…
Reference in New Issue
Block a user