Extract helpers for PjRtClient::Compile.

PiperOrigin-RevId: 347503703
Change-Id: I18f8a4be58740074919be6fccba13cc24e080bae
parent 3fbedf4eba
commit c802bf66ef
@@ -152,7 +152,9 @@ cc_library(
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "//tensorflow/compiler/xla/client:executable_build_options",
         "//tensorflow/compiler/xla/client:xla_computation",
+        "//tensorflow/compiler/xla/service:computation_placer",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service:hlo_proto_cc",
         "@com_google_absl//absl/container:flat_hash_set",
@@ -2101,88 +2101,23 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtStreamExecutorClient::Compile(
   int num_replicas;
   int num_partitions;
   std::shared_ptr<DeviceAssignment> device_assignment;
-  if (options.compile_portable_executable) {
-    if (build_options.has_device_assignment()) {
-      return InvalidArgument(
-          "CompileOptions requests portable executable but "
-          "ExecutableBuildOptions includes a device assignment");
-    }
-    num_replicas = 1;
-    num_partitions = 1;
-  } else {
-    if (!build_options.has_device_assignment()) {
-      VLOG(2) << "PjRtStreamExecutorClient::Compile using default "
-                 "device_assignment.";
-      TF_ASSIGN_OR_RETURN(
-          DeviceAssignment device_assignment,
-          GetDefaultDeviceAssignment(build_options.num_replicas(),
-                                     build_options.num_partitions()));
-      build_options.set_device_assignment(device_assignment);
-    }
-    VLOG(2) << "PjRtStreamExecutorClient::Compile device_assignment:\n"
-            << build_options.device_assignment().ToString();
-    num_replicas = build_options.device_assignment().replica_count();
-    num_partitions = build_options.device_assignment().computation_count();
-    device_assignment =
-        std::make_shared<DeviceAssignment>(build_options.device_assignment());
-  }
+  TF_RETURN_IF_ERROR(ParseDeviceAssignmentCompileOptions(
+      options.compile_portable_executable, &options.executable_build_options,
+      [this](int num_replicas, int num_partitions) {
+        return this->GetDefaultDeviceAssignment(num_replicas, num_partitions);
+      },
+      &num_replicas, &num_partitions, &device_assignment));

-  TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
-                      computation.GetProgramShape());
-  if (!options.argument_layouts) {
-    options.argument_layouts = program_shape.parameters();
-    for (Shape& shape : *options.argument_layouts) {
-      LayoutUtil::ClearLayout(&shape);
-    }
-  } else if (options.argument_layouts->size() !=
-             program_shape.parameters_size()) {
-    return InvalidArgument(
-        "CompileOptions specify %d argument layouts, but computation has %d "
-        "arguments",
-        options.argument_layouts->size(), program_shape.parameters_size());
-  }
   std::vector<const Shape*> argument_layout_pointers;
-  argument_layout_pointers.reserve(options.argument_layouts->size());
-
-  // Assign a default layout based on `sharded_shape` to any array subshapes in
-  // `dst_shape` that are missing layouts.
-  auto assign_layouts = [local_client = client()](const Shape& sharded_shape,
-                                                  Shape* dst_shape) {
-    return ShapeUtil::ForEachMutableSubshapeWithStatus(
-        dst_shape, [&](Shape* subshape, const ShapeIndex& idx) {
-          if (subshape->IsArray() && !subshape->has_layout()) {
-            CHECK(ShapeUtil::IndexIsValid(sharded_shape, idx));
-            const Shape& sharded_subshape =
-                ShapeUtil::GetSubshape(sharded_shape, idx);
-            LayoutUtil::SetToDefaultLayout(subshape);
-            TF_ASSIGN_OR_RETURN(Shape layout, local_client->backend()
-                                                  .transfer_manager()
-                                                  ->ChooseCompactLayoutForShape(
-                                                      sharded_subshape));
-            *subshape->mutable_layout() = layout.layout();
-          }
-          return Status::OK();
-        });
-  };
-  TF_ASSIGN_OR_RETURN(auto sharded_shapes,
-                      GetShardedProgramShapes(computation));
-
-  CHECK_EQ(sharded_shapes.first.size(), options.argument_layouts->size());
-  for (int i = 0; i < options.argument_layouts->size(); ++i) {
-    Shape* layout = &(*options.argument_layouts)[i];
-    argument_layout_pointers.push_back(layout);
-    TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.first[i], layout));
-  }
-
-  Shape result_layout;
-  if (build_options.result_layout()) {
-    result_layout = *build_options.result_layout();
-  } else {
-    result_layout = program_shape.result();
-    LayoutUtil::ClearLayout(&result_layout);
-  }
-  TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.second, &result_layout));
-  build_options.set_result_layout(result_layout);
+  TF_RETURN_IF_ERROR(DetermineArgumentLayoutsFromCompileOptions(
+      computation,
+      [local_client = client()](Shape shape) {
+        return local_client->backend()
+            .transfer_manager()
+            ->ChooseCompactLayoutForShape(shape);
+      },
+      options.argument_layouts, &options.executable_build_options,
+      &argument_layout_pointers));

   // Find devices that are addressable by this client/task.
   std::vector<PjRtExecutable::LogicalDeviceIds> addressable_device_logical_ids;
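Note on the design: both helpers take the client-specific policies as std::function parameters — the default device assignment and the transfer manager's layout choice are injected by the caller — which is what lets the shared logic move out of PjRtStreamExecutorClient into the client-agnostic pjrt/utils target. A minimal sketch of how another PjRt backend could drive the same helpers; the function name PrepareForCompile and both lambda bodies are hypothetical stand-ins, not from this commit:

    #include <memory>
    #include <vector>

    #include "tensorflow/compiler/xla/layout_util.h"
    #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"  // CompileOptions
    #include "tensorflow/compiler/xla/pjrt/utils.h"
    #include "tensorflow/compiler/xla/service/computation_placer.h"

    namespace xla {

    Status PrepareForCompile(const XlaComputation& computation,
                             CompileOptions* options) {
      int num_replicas;
      int num_partitions;
      std::shared_ptr<DeviceAssignment> device_assignment;
      // Resolve replica/partition counts and the device assignment; the
      // lambda supplies this backend's default-assignment policy.
      TF_RETURN_IF_ERROR(ParseDeviceAssignmentCompileOptions(
          options->compile_portable_executable,
          &options->executable_build_options,
          [](int replicas, int partitions) -> StatusOr<DeviceAssignment> {
            ComputationPlacer placer;  // stand-in default policy
            return placer.AssignDevices(replicas, partitions);
          },
          &num_replicas, &num_partitions, &device_assignment));

      // Fill in argument and result layouts; the lambda stands in for a real
      // transfer manager's ChooseCompactLayoutForShape.
      std::vector<const Shape*> argument_layout_pointers;
      TF_RETURN_IF_ERROR(DetermineArgumentLayoutsFromCompileOptions(
          computation,
          [](Shape shape) -> StatusOr<Shape> {
            LayoutUtil::SetToDefaultLayout(&shape);
            return shape;
          },
          options->argument_layouts, &options->executable_build_options,
          &argument_layout_pointers));
      return Status::OK();
    }

    }  // namespace xla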
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/pjrt/utils.h"

 #include "absl/container/flat_hash_set.h"
+#include "tensorflow/compiler/xla/client/executable_build_options.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -69,13 +70,9 @@ StatusOr<Shape> GetShardedShape(const HloInstructionProto& instr) {
   return sharded_shape;
 }

-}  // namespace
-
 // Returns sharded (argument shapes, result shape) without layouts.
 StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
-    const XlaComputation& computation) {
-  TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
-                      computation.GetProgramShape());
+    const XlaComputation& computation, const ProgramShape& program_shape) {
   std::vector<Shape> arg_shapes;
   arg_shapes.resize(program_shape.parameters_size());
   Shape result_shape;
@@ -111,6 +108,103 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
   }
   return std::make_pair(arg_shapes, result_shape);
 }
+}  // namespace
+
+Status ParseDeviceAssignmentCompileOptions(
+    bool compile_portable_executable, ExecutableBuildOptions* build_options,
+    std::function<StatusOr<DeviceAssignment>(int, int)>
+        GetDefaultDeviceAssignmentFunction,
+    int* num_replicas, int* num_partitions,
+    std::shared_ptr<DeviceAssignment>* device_assignment) {
+  if (compile_portable_executable) {
+    if (build_options->has_device_assignment()) {
+      return InvalidArgument(
+          "CompileOptions requests portable executable but "
+          "ExecutableBuildOptions includes a device assignment");
+    }
+    *num_replicas = 1;
+    *num_partitions = 1;
+  } else {
+    if (!build_options->has_device_assignment()) {
+      VLOG(2) << "Compile using default device_assignment.";
+      TF_ASSIGN_OR_RETURN(
+          DeviceAssignment device_assignment,
+          GetDefaultDeviceAssignmentFunction(build_options->num_replicas(),
+                                             build_options->num_partitions()));
+      build_options->set_device_assignment(device_assignment);
+    }
+    VLOG(2) << "Compile device_assignment:\n"
+            << build_options->device_assignment().ToString();
+    *num_replicas = build_options->device_assignment().replica_count();
+    *num_partitions = build_options->device_assignment().computation_count();
+    *device_assignment =
+        std::make_shared<DeviceAssignment>(build_options->device_assignment());
+  }
+  return Status::OK();
+}
+
+Status DetermineArgumentLayoutsFromCompileOptions(
+    const XlaComputation& computation,
+    std::function<StatusOr<Shape>(Shape)>
+        choose_compact_layout_for_shape_function,
+    absl::optional<std::vector<Shape>>& argument_layouts,
+    ExecutableBuildOptions* build_options,
+    std::vector<const Shape*>* argument_layout_pointers) {
+  TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
+                      computation.GetProgramShape());
+  if (!argument_layouts) {
+    argument_layouts.emplace(program_shape.parameters());
+    for (Shape& shape : *argument_layouts) {
+      LayoutUtil::ClearLayout(&shape);
+    }
+  } else if (argument_layouts->size() != program_shape.parameters_size()) {
+    return InvalidArgument(
+        "CompileOptions specify %d argument layouts, but computation has %d "
+        "arguments",
+        argument_layouts->size(), program_shape.parameters_size());
+  }
+  argument_layout_pointers->reserve(argument_layouts->size());
+
+  // Assign a default layout based on `sharded_shape` to any array subshapes in
+  // `dst_shape` that are missing layouts.
+  auto assign_layouts = [&choose_compact_layout_for_shape_function](
+                            const Shape& sharded_shape, Shape* dst_shape) {
+    return ShapeUtil::ForEachMutableSubshapeWithStatus(
+        dst_shape, [&](Shape* subshape, const ShapeIndex& idx) {
+          if (subshape->IsArray() && !subshape->has_layout()) {
+            CHECK(ShapeUtil::IndexIsValid(sharded_shape, idx));
+            const Shape& sharded_subshape =
+                ShapeUtil::GetSubshape(sharded_shape, idx);
+            LayoutUtil::SetToDefaultLayout(subshape);
+            TF_ASSIGN_OR_RETURN(
+                Shape layout,
+                choose_compact_layout_for_shape_function(sharded_subshape));
+            *subshape->mutable_layout() = layout.layout();
+          }
+          return Status::OK();
+        });
+  };
+  TF_ASSIGN_OR_RETURN(auto sharded_shapes,
+                      GetShardedProgramShapes(computation, program_shape));
+
+  CHECK_EQ(sharded_shapes.first.size(), argument_layouts->size());
+  for (int i = 0; i < argument_layouts->size(); ++i) {
+    Shape* layout = &(*argument_layouts)[i];
+    argument_layout_pointers->push_back(layout);
+    TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.first[i], layout));
+  }
+
+  Shape result_layout;
+  if (build_options->result_layout()) {
+    result_layout = *build_options->result_layout();
+  } else {
+    result_layout = program_shape.result();
+    LayoutUtil::ClearLayout(&result_layout);
+  }
+  TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.second, &result_layout));
+  build_options->set_result_layout(result_layout);
+  return Status::OK();
+}
+
 StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
     const HloModule& module, bool tuple_inputs) {
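ParseDeviceAssignmentCompileOptions validates as well as parses: a portable executable is pinned to num_replicas = num_partitions = 1 and must not carry a device assignment, while the non-portable path fills in a default assignment through the injected callback. A hedged, test-style sketch of the rejection path (inside namespace xla; the function name and the callback body are illustrative only, not from this commit):

    Status PortableRejectsDeviceAssignment() {
      ExecutableBuildOptions build_options;
      build_options.set_device_assignment(
          DeviceAssignment(/*replica_count=*/2, /*computation_count=*/1));
      int num_replicas;
      int num_partitions;
      std::shared_ptr<DeviceAssignment> device_assignment;
      // Expected to return InvalidArgument: portable executables may not
      // come with a pre-set device assignment, so the default-assignment
      // callback is never invoked on this path.
      return ParseDeviceAssignmentCompileOptions(
          /*compile_portable_executable=*/true, &build_options,
          [](int replicas, int partitions) -> StatusOr<DeviceAssignment> {
            return Unimplemented("never reached on this path");
          },
          &num_replicas, &num_partitions, &device_assignment);
    }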
@@ -17,17 +17,33 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_XLA_PJRT_UTILS_H_

 #include "absl/container/flat_hash_set.h"
+#include "tensorflow/compiler/xla/client/executable_build_options.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/service/computation_placer.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/shape.h"
 #include "tensorflow/compiler/xla/statusor.h"

 namespace xla {

-// Extract from XlaComputation the sharded program shapes (argument shapes,
-// result shape) without layouts.
-StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
-    const XlaComputation& computation);
+// Returns the num_replicas, num_partitions and device assignment given a
+// ExecutableBuildOptions and whether we want a portable executable.
+Status ParseDeviceAssignmentCompileOptions(
+    bool compile_portable_executable, ExecutableBuildOptions* build_options,
+    std::function<StatusOr<DeviceAssignment>(int, int)>
+        GetDefaultDeviceAssignmentFunction,
+    int* num_replicas, int* num_partitions,
+    std::shared_ptr<DeviceAssignment>* device_assignment);
+
+// Returns pointers to the argument layouts given an XlaComputation and
+// ExecutableBuildOptions.
+Status DetermineArgumentLayoutsFromCompileOptions(
+    const XlaComputation& computation,
+    std::function<StatusOr<Shape>(Shape)>
+        choose_compact_layout_for_shape_function,
+    absl::optional<std::vector<Shape>>& argument_layouts,
+    ExecutableBuildOptions* build_options,
+    std::vector<const Shape*>* argument_layout_pointers);

 // Executables can donate buffers so that buffers can be aliased from inputs
 // to outputs. This function returns the list of parameters that must be
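As declared above, DetermineArgumentLayoutsFromCompileOptions also covers the no-layouts case: when argument_layouts is empty it derives shapes from the computation's program shape, clears their layouts, and asks the callback to choose a concrete layout per array subshape, for the result as well as the arguments. A hedged sketch of that defaulting behavior (inside namespace xla; the builder snippet and the identity-style layout chooser are assumptions, not part of this commit):

    #include "absl/types/optional.h"
    #include "tensorflow/compiler/xla/client/xla_builder.h"
    #include "tensorflow/compiler/xla/layout_util.h"
    #include "tensorflow/compiler/xla/pjrt/utils.h"
    #include "tensorflow/compiler/xla/shape_util.h"

    Status LayoutDefaultingSketch() {
      XlaBuilder builder("add");
      XlaOp x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {4, 2}), "x");
      Add(x, x);
      TF_ASSIGN_OR_RETURN(XlaComputation computation, builder.Build());

      absl::optional<std::vector<Shape>> argument_layouts;  // empty: defaults
      ExecutableBuildOptions build_options;
      std::vector<const Shape*> argument_layout_pointers;
      TF_RETURN_IF_ERROR(DetermineArgumentLayoutsFromCompileOptions(
          computation,
          [](Shape shape) -> StatusOr<Shape> {
            LayoutUtil::SetToDefaultLayout(&shape);  // stand-in chooser
            return shape;
          },
          argument_layouts, &build_options, &argument_layout_pointers));
      // argument_layout_pointers now holds one Shape* with a concrete layout,
      // and build_options.result_layout() has been populated.
      return Status::OK();
    }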