Add a literal field to custom call instruction.

- Add a literal field to custom call instruction.
- Change SetBound custom call to use literal as side data instead of a hand-serialized number.

PiperOrigin-RevId: 358006370
Change-Id: I67727bbe3ce12b082fbcfc7290e57194bd96c29a
This commit is contained in:
Yunxing Dai 2021-02-17 12:20:23 -08:00 committed by TensorFlower Gardener
parent 7ec1faf4d1
commit 09db4abc5b
12 changed files with 183 additions and 49 deletions

View File

@ -134,7 +134,8 @@ StatusOr<XlaOp> MlirHloBuilder::CustomCallInternal(
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing) {
output_operand_aliasing,
const Literal* literal) {
if (operand_shapes_with_layout.has_value())
return Unimplemented(
"CustomCall doesn't support operands shapes with layout");
@ -142,6 +143,8 @@ StatusOr<XlaOp> MlirHloBuilder::CustomCallInternal(
shape, builder_));
TF_RET_CHECK(output_operand_aliasing.empty())
<< "MLIR CustomCallOp does not support output_operand_aliasing yet";
TF_RET_CHECK(literal == nullptr)
<< "MLIR CustomCallOp does not support literal yet";
auto op = builder_.create<mlir::mhlo::CustomCallOp>(
loc_, ty, GetValues(operands), builder_.getStringAttr(call_target_name),
/*has_side_effect=*/builder_.getBoolAttr(has_side_effect),

View File

@ -137,7 +137,8 @@ class MlirHloBuilder : public XlaBuilder {
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing) override;
output_operand_aliasing,
const Literal* literal) override;
StatusOr<XlaOp> ReduceInternal(
const Shape& shape, absl::Span<const XlaOp> all_operands,

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
@ -97,10 +98,11 @@ class XlaSetBoundOp : public XlaOpKernel {
bound_shape.DebugString()));
int64 bound;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("bound", &bound));
xla::XlaOp result = xla::CustomCall(
ctx->builder(), "SetBound", {ctx->Input("input")},
ctx->InputXlaShape("input").ValueOrDie(), absl::StrFormat("%d", bound));
xla::Literal bound_literal = xla::LiteralUtil::CreateR0<int32>(bound);
xla::XlaOp result =
xla::CustomCall(ctx->builder(), "SetBound", {ctx->Input("input")},
ctx->InputXlaShape("input").ValueOrDie(), "", false, {},
&bound_literal);
ctx->SetOutput(0, result);
}
};

View File

