Extract helpers for PjRtClient::Compile.

PiperOrigin-RevId: 347503703
Change-Id: I18f8a4be58740074919be6fccba13cc24e080bae
Qiao Zhang 2020-12-14 17:25:47 -08:00 committed by TensorFlower Gardener
parent 3fbedf4eba
commit c802bf66ef
4 changed files with 136 additions and 89 deletions
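
For orientation, the two extracted helpers (ParseDeviceAssignmentCompileOptions and DetermineArgumentLayoutsFromCompileOptions, both added to tensorflow/compiler/xla/pjrt/utils.h below) are meant to be composed by a client's Compile path roughly as in the following minimal sketch. Everything suffixed ForSketch is an illustrative stand-in, not part of this change, and the includes are approximate.

// --- Sketch (not part of the diff): how the new helpers compose -------------
#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/utils.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace xla {
namespace {

// Illustrative default assignment: every (replica, partition) slot gets
// device 0. A real client would ask its platform for this.
StatusOr<DeviceAssignment> GetDefaultDeviceAssignmentForSketch(
    int num_replicas, int num_partitions) {
  DeviceAssignment assignment(num_replicas, num_partitions);
  assignment.Fill(0);
  return assignment;
}

// Illustrative layout chooser: fall back to XLA's default layout instead of a
// backend-specific compact layout.
StatusOr<Shape> ChooseCompactLayoutForSketch(Shape shape) {
  return LayoutUtil::GetWithDefaultLayout(shape);
}

Status PrepareCompileArgumentsForSketch(const XlaComputation& computation,
                                        CompileOptions options) {
  int num_replicas;
  int num_partitions;
  std::shared_ptr<DeviceAssignment> device_assignment;
  // Step 1: resolve replica/partition counts and the device assignment
  // (or force 1x1 and reject a preset assignment for portable executables).
  TF_RETURN_IF_ERROR(ParseDeviceAssignmentCompileOptions(
      options.compile_portable_executable, &options.executable_build_options,
      &GetDefaultDeviceAssignmentForSketch, &num_replicas, &num_partitions,
      &device_assignment));

  // Step 2: fill in any missing argument/result layouts.
  std::vector<const Shape*> argument_layout_pointers;
  TF_RETURN_IF_ERROR(DetermineArgumentLayoutsFromCompileOptions(
      computation, &ChooseCompactLayoutForSketch, options.argument_layouts,
      &options.executable_build_options, &argument_layout_pointers));
  // A real client would now hand these results to its backend compiler.
  return Status::OK();
}

}  // namespace
}  // namespace xla
// -----------------------------------------------------------------------------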

@@ -152,7 +152,9 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:executable_build_options",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_proto_cc",
"@com_google_absl//absl/container:flat_hash_set",

@@ -2101,88 +2101,23 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtStreamExecutorClient::Compile(
int num_replicas;
int num_partitions;
std::shared_ptr<DeviceAssignment> device_assignment;
if (options.compile_portable_executable) {
if (build_options.has_device_assignment()) {
return InvalidArgument(
"CompileOptions requests portable executable but "
"ExecutableBuildOptions includes a device assignment");
}
num_replicas = 1;
num_partitions = 1;
} else {
if (!build_options.has_device_assignment()) {
VLOG(2) << "PjRtStreamExecutorClient::Compile using default "
"device_assignment.";
TF_ASSIGN_OR_RETURN(
DeviceAssignment device_assignment,
GetDefaultDeviceAssignment(build_options.num_replicas(),
build_options.num_partitions()));
build_options.set_device_assignment(device_assignment);
}
VLOG(2) << "PjRtStreamExecutorClient::Compile device_assignment:\n"
<< build_options.device_assignment().ToString();
num_replicas = build_options.device_assignment().replica_count();
num_partitions = build_options.device_assignment().computation_count();
device_assignment =
std::make_shared<DeviceAssignment>(build_options.device_assignment());
}
TF_RETURN_IF_ERROR(ParseDeviceAssignmentCompileOptions(
options.compile_portable_executable, &options.executable_build_options,
[this](int num_replicas, int num_partitions) {
return this->GetDefaultDeviceAssignment(num_replicas, num_partitions);
},
&num_replicas, &num_partitions, &device_assignment));
TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
computation.GetProgramShape());
if (!options.argument_layouts) {
options.argument_layouts = program_shape.parameters();
for (Shape& shape : *options.argument_layouts) {
LayoutUtil::ClearLayout(&shape);
}
} else if (options.argument_layouts->size() !=
program_shape.parameters_size()) {
return InvalidArgument(
"CompileOptions specify %d argument layouts, but computation has %d "
"arguments",
options.argument_layouts->size(), program_shape.parameters_size());
}
std::vector<const Shape*> argument_layout_pointers;
argument_layout_pointers.reserve(options.argument_layouts->size());
// Assign a default layout based on `sharded_shape` to any array subshapes in
// `dst_shape` that are missing layouts.
auto assign_layouts = [local_client = client()](const Shape& sharded_shape,
Shape* dst_shape) {
return ShapeUtil::ForEachMutableSubshapeWithStatus(
dst_shape, [&](Shape* subshape, const ShapeIndex& idx) {
if (subshape->IsArray() && !subshape->has_layout()) {
CHECK(ShapeUtil::IndexIsValid(sharded_shape, idx));
const Shape& sharded_subshape =
ShapeUtil::GetSubshape(sharded_shape, idx);
LayoutUtil::SetToDefaultLayout(subshape);
TF_ASSIGN_OR_RETURN(Shape layout, local_client->backend()
.transfer_manager()
->ChooseCompactLayoutForShape(
sharded_subshape));
*subshape->mutable_layout() = layout.layout();
}
return Status::OK();
});
};
TF_ASSIGN_OR_RETURN(auto sharded_shapes,
GetShardedProgramShapes(computation));
CHECK_EQ(sharded_shapes.first.size(), options.argument_layouts->size());
for (int i = 0; i < options.argument_layouts->size(); ++i) {
Shape* layout = &(*options.argument_layouts)[i];
argument_layout_pointers.push_back(layout);
TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.first[i], layout));
}
Shape result_layout;
if (build_options.result_layout()) {
result_layout = *build_options.result_layout();
} else {
result_layout = program_shape.result();
LayoutUtil::ClearLayout(&result_layout);
}
TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.second, &result_layout));
build_options.set_result_layout(result_layout);
TF_RETURN_IF_ERROR(DetermineArgumentLayoutsFromCompileOptions(
computation,
[local_client = client()](Shape shape) {
return local_client->backend()
.transfer_manager()
->ChooseCompactLayoutForShape(shape);
},
options.argument_layouts, &options.executable_build_options,
&argument_layout_pointers));
// Find devices that are addressable by this client/task.
std::vector<PjRtExecutable::LogicalDeviceIds> addressable_device_logical_ids;

@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/pjrt/utils.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -69,13 +70,9 @@ StatusOr<Shape> GetShardedShape(const HloInstructionProto& instr) {
return sharded_shape;
}
} // namespace
// Returns sharded (argument shapes, result shape) without layouts.
StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
const XlaComputation& computation) {
TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
computation.GetProgramShape());
const XlaComputation& computation, const ProgramShape& program_shape) {
std::vector<Shape> arg_shapes;
arg_shapes.resize(program_shape.parameters_size());
Shape result_shape;
@@ -111,6 +108,103 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
}
return std::make_pair(arg_shapes, result_shape);
}
} // namespace
Status ParseDeviceAssignmentCompileOptions(
bool compile_portable_executable, ExecutableBuildOptions* build_options,
std::function<StatusOr<DeviceAssignment>(int, int)>
GetDefaultDeviceAssignmentFunction,
int* num_replicas, int* num_partitions,
std::shared_ptr<DeviceAssignment>* device_assignment) {
if (compile_portable_executable) {
if (build_options->has_device_assignment()) {
return InvalidArgument(
"CompileOptions requests portable executable but "
"ExecutableBuildOptions includes a device assignment");
}
*num_replicas = 1;
*num_partitions = 1;
} else {
if (!build_options->has_device_assignment()) {
VLOG(2) << "Compile using default device_assignment.";
TF_ASSIGN_OR_RETURN(
DeviceAssignment device_assignment,
GetDefaultDeviceAssignmentFunction(build_options->num_replicas(),
build_options->num_partitions()));
build_options->set_device_assignment(device_assignment);
}
VLOG(2) << "Compile device_assignment:\n"
<< build_options->device_assignment().ToString();
*num_replicas = build_options->device_assignment().replica_count();
*num_partitions = build_options->device_assignment().computation_count();
*device_assignment =
std::make_shared<DeviceAssignment>(build_options->device_assignment());
}
return Status::OK();
}
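
A hedged usage sketch for ParseDeviceAssignmentCompileOptions as implemented above; the round-robin default assignment and the ForSketch names are illustrative, and the includes are approximate.

// --- Sketch (not part of the diff): ParseDeviceAssignmentCompileOptions -----
#include <memory>

#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/pjrt/utils.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace xla {

Status ParseDeviceAssignmentExampleForSketch() {
  ExecutableBuildOptions build_options;
  build_options.set_num_replicas(2);
  build_options.set_num_partitions(2);

  int num_replicas;
  int num_partitions;
  std::shared_ptr<DeviceAssignment> device_assignment;
  TF_RETURN_IF_ERROR(ParseDeviceAssignmentCompileOptions(
      /*compile_portable_executable=*/false, &build_options,
      [](int replicas, int partitions) -> StatusOr<DeviceAssignment> {
        // Illustrative default: number devices row-major across the grid.
        DeviceAssignment assignment(replicas, partitions);
        int device = 0;
        for (int r = 0; r < replicas; ++r) {
          for (int p = 0; p < partitions; ++p) {
            assignment(r, p) = device++;
          }
        }
        return assignment;
      },
      &num_replicas, &num_partitions, &device_assignment));
  // On success: num_replicas == 2, num_partitions == 2, and build_options now
  // carries the assignment. With compile_portable_executable = true the
  // helper instead forces 1x1 and rejects a preset assignment.
  return Status::OK();
}

}  // namespace xla
// -----------------------------------------------------------------------------

In PjRtStreamExecutorClient::Compile the callback is simply the client's own GetDefaultDeviceAssignment, as the Compile diff above shows.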
Status DetermineArgumentLayoutsFromCompileOptions(
const XlaComputation& computation,
std::function<StatusOr<Shape>(Shape)>
choose_compact_layout_for_shape_function,
absl::optional<std::vector<Shape>>& argument_layouts,
ExecutableBuildOptions* build_options,
std::vector<const Shape*>* argument_layout_pointers) {
TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
computation.GetProgramShape());
if (!argument_layouts) {
argument_layouts.emplace(program_shape.parameters());
for (Shape& shape : *argument_layouts) {
LayoutUtil::ClearLayout(&shape);
}
} else if (argument_layouts->size() != program_shape.parameters_size()) {
return InvalidArgument(
"CompileOptions specify %d argument layouts, but computation has %d "
"arguments",
argument_layouts->size(), program_shape.parameters_size());
}
argument_layout_pointers->reserve(argument_layouts->size());
// Assign a default layout based on `sharded_shape` to any array subshapes in
// `dst_shape` that are missing layouts.
auto assign_layouts = [&choose_compact_layout_for_shape_function](
const Shape& sharded_shape, Shape* dst_shape) {
return ShapeUtil::ForEachMutableSubshapeWithStatus(
dst_shape, [&](Shape* subshape, const ShapeIndex& idx) {
if (subshape->IsArray() && !subshape->has_layout()) {
CHECK(ShapeUtil::IndexIsValid(sharded_shape, idx));
const Shape& sharded_subshape =
ShapeUtil::GetSubshape(sharded_shape, idx);
LayoutUtil::SetToDefaultLayout(subshape);
TF_ASSIGN_OR_RETURN(
Shape layout,
choose_compact_layout_for_shape_function(sharded_subshape));
*subshape->mutable_layout() = layout.layout();
}
return Status::OK();
});
};
TF_ASSIGN_OR_RETURN(auto sharded_shapes,
GetShardedProgramShapes(computation, program_shape));
CHECK_EQ(sharded_shapes.first.size(), argument_layouts->size());
for (int i = 0; i < argument_layouts->size(); ++i) {
Shape* layout = &(*argument_layouts)[i];
argument_layout_pointers->push_back(layout);
TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.first[i], layout));
}
Shape result_layout;
if (build_options->result_layout()) {
result_layout = *build_options->result_layout();
} else {
result_layout = program_shape.result();
LayoutUtil::ClearLayout(&result_layout);
}
TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.second, &result_layout));
build_options->set_result_layout(result_layout);
return Status::OK();
}
StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
const HloModule& module, bool tuple_inputs) {

@@ -17,17 +17,33 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_PJRT_UTILS_H_
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace xla {
// Extract from XlaComputation the sharded program shapes (argument shapes,
// result shape) without layouts.
StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
const XlaComputation& computation);
// Returns the num_replicas, num_partitions and device assignment given an
// ExecutableBuildOptions and whether we want a portable executable.
Status ParseDeviceAssignmentCompileOptions(
bool compile_portable_executable, ExecutableBuildOptions* build_options,
std::function<StatusOr<DeviceAssignment>(int, int)>
GetDefaultDeviceAssignmentFunction,
int* num_replicas, int* num_partitions,
std::shared_ptr<DeviceAssignment>* device_assignment);
// Returns pointers to the argument layouts given an XlaComputation and
// ExecutableBuildOptions.
Status DetermineArgumentLayoutsFromCompileOptions(
const XlaComputation& computation,
std::function<StatusOr<Shape>(Shape)>
choose_compact_layout_for_shape_function,
absl::optional<std::vector<Shape>>& argument_layouts,
ExecutableBuildOptions* build_options,
std::vector<const Shape*>* argument_layout_pointers);
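
A hedged usage sketch for the declaration above: the layout chooser below simply applies XLA's default layout rather than a backend-specific compact layout, construction of the XlaComputation is elided, and the ForSketch name and includes are illustrative approximations.

// --- Sketch (not part of the diff): DetermineArgumentLayoutsFromCompileOptions
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/pjrt/utils.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace xla {

Status DetermineLayoutsExampleForSketch(const XlaComputation& computation) {
  // No explicit layouts requested: the helper derives them from the program
  // shape and fills in any missing array layouts.
  absl::optional<std::vector<Shape>> argument_layouts;
  ExecutableBuildOptions build_options;
  std::vector<const Shape*> argument_layout_pointers;
  TF_RETURN_IF_ERROR(DetermineArgumentLayoutsFromCompileOptions(
      computation,
      [](Shape shape) -> StatusOr<Shape> {
        return LayoutUtil::GetWithDefaultLayout(shape);
      },
      argument_layouts, &build_options, &argument_layout_pointers));
  // On success: argument_layouts holds one laid-out Shape per parameter,
  // argument_layout_pointers points into it, and build_options has its
  // result layout set.
  return Status::OK();
}

}  // namespace xla
// -----------------------------------------------------------------------------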
// Executables can donate buffers so that buffers can be aliased from inputs
// to outputs. This function returns the list of parameters that must be