From 09db4abc5b947bb8956f9f37316035b8b912b22b Mon Sep 17 00:00:00 2001
From: Yunxing Dai
Date: Wed, 17 Feb 2021 12:20:23 -0800
Subject: [PATCH] Add a literal field to custom call instruction.

- Add a literal field to custom call instruction.
- Change SetBound custom call to use literal as side data instead of a
  hand-serialized number.

PiperOrigin-RevId: 358006370
Change-Id: I67727bbe3ce12b082fbcfc7290e57194bd96c29a
---
 .../compiler/mlir/xla/ir/mlir_hlo_builder.cc  |  5 +-
 .../compiler/mlir/xla/ir/mlir_hlo_builder.h   |  3 +-
 .../compiler/tf2xla/kernels/shape_op.cc       | 10 ++--
 tensorflow/compiler/xla/client/xla_builder.cc | 49 +++++++++++--------
 tensorflow/compiler/xla/client/xla_builder.h  | 27 ++++++----
 tensorflow/compiler/xla/literal.cc            | 25 ++++++++++
 tensorflow/compiler/xla/literal.h             |  8 +++
 .../compiler/xla/service/hlo_instruction.cc   | 10 +++-
 .../compiler/xla/service/hlo_instructions.cc  | 30 +++++++-----
 .../compiler/xla/service/hlo_instructions.h   |  8 +++
 tensorflow/compiler/xla/service/hlo_parser.cc | 31 ++++++++++++
 .../compiler/xla/service/hlo_parser_test.cc   | 26 ++++++++++
 12 files changed, 183 insertions(+), 49 deletions(-)

diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
index 9ff44c04fdc..90d74efb401 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
@@ -134,7 +134,8 @@ StatusOr<XlaOp> MlirHloBuilder::CustomCallInternal(
     absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
     bool has_side_effect,
     absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
-        output_operand_aliasing) {
+        output_operand_aliasing,
+    const Literal* literal) {
   if (operand_shapes_with_layout.has_value())
     return Unimplemented(
         "CustomCall doesn't support operands shapes with layout");
@@ -142,6 +143,8 @@ StatusOr<XlaOp> MlirHloBuilder::CustomCallInternal(
                                        shape, builder_));
   TF_RET_CHECK(output_operand_aliasing.empty())
       << "MLIR CustomCallOp does not support output_operand_aliasing yet";
+  TF_RET_CHECK(literal == nullptr)
+      << "MLIR CustomCallOp does not support literal yet";
   auto op = builder_.create<mlir::mhlo::CustomCallOp>(
       loc_, ty, GetValues(operands), builder_.getStringAttr(call_target_name),
       /*has_side_effect=*/builder_.getBoolAttr(has_side_effect),
diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
index cc95b58cae0..2935089b18a 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
@@ -137,7 +137,8 @@ class MlirHloBuilder : public XlaBuilder {
       absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
       bool has_side_effect,
       absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
-          output_operand_aliasing) override;
+          output_operand_aliasing,
+      const Literal* literal) override;
 
   StatusOr<XlaOp> ReduceInternal(
       const Shape& shape, absl::Span<const XlaOp> all_operands,
diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
index 99f50101ee0..22ade147198 100644
--- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/kernel_def_builder.h" @@ -97,10 +98,11 @@ class XlaSetBoundOp : public XlaOpKernel { bound_shape.DebugString())); int64 bound; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("bound", &bound)); - - xla::XlaOp result = xla::CustomCall( - ctx->builder(), "SetBound", {ctx->Input("input")}, - ctx->InputXlaShape("input").ValueOrDie(), absl::StrFormat("%d", bound)); + xla::Literal bound_literal = xla::LiteralUtil::CreateR0(bound); + xla::XlaOp result = + xla::CustomCall(ctx->builder(), "SetBound", {ctx->Input("input")}, + ctx->InputXlaShape("input").ValueOrDie(), "", false, {}, + &bound_literal); ctx->SetOutput(0, result); } }; diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 35cd1c25b7d..b8bfb7e553d 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -1882,7 +1882,8 @@ XlaOp XlaBuilder::CustomCall( absl::optional> operand_shapes_with_layout, bool has_side_effect, absl::Span>> - output_operand_aliasing) { + output_operand_aliasing, + const Literal* literal) { return ReportErrorOrReturn([&]() -> StatusOr { if (absl::StartsWith(call_target_name, "$")) { return InvalidArgument( @@ -1915,7 +1916,7 @@ XlaOp XlaBuilder::CustomCall( } return CustomCallInternal(call_target_name, operands, shape, opaque, operand_shapes_with_layout, has_side_effect, - output_operand_aliasing); + output_operand_aliasing, literal); }); } @@ -1925,7 +1926,8 @@ StatusOr XlaBuilder::CustomCallInternal( absl::optional> operand_shapes_with_layout, bool has_side_effect, absl::Span>> - output_operand_aliasing) { + output_operand_aliasing, + const Literal* literal) { HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); instr.set_custom_call_target(call_target_name); @@ -1936,6 +1938,9 @@ StatusOr XlaBuilder::CustomCallInternal( *instr.add_operand_shapes_with_layout() = operand_shape.ToProto(); } } + if (literal != nullptr) { + *instr.mutable_literal() = literal->ToProto(); + } instr.set_custom_call_has_side_effect(has_side_effect); for (const auto& pair : output_operand_aliasing) { auto aliasing = instr.add_custom_call_output_operand_aliasing(); @@ -1956,7 +1961,8 @@ XlaOp XlaBuilder::CustomCall( absl::optional> operand_shapes_with_layout, bool has_side_effect, absl::Span>> - output_operand_aliasing) { + output_operand_aliasing, + const Literal* literal) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; if (absl::StartsWith(call_target_name, "$")) { @@ -1968,6 +1974,9 @@ XlaOp XlaBuilder::CustomCall( *instr.mutable_shape() = shape.ToProto(); instr.set_custom_call_target(call_target_name); instr.set_backend_config(opaque); + if (literal != nullptr) { + *instr.mutable_literal() = literal->ToProto(); + } if (operand_shapes_with_layout.has_value()) { if (!LayoutUtil::HasLayout(shape)) { return InvalidArgument( @@ -3786,6 +3795,8 @@ StatusOr XlaBuilder::BuildConstantSubGraph( HloOpcodeString(HloOpcode::kGetDimensionSize) || InstrIsSetBound(instr_proto)) { int32 constant_value = -1; + HloInstructionProto const_instr; + if (instr_proto->opcode() == HloOpcodeString(HloOpcode::kGetDimensionSize)) { // At this point, 
@@ -3804,18 +3815,14 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
           constant_value =
               static_cast<int32>(operand_proto->shape().dimensions(dimension));
         }
+        Literal literal = LiteralUtil::CreateR0<int32>(constant_value);
+        *const_instr.mutable_literal() = literal.ToProto();
+        *const_instr.mutable_shape() = literal.shape().ToProto();
       } else {
-        TF_RET_CHECK(
-            absl::SimpleAtoi(instr_proto->backend_config(), &constant_value));
+        *const_instr.mutable_literal() = instr_proto->literal();
+        *const_instr.mutable_shape() = instr_proto->shape();
       }
-
-      Literal literal = LiteralUtil::CreateR0<int32>(constant_value);
-
-      HloInstructionProto const_instr;
-      *const_instr.mutable_shape() = literal.shape().ToProto();
-      *const_instr.mutable_literal() = literal.ToProto();
       *const_instr.mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
-      const_instr.set_id(handle);
       *const_instr.mutable_name() =
           GetFullName(const_instr.opcode(), kNameSeparator, const_instr.id());
 
@@ -3866,7 +3873,6 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
     }
   }
   *module->add_computations() = std::move(entry);
-
   return std::move(computation);
 }
 
@@ -4459,10 +4465,11 @@ XlaOp CustomCall(
     absl::Span<const XlaOp> operands, const Shape& shape, const string& opaque,
     bool has_side_effect,
     absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
-        output_operand_aliasing) {
+        output_operand_aliasing,
+    const Literal* literal) {
   return builder->CustomCall(call_target_name, operands, shape, opaque,
                              /*operand_shapes_with_layout=*/absl::nullopt,
-                             has_side_effect, output_operand_aliasing);
+                             has_side_effect, output_operand_aliasing, literal);
 }
 
 XlaOp CustomCallWithComputation(
@@ -4470,11 +4477,12 @@ XlaOp CustomCallWithComputation(
     absl::Span<const XlaOp> operands, const XlaComputation& computation,
     const Shape& shape, const string& opaque, bool has_side_effect,
     absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
-        output_operand_aliasing) {
+        output_operand_aliasing,
+    const Literal* literal) {
   return builder->CustomCall(call_target_name, operands, computation, shape,
                              opaque,
                              /*operand_shapes_with_layout=*/absl::nullopt,
-                             has_side_effect, output_operand_aliasing);
+                             has_side_effect, output_operand_aliasing, literal);
 }
 
 XlaOp CustomCallWithLayout(
@@ -4483,10 +4491,11 @@ XlaOp CustomCallWithLayout(
     absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
     bool has_side_effect,
     absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
-        output_operand_aliasing) {
+        output_operand_aliasing,
+    const Literal* literal) {
   return builder->CustomCall(call_target_name, operands, shape, opaque,
                              operand_shapes_with_layout, has_side_effect,
-                             output_operand_aliasing);
+                             output_operand_aliasing, literal);
 }
 
 XlaOp Complex(const XlaOp lhs, const XlaOp rhs,
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index cb212e17d40..cc1806be9a4 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -655,7 +655,8 @@ class XlaBuilder {
       absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
       bool has_side_effect,
       absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
-          output_operand_aliasing);
+          output_operand_aliasing,
+      const Literal* literal);
 
   // Internal version of CustomCall without computation that doesn't do op
   // specific error handling and expects arguments to be legal. CustomCall
@@ -666,7 +667,8 @@ class XlaBuilder {
       absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
       bool has_side_effect,
       absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
-          output_operand_aliasing);
+          output_operand_aliasing,
+      const Literal* literal);
 
   XlaOp CustomCall(
       const string& call_target_name, absl::Span<const XlaOp> operands,
@@ -675,7 +677,8 @@ class XlaBuilder {
      absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
       bool has_side_effect,
       absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
-          output_operand_aliasing);
+          output_operand_aliasing,
+      const Literal* literal);
 
   XlaOp Reduce(XlaOp operand, XlaOp init_value,
                const XlaComputation& computation,
@@ -1214,20 +1217,23 @@
       absl::Span<const XlaOp> operands, const Shape& shape,
       const string& opaque, bool has_side_effect,
       absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
-          output_operand_aliasing);
+          output_operand_aliasing,
+      const Literal* literal);
   friend XlaOp CustomCallWithComputation(
       XlaBuilder* builder, const string& call_target_name,
       absl::Span<const XlaOp> operands, const XlaComputation& computation,
       const Shape& shape, const string& opaque, bool has_side_effect,
       absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
-          output_operand_aliasing);
+          output_operand_aliasing,
+      const Literal* literal);
   friend XlaOp CustomCallWithLayout(
       XlaBuilder* builder, const string& call_target_name,
       absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
       absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
       bool has_side_effect,
       absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
-          output_operand_aliasing);
+          output_operand_aliasing,
+      const Literal* literal);
   friend XlaOp Complex(XlaOp real, XlaOp imag,
                        absl::Span<const int64> broadcast_dimensions);
   friend XlaOp Conj(XlaOp operand);
@@ -2025,7 +2031,8 @@ XlaOp CustomCall(
     absl::Span<const XlaOp> operands, const Shape& shape,
     const string& opaque = "", bool has_side_effect = false,
     absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
-        output_operand_aliasing = {});
+        output_operand_aliasing = {},
+    const Literal* literal = nullptr);
 
 // Overload which constructs a custom call that applies an Xla computation.
 XlaOp CustomCallWithComputation(
@@ -2033,7 +2040,8 @@ XlaOp CustomCallWithComputation(
     absl::Span<const XlaOp> operands, const XlaComputation& computation,
     const Shape& shape, const string& opaque = "", bool has_side_effect = false,
     absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
-        output_operand_aliasing = {});
+        output_operand_aliasing = {},
+    const Literal* literal = nullptr);
 
 // Overload which constructs a custom call with fixed layouts. The operands will
 // have the layouts specified by |operand_shapes_with_layout| when provided to
@@ -2046,7 +2054,8 @@ XlaOp CustomCallWithLayout(
     absl::Span<const Shape> operand_shapes_with_layout,
     const string& opaque = "", bool has_side_effect = false,
     absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
-        output_operand_aliasing = {});
+        output_operand_aliasing = {},
+    const Literal* literal = nullptr);
 
 // The following methods enqueue element-wise binary arithmetic operations
 // onto the computation. The shapes of the operands have to match unless one
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc
index d15f78c41e0..57a2ec131f7 100644
--- a/tensorflow/compiler/xla/literal.cc
+++ b/tensorflow/compiler/xla/literal.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/index_util.h" @@ -71,6 +72,22 @@ void ConvertEndianShort(char* bytes, int64 size) { } } +string CompactOneline(const string& input) { + string result; + std::vector v = absl::StrSplit(input, absl::ByAnyChar("\n ")); + bool first = true; + // Concatenate elements in "v" with spaces separating them, but ignoring + // empty entries. + for (const auto& s : v) { + if (s.empty()) { + continue; + } + absl::StrAppend(&result, (first ? "" : " "), s); + first = false; + } + return result; +} + // Since Eigen::half doesn't satisfy the absl::bit_cast contract, we need to be // able to transparently access the raw 16-bit value contained within. template @@ -1281,6 +1298,10 @@ string LiteralBase::ToString() const { return absl::StrJoin(pieces, ""); } +string LiteralBase::ToStringOneline() const { + return CompactOneline(ToString()); +} + string LiteralBase::ToStringWithoutShape() const { std::vector pieces; CHECK(LayoutUtil::HasLayout(this->shape())); @@ -1289,6 +1310,10 @@ string LiteralBase::ToStringWithoutShape() const { return absl::StrJoin(pieces, ""); } +string LiteralBase::ToStringWithoutShapeOneline() const { + return CompactOneline(ToStringWithoutShape()); +} + string LiteralBase::ToStringWithLayout() const { std::vector pieces; CHECK(LayoutUtil::HasLayout(this->shape())); diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index 1ee71618887..4147436330f 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -94,10 +94,18 @@ class LiteralBase { // element Literals. string ToString() const; + // Similar to ToString, but return the result in a compact + // one-line form. + string ToStringOneline() const; + // Returns a string representation of the literal value which does *not* // include the shape string. string ToStringWithoutShape() const; + // Similar to ToStringWithoutShape, but return the result in a compact + // one-line form. + string ToStringWithoutShapeOneline() const; + // Returns a string representation of the literal value which includes the // shape string with its layout.does *not* include the shape string. string ToStringWithLayout() const; diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index c3951d52aa8..9b7679e6f46 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -328,7 +328,9 @@ StatusOr> HloInstruction::CreateFromProto( instruction = CreateConstant(std::move(literal)); // Literal's shape may have no/different tiling info. 
         TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
-            instruction->shape(), shape));
+            instruction->shape(), shape))
+            << instruction->shape().ToString(true) << " vs "
+            << shape.ToString(true);
         *instruction->mutable_shape() = shape;
       } else {
         instruction = absl::make_unique<HloConstantInstruction>(shape);
       }
@@ -578,6 +580,12 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
       if (proto.has_window()) {
         custom_call_instr->set_window(proto.window());
       }
+      if (proto.has_literal()) {
+        TF_ASSIGN_OR_RETURN(
+            auto literal,
+            Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
+        custom_call_instr->set_literal(std::move(literal));
+      }
       if (proto.has_convolution_dimension_numbers()) {
         custom_call_instr->set_convolution_dimension_numbers(
             proto.convolution_dimension_numbers());
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index f2a7fe188c8..7a77e86657d 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -1328,19 +1328,7 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
           options.print_large_constants())) {
     // Literal::ToString emits multidimensional arrays over multiple
     // lines. Compact this into one line by stripping out white space.
-    string tmp = literal().ToStringWithoutShape();
-    std::replace(tmp.begin(), tmp.end(), '\n', ' ');
-    std::vector<string> v = absl::StrSplit(tmp, ' ');
-    bool first = true;
-    // Concatenate elements in "v" with spaces separating them, but ignoring
-    // empty entries.
-    for (const auto& s : v) {
-      if (s.empty()) {
-        continue;
-      }
-      StrAppend(&operands, (first ? "" : " "), s);
-      first = false;
-    }
+    operands = literal_->ToStringWithoutShapeOneline();
   } else {
     // Do not show large constants or tuples.
     operands = "{...}";
@@ -2441,6 +2429,9 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const {
     }
   }
   proto.set_custom_call_has_side_effect(custom_call_has_side_effect_);
+  if (literal_.has_value()) {
+    *proto.mutable_literal() = literal_->ToProto();
+  }
   for (const auto& pair : output_to_operand_aliasing_) {
     auto aliasing = proto.add_custom_call_output_operand_aliasing();
     aliasing->set_operand_index(pair.second.first);
@@ -2495,6 +2486,9 @@ std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
   if (custom_call_has_side_effect_) {
     extra.push_back("custom_call_has_side_effect=true");
   }
+  if (literal_.has_value()) {
+    extra.push_back(StrCat("literal=(", literal_->ToStringOneline(), ")"));
+  }
   if (!output_to_operand_aliasing_.empty()) {
     std::vector<string> pair_strings;
     for (const auto& pair : output_to_operand_aliasing_) {
@@ -2571,6 +2565,12 @@ bool HloCustomCallInstruction::IdenticalSlowPath(
       return false;
     }
   }
+  if (HasLiteral() != casted_other.HasLiteral()) {
+    return false;
+  }
+  if (HasLiteral() && literal() != casted_other.literal()) {
+    return false;
+  }
   // Note: backend_config comparison is done in Identical, which is the
   // intended/exposed way to compare computations, and so not repeated here.
@@ -2593,6 +2593,9 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl(
   if (convolution_dimension_numbers_ != nullptr) {
     cloned->set_convolution_dimension_numbers(*convolution_dimension_numbers_);
   }
+  if (HasLiteral()) {
+    cloned->set_literal(literal().Clone());
+  }
   cloned->set_feature_group_count(feature_group_count_);
   cloned->set_batch_group_count(batch_group_count_);
   cloned->set_custom_call_has_side_effect(custom_call_has_side_effect_);
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index bacbce15206..4df82e186fa 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -1466,6 +1466,13 @@ class HloCustomCallInstruction : public HloInstruction {
     padding_type_ = padding_type;
   }
 
+  // Returns the literal associated with this instruction.
+  const Literal& literal() const { return *literal_; }
+  // Sets the literal to a new value.
+  void set_literal(Literal&& literal) { literal_.emplace(std::move(literal)); }
+  // Returns whether there is a literal associated with this instruction.
+  bool HasLiteral() const { return literal_.has_value(); }
+
   const PrecisionConfig& precision_config() const { return precision_config_; }
   PrecisionConfig* mutable_precision_config() { return &precision_config_; }
 
@@ -1532,6 +1539,7 @@ class HloCustomCallInstruction : public HloInstruction {
   // output_to_operand_aliasing().
   std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
       output_to_operand_aliasing_;
+  absl::optional<Literal> literal_;
 };
 
 class HloPadInstruction : public HloInstruction {
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 3341864d50d..6f5a877ecc9 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -253,6 +253,7 @@ class HloParserImpl : public HloParser {
   bool ParseInstructionRhs(HloComputation::Builder* builder,
                            const std::string& name, LocTy name_loc);
   bool ParseControlPredecessors(HloInstruction* instruction);
+  bool ParseLiteral(Literal* literal);
   bool ParseLiteral(Literal* literal, const Shape& shape);
   bool ParseTupleLiteral(Literal* literal, const Shape& shape);
   bool ParseNonTupleLiteral(Literal* literal, const Shape& shape);
@@ -307,6 +308,7 @@ class HloParserImpl : public HloParser {
     kInt32,
     kFloat,
     kString,
+    kLiteral,
     kBracedInt64List,
     kBracedInt64ListList,
     kHloComputation,
@@ -2268,6 +2270,9 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
       attrs["padding_type"] = {/*required=*/false, AttrTy::kPaddingType,
                                &padding_type};
+
+      optional<Literal> literal;
+      attrs["literal"] = {/*required=*/false, AttrTy::kLiteral, &literal};
       optional<std::vector<PrecisionConfig::Precision>> operand_precision;
       attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
                                     &operand_precision};
@@ -2357,6 +2362,9 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
         custom_call_instr->set_output_to_operand_aliasing(
             std::move(*output_to_operand_aliasing));
       }
+      if (literal.has_value()) {
+        custom_call_instr->set_literal(std::move(*literal));
+      }
       PrecisionConfig precision_config;
       if (operand_precision) {
         *precision_config.mutable_operand_precision() = {
@@ -3048,6 +3056,14 @@ bool HloParserImpl::SetValueInLiteralHelper(LocTy loc, ParsedElemT value,
   return true;
 }
 
+bool HloParserImpl::ParseLiteral(Literal* literal) {
+  Shape literal_shape;
+  if (!ParseShape(&literal_shape)) {
+    return false;
+  }
+  return ParseLiteral(literal, literal_shape);
+}
+
 // literal
 //  ::= tuple
 //  ::= non_tuple
@@ -3830,6 +3846,21 @@ bool HloParserImpl::ParseAttributeHelper(
           ->emplace(std::move(aliasing_output_operand_pairs));
       return true;
     }
+    case AttrTy::kLiteral: {
+      if (!ParseToken(TokKind::kLparen, "expects '(' before literal")) {
+        return false;
+      }
+      Literal result;
+      if (!ParseLiteral(&result)) {
+        return false;
+      }
+      if (!ParseToken(TokKind::kRparen, "expects ')' after literal")) {
+        return false;
+      }
+      static_cast<optional<Literal>*>(attr_out_ptr)
+          ->emplace(std::move(result));
+      return true;
+    }
   }
 }();
 if (!success) {
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 4dac92b59e1..696f8097d83 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -399,6 +399,32 @@ ENTRY %CustomCall () -> f32[1,2,3] {
   ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar", backend_config="this string is opaque"
 }
 
+)"
+},
+
+// CustomCall with literal.
+{
+"CustomCallWithLiteral",
+R"(HloModule custom_call
+
+ENTRY %CustomCall () -> f32[1,2,3] {
+  %constant = f32[1]{0} constant({12345})
+  ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar", literal=(f32[1] {0.1})
+}
+
+)"
+},
+
+// CustomCall with literal R0.
+{
+"CustomCallWithLiteralR0",
+R"(HloModule custom_call
+
+ENTRY %CustomCall () -> f32[1,2,3] {
+  %constant = f32[1]{0} constant({12345})
+  ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar", literal=(f32[] 0.1)
+}
+
 )"
 },
 
 // reduce window
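
Usage sketch (illustrative, not part of the patch): after this change, a client can hand structured side data to a custom call as a typed Literal instead of hand-serializing it into the opaque string. The wrapper below is modeled on the XlaSetBoundOp change above; the EmitSetBound name is hypothetical, while the xla::CustomCall signature matches the extended declaration in xla_builder.h.

  #include "tensorflow/compiler/xla/client/xla_builder.h"
  #include "tensorflow/compiler/xla/literal_util.h"

  // Emits a "SetBound" custom call whose bound travels as an R0 literal
  // rather than as a number serialized into the opaque string.
  xla::XlaOp EmitSetBound(xla::XlaBuilder* builder, xla::XlaOp input,
                          const xla::Shape& shape, xla::int64 bound) {
    xla::Literal bound_literal = xla::LiteralUtil::CreateR0<xla::int64>(bound);
    // opaque, has_side_effect, and output_operand_aliasing keep their
    // defaults; the trailing literal pointer is the new parameter.
    return xla::CustomCall(builder, "SetBound", {input}, shape,
                           /*opaque=*/"", /*has_side_effect=*/false,
                           /*output_operand_aliasing=*/{}, &bound_literal);
  }

The literal is copied into the instruction proto by ToProto(), so a local Literal outliving only the CustomCall call is sufficient. In HLO text the attribute round-trips as literal=(<shape> <value>), e.g. literal=(f32[] 0.1) in the new parser tests above.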