277 lines
11 KiB
C++
277 lines
11 KiB
C++
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "tensorflow/compiler/xla/pjrt/utils.h"
|
|
|
|
#include "absl/container/flat_hash_set.h"
|
|
#include "tensorflow/compiler/xla/client/executable_build_options.h"
|
|
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
|
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
|
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
|
|
#include "tensorflow/compiler/xla/shape.h"
|
|
#include "tensorflow/compiler/xla/statusor.h"
|
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
|
|
|
namespace xla {
|
|
|
|
namespace {
|
|
// Returns the shape obtained by applying `sharding` to `shape`.
//
// For a non-tuple sharding, the proto is converted with
// HloSharding::FromProto and the result of HloSharding::TileShape(shape) is
// returned. For a TUPLE sharding, `shape` must be a tuple with exactly one
// element per tuple sharding; each element is sharded recursively and the
// results are reassembled into a tuple shape.
//
// Returns InvalidArgument when a tuple sharding is paired with a non-tuple
// shape or when the tuple sizes disagree; errors from
// HloSharding::FromProto are propagated.
StatusOr<Shape> GetShardedShape(const Shape& shape,
                                const OpSharding& sharding) {
  // Simple case first: anything that is not a tuple sharding is handled by
  // HloSharding directly.
  if (sharding.type() != OpSharding::TUPLE) {
    TF_ASSIGN_OR_RETURN(HloSharding hlo_sharding,
                        HloSharding::FromProto(sharding));
    return hlo_sharding.TileShape(shape);
  }

  if (!shape.IsTuple()) {
    return InvalidArgument(
        "Got tuple OpSharding (%s) for non-tuple shape (%s)",
        sharding.DebugString(), shape.ToString());
  }
  if (sharding.tuple_shardings_size() != shape.tuple_shapes_size()) {
    return InvalidArgument(
        "Got mismatched OpSharding tuple size (%d) and shape tuple size (%d)."
        " (OpSharding: %s, shape: %s)",
        sharding.tuple_shardings_size(), shape.tuple_shapes_size(),
        sharding.DebugString(), shape.ToString());
  }

  // Shard every tuple element with the sharding at the same position.
  std::vector<Shape> element_shapes;
  for (int element = 0; element < shape.tuple_shapes_size(); ++element) {
    TF_ASSIGN_OR_RETURN(Shape element_shape,
                        GetShardedShape(shape.tuple_shapes(element),
                                        sharding.tuple_shardings(element)));
    element_shapes.push_back(std::move(element_shape));
  }
  return ShapeUtil::MakeTupleShape(element_shapes);
}
|
|
|
|
// Returns the shape of `instr` with its sharding (if any) applied and with
// all layout information cleared.
//
// When the instruction carries no sharding, its own shape is used unchanged
// (apart from the layout being cleared). Errors from applying the sharding
// are propagated.
StatusOr<Shape> GetShardedShape(const HloInstructionProto& instr) {
  Shape result;
  if (!instr.has_sharding()) {
    result = Shape(instr.shape());
  } else {
    TF_ASSIGN_OR_RETURN(
        result, GetShardedShape(Shape(instr.shape()), instr.sharding()));
  }
  // Callers get a layout-free shape.
  LayoutUtil::ClearLayout(&result);
  return result;
}
|
|
|
|
// Returns sharded (argument shapes, result shape) without layouts.
|
|
StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
|
|
const XlaComputation& computation, const ProgramShape& program_shape) {
|
|
std::vector<Shape> arg_shapes;
|
|
arg_shapes.resize(program_shape.parameters_size());
|
|
Shape result_shape;
|
|
for (const HloComputationProto& comp : computation.proto().computations()) {
|
|
if (comp.id() != computation.proto().entry_computation_id()) {
|
|
continue;
|
|
}
|
|
for (const HloInstructionProto& instr : comp.instructions()) {
|
|
if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter)) {
|
|
if (instr.parameter_number() >= program_shape.parameters_size()) {
|
|
return InvalidArgument(
|
|
"Got invalid parameter number %d, expected %d parameters",
|
|
instr.parameter_number(), program_shape.parameters_size());
|
|
}
|
|
TF_ASSIGN_OR_RETURN(arg_shapes[instr.parameter_number()],
|
|
GetShardedShape(instr));
|
|
}
|
|
if (instr.id() == comp.root_id()) {
|
|
if (result_shape.element_type() != PRIMITIVE_TYPE_INVALID) {
|
|
return InvalidArgument("Found multiple root instructions");
|
|
}
|
|
TF_ASSIGN_OR_RETURN(result_shape, GetShardedShape(instr));
|
|
}
|
|
}
|
|
}
|
|
for (int i = 0; i < arg_shapes.size(); ++i) {
|
|
if (arg_shapes[i].element_type() == PRIMITIVE_TYPE_INVALID) {
|
|
return InvalidArgument("Couldn't find parameter %d", i);
|
|
}
|
|
}
|
|
if (result_shape.element_type() == PRIMITIVE_TYPE_INVALID) {
|
|
return InvalidArgument("Couldn't find root instruction");
|
|
}
|
|
return std::make_pair(arg_shapes, result_shape);
|
|
}
|
|
} // namespace
|
|
|
|
// Derives the replica count, partition count and device assignment to compile
// with from `build_options`.
//
// Portable executables: `build_options` must not carry a device assignment
// (InvalidArgument otherwise); `*num_replicas` and `*num_partitions` are set
// to 1 and `*device_assignment` is not written.
//
// Otherwise: if `build_options` has no device assignment, one is obtained
// from `GetDefaultDeviceAssignmentFunction(num_replicas, num_partitions)` and
// stored into `build_options` (errors are propagated). The counts are then
// read from that assignment and `*device_assignment` receives a shared copy
// of it.
Status ParseDeviceAssignmentCompileOptions(
    bool compile_portable_executable, ExecutableBuildOptions* build_options,
    std::function<StatusOr<DeviceAssignment>(int, int)>
        GetDefaultDeviceAssignmentFunction,
    int* num_replicas, int* num_partitions,
    std::shared_ptr<DeviceAssignment>* device_assignment) {
  if (compile_portable_executable) {
    if (build_options->has_device_assignment()) {
      return InvalidArgument(
          "CompileOptions requests portable executable but "
          "ExecutableBuildOptions includes a device assignment");
    }
    *num_replicas = 1;
    *num_partitions = 1;
    return Status::OK();
  }

  if (!build_options->has_device_assignment()) {
    VLOG(2) << "Compile using default device_assignment.";
    TF_ASSIGN_OR_RETURN(
        DeviceAssignment default_assignment,
        GetDefaultDeviceAssignmentFunction(build_options->num_replicas(),
                                           build_options->num_partitions()));
    build_options->set_device_assignment(default_assignment);
  }

  const DeviceAssignment& assignment = build_options->device_assignment();
  VLOG(2) << "Compile device_assignment:\n" << assignment.ToString();
  *num_replicas = assignment.replica_count();
  *num_partitions = assignment.computation_count();
  *device_assignment = std::make_shared<DeviceAssignment>(assignment);
  return Status::OK();
}
|
|
|
|
// Fills in any missing layouts for the arguments and result of `computation`.
//
// * If `argument_layouts` is empty, it is populated with the computation's
//   parameter shapes with their layouts cleared; otherwise it must contain
//   exactly one entry per computation parameter (InvalidArgument if not).
// * Every array subshape of an argument layout that has no layout receives
//   the layout of the shape returned by
//   `choose_compact_layout_for_shape_function` for the corresponding sharded
//   subshape of the computation.
// * The same is done for the result layout, starting from
//   `build_options->result_layout()` when set, or from the program's result
//   shape with its layout cleared otherwise; the outcome is written back via
//   `build_options->set_result_layout`.
// * `argument_layout_pointers` gets one pointer per entry of
//   `*argument_layouts`, pointing into that vector.
//
// Returns InvalidArgument if a layout shape contains a layout-less array at
// an index that does not exist in the matching sharded shape; errors from the
// layout chooser and from shape/sharding inspection are propagated.
Status DetermineArgumentLayoutsFromCompileOptions(
    const XlaComputation& computation,
    std::function<StatusOr<Shape>(Shape)>
        choose_compact_layout_for_shape_function,
    absl::optional<std::vector<Shape>>& argument_layouts,
    ExecutableBuildOptions* build_options,
    std::vector<const Shape*>* argument_layout_pointers) {
  TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                      computation.GetProgramShape());
  if (!argument_layouts) {
    argument_layouts.emplace(program_shape.parameters());
    for (Shape& shape : *argument_layouts) {
      LayoutUtil::ClearLayout(&shape);
    }
  } else if (argument_layouts->size() !=
             static_cast<size_t>(program_shape.parameters_size())) {
    return InvalidArgument(
        "CompileOptions specify %d argument layouts, but computation has %d "
        "arguments",
        argument_layouts->size(), program_shape.parameters_size());
  }
  argument_layout_pointers->reserve(argument_layouts->size());

  // Assign a default layout based on `sharded_shape` to any array subshapes in
  // `dst_shape` that are missing layouts.
  auto assign_layouts = [&choose_compact_layout_for_shape_function](
                            const Shape& sharded_shape, Shape* dst_shape) {
    return ShapeUtil::ForEachMutableSubshapeWithStatus(
        dst_shape, [&](Shape* subshape, const ShapeIndex& idx) -> Status {
          if (!subshape->IsArray() || subshape->has_layout()) {
            return Status::OK();
          }
          // `dst_shape` may come from the caller, so a structural mismatch
          // with the computation is reported as an error instead of aborting
          // the process.
          if (!ShapeUtil::IndexIsValid(sharded_shape, idx)) {
            return InvalidArgument(
                "Layout shape has an array subshape at index %s that does "
                "not exist in the computation's sharded shape %s",
                idx.ToString(), sharded_shape.ToString());
          }
          const Shape& sharded_subshape =
              ShapeUtil::GetSubshape(sharded_shape, idx);
          LayoutUtil::SetToDefaultLayout(subshape);
          TF_ASSIGN_OR_RETURN(
              Shape layout,
              choose_compact_layout_for_shape_function(sharded_subshape));
          *subshape->mutable_layout() = layout.layout();
          return Status::OK();
        });
  };
  TF_ASSIGN_OR_RETURN(auto sharded_shapes,
                      GetShardedProgramShapes(computation, program_shape));

  // Both vectors were sized from the program's parameter count above.
  CHECK_EQ(sharded_shapes.first.size(), argument_layouts->size());
  for (size_t i = 0; i < argument_layouts->size(); ++i) {
    Shape* layout = &(*argument_layouts)[i];
    argument_layout_pointers->push_back(layout);
    TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.first[i], layout));
  }

  Shape result_layout;
  if (build_options->result_layout()) {
    result_layout = *build_options->result_layout();
  } else {
    result_layout = program_shape.result();
    LayoutUtil::ClearLayout(&result_layout);
  }
  TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.second, &result_layout));
  build_options->set_result_layout(result_layout);
  return Status::OK();
}
|
|
|
|
// Returns the set of parameter numbers that appear in the input/output alias
// config of `module`.
//
// With `tuple_inputs` set, the entry computation must have a single tuple
// parameter (CHECK-enforced); the "parameters" are that tuple's elements and
// the donated parameter is the first component of each alias's parameter
// index. Aliases with an empty parameter index contribute nothing. Without
// `tuple_inputs`, the alias's parameter number is used directly.
//
// Returns InvalidArgument if an alias names a parameter outside
// [0, number of parameters), or, with tupled inputs, a parameter number other
// than 0.
StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
    const HloModule& module, bool tuple_inputs) {
  HloComputation* computation = module.entry_computation();
  int number_of_parameters = [&]() -> int {
    if (tuple_inputs) {
      CHECK_EQ(computation->num_parameters(), 1);
      const Shape& input_tuple_shape =
          computation->parameter_instruction(0)->shape();
      CHECK(input_tuple_shape.IsTuple());
      return input_tuple_shape.tuple_shapes_size();
    } else {
      return computation->num_parameters();
    }
  }();
  // If any buffer in a parameter is aliased we will donate the entire input
  // parameter.
  absl::flat_hash_set<int> parameters_to_donate;
  const HloInputOutputAliasConfig& config = module.input_output_alias_config();
  TF_RETURN_IF_ERROR(config.ForEachAliasWithStatus(
      [&](const ShapeIndex& output_index,
          const HloInputOutputAliasConfig::Alias& alias) {
        if (tuple_inputs) {
          if (alias.parameter_number != 0) {
            return InvalidArgument(
                "Unexpected parameter number %d in alias config with tupled "
                "inputs",
                alias.parameter_number);
          }
          const ShapeIndex& index = alias.parameter_index;
          if (!index.empty()) {
            // Range-check the index at its native width; narrowing to int
            // first could wrap an out-of-range value into the valid range.
            const auto this_parameter = index.data()[0];
            if (this_parameter < 0 ||
                this_parameter >= number_of_parameters) {
              return InvalidArgument(
                  "Unexpected parameter index %s in alias config with tupled "
                  "inputs and %d parameters",
                  index.ToString(), number_of_parameters);
            }
            parameters_to_donate.insert(static_cast<int>(this_parameter));
          }
        } else {
          // Same as above: validate before narrowing to int.
          const auto this_parameter = alias.parameter_number;
          if (this_parameter < 0 || this_parameter >= number_of_parameters) {
            return InvalidArgument(
                "Unexpected parameter number %d in alias config without tupled "
                "inputs and %d parameters",
                this_parameter, number_of_parameters);
          }
          parameters_to_donate.insert(static_cast<int>(this_parameter));
        }
        return Status::OK();
      }));
  return parameters_to_donate;
}
|
|
|
|
int DefaultThreadPoolSize() {
|
|
// Google's CI system exposes an environment variable NPROC that describes
|
|
// a CPU reservation for tests.
|
|
// TODO(phawkins): expose a better thought-out set of knobs to control
|
|
// parallelism.
|
|
const char* nproc_str = std::getenv("NPROC");
|
|
int nproc = 0;
|
|
if (nproc_str && absl::SimpleAtoi(nproc_str, &nproc)) {
|
|
return std::max(0, nproc);
|
|
}
|
|
return tensorflow::port::MaxParallelism();
|
|
}
|
|
|
|
} // namespace xla
|