@ -1882,7 +1882,8 @@ XlaOp XlaBuilder::CustomCall(
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing) {
output_operand_aliasing,
const Literal* literal) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
if (absl::StartsWith(call_target_name, "$")) {
return InvalidArgument(
@ -1915,7 +1916,7 @@ XlaOp XlaBuilder::CustomCall(
}
return CustomCallInternal(call_target_name, operands, shape, opaque,
operand_shapes_with_layout, has_side_effect,
output_operand_aliasing);
output_operand_aliasing, literal);
});
}
@ -1925,7 +1926,8 @@ StatusOr<XlaOp> XlaBuilder::CustomCallInternal(
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing) {
output_operand_aliasing,
const Literal* literal) {
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
instr.set_custom_call_target(call_target_name);
@ -1936,6 +1938,9 @@ StatusOr<XlaOp> XlaBuilder::CustomCallInternal(
*instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
}
}
if (literal != nullptr) {
*instr.mutable_literal() = literal->ToProto();
}
instr.set_custom_call_has_side_effect(has_side_effect);
for (const auto& pair : output_operand_aliasing) {
auto aliasing = instr.add_custom_call_output_operand_aliasing();
@ -1956,7 +1961,8 @@ XlaOp XlaBuilder::CustomCall(
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing) {
output_operand_aliasing,
const Literal* literal) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
if (absl::StartsWith(call_target_name, "$")) {
@ -1968,6 +1974,9 @@ XlaOp XlaBuilder::CustomCall(
*instr.mutable_shape() = shape.ToProto();
instr.set_custom_call_target(call_target_name);
instr.set_backend_config(opaque);
if (literal != nullptr) {
*instr.mutable_literal() = literal->ToProto();
}
if (operand_shapes_with_layout.has_value()) {
if (!LayoutUtil::HasLayout(shape)) {
return InvalidArgument(
@ -3786,6 +3795,8 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
HloOpcodeString(HloOpcode::kGetDimensionSize) ||
InstrIsSetBound(instr_proto)) {
int32 constant_value = -1;
HloInstructionProto const_instr;
if (instr_proto->opcode() ==
HloOpcodeString(HloOpcode::kGetDimensionSize)) {
// At this point, BuildConstantSubGraph should never encounter a
@ -3804,18 +3815,14 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
constant_value =
static_cast<int32>(operand_proto->shape().dimensions(dimension));
}
Literal literal = LiteralUtil::CreateR0(constant_value);
*const_instr.mutable_literal() = literal.ToProto();
*const_instr.mutable_shape() = literal.shape().ToProto();
} else {
TF_RET_CHECK(
absl::SimpleAtoi(instr_proto->backend_config(), &constant_value));
*const_instr.mutable_literal() = instr_proto->literal();
*const_instr.mutable_shape() = instr_proto->shape();
}
Literal literal = LiteralUtil::CreateR0(constant_value);
HloInstructionProto const_instr;
*const_instr.mutable_shape() = literal.shape().ToProto();
*const_instr.mutable_literal() = literal.ToProto();
*const_instr.mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
const_instr.set_id(handle);
*const_instr.mutable_name() =
GetFullName(const_instr.opcode(), kNameSeparator, const_instr.id());
@ -3866,7 +3873,6 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
}
}
*module->add_computations() = std::move(entry);
return std::move(computation);
}
@ -4459,10 +4465,11 @@ XlaOp CustomCall(
absl::Span<const XlaOp> operands, const Shape& shape, const string& opaque,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing) {
output_operand_aliasing,
const Literal* literal) {
return builder->CustomCall(call_target_name, operands, shape, opaque,
/*operand_shapes_with_layout=*/absl::nullopt,
has_side_effect, output_operand_aliasing);
has_side_effect, output_operand_aliasing, literal);
}
XlaOp CustomCallWithComputation(
@ -4470,11 +4477,12 @@ XlaOp CustomCallWithComputation(
absl::Span<const XlaOp> operands, const XlaComputation& computation,
const Shape& shape, const string& opaque, bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing) {
output_operand_aliasing,
const Literal* literal) {
return builder->CustomCall(call_target_name, operands, computation, shape,
opaque,
/*operand_shapes_with_layout=*/absl::nullopt,
has_side_effect, output_operand_aliasing);
has_side_effect, output_operand_aliasing, literal);
}
XlaOp CustomCallWithLayout(
@ -4483,10 +4491,11 @@ XlaOp CustomCallWithLayout(
absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing) {
output_operand_aliasing,
const Literal* literal) {
return builder->CustomCall(call_target_name, operands, shape, opaque,
operand_shapes_with_layout, has_side_effect,
output_operand_aliasing);
output_operand_aliasing, literal);
}
XlaOp Complex(const XlaOp lhs, const XlaOp rhs,

View File

@ -655,7 +655,8 @@ class XlaBuilder {
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing);
output_operand_aliasing,
const Literal* literal);
// Internal version of CustomCall without computation that doesn't do op
// specific error handling and expects arguments to be legal. CustomCall
@ -666,7 +667,8 @@ class XlaBuilder {
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing);
output_operand_aliasing,
const Literal* literal);
XlaOp CustomCall(
const string& call_target_name, absl::Span<const XlaOp> operands,
@ -675,7 +677,8 @@ class XlaBuilder {
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing);
output_operand_aliasing,
const Literal* literal);
XlaOp Reduce(XlaOp operand, XlaOp init_value,
const XlaComputation& computation,
@ -1214,20 +1217,23 @@ class XlaBuilder {
absl::Span<const XlaOp> operands, const Shape& shape,
const string& opaque, bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing);
output_operand_aliasing,
const Literal* literal);
friend XlaOp CustomCallWithComputation(
XlaBuilder* builder, const string& call_target_name,
absl::Span<const XlaOp> operands, const XlaComputation& computation,
const Shape& shape, const string& opaque, bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing);
output_operand_aliasing,
const Literal* literal);
friend XlaOp CustomCallWithLayout(
XlaBuilder* builder, const string& call_target_name,
absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing);
output_operand_aliasing,
const Literal* literal);
friend XlaOp Complex(XlaOp real, XlaOp imag,
absl::Span<const int64> broadcast_dimensions);
friend XlaOp Conj(XlaOp operand);
@ -2025,7 +2031,8 @@ XlaOp CustomCall(
absl::Span<const XlaOp> operands, const Shape& shape,
const string& opaque = "", bool has_side_effect = false,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing = {});
output_operand_aliasing = {},
const Literal* literal = nullptr);
// Overload which constructs a custom call that applies an Xla computation.
XlaOp CustomCallWithComputation(
@ -2033,7 +2040,8 @@ XlaOp CustomCallWithComputation(
absl::Span<const XlaOp> operands, const XlaComputation& computation,
const Shape& shape, const string& opaque = "", bool has_side_effect = false,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing = {});
output_operand_aliasing = {},
const Literal* literal = nullptr);
// Overload which constructs a custom call with fixed layouts. The operands will
// have the layouts specified by |operand_shapes_with_layout| when provided to
@ -2046,7 +2054,8 @@ XlaOp CustomCallWithLayout(
absl::Span<const Shape> operand_shapes_with_layout,
const string& opaque = "", bool has_side_effect = false,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing = {});
output_operand_aliasing = {},
const Literal* literal = nullptr);
// The following methods enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/index_util.h"
@ -71,6 +72,22 @@ void ConvertEndianShort(char* bytes, int64 size) {
}
}
string CompactOneline(const string& input) {
string result;
std::vector<string> v = absl::StrSplit(input, absl::ByAnyChar("\n "));
bool first = true;
// Concatenate elements in "v" with spaces separating them, but ignoring
// empty entries.
for (const auto& s : v) {
if (s.empty()) {
continue;
}
absl::StrAppend(&result, (first ? "" : " "), s);
first = false;
}
return result;
}
// Since Eigen::half doesn't satisfy the absl::bit_cast contract, we need to be
// able to transparently access the raw 16-bit value contained within.
template <typename T>
@ -1281,6 +1298,10 @@ string LiteralBase::ToString() const {
return absl::StrJoin(pieces, "");
}
string LiteralBase::ToStringOneline() const {
return CompactOneline(ToString());
}
string LiteralBase::ToStringWithoutShape() const {
std::vector<string> pieces;
CHECK(LayoutUtil::HasLayout(this->shape()));
@ -1289,6 +1310,10 @@ string LiteralBase::ToStringWithoutShape() const {
return absl::StrJoin(pieces, "");
}
string LiteralBase::ToStringWithoutShapeOneline() const {
return CompactOneline(ToStringWithoutShape());
}
string LiteralBase::ToStringWithLayout() const {
std::vector<string> pieces;
CHECK(LayoutUtil::HasLayout(this->shape()));

View File

@ -94,10 +94,18 @@ class LiteralBase {
// element Literals.
string ToString() const;
// Similar to ToString, but return the result in a compact
// one-line form.
string ToStringOneline() const;
// Returns a string representation of the literal value which does *not*
// include the shape string.
string ToStringWithoutShape() const;
// Similar to ToStringWithoutShape, but return the result in a compact
// one-line form.
string ToStringWithoutShapeOneline() const;
// Returns a string representation of the literal value which includes the
// shape string with its layout.does *not* include the shape string.
string ToStringWithLayout() const;

View File

@ -328,7 +328,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction = CreateConstant(std::move(literal));
// Literal's shape may have no/different tiling info.
TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
instruction->shape(), shape));
instruction->shape(), shape))
<< instruction->shape().ToString(true) << " vs "
<< shape.ToString(true);
*instruction->mutable_shape() = shape;
} else {
instruction = absl::make_unique<HloConstantInstruction>(shape);
@ -578,6 +580,12 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
if (proto.has_window()) {
custom_call_instr->set_window(proto.window());
}
if (proto.has_literal()) {
TF_ASSIGN_OR_RETURN(
auto literal,
Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
custom_call_instr->set_literal(std::move(literal));
}
if (proto.has_convolution_dimension_numbers()) {
custom_call_instr->set_convolution_dimension_numbers(
proto.convolution_dimension_numbers());

View File

@ -1328,19 +1328,7 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
options.print_large_constants())) {
// Literal::ToString emits multidimensional arrays over multiple
// lines. Compact this into one line by stripping out white space.
string tmp = literal().ToStringWithoutShape();
std::replace(tmp.begin(), tmp.end(), '\n', ' ');
std::vector<string> v = absl::StrSplit(tmp, ' ');
bool first = true;
// Concatenate elements in "v" with spaces separating them, but ignoring
// empty entries.
for (const auto& s : v) {
if (s.empty()) {
continue;
}
StrAppend(&operands, (first ? "" : " "), s);
first = false;
}
operands = literal_->ToStringWithoutShapeOneline();
} else {
// Do not show large constants or tuples.
operands = "{...}";
@ -2441,6 +2429,9 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const {
}
}
proto.set_custom_call_has_side_effect(custom_call_has_side_effect_);
if (literal_.has_value()) {
*proto.mutable_literal() = literal_->ToProto();
}
for (const auto& pair : output_to_operand_aliasing_) {
auto aliasing = proto.add_custom_call_output_operand_aliasing();
aliasing->set_operand_index(pair.second.first);
@ -2495,6 +2486,9 @@ std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
if (custom_call_has_side_effect_) {
extra.push_back("custom_call_has_side_effect=true");
}
if (literal_.has_value()) {
extra.push_back(StrCat("literal=(", literal_->ToStringOneline(), ")"));
}
if (!output_to_operand_aliasing_.empty()) {
std::vector<string> pair_strings;
for (const auto& pair : output_to_operand_aliasing_) {
@ -2571,6 +2565,13 @@ bool HloCustomCallInstruction::IdenticalSlowPath(
return false;
}
}
if (HasLiteral() == casted_other.HasLiteral()) {
if (HasLiteral() && literal() == casted_other.literal()) {
return false;
}
} else {
return true;
}
// Note: backend_config comparison is done in Identical, which is the
// intended/exposed way to compare computations, and so not repeated here.
@ -2593,6 +2594,9 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl(
if (convolution_dimension_numbers_ != nullptr) {
cloned->set_convolution_dimension_numbers(*convolution_dimension_numbers_);
}
if (HasLiteral()) {
cloned->set_literal(literal().Clone());
}
cloned->set_feature_group_count(feature_group_count_);
cloned->set_batch_group_count(batch_group_count_);
cloned->set_custom_call_has_side_effect(custom_call_has_side_effect_);

View File

@ -1466,6 +1466,13 @@ class HloCustomCallInstruction : public HloInstruction {
padding_type_ = padding_type;
}
// Returns the literal associated with this instruction.
const Literal& literal() const { return *literal_; }
// Set the value of literal to a new one.
void set_literal(Literal&& literal) { literal_.emplace(std::move(literal)); }
// Returns whether there is literal associated with this instruction.
bool HasLiteral() const { return literal_.has_value(); }
const PrecisionConfig& precision_config() const { return precision_config_; }
PrecisionConfig* mutable_precision_config() { return &precision_config_; }
@ -1532,6 +1539,7 @@ class HloCustomCallInstruction : public HloInstruction {
// output_to_operand_aliasing().
std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_to_operand_aliasing_;
absl::optional<Literal> literal_;
};
class HloPadInstruction : public HloInstruction {

View File

@ -253,6 +253,7 @@ class HloParserImpl : public HloParser {
bool ParseInstructionRhs(HloComputation::Builder* builder,
const std::string& name, LocTy name_loc);
bool ParseControlPredecessors(HloInstruction* instruction);
bool ParseLiteral(Literal* literal);
bool ParseLiteral(Literal* literal, const Shape& shape);
bool ParseTupleLiteral(Literal* literal, const Shape& shape);
bool ParseNonTupleLiteral(Literal* literal, const Shape& shape);
@ -307,6 +308,7 @@ class HloParserImpl : public HloParser {
kInt32,
kFloat,
kString,
kLiteral,
kBracedInt64List,
kBracedInt64ListList,
kHloComputation,
@ -2268,6 +2270,9 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
attrs["padding_type"] = {/*required=*/false, AttrTy::kPaddingType,
&padding_type};
optional<Literal> literal;
attrs["literal"] = {/*required=*/false, AttrTy::kLiteral, &literal};
optional<std::vector<PrecisionConfig::Precision>> operand_precision;
attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
&operand_precision};
@ -2357,6 +2362,9 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
custom_call_instr->set_output_to_operand_aliasing(
std::move(*output_to_operand_aliasing));
}
if (literal.has_value()) {
custom_call_instr->set_literal(std::move(*literal));
}
PrecisionConfig precision_config;
if (operand_precision) {
*precision_config.mutable_operand_precision() = {
@ -3048,6 +3056,14 @@ bool HloParserImpl::SetValueInLiteralHelper(LocTy loc, ParsedElemT value,
return true;
}
bool HloParserImpl::ParseLiteral(Literal* literal) {
Shape literal_shape;
if (!ParseShape(&literal_shape)) {
return false;
}
return ParseLiteral(literal, literal_shape);
}
// literal
// ::= tuple
// ::= non_tuple
@ -3830,6 +3846,21 @@ bool HloParserImpl::ParseAttributeHelper(
->emplace(std::move(aliasing_output_operand_pairs));
return true;
}
case AttrTy::kLiteral: {
if (!ParseToken(TokKind::kLparen, "expects '(' before literal")) {
return false;
}
Literal result;
if (!ParseLiteral(&result)) {
return false;
}
if (!ParseToken(TokKind::kRparen, "expects ')' after literal")) {
return false;
}
static_cast<optional<Literal>*>(attr_out_ptr)
->emplace(std::move(result));
return true;
}
}
}();
if (!success) {

View File

@ -399,6 +399,32 @@ ENTRY %CustomCall () -> f32[1,2,3] {
ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar", backend_config="this string is opaque"
}
)"
},
// CustomCall with literal.
{
"CustomCallWithLiteral",
R"(HloModule custom_call
ENTRY %CustomCall () -> f32[1,2,3] {
%constant = f32[1]{0} constant({12345})
ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar", literal=(f32[1] {0.1})
}
)"
},
// CustomCall with literal R0.
{
"CustomCallWithLiteralR0",
R"(HloModule custom_call
ENTRY %CustomCall () -> f32[1,2,3] {
%constant = f32[1]{0} constant({12345})
ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar", literal=(f32[] 0.1)
}
)"
},
// reduce window