Add a literal field to custom call instruction.
- Add a literal field to custom call instruction. - Change SetBound custom call to use literal as side data instead of a hand-serialized number. PiperOrigin-RevId: 358006370 Change-Id: I67727bbe3ce12b082fbcfc7290e57194bd96c29a
This commit is contained in:
parent
7ec1faf4d1
commit
09db4abc5b
tensorflow/compiler
@ -134,7 +134,8 @@ StatusOr<XlaOp> MlirHloBuilder::CustomCallInternal(
|
||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
||||
bool has_side_effect,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing) {
|
||||
output_operand_aliasing,
|
||||
const Literal* literal) {
|
||||
if (operand_shapes_with_layout.has_value())
|
||||
return Unimplemented(
|
||||
"CustomCall doesn't support operands shapes with layout");
|
||||
@ -142,6 +143,8 @@ StatusOr<XlaOp> MlirHloBuilder::CustomCallInternal(
|
||||
shape, builder_));
|
||||
TF_RET_CHECK(output_operand_aliasing.empty())
|
||||
<< "MLIR CustomCallOp does not support output_operand_aliasing yet";
|
||||
TF_RET_CHECK(literal == nullptr)
|
||||
<< "MLIR CustomCallOp does not support literal yet";
|
||||
auto op = builder_.create<mlir::mhlo::CustomCallOp>(
|
||||
loc_, ty, GetValues(operands), builder_.getStringAttr(call_target_name),
|
||||
/*has_side_effect=*/builder_.getBoolAttr(has_side_effect),
|
||||
|
@ -137,7 +137,8 @@ class MlirHloBuilder : public XlaBuilder {
|
||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
||||
bool has_side_effect,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing) override;
|
||||
output_operand_aliasing,
|
||||
const Literal* literal) override;
|
||||
|
||||
StatusOr<XlaOp> ReduceInternal(
|
||||
const Shape& shape, absl::Span<const XlaOp> all_operands,
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
@ -97,10 +98,11 @@ class XlaSetBoundOp : public XlaOpKernel {
|
||||
bound_shape.DebugString()));
|
||||
int64 bound;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("bound", &bound));
|
||||
|
||||
xla::XlaOp result = xla::CustomCall(
|
||||
ctx->builder(), "SetBound", {ctx->Input("input")},
|
||||
ctx->InputXlaShape("input").ValueOrDie(), absl::StrFormat("%d", bound));
|
||||
xla::Literal bound_literal = xla::LiteralUtil::CreateR0<int32>(bound);
|
||||
xla::XlaOp result =
|
||||
xla::CustomCall(ctx->builder(), "SetBound", {ctx->Input("input")},
|
||||
ctx->InputXlaShape("input").ValueOrDie(), "", false, {},
|
||||
&bound_literal);
|
||||
ctx->SetOutput(0, result);
|
||||
}
|
||||
};
|
||||
|
@ -1882,7 +1882,8 @@ XlaOp XlaBuilder::CustomCall(
|
||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
||||
bool has_side_effect,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing) {
|
||||
output_operand_aliasing,
|
||||
const Literal* literal) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
if (absl::StartsWith(call_target_name, "$")) {
|
||||
return InvalidArgument(
|
||||
@ -1915,7 +1916,7 @@ XlaOp XlaBuilder::CustomCall(
|
||||
}
|
||||
return CustomCallInternal(call_target_name, operands, shape, opaque,
|
||||
operand_shapes_with_layout, has_side_effect,
|
||||
output_operand_aliasing);
|
||||
output_operand_aliasing, literal);
|
||||
});
|
||||
}
|
||||
|
||||
@ -1925,7 +1926,8 @@ StatusOr<XlaOp> XlaBuilder::CustomCallInternal(
|
||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
||||
bool has_side_effect,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing) {
|
||||
output_operand_aliasing,
|
||||
const Literal* literal) {
|
||||
HloInstructionProto instr;
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
instr.set_custom_call_target(call_target_name);
|
||||
@ -1936,6 +1938,9 @@ StatusOr<XlaOp> XlaBuilder::CustomCallInternal(
|
||||
*instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
|
||||
}
|
||||
}
|
||||
if (literal != nullptr) {
|
||||
*instr.mutable_literal() = literal->ToProto();
|
||||
}
|
||||
instr.set_custom_call_has_side_effect(has_side_effect);
|
||||
for (const auto& pair : output_operand_aliasing) {
|
||||
auto aliasing = instr.add_custom_call_output_operand_aliasing();
|
||||
@ -1956,7 +1961,8 @@ XlaOp XlaBuilder::CustomCall(
|
||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
||||
bool has_side_effect,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing) {
|
||||
output_operand_aliasing,
|
||||
const Literal* literal) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
if (absl::StartsWith(call_target_name, "$")) {
|
||||
@ -1968,6 +1974,9 @@ XlaOp XlaBuilder::CustomCall(
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
instr.set_custom_call_target(call_target_name);
|
||||
instr.set_backend_config(opaque);
|
||||
if (literal != nullptr) {
|
||||
*instr.mutable_literal() = literal->ToProto();
|
||||
}
|
||||
if (operand_shapes_with_layout.has_value()) {
|
||||
if (!LayoutUtil::HasLayout(shape)) {
|
||||
return InvalidArgument(
|
||||
@ -3786,6 +3795,8 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
|
||||
HloOpcodeString(HloOpcode::kGetDimensionSize) ||
|
||||
InstrIsSetBound(instr_proto)) {
|
||||
int32 constant_value = -1;
|
||||
HloInstructionProto const_instr;
|
||||
|
||||
if (instr_proto->opcode() ==
|
||||
HloOpcodeString(HloOpcode::kGetDimensionSize)) {
|
||||
// At this point, BuildConstantSubGraph should never encounter a
|
||||
@ -3804,18 +3815,14 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
|
||||
constant_value =
|
||||
static_cast<int32>(operand_proto->shape().dimensions(dimension));
|
||||
}
|
||||
Literal literal = LiteralUtil::CreateR0(constant_value);
|
||||
*const_instr.mutable_literal() = literal.ToProto();
|
||||
*const_instr.mutable_shape() = literal.shape().ToProto();
|
||||
} else {
|
||||
TF_RET_CHECK(
|
||||
absl::SimpleAtoi(instr_proto->backend_config(), &constant_value));
|
||||
*const_instr.mutable_literal() = instr_proto->literal();
|
||||
*const_instr.mutable_shape() = instr_proto->shape();
|
||||
}
|
||||
|
||||
Literal literal = LiteralUtil::CreateR0(constant_value);
|
||||
|
||||
HloInstructionProto const_instr;
|
||||
*const_instr.mutable_shape() = literal.shape().ToProto();
|
||||
*const_instr.mutable_literal() = literal.ToProto();
|
||||
*const_instr.mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
|
||||
|
||||
const_instr.set_id(handle);
|
||||
*const_instr.mutable_name() =
|
||||
GetFullName(const_instr.opcode(), kNameSeparator, const_instr.id());
|
||||
@ -3866,7 +3873,6 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
|
||||
}
|
||||
}
|
||||
*module->add_computations() = std::move(entry);
|
||||
|
||||
return std::move(computation);
|
||||
}
|
||||
|
||||
@ -4459,10 +4465,11 @@ XlaOp CustomCall(
|
||||
absl::Span<const XlaOp> operands, const Shape& shape, const string& opaque,
|
||||
bool has_side_effect,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing) {
|
||||
output_operand_aliasing,
|
||||
const Literal* literal) {
|
||||
return builder->CustomCall(call_target_name, operands, shape, opaque,
|
||||
/*operand_shapes_with_layout=*/absl::nullopt,
|
||||
has_side_effect, output_operand_aliasing);
|
||||
has_side_effect, output_operand_aliasing, literal);
|
||||
}
|
||||
|
||||
XlaOp CustomCallWithComputation(
|
||||
@ -4470,11 +4477,12 @@ XlaOp CustomCallWithComputation(
|
||||
absl::Span<const XlaOp> operands, const XlaComputation& computation,
|
||||
const Shape& shape, const string& opaque, bool has_side_effect,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing) {
|
||||
output_operand_aliasing,
|
||||
const Literal* literal) {
|
||||
return builder->CustomCall(call_target_name, operands, computation, shape,
|
||||
opaque,
|
||||
/*operand_shapes_with_layout=*/absl::nullopt,
|
||||
has_side_effect, output_operand_aliasing);
|
||||
has_side_effect, output_operand_aliasing, literal);
|
||||
}
|
||||
|
||||
XlaOp CustomCallWithLayout(
|
||||
@ -4483,10 +4491,11 @@ XlaOp CustomCallWithLayout(
|
||||
absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
|
||||
bool has_side_effect,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing) {
|
||||
output_operand_aliasing,
|
||||
const Literal* literal) {
|
||||
return builder->CustomCall(call_target_name, operands, shape, opaque,
|
||||
operand_shapes_with_layout, has_side_effect,
|
||||
output_operand_aliasing);
|
||||
output_operand_aliasing, literal);
|
||||
}
|
||||
|
||||
XlaOp Complex(const XlaOp lhs, const XlaOp rhs,
|
||||
|
@ -655,7 +655,8 @@ class XlaBuilder {
|
||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
||||
bool has_side_effect,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing);
|
||||
output_operand_aliasing,
|
||||
const Literal* literal);
|
||||
|
||||
// Internal version of CustomCall without computation that doesn't do op
|
||||
// specific error handling and expects arguments to be legal. CustomCall
|
||||
@ -666,7 +667,8 @@ class XlaBuilder {
|
||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
||||
bool has_side_effect,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing);
|
||||
output_operand_aliasing,
|
||||
const Literal* literal);
|
||||
|
||||
XlaOp CustomCall(
|
||||
const string& call_target_name, absl::Span<const XlaOp> operands,
|
||||
@ -675,7 +677,8 @@ class XlaBuilder {
|
||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
||||
bool has_side_effect,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing);
|
||||
output_operand_aliasing,
|
||||
const Literal* literal);
|
||||
|
||||
XlaOp Reduce(XlaOp operand, XlaOp init_value,
|
||||
const XlaComputation& computation,
|
||||
@ -1214,20 +1217,23 @@ class XlaBuilder {
|
||||
absl::Span<const XlaOp> operands, const Shape& shape,
|
||||
const string& opaque, bool has_side_effect,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing);
|
||||
output_operand_aliasing,
|
||||
const Literal* literal);
|
||||
friend XlaOp CustomCallWithComputation(
|
||||
XlaBuilder* builder, const string& call_target_name,
|
||||
absl::Span<const XlaOp> operands, const XlaComputation& computation,
|
||||
const Shape& shape, const string& opaque, bool has_side_effect,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing);
|
||||
output_operand_aliasing,
|
||||
const Literal* literal);
|
||||
friend XlaOp CustomCallWithLayout(
|
||||
XlaBuilder* builder, const string& call_target_name,
|
||||
absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
|
||||
absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
|
||||
bool has_side_effect,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing);
|
||||
output_operand_aliasing,
|
||||
const Literal* literal);
|
||||
friend XlaOp Complex(XlaOp real, XlaOp imag,
|
||||
absl::Span<const int64> broadcast_dimensions);
|
||||
friend XlaOp Conj(XlaOp operand);
|
||||
@ -2025,7 +2031,8 @@ XlaOp CustomCall(
|
||||
absl::Span<const XlaOp> operands, const Shape& shape,
|
||||
const string& opaque = "", bool has_side_effect = false,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing = {});
|
||||
output_operand_aliasing = {},
|
||||
const Literal* literal = nullptr);
|
||||
|
||||
// Overload which constructs a custom call that applies an Xla computation.
|
||||
XlaOp CustomCallWithComputation(
|
||||
@ -2033,7 +2040,8 @@ XlaOp CustomCallWithComputation(
|
||||
absl::Span<const XlaOp> operands, const XlaComputation& computation,
|
||||
const Shape& shape, const string& opaque = "", bool has_side_effect = false,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing = {});
|
||||
output_operand_aliasing = {},
|
||||
const Literal* literal = nullptr);
|
||||
|
||||
// Overload which constructs a custom call with fixed layouts. The operands will
|
||||
// have the layouts specified by |operand_shapes_with_layout| when provided to
|
||||
@ -2046,7 +2054,8 @@ XlaOp CustomCallWithLayout(
|
||||
absl::Span<const Shape> operand_shapes_with_layout,
|
||||
const string& opaque = "", bool has_side_effect = false,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing = {});
|
||||
output_operand_aliasing = {},
|
||||
const Literal* literal = nullptr);
|
||||
|
||||
// The following methods enqueue element-wise binary arithmetic operations
|
||||
// onto the computation. The shapes of the operands have to match unless one
|
||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/index_util.h"
|
||||
@ -71,6 +72,22 @@ void ConvertEndianShort(char* bytes, int64 size) {
|
||||
}
|
||||
}
|
||||
|
||||
string CompactOneline(const string& input) {
|
||||
string result;
|
||||
std::vector<string> v = absl::StrSplit(input, absl::ByAnyChar("\n "));
|
||||
bool first = true;
|
||||
// Concatenate elements in "v" with spaces separating them, but ignoring
|
||||
// empty entries.
|
||||
for (const auto& s : v) {
|
||||
if (s.empty()) {
|
||||
continue;
|
||||
}
|
||||
absl::StrAppend(&result, (first ? "" : " "), s);
|
||||
first = false;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Since Eigen::half doesn't satisfy the absl::bit_cast contract, we need to be
|
||||
// able to transparently access the raw 16-bit value contained within.
|
||||
template <typename T>
|
||||
@ -1281,6 +1298,10 @@ string LiteralBase::ToString() const {
|
||||
return absl::StrJoin(pieces, "");
|
||||
}
|
||||
|
||||
string LiteralBase::ToStringOneline() const {
|
||||
return CompactOneline(ToString());
|
||||
}
|
||||
|
||||
string LiteralBase::ToStringWithoutShape() const {
|
||||
std::vector<string> pieces;
|
||||
CHECK(LayoutUtil::HasLayout(this->shape()));
|
||||
@ -1289,6 +1310,10 @@ string LiteralBase::ToStringWithoutShape() const {
|
||||
return absl::StrJoin(pieces, "");
|
||||
}
|
||||
|
||||
string LiteralBase::ToStringWithoutShapeOneline() const {
|
||||
return CompactOneline(ToStringWithoutShape());
|
||||
}
|
||||
|
||||
string LiteralBase::ToStringWithLayout() const {
|
||||
std::vector<string> pieces;
|
||||
CHECK(LayoutUtil::HasLayout(this->shape()));
|
||||
|
@ -94,10 +94,18 @@ class LiteralBase {
|
||||
// element Literals.
|
||||
string ToString() const;
|
||||
|
||||
// Similar to ToString, but return the result in a compact
|
||||
// one-line form.
|
||||
string ToStringOneline() const;
|
||||
|
||||
// Returns a string representation of the literal value which does *not*
|
||||
// include the shape string.
|
||||
string ToStringWithoutShape() const;
|
||||
|
||||
// Similar to ToStringWithoutShape, but return the result in a compact
|
||||
// one-line form.
|
||||
string ToStringWithoutShapeOneline() const;
|
||||
|
||||
// Returns a string representation of the literal value which includes the
|
||||
// shape string with its layout.does *not* include the shape string.
|
||||
string ToStringWithLayout() const;
|
||||
|
@ -328,7 +328,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
instruction = CreateConstant(std::move(literal));
|
||||
// Literal's shape may have no/different tiling info.
|
||||
TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
|
||||
instruction->shape(), shape));
|
||||
instruction->shape(), shape))
|
||||
<< instruction->shape().ToString(true) << " vs "
|
||||
<< shape.ToString(true);
|
||||
*instruction->mutable_shape() = shape;
|
||||
} else {
|
||||
instruction = absl::make_unique<HloConstantInstruction>(shape);
|
||||
@ -578,6 +580,12 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
if (proto.has_window()) {
|
||||
custom_call_instr->set_window(proto.window());
|
||||
}
|
||||
if (proto.has_literal()) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto literal,
|
||||
Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
|
||||
custom_call_instr->set_literal(std::move(literal));
|
||||
}
|
||||
if (proto.has_convolution_dimension_numbers()) {
|
||||
custom_call_instr->set_convolution_dimension_numbers(
|
||||
proto.convolution_dimension_numbers());
|
||||
|
@ -1328,19 +1328,7 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
|
||||
options.print_large_constants())) {
|
||||
// Literal::ToString emits multidimensional arrays over multiple
|
||||
// lines. Compact this into one line by stripping out white space.
|
||||
string tmp = literal().ToStringWithoutShape();
|
||||
std::replace(tmp.begin(), tmp.end(), '\n', ' ');
|
||||
std::vector<string> v = absl::StrSplit(tmp, ' ');
|
||||
bool first = true;
|
||||
// Concatenate elements in "v" with spaces separating them, but ignoring
|
||||
// empty entries.
|
||||
for (const auto& s : v) {
|
||||
if (s.empty()) {
|
||||
continue;
|
||||
}
|
||||
StrAppend(&operands, (first ? "" : " "), s);
|
||||
first = false;
|
||||
}
|
||||
operands = literal_->ToStringWithoutShapeOneline();
|
||||
} else {
|
||||
// Do not show large constants or tuples.
|
||||
operands = "{...}";
|
||||
@ -2441,6 +2429,9 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const {
|
||||
}
|
||||
}
|
||||
proto.set_custom_call_has_side_effect(custom_call_has_side_effect_);
|
||||
if (literal_.has_value()) {
|
||||
*proto.mutable_literal() = literal_->ToProto();
|
||||
}
|
||||
for (const auto& pair : output_to_operand_aliasing_) {
|
||||
auto aliasing = proto.add_custom_call_output_operand_aliasing();
|
||||
aliasing->set_operand_index(pair.second.first);
|
||||
@ -2495,6 +2486,9 @@ std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
|
||||
if (custom_call_has_side_effect_) {
|
||||
extra.push_back("custom_call_has_side_effect=true");
|
||||
}
|
||||
if (literal_.has_value()) {
|
||||
extra.push_back(StrCat("literal=(", literal_->ToStringOneline(), ")"));
|
||||
}
|
||||
if (!output_to_operand_aliasing_.empty()) {
|
||||
std::vector<string> pair_strings;
|
||||
for (const auto& pair : output_to_operand_aliasing_) {
|
||||
@ -2571,6 +2565,13 @@ bool HloCustomCallInstruction::IdenticalSlowPath(
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (HasLiteral() == casted_other.HasLiteral()) {
|
||||
if (HasLiteral() && literal() == casted_other.literal()) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Note: backend_config comparison is done in Identical, which is the
|
||||
// intended/exposed way to compare computations, and so not repeated here.
|
||||
@ -2593,6 +2594,9 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl(
|
||||
if (convolution_dimension_numbers_ != nullptr) {
|
||||
cloned->set_convolution_dimension_numbers(*convolution_dimension_numbers_);
|
||||
}
|
||||
if (HasLiteral()) {
|
||||
cloned->set_literal(literal().Clone());
|
||||
}
|
||||
cloned->set_feature_group_count(feature_group_count_);
|
||||
cloned->set_batch_group_count(batch_group_count_);
|
||||
cloned->set_custom_call_has_side_effect(custom_call_has_side_effect_);
|
||||
|
@ -1466,6 +1466,13 @@ class HloCustomCallInstruction : public HloInstruction {
|
||||
padding_type_ = padding_type;
|
||||
}
|
||||
|
||||
// Returns the literal associated with this instruction.
|
||||
const Literal& literal() const { return *literal_; }
|
||||
// Set the value of literal to a new one.
|
||||
void set_literal(Literal&& literal) { literal_.emplace(std::move(literal)); }
|
||||
// Returns whether there is literal associated with this instruction.
|
||||
bool HasLiteral() const { return literal_.has_value(); }
|
||||
|
||||
const PrecisionConfig& precision_config() const { return precision_config_; }
|
||||
PrecisionConfig* mutable_precision_config() { return &precision_config_; }
|
||||
|
||||
@ -1532,6 +1539,7 @@ class HloCustomCallInstruction : public HloInstruction {
|
||||
// output_to_operand_aliasing().
|
||||
std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_to_operand_aliasing_;
|
||||
absl::optional<Literal> literal_;
|
||||
};
|
||||
|
||||
class HloPadInstruction : public HloInstruction {
|
||||
|
@ -253,6 +253,7 @@ class HloParserImpl : public HloParser {
|
||||
bool ParseInstructionRhs(HloComputation::Builder* builder,
|
||||
const std::string& name, LocTy name_loc);
|
||||
bool ParseControlPredecessors(HloInstruction* instruction);
|
||||
bool ParseLiteral(Literal* literal);
|
||||
bool ParseLiteral(Literal* literal, const Shape& shape);
|
||||
bool ParseTupleLiteral(Literal* literal, const Shape& shape);
|
||||
bool ParseNonTupleLiteral(Literal* literal, const Shape& shape);
|
||||
@ -307,6 +308,7 @@ class HloParserImpl : public HloParser {
|
||||
kInt32,
|
||||
kFloat,
|
||||
kString,
|
||||
kLiteral,
|
||||
kBracedInt64List,
|
||||
kBracedInt64ListList,
|
||||
kHloComputation,
|
||||
@ -2268,6 +2270,9 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
|
||||
|
||||
attrs["padding_type"] = {/*required=*/false, AttrTy::kPaddingType,
|
||||
&padding_type};
|
||||
|
||||
optional<Literal> literal;
|
||||
attrs["literal"] = {/*required=*/false, AttrTy::kLiteral, &literal};
|
||||
optional<std::vector<PrecisionConfig::Precision>> operand_precision;
|
||||
attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
|
||||
&operand_precision};
|
||||
@ -2357,6 +2362,9 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
|
||||
custom_call_instr->set_output_to_operand_aliasing(
|
||||
std::move(*output_to_operand_aliasing));
|
||||
}
|
||||
if (literal.has_value()) {
|
||||
custom_call_instr->set_literal(std::move(*literal));
|
||||
}
|
||||
PrecisionConfig precision_config;
|
||||
if (operand_precision) {
|
||||
*precision_config.mutable_operand_precision() = {
|
||||
@ -3048,6 +3056,14 @@ bool HloParserImpl::SetValueInLiteralHelper(LocTy loc, ParsedElemT value,
|
||||
return true;
|
||||
}
|
||||
|
||||
bool HloParserImpl::ParseLiteral(Literal* literal) {
|
||||
Shape literal_shape;
|
||||
if (!ParseShape(&literal_shape)) {
|
||||
return false;
|
||||
}
|
||||
return ParseLiteral(literal, literal_shape);
|
||||
}
|
||||
|
||||
// literal
|
||||
// ::= tuple
|
||||
// ::= non_tuple
|
||||
@ -3830,6 +3846,21 @@ bool HloParserImpl::ParseAttributeHelper(
|
||||
->emplace(std::move(aliasing_output_operand_pairs));
|
||||
return true;
|
||||
}
|
||||
case AttrTy::kLiteral: {
|
||||
if (!ParseToken(TokKind::kLparen, "expects '(' before literal")) {
|
||||
return false;
|
||||
}
|
||||
Literal result;
|
||||
if (!ParseLiteral(&result)) {
|
||||
return false;
|
||||
}
|
||||
if (!ParseToken(TokKind::kRparen, "expects ')' after literal")) {
|
||||
return false;
|
||||
}
|
||||
static_cast<optional<Literal>*>(attr_out_ptr)
|
||||
->emplace(std::move(result));
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}();
|
||||
if (!success) {
|
||||
|
@ -399,6 +399,32 @@ ENTRY %CustomCall () -> f32[1,2,3] {
|
||||
ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar", backend_config="this string is opaque"
|
||||
}
|
||||
|
||||
)"
|
||||
},
|
||||
|
||||
// CustomCall with literal.
|
||||
{
|
||||
"CustomCallWithLiteral",
|
||||
R"(HloModule custom_call
|
||||
|
||||
ENTRY %CustomCall () -> f32[1,2,3] {
|
||||
%constant = f32[1]{0} constant({12345})
|
||||
ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar", literal=(f32[1] {0.1})
|
||||
}
|
||||
|
||||
)"
|
||||
},
|
||||
|
||||
// CustomCall with literal R0.
|
||||
{
|
||||
"CustomCallWithLiteralR0",
|
||||
R"(HloModule custom_call
|
||||
|
||||
ENTRY %CustomCall () -> f32[1,2,3] {
|
||||
%constant = f32[1]{0} constant({12345})
|
||||
ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar", literal=(f32[] 0.1)
|
||||
}
|
||||
|
||||
)"
|
||||
},
|
||||
// reduce window
|
||||
|
Loading…
Reference in New Issue
Block a user