Move helper functions to PJRT utils.
PiperOrigin-RevId: 347081669
Change-Id: I2fc8669f3d0b7eb8cfe2e1922034132f04801272
parent 4867489c44
commit bca4ea64e4
tensorflow/compiler/xla/pjrt/BUILD
@@ -143,6 +143,22 @@ cc_library(
    ],
)

cc_library(
    name = "utils",
    srcs = ["utils.cc"],
    hdrs = ["utils.h"],
    visibility = ["//tensorflow/compiler/xla:friends"],
    deps = [
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:xla_computation",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:hlo_proto_cc",
        "@com_google_absl//absl/container:flat_hash_set",
    ],
)

cc_library(
    name = "pjrt_stream_executor_client",
    srcs = ["pjrt_stream_executor_client.cc"],
@@ -153,6 +169,7 @@ cc_library(
        ":local_device_state",
        ":pjrt_client",
        ":tracked_device_buffer",
        ":utils",
        "//tensorflow/compiler/xla:cpu_function_runtime",
        "//tensorflow/compiler/xla:executable_run_options",
        "//tensorflow/compiler/xla:literal",
tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc
@@ -89,6 +89,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/pjrt/event_pool.h"
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
#include "tensorflow/compiler/xla/pjrt/utils.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
@@ -1582,60 +1583,6 @@ PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
  }
}

StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
    const HloModule& module, bool tuple_inputs) {
  HloComputation* computation = module.entry_computation();
  int number_of_parameters = [&]() -> int {
    if (tuple_inputs) {
      CHECK_EQ(computation->num_parameters(), 1);
      const Shape& input_tuple_shape =
          computation->parameter_instruction(0)->shape();
      CHECK(input_tuple_shape.IsTuple());
      return input_tuple_shape.tuple_shapes_size();
    } else {
      return computation->num_parameters();
    }
  }();
  // If any buffer in a parameter is aliased we will donate the entire input
  // parameter.
  absl::flat_hash_set<int> parameters_to_donate;
  const HloInputOutputAliasConfig& config = module.input_output_alias_config();
  TF_RETURN_IF_ERROR(config.ForEachAliasWithStatus(
      [&](const ShapeIndex& output_index,
          const HloInputOutputAliasConfig::Alias& alias) {
        if (tuple_inputs) {
          if (alias.parameter_number != 0) {
            return InvalidArgument(
                "Unexpected parameter number %d in alias config with tupled "
                "inputs",
                alias.parameter_number);
          }
          const ShapeIndex& index = alias.parameter_index;
          if (!index.empty()) {
            int this_parameter = index.data()[0];
            if (this_parameter >= number_of_parameters) {
              return InvalidArgument(
                  "Unexpected parameter index %s in alias config with tupled "
                  "inputs and %d parameters",
                  index.ToString(), number_of_parameters);
            }
            parameters_to_donate.insert(this_parameter);
          }
        } else {
          int this_parameter = alias.parameter_number;
          if (this_parameter >= number_of_parameters) {
            return InvalidArgument(
                "Unexpected parameter number %d in alias config without tupled "
                "inputs and %d parameters",
                this_parameter, number_of_parameters);
          }
          parameters_to_donate.insert(this_parameter);
        }
        return Status::OK();
      }));
  return parameters_to_donate;
}

Status PjRtStreamExecutorExecutable::SetUpDonation(bool tuple_inputs) {
  parameters_that_must_be_donated_.reserve(executables_.size());
  for (auto& executable : executables_) {
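The hunk above removes the file-local donation helper and leaves only the opening of SetUpDonation as context. A minimal sketch of how such a per-executable loop could consume the relocated helper; the exact SetUpDonation body is not shown in this diff, so treat this as an illustration under that assumption rather than the committed code:

// Illustration only: call GetParametersThatMustBeDonated() for each
// underlying executable and cache the resulting parameter set.
Status PjRtStreamExecutorExecutable::SetUpDonation(bool tuple_inputs) {
  parameters_that_must_be_donated_.reserve(executables_.size());
  for (auto& executable : executables_) {
    TF_ASSIGN_OR_RETURN(
        absl::flat_hash_set<int> parameters_to_donate,
        GetParametersThatMustBeDonated(executable->executable()->module(),
                                       tuple_inputs));
    parameters_that_must_be_donated_.emplace_back(
        std::move(parameters_to_donate));
  }
  return Status::OK();
}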
@@ -2142,93 +2089,6 @@ PjRtStreamExecutorExecutable::GetHloModules() const {
  return std::move(modules);
}

namespace {

StatusOr<Shape> GetShardedShape(const Shape& shape,
                                const OpSharding& sharding) {
  if (sharding.type() == OpSharding::TUPLE) {
    if (!shape.IsTuple()) {
      return InvalidArgument(
          "Got tuple OpSharding (%s) for non-tuple shape (%s)",
          sharding.DebugString(), shape.ToString());
    }
    if (sharding.tuple_shardings_size() != shape.tuple_shapes_size()) {
      return InvalidArgument(
          "Got mismatched OpSharding tuple size (%d) and shape tuple size (%d)."
          " (OpSharding: %s, shape: %s)",
          sharding.tuple_shardings_size(), shape.tuple_shapes_size(),
          sharding.DebugString(), shape.ToString());
    }
    std::vector<Shape> sharded_subshapes;
    for (int i = 0; i < shape.tuple_shapes_size(); ++i) {
      TF_ASSIGN_OR_RETURN(
          Shape sharded_subshape,
          GetShardedShape(shape.tuple_shapes(i), sharding.tuple_shardings(i)));
      sharded_subshapes.emplace_back(std::move(sharded_subshape));
    }
    return ShapeUtil::MakeTupleShape(sharded_subshapes);
  }
  TF_ASSIGN_OR_RETURN(HloSharding hlo_sharding,
                      HloSharding::FromProto(sharding));
  return hlo_sharding.TileShape(shape);
}

StatusOr<Shape> GetShardedShape(const HloInstructionProto& instr) {
  const Shape unsharded_shape(instr.shape());
  Shape sharded_shape;
  if (instr.has_sharding()) {
    TF_ASSIGN_OR_RETURN(sharded_shape,
                        GetShardedShape(unsharded_shape, instr.sharding()));
  } else {
    sharded_shape = unsharded_shape;
  }
  LayoutUtil::ClearLayout(&sharded_shape);
  return sharded_shape;
}

// Returns sharded (argument shapes, result shape) without layouts.
StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
    const XlaComputation& computation) {
  TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                      computation.GetProgramShape());
  std::vector<Shape> arg_shapes;
  arg_shapes.resize(program_shape.parameters_size());
  Shape result_shape;
  for (const HloComputationProto& comp : computation.proto().computations()) {
    if (comp.id() != computation.proto().entry_computation_id()) {
      continue;
    }
    for (const HloInstructionProto& instr : comp.instructions()) {
      if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter)) {
        if (instr.parameter_number() >= program_shape.parameters_size()) {
          return InvalidArgument(
              "Got invalid parameter number %d, expected %d parameters",
              instr.parameter_number(), program_shape.parameters_size());
        }
        TF_ASSIGN_OR_RETURN(arg_shapes[instr.parameter_number()],
                            GetShardedShape(instr));
      }
      if (instr.id() == comp.root_id()) {
        if (result_shape.element_type() != PRIMITIVE_TYPE_INVALID) {
          return InvalidArgument("Found multiple root instructions");
        }
        TF_ASSIGN_OR_RETURN(result_shape, GetShardedShape(instr));
      }
    }
  }
  for (int i = 0; i < arg_shapes.size(); ++i) {
    if (arg_shapes[i].element_type() == PRIMITIVE_TYPE_INVALID) {
      return InvalidArgument("Couldn't find parameter %d", i);
    }
  }
  if (result_shape.element_type() == PRIMITIVE_TYPE_INVALID) {
    return InvalidArgument("Couldn't find root instruction");
  }
  return std::make_pair(arg_shapes, result_shape);
}

}  // namespace

StatusOr<std::unique_ptr<PjRtExecutable>> PjRtStreamExecutorClient::Compile(
    const XlaComputation& computation, CompileOptions options) {
  tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::Compile");
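Compile is the main consumer of GetShardedProgramShapes, but the rest of its body falls outside this hunk. A hedged sketch of how a compile path could unpack the helper's (argument shapes, result shape) pair before layout assignment; the wrapper function below is illustrative and not taken from this diff:

// Illustration: unpack the pair returned by GetShardedProgramShapes.
// `computation` is the XlaComputation being compiled.
StatusOr<std::vector<Shape>> ShardedArgumentShapes(
    const XlaComputation& computation) {
  TF_ASSIGN_OR_RETURN(auto program_shapes,
                      GetShardedProgramShapes(computation));
  // program_shapes.second is the sharded result shape, available to callers
  // that need it; here we only forward the argument shapes.
  return std::move(program_shapes.first);
}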
tensorflow/compiler/xla/pjrt/utils.cc (new file, 169 lines)
@@ -0,0 +1,169 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/pjrt/utils.h"

#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {

namespace {
StatusOr<Shape> GetShardedShape(const Shape& shape,
                                const OpSharding& sharding) {
  if (sharding.type() == OpSharding::TUPLE) {
    if (!shape.IsTuple()) {
      return InvalidArgument(
          "Got tuple OpSharding (%s) for non-tuple shape (%s)",
          sharding.DebugString(), shape.ToString());
    }
    if (sharding.tuple_shardings_size() != shape.tuple_shapes_size()) {
      return InvalidArgument(
          "Got mismatched OpSharding tuple size (%d) and shape tuple size (%d)."
          " (OpSharding: %s, shape: %s)",
          sharding.tuple_shardings_size(), shape.tuple_shapes_size(),
          sharding.DebugString(), shape.ToString());
    }
    std::vector<Shape> sharded_subshapes;
    for (int i = 0; i < shape.tuple_shapes_size(); ++i) {
      TF_ASSIGN_OR_RETURN(
          Shape sharded_subshape,
          GetShardedShape(shape.tuple_shapes(i), sharding.tuple_shardings(i)));
      sharded_subshapes.emplace_back(std::move(sharded_subshape));
    }
    return ShapeUtil::MakeTupleShape(sharded_subshapes);
  }
  TF_ASSIGN_OR_RETURN(HloSharding hlo_sharding,
                      HloSharding::FromProto(sharding));
  return hlo_sharding.TileShape(shape);
}

StatusOr<Shape> GetShardedShape(const HloInstructionProto& instr) {
  const Shape unsharded_shape(instr.shape());
  Shape sharded_shape;
  if (instr.has_sharding()) {
    TF_ASSIGN_OR_RETURN(sharded_shape,
                        GetShardedShape(unsharded_shape, instr.sharding()));
  } else {
    sharded_shape = unsharded_shape;
  }
  LayoutUtil::ClearLayout(&sharded_shape);
  return sharded_shape;
}

}  // namespace

// Returns sharded (argument shapes, result shape) without layouts.
StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
    const XlaComputation& computation) {
  TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                      computation.GetProgramShape());
  std::vector<Shape> arg_shapes;
  arg_shapes.resize(program_shape.parameters_size());
  Shape result_shape;
  for (const HloComputationProto& comp : computation.proto().computations()) {
    if (comp.id() != computation.proto().entry_computation_id()) {
      continue;
    }
    for (const HloInstructionProto& instr : comp.instructions()) {
      if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter)) {
        if (instr.parameter_number() >= program_shape.parameters_size()) {
          return InvalidArgument(
              "Got invalid parameter number %d, expected %d parameters",
              instr.parameter_number(), program_shape.parameters_size());
        }
        TF_ASSIGN_OR_RETURN(arg_shapes[instr.parameter_number()],
                            GetShardedShape(instr));
      }
      if (instr.id() == comp.root_id()) {
        if (result_shape.element_type() != PRIMITIVE_TYPE_INVALID) {
          return InvalidArgument("Found multiple root instructions");
        }
        TF_ASSIGN_OR_RETURN(result_shape, GetShardedShape(instr));
      }
    }
  }
  for (int i = 0; i < arg_shapes.size(); ++i) {
    if (arg_shapes[i].element_type() == PRIMITIVE_TYPE_INVALID) {
      return InvalidArgument("Couldn't find parameter %d", i);
    }
  }
  if (result_shape.element_type() == PRIMITIVE_TYPE_INVALID) {
    return InvalidArgument("Couldn't find root instruction");
  }
  return std::make_pair(arg_shapes, result_shape);
}

StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
    const HloModule& module, bool tuple_inputs) {
  HloComputation* computation = module.entry_computation();
  int number_of_parameters = [&]() -> int {
    if (tuple_inputs) {
      CHECK_EQ(computation->num_parameters(), 1);
      const Shape& input_tuple_shape =
          computation->parameter_instruction(0)->shape();
      CHECK(input_tuple_shape.IsTuple());
      return input_tuple_shape.tuple_shapes_size();
    } else {
      return computation->num_parameters();
    }
  }();
  // If any buffer in a parameter is aliased we will donate the entire input
  // parameter.
  absl::flat_hash_set<int> parameters_to_donate;
  const HloInputOutputAliasConfig& config = module.input_output_alias_config();
  TF_RETURN_IF_ERROR(config.ForEachAliasWithStatus(
      [&](const ShapeIndex& output_index,
          const HloInputOutputAliasConfig::Alias& alias) {
        if (tuple_inputs) {
          if (alias.parameter_number != 0) {
            return InvalidArgument(
                "Unexpected parameter number %d in alias config with tupled "
                "inputs",
                alias.parameter_number);
          }
          const ShapeIndex& index = alias.parameter_index;
          if (!index.empty()) {
            int this_parameter = index.data()[0];
            if (this_parameter >= number_of_parameters) {
              return InvalidArgument(
                  "Unexpected parameter index %s in alias config with tupled "
                  "inputs and %d parameters",
                  index.ToString(), number_of_parameters);
            }
            parameters_to_donate.insert(this_parameter);
          }
        } else {
          int this_parameter = alias.parameter_number;
          if (this_parameter >= number_of_parameters) {
            return InvalidArgument(
                "Unexpected parameter number %d in alias config without tupled "
                "inputs and %d parameters",
                this_parameter, number_of_parameters);
          }
          parameters_to_donate.insert(this_parameter);
        }
        return Status::OK();
      }));
  return parameters_to_donate;
}

}  // namespace xla
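The non-tuple branch of GetShardedShape above bottoms out in HloSharding::TileShape. A small worked sketch of that behavior; the expected per-device shape is an assumption derived from the tiling arithmetic, not output captured from this code:

// Sketch: a 2x1 tile assignment over devices {0, 1} splits dimension 0 of
// f32[8,4] in half, so TileShape should report f32[4,4] per device.
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {
Shape PerDeviceTileExample() {
  Array<int64> assignment({{0}, {1}});  // 2x1 device grid
  HloSharding sharding = HloSharding::Tile(assignment);
  Shape full_shape = ShapeUtil::MakeShape(F32, {8, 4});
  return sharding.TileShape(full_shape);  // expected: f32[4,4]
}
}  // namespace xla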
tensorflow/compiler/xla/pjrt/utils.h (new file, 41 lines)
@@ -0,0 +1,41 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PJRT_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_UTILS_H_

#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

// Extracts from an XlaComputation the sharded program shapes (argument
// shapes, result shape) without layouts.
StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
    const XlaComputation& computation);

// Executables can donate buffers so that buffers can be aliased from inputs
// to outputs. This function returns the set of parameters that must be
// donated when the executable is run. tuple_inputs reflects the option the
// executable was compiled with.
StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
    const HloModule& module, bool tuple_inputs);

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_PJRT_UTILS_H_
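The point of hoisting these declarations into a shared header is that PJRT client implementations other than the stream-executor one can reuse them. A hedged sketch of such a reuse site; the surrounding function and backend are hypothetical and not part of this commit:

// Hypothetical backend code: a compile step that reuses both shared helpers.
#include "tensorflow/compiler/xla/pjrt/utils.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace xla {
StatusOr<absl::flat_hash_set<int>> PrepareDonationForMyBackend(
    const XlaComputation& computation, const HloModule& optimized_module,
    bool tuple_inputs) {
  // Sharded (argument shapes, result shape) without layouts, e.g. to choose
  // per-device argument layouts before compilation.
  TF_ASSIGN_OR_RETURN(auto sharded_shapes,
                      GetShardedProgramShapes(computation));
  (void)sharded_shapes;  // consumed by backend-specific layout logic
  // Parameters whose buffers the runtime may reuse for outputs.
  return GetParametersThatMustBeDonated(optimized_module, tuple_inputs);
}
}  // namespace xla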