Add a literal field to custom call instruction.
- Add a literal field to custom call instruction. - Change SetBound custom call to use literal as side data instead of a hand-serialized number. PiperOrigin-RevId: 358006370 Change-Id: I67727bbe3ce12b082fbcfc7290e57194bd96c29a
This commit is contained in:
parent
7ec1faf4d1
commit
09db4abc5b
@ -134,7 +134,8 @@ StatusOr<XlaOp> MlirHloBuilder::CustomCallInternal(
|
|||||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
||||||
bool has_side_effect,
|
bool has_side_effect,
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||||
output_operand_aliasing) {
|
output_operand_aliasing,
|
||||||
|
const Literal* literal) {
|
||||||
if (operand_shapes_with_layout.has_value())
|
if (operand_shapes_with_layout.has_value())
|
||||||
return Unimplemented(
|
return Unimplemented(
|
||||||
"CustomCall doesn't support operands shapes with layout");
|
"CustomCall doesn't support operands shapes with layout");
|
||||||
@ -142,6 +143,8 @@ StatusOr<XlaOp> MlirHloBuilder::CustomCallInternal(
|
|||||||
shape, builder_));
|
shape, builder_));
|
||||||
TF_RET_CHECK(output_operand_aliasing.empty())
|
TF_RET_CHECK(output_operand_aliasing.empty())
|
||||||
<< "MLIR CustomCallOp does not support output_operand_aliasing yet";
|
<< "MLIR CustomCallOp does not support output_operand_aliasing yet";
|
||||||
|
TF_RET_CHECK(literal == nullptr)
|
||||||
|
<< "MLIR CustomCallOp does not support literal yet";
|
||||||
auto op = builder_.create<mlir::mhlo::CustomCallOp>(
|
auto op = builder_.create<mlir::mhlo::CustomCallOp>(
|
||||||
loc_, ty, GetValues(operands), builder_.getStringAttr(call_target_name),
|
loc_, ty, GetValues(operands), builder_.getStringAttr(call_target_name),
|
||||||
/*has_side_effect=*/builder_.getBoolAttr(has_side_effect),
|
/*has_side_effect=*/builder_.getBoolAttr(has_side_effect),
|
||||||
|
@ -137,7 +137,8 @@ class MlirHloBuilder : public XlaBuilder {
|
|||||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
||||||
bool has_side_effect,
|
bool has_side_effect,
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||||
output_operand_aliasing) override;
|
output_operand_aliasing,
|
||||||
|
const Literal* literal) override;
|
||||||
|
|
||||||
StatusOr<XlaOp> ReduceInternal(
|
StatusOr<XlaOp> ReduceInternal(
|
||||||
const Shape& shape, absl::Span<const XlaOp> all_operands,
|
const Shape& shape, absl::Span<const XlaOp> all_operands,
|
||||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||||
|
#include "tensorflow/compiler/xla/literal.h"
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/core/framework/bounds_check.h"
|
#include "tensorflow/core/framework/bounds_check.h"
|
||||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||||
@ -97,10 +98,11 @@ class XlaSetBoundOp : public XlaOpKernel {
|
|||||||
bound_shape.DebugString()));
|
bound_shape.DebugString()));
|
||||||
int64 bound;
|
int64 bound;
|
||||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("bound", &bound));
|
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("bound", &bound));
|
||||||
|
xla::Literal bound_literal = xla::LiteralUtil::CreateR0<int32>(bound);
|
||||||
xla::XlaOp result = xla::CustomCall(
|
xla::XlaOp result =
|
||||||
ctx->builder(), "SetBound", {ctx->Input("input")},
|
xla::CustomCall(ctx->builder(), "SetBound", {ctx->Input("input")},
|
||||||
ctx->InputXlaShape("input").ValueOrDie(), absl::StrFormat("%d", bound));
|
ctx->InputXlaShape("input").ValueOrDie(), "", false, {},
|
||||||
|
&bound_literal);
|
||||||
ctx->SetOutput(0, result);
|
ctx->SetOutput(0, result);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -1882,7 +1882,8 @@ XlaOp XlaBuilder::CustomCall(
|
|||||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
||||||
bool has_side_effect,
|
bool has_side_effect,
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||||
output_operand_aliasing) {
|
output_operand_aliasing,
|
||||||
|
const Literal* literal) {
|
||||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||||
if (absl::StartsWith(call_target_name, "$")) {
|
if (absl::StartsWith(call_target_name, "$")) {
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
@ -1915,7 +1916,7 @@ XlaOp XlaBuilder::CustomCall(
|
|||||||
}
|
}
|
||||||
return CustomCallInternal(call_target_name, operands, shape, opaque,
|
return CustomCallInternal(call_target_name, operands, shape, opaque,
|
||||||
operand_shapes_with_layout, has_side_effect,
|
operand_shapes_with_layout, has_side_effect,
|
||||||
output_operand_aliasing);
|
output_operand_aliasing, literal);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1925,7 +1926,8 @@ StatusOr<XlaOp> XlaBuilder::CustomCallInternal(
|
|||||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
||||||
bool has_side_effect,
|
bool has_side_effect,
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||||
output_operand_aliasing) {
|
output_operand_aliasing,
|
||||||
|
const Literal* literal) {
|
||||||
HloInstructionProto instr;
|
HloInstructionProto instr;
|
||||||
*instr.mutable_shape() = shape.ToProto();
|
*instr.mutable_shape() = shape.ToProto();
|
||||||
instr.set_custom_call_target(call_target_name);
|
instr.set_custom_call_target(call_target_name);
|
||||||
@ -1936,6 +1938,9 @@ StatusOr<XlaOp> XlaBuilder::CustomCallInternal(
|
|||||||
*instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
|
*instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (literal != nullptr) {
|
||||||
|
*instr.mutable_literal() = literal->ToProto();
|
||||||
|
}
|
||||||
instr.set_custom_call_has_side_effect(has_side_effect);
|
instr.set_custom_call_has_side_effect(has_side_effect);
|
||||||
for (const auto& pair : output_operand_aliasing) {
|
for (const auto& pair : output_operand_aliasing) {
|
||||||
auto aliasing = instr.add_custom_call_output_operand_aliasing();
|
auto aliasing = instr.add_custom_call_output_operand_aliasing();
|
||||||
@ -1956,7 +1961,8 @@ XlaOp XlaBuilder::CustomCall(
|
|||||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
||||||
bool has_side_effect,
|
bool has_side_effect,
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||||
output_operand_aliasing) {
|
output_operand_aliasing,
|
||||||
|
const Literal* literal) {
|
||||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||||
HloInstructionProto instr;
|
HloInstructionProto instr;
|
||||||
if (absl::StartsWith(call_target_name, "$")) {
|
if (absl::StartsWith(call_target_name, "$")) {
|
||||||
@ -1968,6 +1974,9 @@ XlaOp XlaBuilder::CustomCall(
|
|||||||
*instr.mutable_shape() = shape.ToProto();
|
*instr.mutable_shape() = shape.ToProto();
|
||||||
instr.set_custom_call_target(call_target_name);
|
instr.set_custom_call_target(call_target_name);
|
||||||
instr.set_backend_config(opaque);
|
instr.set_backend_config(opaque);
|
||||||
|
if (literal != nullptr) {
|
||||||
|
*instr.mutable_literal() = literal->ToProto();
|
||||||
|
}
|
||||||
if (operand_shapes_with_layout.has_value()) {
|
if (operand_shapes_with_layout.has_value()) {
|
||||||
if (!LayoutUtil::HasLayout(shape)) {
|
if (!LayoutUtil::HasLayout(shape)) {
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
@ -3786,6 +3795,8 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
|
|||||||
HloOpcodeString(HloOpcode::kGetDimensionSize) ||
|
HloOpcodeString(HloOpcode::kGetDimensionSize) ||
|
||||||
InstrIsSetBound(instr_proto)) {
|
InstrIsSetBound(instr_proto)) {
|
||||||
int32 constant_value = -1;
|
int32 constant_value = -1;
|
||||||
|
HloInstructionProto const_instr;
|
||||||
|
|
||||||
if (instr_proto->opcode() ==
|
if (instr_proto->opcode() ==
|
||||||
HloOpcodeString(HloOpcode::kGetDimensionSize)) {
|
HloOpcodeString(HloOpcode::kGetDimensionSize)) {
|
||||||
// At this point, BuildConstantSubGraph should never encounter a
|
// At this point, BuildConstantSubGraph should never encounter a
|
||||||
@ -3804,18 +3815,14 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
|
|||||||
constant_value =
|
constant_value =
|
||||||
static_cast<int32>(operand_proto->shape().dimensions(dimension));
|
static_cast<int32>(operand_proto->shape().dimensions(dimension));
|
||||||
}
|
}
|
||||||
|
Literal literal = LiteralUtil::CreateR0(constant_value);
|
||||||
|
*const_instr.mutable_literal() = literal.ToProto();
|
||||||
|
*const_instr.mutable_shape() = literal.shape().ToProto();
|
||||||
} else {
|
} else {
|
||||||
TF_RET_CHECK(
|
*const_instr.mutable_literal() = instr_proto->literal();
|
||||||
absl::SimpleAtoi(instr_proto->backend_config(), &constant_value));
|
*const_instr.mutable_shape() = instr_proto->shape();
|
||||||
}
|
}
|
||||||
|
|
||||||
Literal literal = LiteralUtil::CreateR0(constant_value);
|
|
||||||
|
|
||||||
HloInstructionProto const_instr;
|
|
||||||
*const_instr.mutable_shape() = literal.shape().ToProto();
|
|
||||||
*const_instr.mutable_literal() = literal.ToProto();
|
|
||||||
*const_instr.mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
|
*const_instr.mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
|
||||||
|
|
||||||
const_instr.set_id(handle);
|
const_instr.set_id(handle);
|
||||||
*const_instr.mutable_name() =
|
*const_instr.mutable_name() =
|
||||||
GetFullName(const_instr.opcode(), kNameSeparator, const_instr.id());
|
GetFullName(const_instr.opcode(), kNameSeparator, const_instr.id());
|
||||||
@ -3866,7 +3873,6 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
*module->add_computations() = std::move(entry);
|
*module->add_computations() = std::move(entry);
|
||||||
|
|
||||||
return std::move(computation);
|
return std::move(computation);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -4459,10 +4465,11 @@ XlaOp CustomCall(
|
|||||||
absl::Span<const XlaOp> operands, const Shape& shape, const string& opaque,
|
absl::Span<const XlaOp> operands, const Shape& shape, const string& opaque,
|
||||||
bool has_side_effect,
|
bool has_side_effect,
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||||
output_operand_aliasing) {
|
output_operand_aliasing,
|
||||||
|
const Literal* literal) {
|
||||||
return builder->CustomCall(call_target_name, operands, shape, opaque,
|
return builder->CustomCall(call_target_name, operands, shape, opaque,
|
||||||
/*operand_shapes_with_layout=*/absl::nullopt,
|
/*operand_shapes_with_layout=*/absl::nullopt,
|
||||||
has_side_effect, output_operand_aliasing);
|
has_side_effect, output_operand_aliasing, literal);
|
||||||
}
|
}
|
||||||
|
|
||||||
XlaOp CustomCallWithComputation(
|
XlaOp CustomCallWithComputation(
|
||||||
@ -4470,11 +4477,12 @@ XlaOp CustomCallWithComputation(
|
|||||||
absl::Span<const XlaOp> operands, const XlaComputation& computation,
|
absl::Span<const XlaOp> operands, const XlaComputation& computation,
|
||||||
const Shape& shape, const string& opaque, bool has_side_effect,
|
const Shape& shape, const string& opaque, bool has_side_effect,
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||||
output_operand_aliasing) {
|
output_operand_aliasing,
|
||||||
|
const Literal* literal) {
|
||||||
return builder->CustomCall(call_target_name, operands, computation, shape,
|
return builder->CustomCall(call_target_name, operands, computation, shape,
|
||||||
opaque,
|
opaque,
|
||||||
/*operand_shapes_with_layout=*/absl::nullopt,
|
/*operand_shapes_with_layout=*/absl::nullopt,
|
||||||
has_side_effect, output_operand_aliasing);
|
has_side_effect, output_operand_aliasing, literal);
|
||||||
}
|
}
|
||||||
|
|
||||||
XlaOp CustomCallWithLayout(
|
XlaOp CustomCallWithLayout(
|
||||||
@ -4483,10 +4491,11 @@ XlaOp CustomCallWithLayout(
|
|||||||
absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
|
absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
|
||||||
bool has_side_effect,
|
bool has_side_effect,
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||||
output_operand_aliasing) {
|
output_operand_aliasing,
|
||||||
|
const Literal* literal) {
|
||||||
return builder->CustomCall(call_target_name, operands, shape, opaque,
|
return builder->CustomCall(call_target_name, operands, shape, opaque,
|
||||||
operand_shapes_with_layout, has_side_effect,
|
operand_shapes_with_layout, has_side_effect,
|
||||||
output_operand_aliasing);
|
output_operand_aliasing, literal);
|
||||||
}
|
}
|
||||||
|
|
||||||
XlaOp Complex(const XlaOp lhs, const XlaOp rhs,
|
XlaOp Complex(const XlaOp lhs, const XlaOp rhs,
|
||||||
|
@ -655,7 +655,8 @@ class XlaBuilder {
|
|||||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
||||||
bool has_side_effect,
|
bool has_side_effect,
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||||
output_operand_aliasing);
|
output_operand_aliasing,
|
||||||
|
const Literal* literal);
|
||||||
|
|
||||||
// Internal version of CustomCall without computation that doesn't do op
|
// Internal version of CustomCall without computation that doesn't do op
|
||||||
// specific error handling and expects arguments to be legal. CustomCall
|
// specific error handling and expects arguments to be legal. CustomCall
|
||||||
@ -666,7 +667,8 @@ class XlaBuilder {
|
|||||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
||||||
bool has_side_effect,
|
bool has_side_effect,
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||||
output_operand_aliasing);
|
output_operand_aliasing,
|
||||||
|
const Literal* literal);
|
||||||
|
|
||||||
XlaOp CustomCall(
|
XlaOp CustomCall(
|
||||||
const string& call_target_name, absl::Span<const XlaOp> operands,
|
const string& call_target_name, absl::Span<const XlaOp> operands,
|
||||||
@ -675,7 +677,8 @@ class XlaBuilder {
|
|||||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
||||||
bool has_side_effect,
|
bool has_side_effect,
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||||
output_operand_aliasing);
|
output_operand_aliasing,
|
||||||
|
const Literal* literal);
|
||||||
|
|
||||||
XlaOp Reduce(XlaOp operand, XlaOp init_value,
|
XlaOp Reduce(XlaOp operand, XlaOp init_value,
|
||||||
const XlaComputation& computation,
|
const XlaComputation& computation,
|
||||||
@ -1214,20 +1217,23 @@ class XlaBuilder {
|
|||||||
absl::Span<const XlaOp> operands, const Shape& shape,
|
absl::Span<const XlaOp> operands, const Shape& shape,
|
||||||
const string& opaque, bool has_side_effect,
|
const string& opaque, bool has_side_effect,
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||||
output_operand_aliasing);
|
output_operand_aliasing,
|
||||||
|
const Literal* literal);
|
||||||
friend XlaOp CustomCallWithComputation(
|
friend XlaOp CustomCallWithComputation(
|
||||||
XlaBuilder* builder, const string& call_target_name,
|
XlaBuilder* builder, const string& call_target_name,
|
||||||
absl::Span<const XlaOp> operands, const XlaComputation& computation,
|
absl::Span<const XlaOp> operands, const XlaComputation& computation,
|
||||||
const Shape& shape, const string& opaque, bool has_side_effect,
|
const Shape& shape, const string& opaque, bool has_side_effect,
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||||
output_operand_aliasing);
|
output_operand_aliasing,
|
||||||
|
const Literal* literal);
|
||||||
friend XlaOp CustomCallWithLayout(
|
friend XlaOp CustomCallWithLayout(
|
||||||
XlaBuilder* builder, const string& call_target_name,
|
XlaBuilder* builder, const string& call_target_name,
|
||||||
absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
|
absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
|
||||||
absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
|
absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
|
||||||
bool has_side_effect,
|
bool has_side_effect,
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||||
output_operand_aliasing);
|
output_operand_aliasing,
|
||||||
|
const Literal* literal);
|
||||||
friend XlaOp Complex(XlaOp real, XlaOp imag,
|
friend XlaOp Complex(XlaOp real, XlaOp imag,
|
||||||
absl::Span<const int64> broadcast_dimensions);
|
absl::Span<const int64> broadcast_dimensions);
|
||||||
friend XlaOp Conj(XlaOp operand);
|
friend XlaOp Conj(XlaOp operand);
|
||||||
@ -2025,7 +2031,8 @@ XlaOp CustomCall(
|
|||||||
absl::Span<const XlaOp> operands, const Shape& shape,
|
absl::Span<const XlaOp> operands, const Shape& shape,
|
||||||
const string& opaque = "", bool has_side_effect = false,
|
const string& opaque = "", bool has_side_effect = false,
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||||
output_operand_aliasing = {});
|
output_operand_aliasing = {},
|
||||||
|
const Literal* literal = nullptr);
|
||||||
|
|
||||||
// Overload which constructs a custom call that applies an Xla computation.
|
// Overload which constructs a custom call that applies an Xla computation.
|
||||||
XlaOp CustomCallWithComputation(
|
XlaOp CustomCallWithComputation(
|
||||||
@ -2033,7 +2040,8 @@ XlaOp CustomCallWithComputation(
|
|||||||
absl::Span<const XlaOp> operands, const XlaComputation& computation,
|
absl::Span<const XlaOp> operands, const XlaComputation& computation,
|
||||||
const Shape& shape, const string& opaque = "", bool has_side_effect = false,
|
const Shape& shape, const string& opaque = "", bool has_side_effect = false,
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||||
output_operand_aliasing = {});
|
output_operand_aliasing = {},
|
||||||
|
const Literal* literal = nullptr);
|
||||||
|
|
||||||
// Overload which constructs a custom call with fixed layouts. The operands will
|
// Overload which constructs a custom call with fixed layouts. The operands will
|
||||||
// have the layouts specified by |operand_shapes_with_layout| when provided to
|
// have the layouts specified by |operand_shapes_with_layout| when provided to
|
||||||
@ -2046,7 +2054,8 @@ XlaOp CustomCallWithLayout(
|
|||||||
absl::Span<const Shape> operand_shapes_with_layout,
|
absl::Span<const Shape> operand_shapes_with_layout,
|
||||||
const string& opaque = "", bool has_side_effect = false,
|
const string& opaque = "", bool has_side_effect = false,
|
||||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||||
output_operand_aliasing = {});
|
output_operand_aliasing = {},
|
||||||
|
const Literal* literal = nullptr);
|
||||||
|
|
||||||
// The following methods enqueue element-wise binary arithmetic operations
|
// The following methods enqueue element-wise binary arithmetic operations
|
||||||
// onto the computation. The shapes of the operands have to match unless one
|
// onto the computation. The shapes of the operands have to match unless one
|
||||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
|||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "absl/strings/str_format.h"
|
#include "absl/strings/str_format.h"
|
||||||
#include "absl/strings/str_join.h"
|
#include "absl/strings/str_join.h"
|
||||||
|
#include "absl/strings/str_split.h"
|
||||||
#include "absl/types/optional.h"
|
#include "absl/types/optional.h"
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/compiler/xla/index_util.h"
|
#include "tensorflow/compiler/xla/index_util.h"
|
||||||
@ -71,6 +72,22 @@ void ConvertEndianShort(char* bytes, int64 size) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
string CompactOneline(const string& input) {
|
||||||
|
string result;
|
||||||
|
std::vector<string> v = absl::StrSplit(input, absl::ByAnyChar("\n "));
|
||||||
|
bool first = true;
|
||||||
|
// Concatenate elements in "v" with spaces separating them, but ignoring
|
||||||
|
// empty entries.
|
||||||
|
for (const auto& s : v) {
|
||||||
|
if (s.empty()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
absl::StrAppend(&result, (first ? "" : " "), s);
|
||||||
|
first = false;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
// Since Eigen::half doesn't satisfy the absl::bit_cast contract, we need to be
|
// Since Eigen::half doesn't satisfy the absl::bit_cast contract, we need to be
|
||||||
// able to transparently access the raw 16-bit value contained within.
|
// able to transparently access the raw 16-bit value contained within.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -1281,6 +1298,10 @@ string LiteralBase::ToString() const {
|
|||||||
return absl::StrJoin(pieces, "");
|
return absl::StrJoin(pieces, "");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
string LiteralBase::ToStringOneline() const {
|
||||||
|
return CompactOneline(ToString());
|
||||||
|
}
|
||||||
|
|
||||||
string LiteralBase::ToStringWithoutShape() const {
|
string LiteralBase::ToStringWithoutShape() const {
|
||||||
std::vector<string> pieces;
|
std::vector<string> pieces;
|
||||||
CHECK(LayoutUtil::HasLayout(this->shape()));
|
CHECK(LayoutUtil::HasLayout(this->shape()));
|
||||||
@ -1289,6 +1310,10 @@ string LiteralBase::ToStringWithoutShape() const {
|
|||||||
return absl::StrJoin(pieces, "");
|
return absl::StrJoin(pieces, "");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
string LiteralBase::ToStringWithoutShapeOneline() const {
|
||||||
|
return CompactOneline(ToStringWithoutShape());
|
||||||
|
}
|
||||||
|
|
||||||
string LiteralBase::ToStringWithLayout() const {
|
string LiteralBase::ToStringWithLayout() const {
|
||||||
std::vector<string> pieces;
|
std::vector<string> pieces;
|
||||||
CHECK(LayoutUtil::HasLayout(this->shape()));
|
CHECK(LayoutUtil::HasLayout(this->shape()));
|
||||||
|
@ -94,10 +94,18 @@ class LiteralBase {
|
|||||||
// element Literals.
|
// element Literals.
|
||||||
string ToString() const;
|
string ToString() const;
|
||||||
|
|
||||||
|
// Similar to ToString, but return the result in a compact
|
||||||
|
// one-line form.
|
||||||
|
string ToStringOneline() const;
|
||||||
|
|
||||||
// Returns a string representation of the literal value which does *not*
|
// Returns a string representation of the literal value which does *not*
|
||||||
// include the shape string.
|
// include the shape string.
|
||||||
string ToStringWithoutShape() const;
|
string ToStringWithoutShape() const;
|
||||||
|
|
||||||
|
// Similar to ToStringWithoutShape, but return the result in a compact
|
||||||
|
// one-line form.
|
||||||
|
string ToStringWithoutShapeOneline() const;
|
||||||
|
|
||||||
// Returns a string representation of the literal value which includes the
|
// Returns a string representation of the literal value which includes the
|
||||||
// shape string with its layout.does *not* include the shape string.
|
// shape string with its layout.does *not* include the shape string.
|
||||||
string ToStringWithLayout() const;
|
string ToStringWithLayout() const;
|
||||||
|
@ -328,7 +328,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
|||||||
instruction = CreateConstant(std::move(literal));
|
instruction = CreateConstant(std::move(literal));
|
||||||
// Literal's shape may have no/different tiling info.
|
// Literal's shape may have no/different tiling info.
|
||||||
TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
|
TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
|
||||||
instruction->shape(), shape));
|
instruction->shape(), shape))
|
||||||
|
<< instruction->shape().ToString(true) << " vs "
|
||||||
|
<< shape.ToString(true);
|
||||||
*instruction->mutable_shape() = shape;
|
*instruction->mutable_shape() = shape;
|
||||||
} else {
|
} else {
|
||||||
instruction = absl::make_unique<HloConstantInstruction>(shape);
|
instruction = absl::make_unique<HloConstantInstruction>(shape);
|
||||||
@ -578,6 +580,12 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
|||||||
if (proto.has_window()) {
|
if (proto.has_window()) {
|
||||||
custom_call_instr->set_window(proto.window());
|
custom_call_instr->set_window(proto.window());
|
||||||
}
|
}
|
||||||
|
if (proto.has_literal()) {
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
auto literal,
|
||||||
|
Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
|
||||||
|
custom_call_instr->set_literal(std::move(literal));
|
||||||
|
}
|
||||||
if (proto.has_convolution_dimension_numbers()) {
|
if (proto.has_convolution_dimension_numbers()) {
|
||||||
custom_call_instr->set_convolution_dimension_numbers(
|
custom_call_instr->set_convolution_dimension_numbers(
|
||||||
proto.convolution_dimension_numbers());
|
proto.convolution_dimension_numbers());
|
||||||
|
@ -1328,19 +1328,7 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
|
|||||||
options.print_large_constants())) {
|
options.print_large_constants())) {
|
||||||
// Literal::ToString emits multidimensional arrays over multiple
|
// Literal::ToString emits multidimensional arrays over multiple
|
||||||
// lines. Compact this into one line by stripping out white space.
|
// lines. Compact this into one line by stripping out white space.
|
||||||
string tmp = literal().ToStringWithoutShape();
|
operands = literal_->ToStringWithoutShapeOneline();
|
||||||
std::replace(tmp.begin(), tmp.end(), '\n', ' ');
|
|
||||||
std::vector<string> v = absl::StrSplit(tmp, ' ');
|
|
||||||
bool first = true;
|
|
||||||
// Concatenate elements in "v" with spaces separating them, but ignoring
|
|
||||||
// empty entries.
|
|
||||||
for (const auto& s : v) {
|
|
||||||
if (s.empty()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
StrAppend(&operands, (first ? "" : " "), s);
|
|
||||||
first = false;
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
// Do not show large constants or tuples.
|
// Do not show large constants or tuples.
|
||||||
operands = "{...}";
|
operands = "{...}";
|
||||||
@ -2441,6 +2429,9 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
proto.set_custom_call_has_side_effect(custom_call_has_side_effect_);
|
proto.set_custom_call_has_side_effect(custom_call_has_side_effect_);
|
||||||
|
if (literal_.has_value()) {
|
||||||
|
*proto.mutable_literal() = literal_->ToProto();
|
||||||
|
}
|
||||||
for (const auto& pair : output_to_operand_aliasing_) {
|
for (const auto& pair : output_to_operand_aliasing_) {
|
||||||
auto aliasing = proto.add_custom_call_output_operand_aliasing();
|
auto aliasing = proto.add_custom_call_output_operand_aliasing();
|
||||||
aliasing->set_operand_index(pair.second.first);
|
aliasing->set_operand_index(pair.second.first);
|
||||||
@ -2495,6 +2486,9 @@ std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
|
|||||||
if (custom_call_has_side_effect_) {
|
if (custom_call_has_side_effect_) {
|
||||||
extra.push_back("custom_call_has_side_effect=true");
|
extra.push_back("custom_call_has_side_effect=true");
|
||||||
}
|
}
|
||||||
|
if (literal_.has_value()) {
|
||||||
|
extra.push_back(StrCat("literal=(", literal_->ToStringOneline(), ")"));
|
||||||
|
}
|
||||||
if (!output_to_operand_aliasing_.empty()) {
|
if (!output_to_operand_aliasing_.empty()) {
|
||||||
std::vector<string> pair_strings;
|
std::vector<string> pair_strings;
|
||||||
for (const auto& pair : output_to_operand_aliasing_) {
|
for (const auto& pair : output_to_operand_aliasing_) {
|
||||||
@ -2571,6 +2565,13 @@ bool HloCustomCallInstruction::IdenticalSlowPath(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (HasLiteral() == casted_other.HasLiteral()) {
|
||||||
|
if (HasLiteral() && literal() == casted_other.literal()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
// Note: backend_config comparison is done in Identical, which is the
|
// Note: backend_config comparison is done in Identical, which is the
|
||||||
// intended/exposed way to compare computations, and so not repeated here.
|
// intended/exposed way to compare computations, and so not repeated here.
|
||||||
@ -2593,6 +2594,9 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl(
|
|||||||
if (convolution_dimension_numbers_ != nullptr) {
|
if (convolution_dimension_numbers_ != nullptr) {
|
||||||
cloned->set_convolution_dimension_numbers(*convolution_dimension_numbers_);
|
cloned->set_convolution_dimension_numbers(*convolution_dimension_numbers_);
|
||||||
}
|
}
|
||||||
|
if (HasLiteral()) {
|
||||||
|
cloned->set_literal(literal().Clone());
|
||||||
|
}
|
||||||
cloned->set_feature_group_count(feature_group_count_);
|
cloned->set_feature_group_count(feature_group_count_);
|
||||||
cloned->set_batch_group_count(batch_group_count_);
|
cloned->set_batch_group_count(batch_group_count_);
|
||||||
cloned->set_custom_call_has_side_effect(custom_call_has_side_effect_);
|
cloned->set_custom_call_has_side_effect(custom_call_has_side_effect_);
|
||||||
|
@ -1466,6 +1466,13 @@ class HloCustomCallInstruction : public HloInstruction {
|
|||||||
padding_type_ = padding_type;
|
padding_type_ = padding_type;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the literal associated with this instruction.
|
||||||
|
const Literal& literal() const { return *literal_; }
|
||||||
|
// Set the value of literal to a new one.
|
||||||
|
void set_literal(Literal&& literal) { literal_.emplace(std::move(literal)); }
|
||||||
|
// Returns whether there is literal associated with this instruction.
|
||||||
|
bool HasLiteral() const { return literal_.has_value(); }
|
||||||
|
|
||||||
const PrecisionConfig& precision_config() const { return precision_config_; }
|
const PrecisionConfig& precision_config() const { return precision_config_; }
|
||||||
PrecisionConfig* mutable_precision_config() { return &precision_config_; }
|
PrecisionConfig* mutable_precision_config() { return &precision_config_; }
|
||||||
|
|
||||||
@ -1532,6 +1539,7 @@ class HloCustomCallInstruction : public HloInstruction {
|
|||||||
// output_to_operand_aliasing().
|
// output_to_operand_aliasing().
|
||||||
std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||||
output_to_operand_aliasing_;
|
output_to_operand_aliasing_;
|
||||||
|
absl::optional<Literal> literal_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class HloPadInstruction : public HloInstruction {
|
class HloPadInstruction : public HloInstruction {
|
||||||
|
@ -253,6 +253,7 @@ class HloParserImpl : public HloParser {
|
|||||||
bool ParseInstructionRhs(HloComputation::Builder* builder,
|
bool ParseInstructionRhs(HloComputation::Builder* builder,
|
||||||
const std::string& name, LocTy name_loc);
|
const std::string& name, LocTy name_loc);
|
||||||
bool ParseControlPredecessors(HloInstruction* instruction);
|
bool ParseControlPredecessors(HloInstruction* instruction);
|
||||||
|
bool ParseLiteral(Literal* literal);
|
||||||
bool ParseLiteral(Literal* literal, const Shape& shape);
|
bool ParseLiteral(Literal* literal, const Shape& shape);
|
||||||
bool ParseTupleLiteral(Literal* literal, const Shape& shape);
|
bool ParseTupleLiteral(Literal* literal, const Shape& shape);
|
||||||
bool ParseNonTupleLiteral(Literal* literal, const Shape& shape);
|
bool ParseNonTupleLiteral(Literal* literal, const Shape& shape);
|
||||||
@ -307,6 +308,7 @@ class HloParserImpl : public HloParser {
|
|||||||
kInt32,
|
kInt32,
|
||||||
kFloat,
|
kFloat,
|
||||||
kString,
|
kString,
|
||||||
|
kLiteral,
|
||||||
kBracedInt64List,
|
kBracedInt64List,
|
||||||
kBracedInt64ListList,
|
kBracedInt64ListList,
|
||||||
kHloComputation,
|
kHloComputation,
|
||||||
@ -2268,6 +2270,9 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
|
|||||||
|
|
||||||
attrs["padding_type"] = {/*required=*/false, AttrTy::kPaddingType,
|
attrs["padding_type"] = {/*required=*/false, AttrTy::kPaddingType,
|
||||||
&padding_type};
|
&padding_type};
|
||||||
|
|
||||||
|
optional<Literal> literal;
|
||||||
|
attrs["literal"] = {/*required=*/false, AttrTy::kLiteral, &literal};
|
||||||
optional<std::vector<PrecisionConfig::Precision>> operand_precision;
|
optional<std::vector<PrecisionConfig::Precision>> operand_precision;
|
||||||
attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
|
attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
|
||||||
&operand_precision};
|
&operand_precision};
|
||||||
@ -2357,6 +2362,9 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
|
|||||||
custom_call_instr->set_output_to_operand_aliasing(
|
custom_call_instr->set_output_to_operand_aliasing(
|
||||||
std::move(*output_to_operand_aliasing));
|
std::move(*output_to_operand_aliasing));
|
||||||
}
|
}
|
||||||
|
if (literal.has_value()) {
|
||||||
|
custom_call_instr->set_literal(std::move(*literal));
|
||||||
|
}
|
||||||
PrecisionConfig precision_config;
|
PrecisionConfig precision_config;
|
||||||
if (operand_precision) {
|
if (operand_precision) {
|
||||||
*precision_config.mutable_operand_precision() = {
|
*precision_config.mutable_operand_precision() = {
|
||||||
@ -3048,6 +3056,14 @@ bool HloParserImpl::SetValueInLiteralHelper(LocTy loc, ParsedElemT value,
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool HloParserImpl::ParseLiteral(Literal* literal) {
|
||||||
|
Shape literal_shape;
|
||||||
|
if (!ParseShape(&literal_shape)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return ParseLiteral(literal, literal_shape);
|
||||||
|
}
|
||||||
|
|
||||||
// literal
|
// literal
|
||||||
// ::= tuple
|
// ::= tuple
|
||||||
// ::= non_tuple
|
// ::= non_tuple
|
||||||
@ -3830,6 +3846,21 @@ bool HloParserImpl::ParseAttributeHelper(
|
|||||||
->emplace(std::move(aliasing_output_operand_pairs));
|
->emplace(std::move(aliasing_output_operand_pairs));
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
case AttrTy::kLiteral: {
|
||||||
|
if (!ParseToken(TokKind::kLparen, "expects '(' before literal")) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
Literal result;
|
||||||
|
if (!ParseLiteral(&result)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!ParseToken(TokKind::kRparen, "expects ')' after literal")) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
static_cast<optional<Literal>*>(attr_out_ptr)
|
||||||
|
->emplace(std::move(result));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}();
|
}();
|
||||||
if (!success) {
|
if (!success) {
|
||||||
|
@ -399,6 +399,32 @@ ENTRY %CustomCall () -> f32[1,2,3] {
|
|||||||
ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar", backend_config="this string is opaque"
|
ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar", backend_config="this string is opaque"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
)"
|
||||||
|
},
|
||||||
|
|
||||||
|
// CustomCall with literal.
|
||||||
|
{
|
||||||
|
"CustomCallWithLiteral",
|
||||||
|
R"(HloModule custom_call
|
||||||
|
|
||||||
|
ENTRY %CustomCall () -> f32[1,2,3] {
|
||||||
|
%constant = f32[1]{0} constant({12345})
|
||||||
|
ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar", literal=(f32[1] {0.1})
|
||||||
|
}
|
||||||
|
|
||||||
|
)"
|
||||||
|
},
|
||||||
|
|
||||||
|
// CustomCall with literal R0.
|
||||||
|
{
|
||||||
|
"CustomCallWithLiteralR0",
|
||||||
|
R"(HloModule custom_call
|
||||||
|
|
||||||
|
ENTRY %CustomCall () -> f32[1,2,3] {
|
||||||
|
%constant = f32[1]{0} constant({12345})
|
||||||
|
ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar", literal=(f32[] 0.1)
|
||||||
|
}
|
||||||
|
|
||||||
)"
|
)"
|
||||||
},
|
},
|
||||||
// reduce window
|
// reduce window
|
||||||
|
Loading…
Reference in New Issue
Block a user