[XLA] Allow CustomCall to specify aliasing buffers
PiperOrigin-RevId: 333464460 Change-Id: I666c7a7fade6c9925cd8d2c7dd4673b897a38034
This commit is contained in:
parent
55b546de0e
commit
1b5b9486b0
@ -135,16 +135,12 @@ StatusOr<XlaOp> MlirHloBuilder::CustomCallInternal(
|
|||||||
const string& call_target_name, absl::Span<const XlaOp> operands,
|
const string& call_target_name, absl::Span<const XlaOp> operands,
|
||||||
const Shape& shape, const string& opaque,
|
const Shape& shape, const string& opaque,
|
||||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
||||||
bool has_side_effect,
|
bool has_side_effect) {
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
|
||||||
output_operand_aliasing) {
|
|
||||||
if (operand_shapes_with_layout.has_value())
|
if (operand_shapes_with_layout.has_value())
|
||||||
return Unimplemented(
|
return Unimplemented(
|
||||||
"CustomCall doesn't support operands shapes with layout");
|
"CustomCall doesn't support operands shapes with layout");
|
||||||
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
|
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
|
||||||
shape, builder_));
|
shape, builder_));
|
||||||
TF_RET_CHECK(output_operand_aliasing.empty())
|
|
||||||
<< "MLIR CustomCallOp does not support output_operand_aliasing yet";
|
|
||||||
auto op = builder_.create<mlir::mhlo::CustomCallOp>(
|
auto op = builder_.create<mlir::mhlo::CustomCallOp>(
|
||||||
loc_, ty, GetValues(operands), builder_.getStringAttr(call_target_name),
|
loc_, ty, GetValues(operands), builder_.getStringAttr(call_target_name),
|
||||||
/*has_side_effect=*/builder_.getBoolAttr(has_side_effect),
|
/*has_side_effect=*/builder_.getBoolAttr(has_side_effect),
|
||||||
|
@ -135,9 +135,7 @@ class MlirHloBuilder : public XlaBuilder {
|
|||||||
const string& call_target_name, absl::Span<const XlaOp> operands,
|
const string& call_target_name, absl::Span<const XlaOp> operands,
|
||||||
const Shape& shape, const string& opaque,
|
const Shape& shape, const string& opaque,
|
||||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
||||||
bool has_side_effect,
|
bool has_side_effect) override;
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
|
||||||
output_operand_aliasing) override;
|
|
||||||
|
|
||||||
StatusOr<XlaOp> ReduceInternal(
|
StatusOr<XlaOp> ReduceInternal(
|
||||||
const Shape& shape, absl::Span<const XlaOp> all_operands,
|
const Shape& shape, absl::Span<const XlaOp> all_operands,
|
||||||
|
@ -1707,9 +1707,7 @@ XlaOp XlaBuilder::CustomCall(
|
|||||||
const string& call_target_name, absl::Span<const XlaOp> operands,
|
const string& call_target_name, absl::Span<const XlaOp> operands,
|
||||||
const Shape& shape, const string& opaque,
|
const Shape& shape, const string& opaque,
|
||||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
||||||
bool has_side_effect,
|
bool has_side_effect) {
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
|
||||||
output_operand_aliasing) {
|
|
||||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||||
if (absl::StartsWith(call_target_name, "$")) {
|
if (absl::StartsWith(call_target_name, "$")) {
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
@ -1741,8 +1739,7 @@ XlaOp XlaBuilder::CustomCall(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
return CustomCallInternal(call_target_name, operands, shape, opaque,
|
return CustomCallInternal(call_target_name, operands, shape, opaque,
|
||||||
operand_shapes_with_layout, has_side_effect,
|
operand_shapes_with_layout, has_side_effect);
|
||||||
output_operand_aliasing);
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1750,9 +1747,7 @@ StatusOr<XlaOp> XlaBuilder::CustomCallInternal(
|
|||||||
const string& call_target_name, absl::Span<const XlaOp> operands,
|
const string& call_target_name, absl::Span<const XlaOp> operands,
|
||||||
const Shape& shape, const string& opaque,
|
const Shape& shape, const string& opaque,
|
||||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
||||||
bool has_side_effect,
|
bool has_side_effect) {
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
|
||||||
output_operand_aliasing) {
|
|
||||||
HloInstructionProto instr;
|
HloInstructionProto instr;
|
||||||
*instr.mutable_shape() = shape.ToProto();
|
*instr.mutable_shape() = shape.ToProto();
|
||||||
instr.set_custom_call_target(call_target_name);
|
instr.set_custom_call_target(call_target_name);
|
||||||
@ -1764,16 +1759,6 @@ StatusOr<XlaOp> XlaBuilder::CustomCallInternal(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
instr.set_custom_call_has_side_effect(has_side_effect);
|
instr.set_custom_call_has_side_effect(has_side_effect);
|
||||||
for (const auto& pair : output_operand_aliasing) {
|
|
||||||
auto aliasing = instr.add_custom_call_output_operand_aliasing();
|
|
||||||
aliasing->set_operand_index(pair.second.first);
|
|
||||||
for (int64 index : pair.second.second) {
|
|
||||||
aliasing->add_operand_shape_index(index);
|
|
||||||
}
|
|
||||||
for (int64 index : pair.first) {
|
|
||||||
aliasing->add_output_shape_index(index);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
|
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1781,9 +1766,7 @@ XlaOp XlaBuilder::CustomCall(
|
|||||||
const string& call_target_name, absl::Span<const XlaOp> operands,
|
const string& call_target_name, absl::Span<const XlaOp> operands,
|
||||||
const XlaComputation& computation, const Shape& shape, const string& opaque,
|
const XlaComputation& computation, const Shape& shape, const string& opaque,
|
||||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
||||||
bool has_side_effect,
|
bool has_side_effect) {
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
|
||||||
output_operand_aliasing) {
|
|
||||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||||
HloInstructionProto instr;
|
HloInstructionProto instr;
|
||||||
if (absl::StartsWith(call_target_name, "$")) {
|
if (absl::StartsWith(call_target_name, "$")) {
|
||||||
@ -1821,16 +1804,6 @@ XlaOp XlaBuilder::CustomCall(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
AddCalledComputation(computation, &instr);
|
AddCalledComputation(computation, &instr);
|
||||||
for (const auto& pair : output_operand_aliasing) {
|
|
||||||
auto aliasing = instr.add_custom_call_output_operand_aliasing();
|
|
||||||
aliasing->set_operand_index(pair.second.first);
|
|
||||||
for (int64 index : pair.second.second) {
|
|
||||||
aliasing->add_operand_shape_index(index);
|
|
||||||
}
|
|
||||||
for (int64 index : pair.first) {
|
|
||||||
aliasing->add_output_shape_index(index);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
|
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -3888,39 +3861,31 @@ XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
|
|||||||
return builder->Call(computation, operands);
|
return builder->Call(computation, operands);
|
||||||
}
|
}
|
||||||
|
|
||||||
XlaOp CustomCall(
|
XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
|
||||||
XlaBuilder* builder, const string& call_target_name,
|
absl::Span<const XlaOp> operands, const Shape& shape,
|
||||||
absl::Span<const XlaOp> operands, const Shape& shape, const string& opaque,
|
const string& opaque, bool has_side_effect) {
|
||||||
bool has_side_effect,
|
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
|
||||||
output_operand_aliasing) {
|
|
||||||
return builder->CustomCall(call_target_name, operands, shape, opaque,
|
return builder->CustomCall(call_target_name, operands, shape, opaque,
|
||||||
/*operand_shapes_with_layout=*/absl::nullopt,
|
/*operand_shapes_with_layout=*/absl::nullopt,
|
||||||
has_side_effect, output_operand_aliasing);
|
has_side_effect);
|
||||||
}
|
}
|
||||||
|
|
||||||
XlaOp CustomCallWithComputation(
|
XlaOp CustomCallWithComputation(XlaBuilder* builder,
|
||||||
XlaBuilder* builder, const string& call_target_name,
|
const string& call_target_name,
|
||||||
absl::Span<const XlaOp> operands, const XlaComputation& computation,
|
absl::Span<const XlaOp> operands,
|
||||||
const Shape& shape, const string& opaque, bool has_side_effect,
|
const XlaComputation& computation,
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
const Shape& shape, const string& opaque,
|
||||||
output_operand_aliasing) {
|
bool has_side_effect) {
|
||||||
return builder->CustomCall(call_target_name, operands, computation, shape,
|
return builder->CustomCall(
|
||||||
opaque,
|
call_target_name, operands, computation, shape, opaque,
|
||||||
/*operand_shapes_with_layout=*/absl::nullopt,
|
/*operand_shapes_with_layout=*/absl::nullopt, has_side_effect);
|
||||||
has_side_effect, output_operand_aliasing);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
XlaOp CustomCallWithLayout(
|
XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name,
|
||||||
XlaBuilder* builder, const string& call_target_name,
|
absl::Span<const XlaOp> operands, const Shape& shape,
|
||||||
absl::Span<const XlaOp> operands, const Shape& shape,
|
absl::Span<const Shape> operand_shapes_with_layout,
|
||||||
absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
|
const string& opaque, bool has_side_effect) {
|
||||||
bool has_side_effect,
|
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
|
||||||
output_operand_aliasing) {
|
|
||||||
return builder->CustomCall(call_target_name, operands, shape, opaque,
|
return builder->CustomCall(call_target_name, operands, shape, opaque,
|
||||||
operand_shapes_with_layout, has_side_effect,
|
operand_shapes_with_layout, has_side_effect);
|
||||||
output_operand_aliasing);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
XlaOp Complex(const XlaOp lhs, const XlaOp rhs,
|
XlaOp Complex(const XlaOp lhs, const XlaOp rhs,
|
||||||
|
@ -593,9 +593,7 @@ class XlaBuilder {
|
|||||||
const string& call_target_name, absl::Span<const XlaOp> operands,
|
const string& call_target_name, absl::Span<const XlaOp> operands,
|
||||||
const Shape& shape_with_layout, const string& opaque,
|
const Shape& shape_with_layout, const string& opaque,
|
||||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
||||||
bool has_side_effect,
|
bool has_side_effect);
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
|
||||||
output_operand_aliasing);
|
|
||||||
|
|
||||||
// Internal version of CustomCall without computation that doesn't do op
|
// Internal version of CustomCall without computation that doesn't do op
|
||||||
// specific error handling and expects arguments to be legal. CustomCall
|
// specific error handling and expects arguments to be legal. CustomCall
|
||||||
@ -604,18 +602,14 @@ class XlaBuilder {
|
|||||||
const string& call_target_name, absl::Span<const XlaOp> operands,
|
const string& call_target_name, absl::Span<const XlaOp> operands,
|
||||||
const Shape& shape_with_layout, const string& opaque,
|
const Shape& shape_with_layout, const string& opaque,
|
||||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
||||||
bool has_side_effect,
|
bool has_side_effect);
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
|
||||||
output_operand_aliasing);
|
|
||||||
|
|
||||||
XlaOp CustomCall(
|
XlaOp CustomCall(
|
||||||
const string& call_target_name, absl::Span<const XlaOp> operands,
|
const string& call_target_name, absl::Span<const XlaOp> operands,
|
||||||
const XlaComputation& computation, const Shape& shape_with_layout,
|
const XlaComputation& computation, const Shape& shape_with_layout,
|
||||||
const string& opaque,
|
const string& opaque,
|
||||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
||||||
bool has_side_effect,
|
bool has_side_effect);
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
|
||||||
output_operand_aliasing);
|
|
||||||
|
|
||||||
XlaOp Reduce(XlaOp operand, XlaOp init_value,
|
XlaOp Reduce(XlaOp operand, XlaOp init_value,
|
||||||
const XlaComputation& computation,
|
const XlaComputation& computation,
|
||||||
@ -1064,25 +1058,18 @@ class XlaBuilder {
|
|||||||
const string& outfeed_config);
|
const string& outfeed_config);
|
||||||
friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
|
friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
|
||||||
absl::Span<const XlaOp> operands);
|
absl::Span<const XlaOp> operands);
|
||||||
friend XlaOp CustomCall(
|
friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
|
||||||
XlaBuilder* builder, const string& call_target_name,
|
absl::Span<const XlaOp> operands, const Shape& shape,
|
||||||
absl::Span<const XlaOp> operands, const Shape& shape,
|
const string& opaque, bool has_side_effect);
|
||||||
const string& opaque, bool has_side_effect,
|
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
|
||||||
output_operand_aliasing);
|
|
||||||
friend XlaOp CustomCallWithComputation(
|
friend XlaOp CustomCallWithComputation(
|
||||||
XlaBuilder* builder, const string& call_target_name,
|
XlaBuilder* builder, const string& call_target_name,
|
||||||
absl::Span<const XlaOp> operands, const XlaComputation& computation,
|
absl::Span<const XlaOp> operands, const XlaComputation& computation,
|
||||||
const Shape& shape, const string& opaque, bool has_side_effect,
|
const Shape& shape, const string& opaque, bool has_side_effect);
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
|
||||||
output_operand_aliasing);
|
|
||||||
friend XlaOp CustomCallWithLayout(
|
friend XlaOp CustomCallWithLayout(
|
||||||
XlaBuilder* builder, const string& call_target_name,
|
XlaBuilder* builder, const string& call_target_name,
|
||||||
absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
|
absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
|
||||||
absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
|
absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
|
||||||
bool has_side_effect,
|
bool has_side_effect);
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
|
||||||
output_operand_aliasing);
|
|
||||||
friend XlaOp Complex(XlaOp real, XlaOp imag,
|
friend XlaOp Complex(XlaOp real, XlaOp imag,
|
||||||
absl::Span<const int64> broadcast_dimensions);
|
absl::Span<const int64> broadcast_dimensions);
|
||||||
friend XlaOp Conj(XlaOp operand);
|
friend XlaOp Conj(XlaOp operand);
|
||||||
@ -1818,39 +1805,30 @@ XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
|
|||||||
// backend, a call instruction is emitted which targets a symbol with the name
|
// backend, a call instruction is emitted which targets a symbol with the name
|
||||||
// |call_target_name|. |call_target_name| and |opaque| can arbitrary strings,
|
// |call_target_name|. |call_target_name| and |opaque| can arbitrary strings,
|
||||||
// but |call_target_name| should be short as it may be used in labels. |opaque|
|
// but |call_target_name| should be short as it may be used in labels. |opaque|
|
||||||
// can encode arbitrarily large amounts of information. |has_side_effect|
|
// can encode arbitrarily large amounts of information.
|
||||||
// specifies whether the instruction can have side effects.
|
XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
|
||||||
// |output_operand_aliasing| specifies a list of output/operand buffer pairs
|
absl::Span<const XlaOp> operands, const Shape& shape,
|
||||||
// that alias each other, where the output buffer is represented as a
|
const string& opaque = "", bool has_side_effect = false);
|
||||||
// ShapeIndex, and the operand buffer is represented as the operand index and
|
|
||||||
// the ShapeIndex.
|
|
||||||
XlaOp CustomCall(
|
|
||||||
XlaBuilder* builder, const string& call_target_name,
|
|
||||||
absl::Span<const XlaOp> operands, const Shape& shape,
|
|
||||||
const string& opaque = "", bool has_side_effect = false,
|
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
|
||||||
output_operand_aliasing = {});
|
|
||||||
|
|
||||||
// Overload which constructs a custom call that applies an Xla computation.
|
// Overload which constructs a custom call that applies an Xla computation.
|
||||||
XlaOp CustomCallWithComputation(
|
XlaOp CustomCallWithComputation(XlaBuilder* builder,
|
||||||
XlaBuilder* builder, const string& call_target_name,
|
const string& call_target_name,
|
||||||
absl::Span<const XlaOp> operands, const XlaComputation& computation,
|
absl::Span<const XlaOp> operands,
|
||||||
const Shape& shape, const string& opaque = "", bool has_side_effect = false,
|
const XlaComputation& computation,
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
const Shape& shape, const string& opaque = "",
|
||||||
output_operand_aliasing = {});
|
bool has_side_effect = false);
|
||||||
|
|
||||||
// Overload which constructs a custom call with fixed layouts. The operands will
|
// Overload which constructs a custom call with fixed layouts. The operands will
|
||||||
// have the layouts specified by |operand_shapes_with_layout| when provided to
|
// have the layouts specified by |operand_shapes_with_layout| when provided to
|
||||||
// external code, and the external code is expected to produce a result with the
|
// external code, and the external code is expected to produce a result with the
|
||||||
// layout specified by |shape_with_layout|. All shapes in |shape_with_layout|
|
// layout specified by |shape_with_layout|. All shapes in |shape_with_layout|
|
||||||
// and |operand_shapes_with_layout| must have layouts.
|
// and |operand_shapes_with_layout| must have layouts.
|
||||||
XlaOp CustomCallWithLayout(
|
XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name,
|
||||||
XlaBuilder* builder, const string& call_target_name,
|
absl::Span<const XlaOp> operands,
|
||||||
absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
|
const Shape& shape_with_layout,
|
||||||
absl::Span<const Shape> operand_shapes_with_layout,
|
absl::Span<const Shape> operand_shapes_with_layout,
|
||||||
const string& opaque = "", bool has_side_effect = false,
|
const string& opaque = "",
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
bool has_side_effect = false);
|
||||||
output_operand_aliasing = {});
|
|
||||||
|
|
||||||
// The following methods enqueue element-wise binary arithmetic operations
|
// The following methods enqueue element-wise binary arithmetic operations
|
||||||
// onto the computation. The shapes of the operands have to match unless one
|
// onto the computation. The shapes of the operands have to match unless one
|
||||||
|
@ -3492,8 +3492,6 @@ cc_library(
|
|||||||
hdrs = ["memory_space_assignment_utils.h"],
|
hdrs = ["memory_space_assignment_utils.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":heap_simulator",
|
":heap_simulator",
|
||||||
":hlo",
|
|
||||||
":hlo_casting_utils",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -35,7 +35,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
|
|||||||
option cc_enable_arenas = true;
|
option cc_enable_arenas = true;
|
||||||
|
|
||||||
// Serialization of HloInstruction.
|
// Serialization of HloInstruction.
|
||||||
// Next ID: 75
|
// Next ID: 74
|
||||||
message HloInstructionProto {
|
message HloInstructionProto {
|
||||||
reserved 10;
|
reserved 10;
|
||||||
reserved "parameter_name";
|
reserved "parameter_name";
|
||||||
@ -232,11 +232,6 @@ message HloInstructionProto {
|
|||||||
// kCustomCall.
|
// kCustomCall.
|
||||||
bool custom_call_has_side_effect = 65;
|
bool custom_call_has_side_effect = 65;
|
||||||
|
|
||||||
// A list of CustomCallOutputOperandAliasing pairs that specifies aliasing
|
|
||||||
// buffers between output and operands for kCustomCall.
|
|
||||||
repeated xla.CustomCallOutputOperandAliasing
|
|
||||||
custom_call_output_operand_aliasing = 74;
|
|
||||||
|
|
||||||
// The delta value for kRngGetAndUpdateState.
|
// The delta value for kRngGetAndUpdateState.
|
||||||
int64 delta = 66;
|
int64 delta = 66;
|
||||||
|
|
||||||
|
@ -432,23 +432,6 @@ bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) {
|
|||||||
return changed;
|
return changed;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool HloDataflowAnalysis::UpdateCustomCallValueSet(
|
|
||||||
HloInstruction* custom_call) {
|
|
||||||
CHECK_EQ(custom_call->opcode(), HloOpcode::kCustomCall);
|
|
||||||
bool changed = false;
|
|
||||||
for (const auto& aliasing : Cast<HloCustomCallInstruction>(custom_call)
|
|
||||||
->output_to_operand_aliasing()) {
|
|
||||||
const HloValueSet& operand_value_set = GetValueSet(
|
|
||||||
custom_call->operand(aliasing.second.first), aliasing.second.second);
|
|
||||||
HloValueSet& value_set = GetValueSet(custom_call, aliasing.first);
|
|
||||||
if (value_set != operand_value_set) {
|
|
||||||
value_set = operand_value_set;
|
|
||||||
changed = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return changed;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool HloDataflowAnalysis::UpdateCopyStartValueSet(HloInstruction* copy_start) {
|
bool HloDataflowAnalysis::UpdateCopyStartValueSet(HloInstruction* copy_start) {
|
||||||
CHECK_EQ(copy_start->opcode(), HloOpcode::kCopyStart);
|
CHECK_EQ(copy_start->opcode(), HloOpcode::kCopyStart);
|
||||||
bool changed = false;
|
bool changed = false;
|
||||||
@ -774,8 +757,6 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet(
|
|||||||
return UpdateAddDependencyValueSet(instruction);
|
return UpdateAddDependencyValueSet(instruction);
|
||||||
case HloOpcode::kBitcast:
|
case HloOpcode::kBitcast:
|
||||||
return UpdateBitcastValueSet(instruction);
|
return UpdateBitcastValueSet(instruction);
|
||||||
case HloOpcode::kCustomCall:
|
|
||||||
return UpdateCustomCallValueSet(instruction);
|
|
||||||
case HloOpcode::kSetDimensionSize:
|
case HloOpcode::kSetDimensionSize:
|
||||||
return UpdateSetDimensionSizeValueSet(instruction);
|
return UpdateSetDimensionSizeValueSet(instruction);
|
||||||
case HloOpcode::kDomain:
|
case HloOpcode::kDomain:
|
||||||
@ -1037,22 +1018,6 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
|
|||||||
define_value_at(/*index=*/{1});
|
define_value_at(/*index=*/{1});
|
||||||
define_value_at(/*index=*/{2});
|
define_value_at(/*index=*/{2});
|
||||||
break;
|
break;
|
||||||
case HloOpcode::kCustomCall: {
|
|
||||||
absl::flat_hash_set<ShapeIndex> aliasing_indices;
|
|
||||||
for (const auto& aliasing :
|
|
||||||
Cast<HloCustomCallInstruction>(instruction)
|
|
||||||
->output_to_operand_aliasing()) {
|
|
||||||
aliasing_indices.insert(aliasing.first);
|
|
||||||
}
|
|
||||||
ShapeUtil::ForEachSubshape(
|
|
||||||
instruction->shape(),
|
|
||||||
[&](const Shape& /*subshape*/, const ShapeIndex& index) {
|
|
||||||
if (!aliasing_indices.contains(index)) {
|
|
||||||
define_value_at(index);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
default:
|
default:
|
||||||
define_all_values();
|
define_all_values();
|
||||||
break;
|
break;
|
||||||
|
@ -216,7 +216,6 @@ class HloDataflowAnalysis {
|
|||||||
bool UpdateCallValueSet(HloInstruction* call);
|
bool UpdateCallValueSet(HloInstruction* call);
|
||||||
bool UpdateConditionalValueSet(HloInstruction* conditional);
|
bool UpdateConditionalValueSet(HloInstruction* conditional);
|
||||||
bool UpdateCopyValueSet(HloInstruction* copy);
|
bool UpdateCopyValueSet(HloInstruction* copy);
|
||||||
bool UpdateCustomCallValueSet(HloInstruction* custom_call);
|
|
||||||
bool UpdateDomainValueSet(HloInstruction* domain);
|
bool UpdateDomainValueSet(HloInstruction* domain);
|
||||||
bool UpdateGetTupleElementValueSet(HloInstruction* gte);
|
bool UpdateGetTupleElementValueSet(HloInstruction* gte);
|
||||||
bool UpdateParameterValueSet(HloInstruction* parameter);
|
bool UpdateParameterValueSet(HloInstruction* parameter);
|
||||||
|
@ -568,19 +568,6 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
|||||||
std::max(static_cast<int64>(proto.batch_group_count()), int64{1}));
|
std::max(static_cast<int64>(proto.batch_group_count()), int64{1}));
|
||||||
custom_call_instr->set_custom_call_has_side_effect(
|
custom_call_instr->set_custom_call_has_side_effect(
|
||||||
proto.custom_call_has_side_effect());
|
proto.custom_call_has_side_effect());
|
||||||
std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
|
||||||
output_to_operand_aliasing;
|
|
||||||
for (const auto& aliasing : proto.custom_call_output_operand_aliasing()) {
|
|
||||||
output_to_operand_aliasing.emplace_back(
|
|
||||||
ShapeIndex(aliasing.output_shape_index().begin(),
|
|
||||||
aliasing.output_shape_index().end()),
|
|
||||||
std::pair<int64, ShapeIndex>{
|
|
||||||
aliasing.operand_index(),
|
|
||||||
ShapeIndex(aliasing.operand_shape_index().begin(),
|
|
||||||
aliasing.operand_shape_index().end())});
|
|
||||||
}
|
|
||||||
custom_call_instr->set_output_to_operand_aliasing(
|
|
||||||
std::move(output_to_operand_aliasing));
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case HloOpcode::kPad:
|
case HloOpcode::kPad:
|
||||||
|
@ -2395,16 +2395,6 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
proto.set_custom_call_has_side_effect(custom_call_has_side_effect_);
|
proto.set_custom_call_has_side_effect(custom_call_has_side_effect_);
|
||||||
for (const auto& pair : output_to_operand_aliasing_) {
|
|
||||||
auto aliasing = proto.add_custom_call_output_operand_aliasing();
|
|
||||||
aliasing->set_operand_index(pair.second.first);
|
|
||||||
for (int64 index : pair.first) {
|
|
||||||
aliasing->add_output_shape_index(index);
|
|
||||||
}
|
|
||||||
for (int64 index : pair.second.second) {
|
|
||||||
aliasing->add_operand_shape_index(index);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return proto;
|
return proto;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2442,16 +2432,6 @@ std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
|
|||||||
if (custom_call_has_side_effect_) {
|
if (custom_call_has_side_effect_) {
|
||||||
extra.push_back("custom_call_has_side_effect=true");
|
extra.push_back("custom_call_has_side_effect=true");
|
||||||
}
|
}
|
||||||
if (!output_to_operand_aliasing_.empty()) {
|
|
||||||
std::vector<string> pair_strings;
|
|
||||||
for (const auto& pair : output_to_operand_aliasing_) {
|
|
||||||
pair_strings.push_back(StrCat(pair.first.ToString(), ": (",
|
|
||||||
pair.second.first, ", ",
|
|
||||||
pair.second.second.ToString(), ")"));
|
|
||||||
}
|
|
||||||
extra.push_back(StrCat("output_to_operand_aliasing={",
|
|
||||||
StrJoin(pair_strings, ", "), "}"));
|
|
||||||
}
|
|
||||||
return extra;
|
return extra;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2495,10 +2475,6 @@ bool HloCustomCallInstruction::IdenticalSlowPath(
|
|||||||
casted_other.custom_call_has_side_effect()) {
|
casted_other.custom_call_has_side_effect()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (output_to_operand_aliasing_ !=
|
|
||||||
casted_other.output_to_operand_aliasing()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
// Note: backend_config comparison is done in Identical, which is the
|
// Note: backend_config comparison is done in Identical, which is the
|
||||||
// intended/exposed way to compare computations, and so not repeated here.
|
// intended/exposed way to compare computations, and so not repeated here.
|
||||||
return custom_call_target_ == casted_other.custom_call_target_;
|
return custom_call_target_ == casted_other.custom_call_target_;
|
||||||
@ -2523,7 +2499,6 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl(
|
|||||||
cloned->set_feature_group_count(feature_group_count_);
|
cloned->set_feature_group_count(feature_group_count_);
|
||||||
cloned->set_batch_group_count(batch_group_count_);
|
cloned->set_batch_group_count(batch_group_count_);
|
||||||
cloned->set_custom_call_has_side_effect(custom_call_has_side_effect_);
|
cloned->set_custom_call_has_side_effect(custom_call_has_side_effect_);
|
||||||
cloned->set_output_to_operand_aliasing(output_to_operand_aliasing_);
|
|
||||||
return std::move(cloned);
|
return std::move(cloned);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1430,20 +1430,6 @@ class HloCustomCallInstruction : public HloInstruction {
|
|||||||
CHECK(layout_constrained());
|
CHECK(layout_constrained());
|
||||||
return operand_shapes_with_layout_;
|
return operand_shapes_with_layout_;
|
||||||
}
|
}
|
||||||
// Gets a list of output/operand buffer pairs that alias each other, where the
|
|
||||||
// output buffer is represented as a ShapeIndex, and the operand buffer is
|
|
||||||
// represented as the operand index and the ShapeIndex. By default this list
|
|
||||||
// is empty.
|
|
||||||
const std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>&
|
|
||||||
output_to_operand_aliasing() const {
|
|
||||||
return output_to_operand_aliasing_;
|
|
||||||
}
|
|
||||||
// Sets the list of output/operand buffer pairs that alias each other.
|
|
||||||
void set_output_to_operand_aliasing(
|
|
||||||
std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
|
||||||
aliasing) {
|
|
||||||
output_to_operand_aliasing_ = std::move(aliasing);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<string> ExtraAttributesToStringImpl(
|
std::vector<string> ExtraAttributesToStringImpl(
|
||||||
@ -1472,10 +1458,6 @@ class HloCustomCallInstruction : public HloInstruction {
|
|||||||
std::vector<Shape> operand_shapes_with_layout_;
|
std::vector<Shape> operand_shapes_with_layout_;
|
||||||
// Whether this custom call has a side-effect.
|
// Whether this custom call has a side-effect.
|
||||||
bool custom_call_has_side_effect_;
|
bool custom_call_has_side_effect_;
|
||||||
// A list of output/operand buffer pairs that alias each other. See comment of
|
|
||||||
// output_to_operand_aliasing().
|
|
||||||
std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
|
||||||
output_to_operand_aliasing_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class HloPadInstruction : public HloInstruction {
|
class HloPadInstruction : public HloInstruction {
|
||||||
|
@ -212,7 +212,6 @@ class HloParserImpl : public HloParser {
|
|||||||
kEnum,
|
kEnum,
|
||||||
kRandomAlgorithm,
|
kRandomAlgorithm,
|
||||||
kAliasing,
|
kAliasing,
|
||||||
kInstructionAliasing,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct AttrConfig {
|
struct AttrConfig {
|
||||||
@ -347,12 +346,6 @@ class HloParserImpl : public HloParser {
|
|||||||
// fails.
|
// fails.
|
||||||
bool ParseAliasing(AliasingData* data);
|
bool ParseAliasing(AliasingData* data);
|
||||||
|
|
||||||
// Parses the per-instruction aliasing information from string `s`, returns
|
|
||||||
// `false` if it fails.
|
|
||||||
bool ParseInstructionOutputOperandAliasing(
|
|
||||||
std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>*
|
|
||||||
aliasing_output_operand_pairs);
|
|
||||||
|
|
||||||
bool ParseShapeIndex(ShapeIndex* out);
|
bool ParseShapeIndex(ShapeIndex* out);
|
||||||
|
|
||||||
// Returns true if the current token is the beginning of a shape.
|
// Returns true if the current token is the beginning of a shape.
|
||||||
@ -605,58 +598,6 @@ bool HloParserImpl::ParseAliasing(AliasingData* data) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool HloParserImpl::ParseInstructionOutputOperandAliasing(
|
|
||||||
std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>*
|
|
||||||
aliasing_output_operand_pairs) {
|
|
||||||
if (!ParseToken(
|
|
||||||
TokKind::kLbrace,
|
|
||||||
"Expects '{' at the start of instruction aliasing description")) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
while (lexer_.GetKind() != TokKind::kRbrace) {
|
|
||||||
ShapeIndex out;
|
|
||||||
if (!ParseShapeIndex(&out)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
std::string errmsg =
|
|
||||||
"Expected format: <output_shape_index>: (<operand_index>, "
|
|
||||||
"<operand_shape_index>)";
|
|
||||||
if (!ParseToken(TokKind::kColon, errmsg)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!ParseToken(TokKind::kLparen, errmsg)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
int64 operand_index;
|
|
||||||
ParseInt64(&operand_index);
|
|
||||||
if (!ParseToken(TokKind::kComma, errmsg)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
ShapeIndex operand_shape_index;
|
|
||||||
if (!ParseShapeIndex(&operand_shape_index)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
aliasing_output_operand_pairs->emplace_back(
|
|
||||||
out, std::pair<int64, ShapeIndex>{operand_index, operand_shape_index});
|
|
||||||
if (!ParseToken(TokKind::kRparen, errmsg)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!EatIfPresent(TokKind::kComma)) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!ParseToken(
|
|
||||||
TokKind::kRbrace,
|
|
||||||
"Expects '}' at the end of instruction aliasing description")) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// ::= 'HloModule' name computations
|
// ::= 'HloModule' name computations
|
||||||
bool HloParserImpl::ParseHloModule(HloModule* module) {
|
bool HloParserImpl::ParseHloModule(HloModule* module) {
|
||||||
if (lexer_.GetKind() != TokKind::kw_HloModule) {
|
if (lexer_.GetKind() != TokKind::kw_HloModule) {
|
||||||
@ -1836,8 +1777,6 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
|
|||||||
optional<std::vector<Shape>> operand_layout_constraints;
|
optional<std::vector<Shape>> operand_layout_constraints;
|
||||||
optional<bool> custom_call_has_side_effect;
|
optional<bool> custom_call_has_side_effect;
|
||||||
optional<HloComputation*> to_apply;
|
optional<HloComputation*> to_apply;
|
||||||
optional<std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>>
|
|
||||||
output_to_operand_aliasing;
|
|
||||||
attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
|
attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
|
||||||
&custom_call_target};
|
&custom_call_target};
|
||||||
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
|
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
|
||||||
@ -1853,9 +1792,6 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
|
|||||||
&custom_call_has_side_effect};
|
&custom_call_has_side_effect};
|
||||||
attrs["to_apply"] = {/*required=*/false, AttrTy::kHloComputation,
|
attrs["to_apply"] = {/*required=*/false, AttrTy::kHloComputation,
|
||||||
&to_apply};
|
&to_apply};
|
||||||
attrs["output_to_operand_aliasing"] = {/*required=*/false,
|
|
||||||
AttrTy::kInstructionAliasing,
|
|
||||||
&output_to_operand_aliasing};
|
|
||||||
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
|
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -1925,10 +1861,6 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
|
|||||||
custom_call_instr->set_custom_call_has_side_effect(
|
custom_call_instr->set_custom_call_has_side_effect(
|
||||||
*custom_call_has_side_effect);
|
*custom_call_has_side_effect);
|
||||||
}
|
}
|
||||||
if (output_to_operand_aliasing.has_value()) {
|
|
||||||
custom_call_instr->set_output_to_operand_aliasing(
|
|
||||||
std::move(*output_to_operand_aliasing));
|
|
||||||
}
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case HloOpcode::kDot: {
|
case HloOpcode::kDot: {
|
||||||
@ -3291,19 +3223,6 @@ bool HloParserImpl::ParseAttributeHelper(
|
|||||||
->emplace(aliasing_data);
|
->emplace(aliasing_data);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
case AttrTy::kInstructionAliasing: {
|
|
||||||
std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
|
||||||
aliasing_output_operand_pairs;
|
|
||||||
if (!ParseInstructionOutputOperandAliasing(
|
|
||||||
&aliasing_output_operand_pairs)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
static_cast<optional<
|
|
||||||
std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>>*>(
|
|
||||||
attr_out_ptr)
|
|
||||||
->emplace(std::move(aliasing_output_operand_pairs));
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}();
|
}();
|
||||||
if (!success) {
|
if (!success) {
|
||||||
|
@ -991,19 +991,6 @@ ENTRY %CustomCallWithHasSideEffect (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4])
|
|||||||
ROOT %custom-call = (f32[1,2,3]{0,2,1}, f32[1,2,3]{1,2,0}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", custom_call_has_side_effect=true
|
ROOT %custom-call = (f32[1,2,3]{0,2,1}, f32[1,2,3]{1,2,0}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", custom_call_has_side_effect=true
|
||||||
}
|
}
|
||||||
|
|
||||||
)"
|
|
||||||
},
|
|
||||||
// CustomCallWithAliasing
|
|
||||||
{
|
|
||||||
"CustomCallWithAliasing",
|
|
||||||
R"(HloModule CustomCallWithAliasing
|
|
||||||
|
|
||||||
ENTRY %CustomCallWithAliasing (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4]) -> (f32[123,4], f32[2,2], f32[1,2,3]) {
|
|
||||||
%p0 = (f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) parameter(0)
|
|
||||||
%p1 = f32[123,4]{0,1} parameter(1)
|
|
||||||
ROOT %custom-call = (f32[123,4]{0,1}, f32[2,2]{0,1}, f32[1,2,3]{0,1,2}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", output_to_operand_aliasing={{0}: (1, {}), {1}: (0, {0})}
|
|
||||||
}
|
|
||||||
|
|
||||||
)"
|
)"
|
||||||
},
|
},
|
||||||
// Parse c64 literal
|
// Parse c64 literal
|
||||||
|
@ -801,28 +801,6 @@ Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) {
|
|||||||
TF_RET_CHECK(LayoutUtil::HasLayout(operand_shape_with_layout));
|
TF_RET_CHECK(LayoutUtil::HasLayout(operand_shape_with_layout));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (const auto& pair : custom_call->output_to_operand_aliasing()) {
|
|
||||||
TF_RET_CHECK(pair.second.first < custom_call->operand_count())
|
|
||||||
<< "Invalid aliasing operand index.";
|
|
||||||
TF_RET_CHECK(ShapeUtil::IndexIsValid(
|
|
||||||
custom_call->operand(pair.second.first)->shape(), pair.second.second))
|
|
||||||
<< "Invalid aliasing operand shape index.";
|
|
||||||
TF_RET_CHECK(ShapeUtil::IndexIsValid(custom_call->shape(), pair.first))
|
|
||||||
<< "Invalid aliasing output shape index.";
|
|
||||||
const Shape& output_subshape =
|
|
||||||
ShapeUtil::GetSubshape(custom_call->shape(), pair.first);
|
|
||||||
const Shape& operand_subshape = ShapeUtil::GetSubshape(
|
|
||||||
custom_call->operand(pair.second.first)->shape(), pair.second.second);
|
|
||||||
if (layout_sensitive_) {
|
|
||||||
TF_RET_CHECK(operand_subshape == output_subshape)
|
|
||||||
<< "Different aliasing shapes: " << operand_subshape.ToString()
|
|
||||||
<< " vs " << output_subshape.ToString();
|
|
||||||
} else {
|
|
||||||
TF_RET_CHECK(ShapeUtil::Compatible(output_subshape, operand_subshape))
|
|
||||||
<< "Different aliasing shapes: " << operand_subshape.ToString()
|
|
||||||
<< " vs " << output_subshape.ToString();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -15,9 +15,6 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/memory_space_assignment_utils.h"
|
#include "tensorflow/compiler/xla/service/memory_space_assignment_utils.h"
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
bool MemorySpaceAssignmentUtils::IsValueAllowedInAlternateMemory(
|
bool MemorySpaceAssignmentUtils::IsValueAllowedInAlternateMemory(
|
||||||
@ -90,17 +87,6 @@ bool MemorySpaceAssignmentUtils::IsValueAllowedInAlternateMemory(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (auto* custom_call =
|
|
||||||
DynCast<HloCustomCallInstruction>(position.instruction)) {
|
|
||||||
for (const auto& pair : custom_call->output_to_operand_aliasing()) {
|
|
||||||
if (position.index == pair.first) {
|
|
||||||
VLOG(4) << "Keeping value " << value->ToShortString()
|
|
||||||
<< " in default mem because it is a custom-call output that "
|
|
||||||
"aliases an operand buffer.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
|
@ -691,11 +691,3 @@ message WhileLoopBackendConfig {
|
|||||||
// unknown-trip-count.
|
// unknown-trip-count.
|
||||||
KnownTripCount known_trip_count = 1;
|
KnownTripCount known_trip_count = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Specifies a pair of output/operand buffers for kCustomCall that alias each
|
|
||||||
// other.
|
|
||||||
message CustomCallOutputOperandAliasing {
|
|
||||||
repeated int64 output_shape_index = 1;
|
|
||||||
int64 operand_index = 2;
|
|
||||||
repeated int64 operand_shape_index = 3;
|
|
||||||
}
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user