From 71ac02ec1f8a0af70d6a49447c9b5d49f4115f8a Mon Sep 17 00:00:00 2001
From: Yunxing Dai
Date: Wed, 3 Mar 2021 14:09:30 -0800
Subject: [PATCH] XLA value inference.

XLA value inference is designed to be a one-stop service that analyzes the attributes of each value in a tensor. The attributes include:
- What's the upper-bound of each value in a tensor.
- What's the lower-bound of each value in a tensor.
- What's the constant value of each tensor.
- Whether or not each value in a tensor is dynamic.

This CL implements the dynamism-inference part, which replaces the old implementation in the XLA builder that had become too complex to maintain over time.

PiperOrigin-RevId: 360753561
Change-Id: I55c6ec9abf43f1502c5173dafcc47ac6c41bea09
---
 tensorflow/compiler/tf2xla/BUILD              |   1 +
 tensorflow/compiler/tf2xla/xla_expression.cc  |  10 +-
 tensorflow/compiler/xla/client/BUILD          |  19 +
 .../compiler/xla/client/value_inference.cc    | 384 +++++++++++++++++
 .../compiler/xla/client/value_inference.h     |  70 +++
 tensorflow/compiler/xla/client/xla_builder.cc | 397 ------------------
 tensorflow/compiler/xla/client/xla_builder.h  |  28 +-
 .../compiler/xla/service/hlo_instruction.h    |   3 +-
 tensorflow/compiler/xla/tests/BUILD           |   1 +
 .../xla/tests/dynamism_inference_test.cc      | 372 +++++++---------
 10 files changed, 637 insertions(+), 648 deletions(-)
 create mode 100644 tensorflow/compiler/xla/client/value_inference.cc
 create mode 100644 tensorflow/compiler/xla/client/value_inference.h

diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 78e054dcc22..05008285e14 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -485,6 +485,7 @@ cc_library(
         ":xla_resource",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla/client",
+        "//tensorflow/compiler/xla/client:value_inference",
         "//tensorflow/compiler/xla/client:xla_builder",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc
index 498b3f80d41..099d54685ca 100644
--- a/tensorflow/compiler/tf2xla/xla_expression.cc
+++ b/tensorflow/compiler/tf2xla/xla_expression.cc
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/tf2xla/literal_util.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/xla/client/value_inference.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
 
@@ -136,18 +137,15 @@ xla::StatusOr<Tensor> XlaExpression::ResolveDynamism(
   if (!client)
     return errors::InvalidArgument("client is required to resolve constant");
 
-  TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph,
-                      handle().builder()->BuildDynamicInferenceGraph(handle()));
-
   TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape());
 
   // The XLA layout is specified minor to major, and TensorFlow uses a major to
   // minor order.
  std::vector<int64> layout_indices(shape.dims());
   std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
-  xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);
-  TF_ASSIGN_OR_RETURN(xla::Literal literal,
-                      client->ComputeConstant(constant_graph, &layout));
+  xla::ValueInference value_inference(handle().builder());
+  TF_ASSIGN_OR_RETURN(xla::LiteralSlice literal,
+                      value_inference.AnalyzeIsDynamic(handle()));
   Tensor tensor(DT_BOOL);
   TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, DT_BOOL, &tensor));
   return tensor;
diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD
index 6d3a1261d5d..e673fdea43d 100644
--- a/tensorflow/compiler/xla/client/BUILD
+++ b/tensorflow/compiler/xla/client/BUILD
@@ -211,6 +211,25 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "value_inference",
+    srcs = ["value_inference.cc"],
+    hdrs = ["value_inference.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":xla_builder",
+        "//tensorflow/compiler/xla:literal",
+        "//tensorflow/compiler/xla:literal_util",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_evaluator",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/types:optional",
+    ],
+)
+
 cc_library(
     name = "xla_builder",
     srcs = ["xla_builder.cc"],
diff --git a/tensorflow/compiler/xla/client/value_inference.cc b/tensorflow/compiler/xla/client/value_inference.cc
new file mode 100644
index 00000000000..bd773490617
--- /dev/null
+++ b/tensorflow/compiler/xla/client/value_inference.cc
@@ -0,0 +1,384 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/xla/client/value_inference.h"
+
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+namespace {
+Literal CreatePredLiteral(bool pred, const Shape& reference_shape) {
+  Literal literal = LiteralUtil::CreateR0<bool>(pred);
+  Literal literal_broadcast =
+      literal
+          .Broadcast(ShapeUtil::ChangeElementType(Shape(reference_shape), PRED),
+                     {})
+          .ValueOrDie();
+  return literal_broadcast;
+}
+
+using GetOperand = std::function<StatusOr<LiteralSlice>(int64 operand_index,
+                                                        int64 operand_handle)>;
+
+// HloProtoEvaluator evaluates an HLO proto and returns a literal. The user has
+// to provide operands as literals through the get_operand function.
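+//
+// A rough usage sketch (illustrative only; the callback below stands in for
+// the AnalyzeIsDynamic/AnalyzeConstant call sites later in this file):
+//
+//   HloProtoEvaluator(*root,
+//                     [&](int64 operand_index, int64 operand_handle) {
+//                       return AnalyzeIsDynamic(operand_handle);
+//                     })
+//       .WithPrimitiveType(PRED)     // evaluate in predicate space
+//       .WithOpCode(HloOpcode::kOr)  // e.g. rewrite a binary op into OR
+//       .Evaluate();
+//
+// Operands are materialized as kConstant instructions from the literals the
+// callback returns, then the (possibly retyped and re-opcoded) instruction
+// is constant-folded with HloEvaluator.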
+struct HloProtoEvaluator {
+  explicit HloProtoEvaluator(HloInstructionProto inst, GetOperand get_operand)
+      : inst(std::move(inst)),
+        get_operand(get_operand),
+        module("EmptyModuleForEvaluation", HloModuleConfig()) {}
+
+  // WithComputation changes the called computation of the instruction being
+  // evaluated.
+  HloProtoEvaluator& WithComputation(
+      std::unique_ptr<HloComputation> new_computation) {
+    computation = new_computation.get();
+    computation->ClearUniqueIdInternal();
+    for (HloInstruction* inst : computation->instructions()) {
+      inst->ClearUniqueIdInternal();
+    }
+    module.AddEmbeddedComputation(std::move(new_computation));
+    return *this;
+  }
+
+  // WithPrimitiveType changes the primitive type of the instruction being
+  // evaluated.
+  HloProtoEvaluator& WithPrimitiveType(PrimitiveType new_primitive_type) {
+    primitive_type = new_primitive_type;
+    return *this;
+  }
+
+  // WithOpCode changes the opcode of the instruction being evaluated.
+  HloProtoEvaluator& WithOpCode(HloOpcode new_opcode) {
+    opcode = new_opcode;
+    return *this;
+  }
+
+  StatusOr<Literal> Evaluate() {
+    // Evaluate the instruction by swapping its operands with constant
+    // instructions holding the given literals.
+    HloComputation::Builder builder("EmptyComputation");
+    absl::flat_hash_map<int64, HloInstruction*> operand_map;
+    for (int64 i = 0; i < inst.operand_ids_size(); ++i) {
+      int64 operand_handle = inst.operand_ids(i);
+      TF_ASSIGN_OR_RETURN(auto literal, get_operand(i, inst.operand_ids(i)));
+      std::unique_ptr<HloInstruction> operand =
+          HloInstruction::CreateConstant(literal.Clone());
+      operand_map[operand_handle] = operand.get();
+      builder.AddInstruction(std::move(operand));
+    }
+
+    if (primitive_type.has_value()) {
+      *inst.mutable_shape() = ShapeUtil::ChangeElementType(
+                                  Shape(inst.shape()), primitive_type.value())
+                                  .ToProto();
+    }
+    if (opcode.has_value()) {
+      *inst.mutable_opcode() = HloOpcodeString(opcode.value());
+    }
+    absl::flat_hash_map<int64, HloComputation*> computation_map;
+    if (inst.called_computation_ids_size() != 0) {
+      TF_RET_CHECK(inst.called_computation_ids_size() == 1 &&
+                   computation != nullptr)
+          << inst.DebugString();
+      computation_map[inst.called_computation_ids(0)] = computation;
+    }
+    TF_ASSIGN_OR_RETURN(
+        auto new_instruction,
+        HloInstruction::CreateFromProto(inst, operand_map, computation_map));
+    new_instruction->ClearUniqueIdInternal();
+    builder.AddInstruction(std::move(new_instruction));
+    auto computation = builder.Build();
+    module.AddEntryComputation(std::move(computation));
+    HloEvaluator evaluator;
+    return evaluator.Evaluate(module.entry_computation()->root_instruction());
+  }
+
+  HloInstructionProto inst;
+  GetOperand get_operand;
+  HloModule module;
+  HloComputation* computation = nullptr;
+  absl::optional<PrimitiveType> primitive_type = absl::nullopt;
+  absl::optional<HloOpcode> opcode = absl::nullopt;
+};
+}  // namespace
+
+StatusOr<Literal> ValueInference::AnalyzeConstantLiteral(int64 handle) {
+  TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
+                      builder_->LookUpInstructionByHandle(handle));
+  if (constant_.contains(handle)) {
+    return constant_[handle].Clone();
+  }
+  TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
+  switch (opcode) {
+    case HloOpcode::kGetDimensionSize: {
+      int64 dimension = root->dimensions(0);
+      int64 operand_handle = root->operand_ids(0);
+      TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
+                          builder_->LookUpInstructionByHandle(operand_handle));
+      if (operand_proto->shape().is_dynamic_dimension(dimension)) {
+        return InvalidArgument(
+            "AnalyzeConstant is called on a GetDimensionSize on dynamic "
+            "dimension.");
+      } else {
+        return
 LiteralUtil::CreateR0<int32>(
+            operand_proto->shape().dimensions(dimension));
+      }
+    }
+    // Non-functional ops.
+    case HloOpcode::kRng:
+    case HloOpcode::kAllReduce:
+      // TODO(b/33009255): Implement constant folding for cross replica sum.
+    case HloOpcode::kInfeed:
+    case HloOpcode::kOutfeed:
+    case HloOpcode::kCall:
+      // TODO(b/32495713): We aren't checking the to_apply computation itself,
+      // so we conservatively say that computations containing the Call op
+      // cannot be constant. We cannot set is_functional=false in other similar
+      // cases since we're already relying on IsConstant to return true.
+    case HloOpcode::kCustomCall:
+    case HloOpcode::kWhile:
+    case HloOpcode::kConditional:
+      // TODO(b/32495713): We aren't checking the condition and body
+      // computations themselves.
+    case HloOpcode::kSend:
+    case HloOpcode::kRecv:
+    case HloOpcode::kParameter: {
+      return InvalidArgument("Can't analyze constant values on instruction %s",
+                             root->DebugString());
+    }
+    case HloOpcode::kReduce:
+    case HloOpcode::kScatter:
+    case HloOpcode::kReduceWindow: {
+      HloComputationProto computation_proto =
+          builder_->embedded_[root->called_computation_ids(0)];
+      TF_ASSIGN_OR_RETURN(auto computation, HloComputation::CreateFromProto(
+                                                computation_proto, {}));
+      return HloProtoEvaluator(*root,
+                               [&](int64 operand_index, int64 operand_handle) {
+                                 return AnalyzeConstant(operand_handle);
+                               })
+          .WithComputation(std::move(computation))
+          .Evaluate();
+    }
+    default:
+      return HloProtoEvaluator(*root,
+                               [&](int64 operand_index, int64 operand_handle) {
+                                 return AnalyzeConstant(operand_handle);
+                               })
+          .Evaluate();
+  }
+}
+
+StatusOr<Literal> ValueInference::AnalyzeIsDynamicLiteral(int64 handle) {
+  TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
+                      builder_->LookUpInstructionByHandle(handle));
+  if (is_dynamic_.contains(handle)) {
+    return is_dynamic_[handle].Clone();
+  }
+  TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
+  switch (opcode) {
+    case HloOpcode::kGetDimensionSize: {
+      int64 dimension = root->dimensions(0);
+      int64 operand_handle = root->operand_ids(0);
+      TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
+                          builder_->LookUpInstructionByHandle(operand_handle));
+      return LiteralUtil::CreateR0<bool>(
+          operand_proto->shape().is_dynamic_dimension(dimension));
+    }
+    case HloOpcode::kAbs:
+    case HloOpcode::kRoundNearestAfz:
+    case HloOpcode::kBitcast:
+    case HloOpcode::kCeil:
+    case HloOpcode::kCollectivePermuteDone:
+    case HloOpcode::kCos:
+    case HloOpcode::kClz:
+    case HloOpcode::kExp:
+    case HloOpcode::kExpm1:
+    case HloOpcode::kFloor:
+    case HloOpcode::kImag:
+    case HloOpcode::kIsFinite:
+    case HloOpcode::kLog:
+    case HloOpcode::kLog1p:
+    case HloOpcode::kNot:
+    case HloOpcode::kNegate:
+    case HloOpcode::kPopulationCount:
+    case HloOpcode::kReal:
+    case HloOpcode::kRsqrt:
+    case HloOpcode::kLogistic:
+    case HloOpcode::kSign:
+    case HloOpcode::kSin:
+    case HloOpcode::kConvert:
+    case HloOpcode::kSqrt:
+    case HloOpcode::kCbrt:
+    case HloOpcode::kTanh: {
+      // Forward the operand; these ops don't change whether a value is
+      // dynamic or static.
+      int64 operand_handle = root->operand_ids(0);
+      TF_ASSIGN_OR_RETURN(auto literal, AnalyzeIsDynamic(operand_handle));
+      return literal.Clone();
+    }
+    case HloOpcode::kAdd:
+    case HloOpcode::kAtan2:
+    case HloOpcode::kDivide:
+    case HloOpcode::kComplex:
+    case HloOpcode::kMaximum:
+    case HloOpcode::kMinimum:
+    case HloOpcode::kMultiply:
+    case HloOpcode::kPower:
+    case HloOpcode::kRemainder:
+    case HloOpcode::kSubtract:
+    case HloOpcode::kCompare:
+    case HloOpcode::kAnd:
+    case HloOpcode::kOr:
+    case HloOpcode::kXor:
+    case HloOpcode::kShiftLeft:
+    case HloOpcode::kShiftRightArithmetic:
+    case HloOpcode::kShiftRightLogical: {
+      return HloProtoEvaluator(*root,
+                               [&](int64 operand_index, int64 operand_handle) {
+                                 return AnalyzeIsDynamic(operand_handle);
+                               })
+          .WithPrimitiveType(PRED)
+          .WithOpCode(HloOpcode::kOr)
+          .Evaluate();
+    }
+    case HloOpcode::kTuple:
+    case HloOpcode::kTranspose:
+    case HloOpcode::kSlice:
+    case HloOpcode::kBroadcast:
+    case HloOpcode::kConcatenate:
+    case HloOpcode::kReshape:
+    case HloOpcode::kPad: {
+      return HloProtoEvaluator(*root,
+                               [&](int64 operand_index, int64 operand_handle) {
+                                 return AnalyzeIsDynamic(operand_handle);
+                               })
+          .WithPrimitiveType(PRED)
+          .Evaluate();
+    }
+    case HloOpcode::kGetTupleElement: {
+      int64 operand_handle = root->operand_ids(0);
+      TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
+                          builder_->LookUpInstructionByHandle(operand_handle));
+      TF_ASSIGN_OR_RETURN(HloOpcode operand_opcode,
+                          StringToHloOpcode(operand_proto->opcode()));
+      if (operand_opcode == HloOpcode::kParameter) {
+        // Don't materialize the whole parameter if it's followed by a GTE.
+        return CreatePredLiteral(true, Shape(root->shape()));
+      }
+      return HloProtoEvaluator(*root,
+                               [&](int64 operand_index, int64 operand_handle) {
+                                 return AnalyzeIsDynamic(operand_handle);
+                               })
+          .WithPrimitiveType(PRED)
+          .Evaluate();
+    }
+
+    case HloOpcode::kReduce: {
+      std::vector<std::unique_ptr<HloInstruction>> operand_storage;
+      absl::flat_hash_map<int64, HloInstruction*> operand_map;
+      absl::flat_hash_map<int64, HloComputation*> computation_map;
+
+      Shape scalar_shape = ShapeUtil::MakeScalarShape(xla::PRED);
+      HloComputation::Builder b("reduce_or");
+      auto lhs = b.AddInstruction(
+          HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
+      auto rhs = b.AddInstruction(
+          HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
+      b.AddInstruction(
+          HloInstruction::CreateBinary(scalar_shape, HloOpcode::kOr, lhs, rhs));
+      auto reduce_computation = b.Build();
+      return HloProtoEvaluator(*root,
+                               [&](int64 operand_index, int64 operand_handle) {
+                                 return AnalyzeIsDynamic(operand_handle);
+                               })
+          .WithPrimitiveType(PRED)
+          .WithComputation(std::move(reduce_computation))
+          .Evaluate();
+    }
+    case HloOpcode::kConstant:
+    case HloOpcode::kIota: {
+      return CreatePredLiteral(false, Shape(root->shape()));
+    }
+    case HloOpcode::kParameter: {
+      return CreatePredLiteral(true, Shape(root->shape()));
+    }
+    case HloOpcode::kSelect: {
+      if (!AnalyzeConstant(root->operand_ids(0)).ok()) {
+        // If the predicate operand is not constant, conservatively assume
+        // that all result values are dynamic.
+        return CreatePredLiteral(true, Shape(root->shape()));
+      }
+      return HloProtoEvaluator(*root,
+                               [&](int64 operand_index, int64 operand_handle) {
+                                 if (operand_index == 0) {
+                                   return AnalyzeConstant(operand_handle);
+                                 } else {
+                                   return AnalyzeIsDynamic(operand_handle);
+                                 }
+                               })
+          .WithPrimitiveType(PRED)
+          .Evaluate();
+    }
+    case HloOpcode::kGather: {
+      if (!AnalyzeConstant(root->operand_ids(1)).ok()) {
+        // If the index operand is not constant, conservatively assume that
+        // all result values are dynamic.
+        return CreatePredLiteral(true, Shape(root->shape()));
+      }
+      return HloProtoEvaluator(*root,
+                               [&](int64 operand_index, int64 operand_handle) {
+                                 if (operand_index == 1) {
+                                   return AnalyzeConstant(operand_handle);
+                                 } else {
+                                   return AnalyzeIsDynamic(operand_handle);
+                                 }
+                               })
+          .WithPrimitiveType(PRED)
+          .Evaluate();
+    }
+    case HloOpcode::kCustomCall: {
+      if (root->custom_call_target() == "SetBound") {
+        return CreatePredLiteral(true, Shape(root->shape()));
+      } else {
+        return InvalidArgument(
+            "Dynamic inferencing on custom call %s is not supported",
+            root->DebugString());
+      }
+
+      break;
+    }
+    default:
+      return Unimplemented("Can't infer dynamism through %s: %s",
+                           root->opcode(), root->DebugString());
+  }
+}
+
+StatusOr<LiteralSlice> ValueInference::AnalyzeIsDynamic(int64 handle) {
+  TF_ASSIGN_OR_RETURN(Literal literal, AnalyzeIsDynamicLiteral(handle));
+  is_dynamic_[handle] = std::move(literal);
+  return LiteralSlice(is_dynamic_[handle]);
+}
+
+StatusOr<LiteralSlice> ValueInference::AnalyzeConstant(int64 handle) {
+  TF_ASSIGN_OR_RETURN(Literal literal, AnalyzeConstantLiteral(handle));
+  constant_[handle] = std::move(literal);
+  return LiteralSlice(constant_[handle]);
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/client/value_inference.h b/tensorflow/compiler/xla/client/value_inference.h
new file mode 100644
index 00000000000..9906df87efc
--- /dev/null
+++ b/tensorflow/compiler/xla/client/value_inference.h
@@ -0,0 +1,70 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_VALUE_INFERENCE_H_
+#define TENSORFLOW_COMPILER_XLA_CLIENT_VALUE_INFERENCE_H_
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/types/optional.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+class ValueInference {
+ public:
+  // ValueInference analyzes values in an XlaOp, answering the following
+  // questions:
+  //  - What's the upper-bound of each value in a tensor.
+  //  - What's the lower-bound of each value in a tensor.
+  //  - What's the constant value of each tensor.
+  //  - Whether or not each value in a tensor is dynamic.
+  explicit ValueInference(XlaBuilder* builder) : builder_(builder) {}
+  StatusOr<LiteralSlice> AnalyzeUpperBound(XlaOp op) {
+    return Unimplemented("Analyzing upper-bound is not implemented yet.");
+  }
+  StatusOr<LiteralSlice> AnalyzeLowerBound(XlaOp op) {
+    return Unimplemented("Analyzing lower-bound is not implemented yet.");
+  }
+  StatusOr<LiteralSlice> AnalyzeIsDynamic(XlaOp op) {
+    return AnalyzeIsDynamic(op.handle());
+  }
+  StatusOr<LiteralSlice> AnalyzeConstant(XlaOp op) {
+    return AnalyzeConstant(op.handle());
+  }
+
+ private:
+  StatusOr<LiteralSlice> AnalyzeIsDynamic(int64 handle);
+  StatusOr<LiteralSlice> AnalyzeConstant(int64 handle);
+
+  StatusOr<Literal> AnalyzeIsDynamicLiteral(int64 handle);
+  StatusOr<Literal> AnalyzeConstantLiteral(int64 handle);
+
+  XlaBuilder* builder_;
+  // Cache to avoid re-evaluating. Mapping from XLA handle to evaluated
+  // literals.
+  absl::flat_hash_map<int64, Literal> upper_bound_;
+  absl::flat_hash_map<int64, Literal> lower_bound_;
+  absl::flat_hash_map<int64, Literal> is_dynamic_;
+  absl::flat_hash_map<int64, Literal> constant_;
+  HloEvaluator evaluator_;
+};
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_VALUE_INFERENCE_H_
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 04b87c6b37d..e3842f18a1e 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -83,61 +83,6 @@ void SetProtoIdAndName(T* entry, const string& base_name, char separator,
   entry->set_name(GetFullName(base_name, separator, id));
 }
 
-ShapeProto ConvertShapeProtoToPred(const ShapeProto& shape_proto) {
-  return ShapeUtil::ChangeElementType(Shape(shape_proto), PRED).ToProto();
-}
-
-void SetInstructionAsConstant(HloInstructionProto* instr, int64 id,
-                              const Shape& shape, bool pred) {
-  Literal literal = LiteralUtil::CreateR0<bool>(pred);
-  Literal literal_broadcast = literal.Broadcast(shape, {}).ValueOrDie();
-  *instr->mutable_shape() = shape.ToProto();
-  *instr->mutable_literal() = literal_broadcast.ToProto();
-  *instr->mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
-}
-
-// Copy `original_reducer` into a new computation proto with `reducer_id` as
-// new id. If `rewrite_into_pred` is true, the instructions in the reducer are
-// rewritten into predicate form.
-HloComputationProto CopyReducer(int64 reducer_id,
-                                HloComputationProto* original_reducer,
-                                bool rewrite_into_pred, int64* global_id) {
-  HloComputationProto reducer;
-  SetProtoIdAndName(&reducer, StrCat("reduce_or"), kNameSeparator, reducer_id);
-  std::vector<int64> operands_id;
-  for (auto& inst : original_reducer->instructions()) {
-    // Copy params.
-    if (StringToHloOpcode(inst.opcode()).ValueOrDie() ==
-        HloOpcode::kParameter) {
-      HloInstructionProto* new_param = reducer.add_instructions();
-      *new_param = inst;
-      new_param->set_id((*global_id)++);
-      *new_param->mutable_name() =
-          GetFullName(inst.name(), '.', new_param->id());
-      if (rewrite_into_pred) {
-        *new_param->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
-      }
-      operands_id.push_back(new_param->id());
-    }
-    if (inst.id() == original_reducer->root_id()) {
-      HloInstructionProto* new_root = reducer.add_instructions();
-      *new_root = inst;
-      new_root->set_id((*global_id)++);
-      *new_root->mutable_name() = GetFullName(inst.name(), '.', new_root->id());
-      if (rewrite_into_pred) {
-        *new_root->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
-        *new_root->mutable_opcode() = HloOpcodeString(HloOpcode::kOr);
-      }
-      new_root->clear_operand_ids();
-      for (int64 operand_id : operands_id) {
-        new_root->add_operand_ids(operand_id);
-      }
-      reducer.set_root_id(new_root->id());
-    }
-  }
-  return reducer;
-}
-
 bool InstrIsSetBound(const HloInstructionProto* instr_proto) {
   HloOpcode opcode = StringToHloOpcode(instr_proto->opcode()).ValueOrDie();
   if (opcode == HloOpcode::kCustomCall &&
@@ -3444,348 +3389,6 @@ StatusOr<bool> XlaBuilder::IsConstant(XlaOp operand) const {
   return is_constant;
 }
 
-StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
-  TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
-                      LookUpInstruction(root_op));
-
-  HloComputationProto entry;
-  SetProtoIdAndName(&entry, StrCat(name_, "_dynamic_inference"), kNameSeparator,
-                    GetNextId());
-  ProgramShapeProto* program_shape = entry.mutable_program_shape();
-  *program_shape->mutable_result() =
-      ShapeUtil::ChangeElementType(Shape(root->shape()), PRED).ToProto();
-
-  std::vector<HloComputationProto> called_computations;
-  auto operand_is_constant = [&](const HloInstructionProto* instr_proto,
-                                 int64 operand_index) -> StatusOr<bool> {
-    int64 operand_id = instr_proto->operand_ids(operand_index);
-    bool is_constant = true;
-    absl::flat_hash_set<int64> visited;
-    IsConstantVisitor(operand_id, &visited, &is_constant);
-    return is_constant;
-  };
-  // Process instruction and copy it into the new graph. The new node in the
-  // new graph with have id set to `id`.
-  auto process_instruction = [&](const HloInstructionProto* instr_proto,
-                                 bool need_rewrite, int64 id,
-                                 absl::Span<int64 const> operand_ids,
-                                 int64* global_id) {
-    // Rewrite the instruction with following rules:
-    // - Unary ops: Convert into bitcast (identity) with type Pred.
-    // - Binary ops: Convert into binary or.
-    // - Select: Convert into binary or with its two data operands.
-    // - Concat / Tuple/ GTE / Bitcast: Copy.
-    // - Param: Convert to constant True.
-    // - GetDimensionSize: Convert to constant True if dimension is dynamic,
-    //   contant False if dimension is static.
-    // - Reduce: Convert to reduce or.
-    // - Constant: Convert to constant False.
-    // - Reshape, slice, transpose, pad:
-    //   Convert into predicate type with same opcode.
-    // - Other ops: Not supported.
-    // Create the instruction for the new handle.
-    TF_ASSIGN_OR_RETURN(HloOpcode opcode,
-                        StringToHloOpcode(instr_proto->opcode()));
-    auto* new_instr = entry.add_instructions();
-    *new_instr = *instr_proto;
-    new_instr->set_id(id);
-    new_instr->mutable_operand_ids()->Clear();
-    for (auto operand_id : operand_ids) {
-      new_instr->mutable_operand_ids()->Add(operand_id);
-    }
-
-    if (!need_rewrite) {
-      *new_instr->mutable_name() =
-          GetFullName(instr_proto->opcode(), kNameSeparator, id);
-      if (opcode == HloOpcode::kReduce) {
-        // Copy the reducer to the new module, with a new id that's same as the
-        // reduce op.
-        HloComputationProto* reducer =
-            &embedded_[new_instr->called_computation_ids(0)];
-        int64 reducer_id = (*global_id)++;
-        new_instr->clear_called_computation_ids();
-        new_instr->add_called_computation_ids(reducer_id);
-        called_computations.push_back(CopyReducer(
-            reducer_id, reducer, /*rewrite_into_pred=*/false, global_id));
-      }
-      return Status::OK();
-    }
-    *new_instr->mutable_shape() = ConvertShapeProtoToPred(instr_proto->shape());
-    Shape new_shape(new_instr->shape());
-    switch (opcode) {
-      case HloOpcode::kAbs:
-      case HloOpcode::kRoundNearestAfz:
-      case HloOpcode::kBitcast:
-      case HloOpcode::kCeil:
-      case HloOpcode::kCollectivePermuteDone:
-      case HloOpcode::kCos:
-      case HloOpcode::kClz:
-      case HloOpcode::kExp:
-      case HloOpcode::kExpm1:
-      case HloOpcode::kFloor:
-      case HloOpcode::kImag:
-      case HloOpcode::kIsFinite:
-      case HloOpcode::kLog:
-      case HloOpcode::kLog1p:
-      case HloOpcode::kNot:
-      case HloOpcode::kNegate:
-      case HloOpcode::kPopulationCount:
-      case HloOpcode::kReal:
-      case HloOpcode::kRsqrt:
-      case HloOpcode::kLogistic:
-      case HloOpcode::kSign:
-      case HloOpcode::kSin:
-      case HloOpcode::kConvert:
-      case HloOpcode::kSqrt:
-      case HloOpcode::kCbrt:
-      case HloOpcode::kTanh:
-        CHECK_EQ(instr_proto->operand_ids_size(), 1);
-        *new_instr->mutable_opcode() = HloOpcodeString(HloOpcode::kBitcast);
-        break;
-      case HloOpcode::kAdd:
-      case HloOpcode::kAtan2:
-      case HloOpcode::kDivide:
-      case HloOpcode::kComplex:
-      case HloOpcode::kMaximum:
-      case HloOpcode::kMinimum:
-      case HloOpcode::kMultiply:
-      case HloOpcode::kPower:
-      case HloOpcode::kRemainder:
-      case HloOpcode::kSubtract:
-      case HloOpcode::kCompare:
-      case HloOpcode::kAnd:
-      case HloOpcode::kOr:
-      case HloOpcode::kXor:
-      case HloOpcode::kShiftLeft:
-      case HloOpcode::kShiftRightArithmetic:
-      case HloOpcode::kShiftRightLogical:
-        CHECK_EQ(instr_proto->operand_ids_size(), 2);
-        *new_instr->mutable_opcode() = HloOpcodeString(HloOpcode::kOr);
-        break;
-      case HloOpcode::kSelect: {
-        TF_ASSIGN_OR_RETURN(bool constant_predicate,
-                            operand_is_constant(instr_proto, 0));
-        if (!constant_predicate) {
-          // If the predicate operand is not constant, conservatively assume the
-          // entire result values are dynamic.
-          SetInstructionAsConstant(new_instr, id, new_shape, true);
-        }
-        break;
-      }
-      case HloOpcode::kGather: {
-        TF_ASSIGN_OR_RETURN(bool constant_indices,
-                            operand_is_constant(instr_proto, 1));
-        if (!constant_indices) {
-          // If the indices operand is not constant, conservatively assume the
-          // entire result values are dynamic.
-          SetInstructionAsConstant(new_instr, id, new_shape, true);
-        }
-        break;
-      }
-      case HloOpcode::kReduce: {
-        auto* reducer = &embedded_[new_instr->called_computation_ids(0)];
-        int64 reducer_id = (*global_id)++;
-        new_instr->clear_called_computation_ids();
-        new_instr->add_called_computation_ids(reducer_id);
-        called_computations.push_back(CopyReducer(
-            reducer_id, reducer, /*rewrite_into_pred=*/true, global_id));
-        break;
-      }
-      case HloOpcode::kTuple:
-      case HloOpcode::kTranspose:
-      case HloOpcode::kSlice:
-      case HloOpcode::kBroadcast:
-      case HloOpcode::kConcatenate:
-      case HloOpcode::kReshape:
-      case HloOpcode::kPad:
-        break;
-      case HloOpcode::kGetTupleElement: {
-        // Rewrite parameter followed by gte into constants to avoid
-        // rematerializing the tuple parameter (could be very large).
-        int64 operand_handle = instr_proto->operand_ids(0);
-        TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
-                            LookUpInstructionByHandle(operand_handle));
-        TF_ASSIGN_OR_RETURN(HloOpcode operand_opcode,
-                            StringToHloOpcode(operand_proto->opcode()));
-        if (operand_opcode == HloOpcode::kParameter) {
-          SetInstructionAsConstant(new_instr, id, new_shape, true);
-        }
-        break;
-      }
-      case HloOpcode::kGetDimensionSize: {
-        int64 dimension = instr_proto->dimensions(0);
-        int64 operand_handle = instr_proto->operand_ids(0);
-        TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
-                            LookUpInstructionByHandle(operand_handle));
-
-        SetInstructionAsConstant(
-            new_instr, id, new_shape,
-            operand_proto->shape().is_dynamic_dimension(dimension));
-        break;
-      }
-      case HloOpcode::kConstant:
-      case HloOpcode::kIota:
-        SetInstructionAsConstant(new_instr, id, new_shape, false);
-        break;
-      case HloOpcode::kCustomCall:
-        if (instr_proto->custom_call_target() == "SetBound") {
-          SetInstructionAsConstant(new_instr, id, new_shape, true);
-          break;
-        } else {
-          return InvalidArgument(
-              "Dynamic inferencing on custom call %s is not supported",
-              instr_proto->DebugString());
-        }
-      case HloOpcode::kParameter:
-        SetInstructionAsConstant(new_instr, id, new_shape, true);
-        break;
-      default:
-        return InvalidArgument("Dynamic inferencing %s is not supported",
-                               instr_proto->DebugString());
-    }
-    *new_instr->mutable_name() =
-        GetFullName(instr_proto->opcode(), kNameSeparator, id);
-    return Status::OK();
-  };
-
-  struct WorkItem {
-    explicit WorkItem(int64 handle, bool need_rewrite)
-        : handle(handle), need_rewrite(need_rewrite), visited(false) {}
-    int64 handle;
-    // If need_rewrite is true, the instruction will be copied and rewrite into
-    // a pred instruction indicating if each value is dynamic. If need_rewrite
-    // is false, simply copy the instruction to the output graph.
-    // E.g.,
-    // For select(P, A, B), we need to rewrite A and B into predicates, but
-    // don't need to rewrite P.
-    bool need_rewrite;
-    // Used in dfs to remember the ids of processed operands of this item.
-    std::vector<int64> processed_operands;
-    // Whether this node been visited before or not.
-    bool visited;
-  };
-  // Only copy each pair of {handle, need_rewrite} once. Value is the id in the
-  // new graph.
-  absl::flat_hash_map<std::pair<int64, bool>, int64> seen;
-  // Monotonically increasing id to assign to new instructions.
-  int64 global_id = 0;
-  // The result id of the last rewritten item -- return value of last stack
-  // item.
-  int64 stacktop_id = -1;
-  std::vector<WorkItem> worklist;
-  worklist.push_back(WorkItem(root->id(), true));
-  while (!worklist.empty()) {
-    WorkItem& item = worklist.back();
-    auto item_key = std::make_pair(item.handle, item.need_rewrite);
-    auto iter = seen.find(item_key);
-    // Already processed this item. Return previous results.
-    if (iter != seen.end()) {
-      stacktop_id = iter->second;
-      worklist.pop_back();
-      continue;
-    }
-
-    int64 next_operand = item.processed_operands.size();
-    TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
-                        LookUpInstructionByHandle(item.handle));
-    VLOG(3) << "Processing" << instr_proto->name();
-    if (!item.visited) {
-      item.visited = true;
-    } else {
-      // Record previous processed operand.
-      item.processed_operands.push_back(stacktop_id);
-      next_operand++;
-    }
-    TF_ASSIGN_OR_RETURN(HloOpcode opcode,
-                        StringToHloOpcode(instr_proto->opcode()));
-    bool should_visit_operand = true;
-    if (opcode == HloOpcode::kGetDimensionSize) {
-      // We rewrite gte instructions into constants based on its operand shape
-      // so no need to visit their operands.
-      should_visit_operand = false;
-    }
-
-    if (opcode == HloOpcode::kGetTupleElement) {
-      int64 operand_handle = instr_proto->operand_ids(0);
-      TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
-                          LookUpInstructionByHandle(operand_handle));
-      TF_ASSIGN_OR_RETURN(HloOpcode operand_opcode,
-                          StringToHloOpcode(operand_proto->opcode()));
-      if (operand_opcode == HloOpcode::kParameter) {
-        // Don't rematerialize the whole parameter if it's followed by a GTE.
-        should_visit_operand = false;
-      }
-    }
-
-    if (opcode == HloOpcode::kSelect) {
-      TF_ASSIGN_OR_RETURN(bool constant_predicate,
-                          operand_is_constant(instr_proto, 0));
-      // If the predicate operand (first operand) is non-constant, we don't
-      // visit operands and we say the all result values are dynamic.
-      should_visit_operand = constant_predicate;
-    }
-    if (opcode == HloOpcode::kGather) {
-      TF_ASSIGN_OR_RETURN(bool constant_indices,
-                          operand_is_constant(instr_proto, 1));
-      // If the indices operand (second operand) is non-constant, we don't
-      // visit operands and we say the all result values are dynamic.
-      should_visit_operand = constant_indices;
-    }
-    if (opcode == HloOpcode::kGetDimensionSize && next_operand == 0) {
-      // Always rewrite get dimension size into constant.
-      item.need_rewrite = true;
-    }
-    if (next_operand >= instr_proto->operand_ids_size() ||
-        !should_visit_operand || InstrIsSetBound(instr_proto)) {
-      // No more operands to process, process self.
-      int64 new_id = global_id++;
-      VLOG(3) << "new_id: " << new_id << "instr: " << instr_proto->name();
-      TF_RETURN_IF_ERROR(process_instruction(instr_proto, item.need_rewrite,
-                                             new_id, item.processed_operands,
-                                             &global_id));
-      stacktop_id = new_id;
-      seen[item_key] = stacktop_id;
-      worklist.pop_back();
-    } else {
-      // Visit and process operand. If an operand doesn't need rewrite
-      // (predicate of kSelect, or indices of kGather), we also don't rewrite
-      // its ancestors.
-      WorkItem next_item(instr_proto->operand_ids(next_operand),
-                         item.need_rewrite);
-      if (opcode == HloOpcode::kSelect && next_operand == 0) {
-        next_item.need_rewrite = false;
-      }
-      if (opcode == HloOpcode::kGather && next_operand == 1) {
-        next_item.need_rewrite = false;
-      }
-      // Push next operand into worklist.
-      worklist.push_back(next_item);
-    }
-  }
-  TF_RET_CHECK(stacktop_id != -1);
-  entry.set_root_id(stacktop_id);
-  absl::c_sort(*entry.mutable_instructions(),
-               [](const HloInstructionProto& p1,
-                  const HloInstructionProto& p2) { return p1.id() < p2.id(); });
-  XlaComputation computation(entry.id());
-  HloModuleProto* module = computation.mutable_proto();
-  module->set_name(entry.name());
-  module->set_id(entry.id());
-  module->set_entry_computation_name(entry.name());
-  module->set_entry_computation_id(entry.id());
-  *module->mutable_host_program_shape() = *program_shape;
-  for (auto& called_comp : called_computations) {
-    *module->add_computations() = called_comp;
-  }
-  *module->add_computations() = std::move(entry);
-  // Make sure all ids appear in the computation with ascending order.
-  absl::c_sort(*module->mutable_computations(),
-               [](const HloComputationProto& c1,
-                  const HloComputationProto& c2) { return c1.id() < c2.id(); });
-  XLA_VLOG_LINES(3, module->DebugString());
-  return std::move(computation);
-}
-
 StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
     XlaOp root_op, bool dynamic_dimension_is_minus_one) {
   TF_ASSIGN_OR_RETURN(bool is_constant, IsConstant(root_op));
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 561934c7fdd..a6cbdbe5597 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -114,6 +114,7 @@ class XlaOp {
   int64 handle() const { return handle_; }
 
   friend class XlaBuilder;
+  friend class ValueInference;
   friend class MlirHloBuilder;
   friend struct internal::XlaBuilderFriend;
 
@@ -296,31 +297,6 @@ class XlaBuilder {
   StatusOr<XlaComputation> BuildConstantSubGraph(
       XlaOp root_op, bool dynamic_dimension_is_uint_max = false);
 
-  // Similar to BuildConstantSubGraph, but with root element type changed to
-  // boolean. A true value in the root indicates that the value is dynamic while
-  // false value indicates that the value is a constant. This will copy the
-  // needed ops/computations to the subgraph.
-  //
-  // E.g.,
-  // Compuptation {
-  //   a = 3
-  //   b = param(0)
-  //   ROOT Tuple(a + b, a + 1, b + 1)
-  // }
-  // Calling BuildDynamicInferenceGraph on root will produce the following
-  // graph:
-  //
-  // Compuptation {
-  //   a = False
-  //   b = True
-  //   ROOT Tuple(a | b, a, b)
-  // }
-  //
-  // The result, which is (True, False, True) after evaluation, can be
-  // interpreted as "First element is dynamic; Second element is static; Third
-  // element is dynamic".
-  StatusOr<XlaComputation> BuildDynamicInferenceGraph(XlaOp root_op);
-
   // Returns the first error that was encountered while building the
  // computation. When an error is encountered, by default we return a vacuous
   // XlaOp and inform the user of the error that occurred while
@@ -1478,6 +1454,8 @@ class XlaBuilder {
   }
 
   friend struct internal::XlaBuilderFriend;
+
+  friend class ValueInference;
 };
 
 // RAII-style object: sets the current sharding assignment in builder on
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index da11d3e3367..561e6214b2e 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -497,7 +497,7 @@ class HloInstruction {
   static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
       const HloInstructionProto& proto,
       const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
-      const absl::flat_hash_map<int64, HloComputation*>& computation_map,
+      const absl::flat_hash_map<int64, HloComputation*>& computation_map = {},
       bool prohibit_empty_literal = true);
 
   // Creates a parameter-retrieving instruction.
@@ -1043,6 +1043,7 @@ class HloInstruction {
 
   // Returns the opcode for this instruction.
   HloOpcode opcode() const { return opcode_; }
+  HloOpcode* mutable_opcode() { return &opcode_; }
 
   // Returns true if this instruction has a side effect, irrespective of whether
   // any called computations may contain an instruction with side effects.
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 99a38fdefa2..6f7085736b9 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -2097,6 +2097,7 @@ xla_test(
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/client:client_library",
         "//tensorflow/compiler/xla/client:global_data",
+        "//tensorflow/compiler/xla/client:value_inference",
         "//tensorflow/compiler/xla/client:xla_builder",
         "//tensorflow/compiler/xla/client:xla_computation",
         "//tensorflow/compiler/xla/client/lib:arithmetic",
diff --git a/tensorflow/compiler/xla/tests/dynamism_inference_test.cc b/tensorflow/compiler/xla/tests/dynamism_inference_test.cc
index 892fdb86362..9d547fdf3b7 100644
--- a/tensorflow/compiler/xla/tests/dynamism_inference_test.cc
+++ b/tensorflow/compiler/xla/tests/dynamism_inference_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/global_data.h"
 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
 #include "tensorflow/compiler/xla/client/lib/prng.h"
+#include "tensorflow/compiler/xla/client/value_inference.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/layout_util.h"
@@ -44,7 +45,6 @@ namespace {
 
 // An enumerator for the client types that we want to iterate over in
 // the various tests.
 enum class ClientType { kLocal, kCompileOnly };
-ClientType client_types[] = {ClientType::kLocal, ClientType::kCompileOnly};
 
 class DynamismInferenceTest : public ::testing::Test {
  public:
@@ -72,21 +72,18 @@ class DynamismInferenceTest : public ::testing::Test {
     LOG(FATAL) << "invalid client_type value";
   }
 
-  StatusOr<Literal> ComputeDynamismLiteral(Client* client, XlaOp operand,
-                                           XlaBuilder* builder,
+  StatusOr<Literal> ComputeDynamismLiteral(XlaOp operand, XlaBuilder* builder,
                                            Layout* output_layout = nullptr) {
-    TF_ASSIGN_OR_RETURN(auto subgraph,
-                        builder->BuildDynamicInferenceGraph(operand));
-    TF_ASSIGN_OR_RETURN(auto computed,
-                        client->ComputeConstant(subgraph, output_layout));
-    return std::move(computed);
+    ValueInference value_inference(builder);
+    TF_ASSIGN_OR_RETURN(auto literal_slice,
+                        value_inference.AnalyzeIsDynamic(operand));
+    return literal_slice.Clone();
   }
 
-  StatusOr<bool> ComputeDynamismScalar(Client* client, XlaOp operand,
-                                       XlaBuilder* builder,
+  StatusOr<bool> ComputeDynamismScalar(XlaOp operand, XlaBuilder* builder,
                                        ShapeIndex index = {}) {
-    TF_ASSIGN_OR_RETURN(auto literal, ComputeDynamismLiteral(client, operand,
-                                                             builder, nullptr));
+    TF_ASSIGN_OR_RETURN(auto literal,
+                        ComputeDynamismLiteral(operand, builder, nullptr));
     return literal.Get<bool>({}, index);
   }
 
@@ -94,265 +91,202 @@ class DynamismInferenceTest : public ::testing::Test {
 };
 
 TEST_F(DynamismInferenceTest, ScalarInt32Literal) {
-  for (ClientType client_type : client_types) {
-    Client* client = ClientOrDie(platform_, client_type);
-    XlaBuilder b(TestName());
-    auto computation = ConstantR0<int32>(&b, 42);
+  XlaBuilder b(TestName());
+  auto computation = ConstantR0<int32>(&b, 42);
 
-    auto value = ComputeDynamismScalar(client, computation, &b);
-    ASSERT_TRUE(value.ok()) << value.status();
-    // A constant is not dynamic.
-    EXPECT_EQ(value.ValueOrDie(), false);
-  }
+  auto value = ComputeDynamismScalar(computation, &b);
+  ASSERT_TRUE(value.ok()) << value.status();
+  // A constant is not dynamic.
+  EXPECT_EQ(value.ValueOrDie(), false);
 }
 
 TEST_F(DynamismInferenceTest, Iota) {
   // The output of iota is considered static.
-  for (ClientType client_type : client_types) {
-    Client* client = ClientOrDie(platform_, client_type);
-    XlaBuilder b(TestName());
-    auto computation = Iota(&b, S32, 2);
-    // Iota is not dynamic.
-    EXPECT_FALSE(ComputeDynamismLiteral(client, computation, &b)
-                     .ValueOrDie()
-                     .Get<bool>({0}));
-  }
+  XlaBuilder b(TestName());
+  auto computation = Iota(&b, S32, 2);
+  // Iota is not dynamic.
+  EXPECT_FALSE(
+      ComputeDynamismLiteral(computation, &b).ValueOrDie().Get<bool>({0}));
 }
 
 TEST_F(DynamismInferenceTest, TupleSimple) {
-  for (ClientType client_type : client_types) {
-    Client* client = ClientOrDie(platform_, client_type);
-    XlaBuilder b(TestName());
-    auto c = ConstantR0<int32>(&b, 42);
-    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
+  XlaBuilder b(TestName());
+  auto c = ConstantR0<int32>(&b, 42);
+  auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
 
-    auto tuple = Tuple(&b, {c, p});
-    EXPECT_EQ(ComputeDynamismScalar(client, tuple, &b, {0}).ValueOrDie(),
-              false);
-    EXPECT_EQ(ComputeDynamismScalar(client, tuple, &b, {1}).ValueOrDie(), true);
-  }
+  auto tuple = Tuple(&b, {c, p});
+  EXPECT_EQ(ComputeDynamismScalar(tuple, &b, {0}).ValueOrDie(), false);
+  EXPECT_EQ(ComputeDynamismScalar(tuple, &b, {1}).ValueOrDie(), true);
 }
 
 TEST_F(DynamismInferenceTest, TupleGteKeepsDynamism) {
-  for (ClientType client_type : client_types) {
-    Client* client = ClientOrDie(platform_, client_type);
-    XlaBuilder b(TestName());
-    auto c = ConstantR0<int32>(&b, 42);
-    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
+  XlaBuilder b(TestName());
+  auto c = ConstantR0<int32>(&b, 42);
+  auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
 
-    auto tuple = Tuple(&b, {c, p});
-    auto gte0 = GetTupleElement(tuple, 0);
-    auto gte1 = GetTupleElement(tuple, 1);
-    auto tuple_2 = Tuple(&b, {gte0, gte1});
-    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
-              false);
-    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
-              true);
-  }
+  auto tuple = Tuple(&b, {c, p});
+  auto gte0 = GetTupleElement(tuple, 0);
+  auto gte1 = GetTupleElement(tuple, 1);
+  auto tuple_2 = Tuple(&b, {gte0, gte1});
+  EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {0}).ValueOrDie(), false);
+  EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {1}).ValueOrDie(), true);
 }
 
 TEST_F(DynamismInferenceTest, PredValueUsedTwice) {
-  for (ClientType client_type : client_types) {
-    Client* client = ClientOrDie(platform_, client_type);
-    XlaBuilder b(TestName());
-    auto c = ConstantR0<int32>(&b, 42);
-    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
-    auto pred = Eq(c, p);
-    auto result = Select(pred, p, c);
-    EXPECT_EQ(ComputeDynamismScalar(client, result, &b, {}).ValueOrDie(), true);
-  }
+  XlaBuilder b(TestName());
+  auto c = ConstantR0<int32>(&b, 42);
+  auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
+  auto pred = Eq(c, p);
+  auto result = Select(pred, p, c);
+  EXPECT_EQ(ComputeDynamismScalar(result, &b, {}).ValueOrDie(), true);
 }
 
 TEST_F(DynamismInferenceTest, ReduceUsedTwice) {
-  for (ClientType client_type : client_types) {
-    Client* client = ClientOrDie(platform_, client_type);
-    XlaBuilder b(TestName());
-    auto c = ConstantR0<int32>(&b, 42);
-    auto p = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2}), "p0");
-    auto zero = ConstantR0<int32>(&b, 0);
-    XlaComputation add_s32 = CreateScalarAddComputation(S32, &b);
-    auto reduce = Reduce(p, zero, add_s32, {0});
-    auto pred = Eq(c, reduce);
-    auto result = Select(pred, reduce, c);
-    EXPECT_EQ(ComputeDynamismScalar(client, result, &b, {}).ValueOrDie(), true);
-  }
+  XlaBuilder b(TestName());
+  auto c = ConstantR0<int32>(&b, 42);
+  auto p = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2}), "p0");
+  auto zero = ConstantR0<int32>(&b, 0);
+  XlaComputation add_s32 = CreateScalarAddComputation(S32, &b);
+  auto reduce = Reduce(p, zero, add_s32, {0});
+  auto pred = Eq(c, reduce);
+  auto result = Select(pred, reduce, c);
+  EXPECT_EQ(ComputeDynamismScalar(result, &b, {}).ValueOrDie(), true);
 }
 
 TEST_F(DynamismInferenceTest, ConcatSliceReshapeKeepsDynamism) {
-  for (ClientType client_type : client_types) {
-    Client* client = ClientOrDie(platform_, client_type);
-    XlaBuilder b(TestName());
-    auto c = ConstantR0<int32>(&b, 42);
-    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
+  XlaBuilder b(TestName());
+  auto c = ConstantR0<int32>(&b, 42);
+  auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
 
-    auto concat = ConcatScalars(&b, {c, p});
-    auto slice0 = SliceInDim(concat, 0, 1, 1, 0);
-    auto reshape0 = Reshape(slice0, {});
-    auto slice1 = SliceInDim(concat, 1, 2, 1, 0);
-    auto reshape1 = Reshape(slice1, {});
-    auto tuple_2 = Tuple(&b, {reshape0, reshape1});
-    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
-              false);
-    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
-              true);
-  }
+  auto concat = ConcatScalars(&b, {c, p});
+  auto slice0 = SliceInDim(concat, 0, 1, 1, 0);
+  auto reshape0 = Reshape(slice0, {});
+  auto slice1 = SliceInDim(concat, 1, 2, 1, 0);
+  auto reshape1 = Reshape(slice1, {});
+  auto tuple_2 = Tuple(&b, {reshape0, reshape1});
+  EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {0}).ValueOrDie(), false);
+  EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {1}).ValueOrDie(), true);
 }
 
 TEST_F(DynamismInferenceTest, ParameterIsDynamic) {
-  for (ClientType client_type : client_types) {
-    Client* client = ClientOrDie(platform_, client_type);
-    XlaBuilder b(TestName());
-    auto computation = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
+  XlaBuilder b(TestName());
+  auto computation = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
 
-    auto value = ComputeDynamismScalar(client, computation, &b);
-    ASSERT_TRUE(value.ok()) << value.status();
-    // A parameter is considered dynamic.
-    EXPECT_EQ(value.ValueOrDie(), true);
-  }
+  auto value = ComputeDynamismScalar(computation, &b);
+  ASSERT_TRUE(value.ok()) << value.status();
+  // A parameter is considered dynamic.
+  EXPECT_EQ(value.ValueOrDie(), true);
 }
 
 TEST_F(DynamismInferenceTest, UnaryOpKeepsDynamism) {
-  for (ClientType client_type : client_types) {
-    Client* client = ClientOrDie(platform_, client_type);
-    XlaBuilder b(TestName());
-    auto c = ConstantR0<int32>(&b, 42);
-    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
+  XlaBuilder b(TestName());
+  auto c = ConstantR0<int32>(&b, 42);
+  auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
 
-    auto neg0 = Neg(c);
-    auto neg1 = Neg(p);
-    auto tuple_2 = Tuple(&b, {neg0, neg1});
-    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
-              false);
-    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
-              true);
-  }
+  auto neg0 = Neg(c);
+  auto neg1 = Neg(p);
+  auto tuple_2 = Tuple(&b, {neg0, neg1});
+  EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {0}).ValueOrDie(), false);
+  EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {1}).ValueOrDie(), true);
 }
 
 TEST_F(DynamismInferenceTest, BinaryOpsOrsDynamism) {
-  for (ClientType client_type : client_types) {
-    Client* client = ClientOrDie(platform_, client_type);
-    XlaBuilder b(TestName());
-    auto c = ConstantR0<int32>(&b, 42);
-    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
+  XlaBuilder b(TestName());
+  auto c = ConstantR0<int32>(&b, 42);
+  auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
 
-    // Static value + static value = static
-    auto add1 = Add(c, c);
-    // Dynamic value + dynamic value = dynamic
-    auto add2 = Add(p, c);
-    auto tuple_2 = Tuple(&b, {add1, add2});
-    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
-              false);
-    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
-              true);
-  }
+  // Static value + static value = static
+  auto add1 = Add(c, c);
+  // Dynamic value + static value = dynamic
+  auto add2 = Add(p, c);
+  auto tuple_2 = Tuple(&b, {add1, add2});
+  EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {0}).ValueOrDie(), false);
+  EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {1}).ValueOrDie(), true);
 }
 
 TEST_F(DynamismInferenceTest, GetDimensionSize) {
-  for (ClientType client_type : client_types) {
-    Client* client = ClientOrDie(platform_, client_type);
-    XlaBuilder b(TestName());
-    // param = Param([<=2, 3])
-    // get_dimension_size(param, 0) is dynamic
-    // get_dimension_size(param, 1) is static
-    auto p = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}),
-                       "p0");
+  XlaBuilder b(TestName());
+  // param = Param([<=2, 3])
+  // get_dimension_size(param, 0) is dynamic
+  // get_dimension_size(param, 1) is static
+  auto p =
+      Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "p0");
 
-    auto gds0 = GetDimensionSize(p, 0);
-    auto gds1 = GetDimensionSize(p, 1);
-    auto tuple_2 = Tuple(&b, {gds0, gds1});
-    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
-              true);
-    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
-              false);
-  }
+  auto gds0 = GetDimensionSize(p, 0);
+  auto gds1 = GetDimensionSize(p, 1);
+  auto tuple_2 = Tuple(&b, {gds0, gds1});
+  EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {0}).ValueOrDie(), true);
+  EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {1}).ValueOrDie(), false);
 }
 
 TEST_F(DynamismInferenceTest, GatherWithCommonParent) {
-  for (ClientType client_type : client_types) {
-    Client* client = ClientOrDie(platform_, client_type);
-    XlaBuilder b(TestName());
-    // Test the analysis on a gather where first operand and second operand have
-    // common parents.
-    Shape indices_shape = ShapeUtil::MakeShape(S32, {2});
+  XlaBuilder b(TestName());
+  // Test the analysis on a gather where the first and second operands have
+  // common parents.
+  Shape indices_shape = ShapeUtil::MakeShape(S32, {2});
 
-    auto operand1 = Parameter(&b, 0, indices_shape, "p1");
-    auto operand2 = Parameter(&b, 1, indices_shape, "p2");
-    auto indices = Sub(operand1, operand2);
-    GatherDimensionNumbers dim_numbers;
-    dim_numbers.add_offset_dims(1);
-    dim_numbers.add_start_index_map(0);
-    dim_numbers.set_index_vector_dim(1);
-    auto gather = Gather(operand1, indices, dim_numbers, {1});
-    ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
-    EXPECT_TRUE(ComputeDynamismLiteral(client, gather, &b)
-                    .ValueOrDie()
-                    .Get<bool>({0, 0}));
-  }
+  auto operand1 = Parameter(&b, 0, indices_shape, "p1");
+  auto operand2 = Parameter(&b, 1, indices_shape, "p2");
+  auto indices = Sub(operand1, operand2);
+  GatherDimensionNumbers dim_numbers;
+  dim_numbers.add_offset_dims(1);
+  dim_numbers.add_start_index_map(0);
+  dim_numbers.set_index_vector_dim(1);
+  auto gather = Gather(operand1, indices, dim_numbers, {1});
+  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
+  EXPECT_TRUE(
+      ComputeDynamismLiteral(gather, &b).ValueOrDie().Get<bool>({0, 0}));
 }
 
 TEST_F(DynamismInferenceTest, GatherWithConstantParent) {
-  for (ClientType client_type : client_types) {
-    Client* client = ClientOrDie(platform_, client_type);
-    XlaBuilder b(TestName());
-    // Test the analysis on a gather.
-    Shape indices_shape = ShapeUtil::MakeShape(S32, {2});
-    auto data_operand = ConstantR1<int32>(&b, {1, 2});
-    auto indices = ConstantR1<int32>(&b, {1, 2});
-    GatherDimensionNumbers dim_numbers;
-    dim_numbers.add_offset_dims(1);
-    dim_numbers.add_start_index_map(0);
-    dim_numbers.set_index_vector_dim(1);
-    auto gather = Gather(data_operand, indices, dim_numbers, {1});
-    ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
-    // Everything is constant, result is also contant.
-    EXPECT_FALSE(ComputeDynamismLiteral(client, gather, &b)
-                     .ValueOrDie()
-                     .Get<bool>({0, 0}));
-  }
+  XlaBuilder b(TestName());
+  // Test the analysis on a gather.
+  Shape indices_shape = ShapeUtil::MakeShape(S32, {2});
+  auto data_operand = ConstantR1<int32>(&b, {1, 2});
+  auto indices = ConstantR1<int32>(&b, {1, 2});
+  GatherDimensionNumbers dim_numbers;
+  dim_numbers.add_offset_dims(1);
+  dim_numbers.add_start_index_map(0);
+  dim_numbers.set_index_vector_dim(1);
+  auto gather = Gather(data_operand, indices, dim_numbers, {1});
+  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
+  // Everything is constant, result is also constant.
+  EXPECT_FALSE(
+      ComputeDynamismLiteral(gather, &b).ValueOrDie().Get<bool>({0, 0}));
 }
 
 TEST_F(DynamismInferenceTest, GatherWithSharedConstantParent) {
-  for (ClientType client_type : client_types) {
-    Client* client = ClientOrDie(platform_, client_type);
-    XlaBuilder b(TestName());
-    // Test the analysis on a gather.
-    Shape indices_shape = ShapeUtil::MakeShape(S32, {2});
-    auto operand1 = ConstantR1<int32>(&b, {1, 2});
-    auto operand2 = ConstantR1<int32>(&b, {1, 2});
-    auto indices = Sub(operand1, operand2);
-    GatherDimensionNumbers dim_numbers;
-    dim_numbers.add_offset_dims(1);
-    dim_numbers.add_start_index_map(0);
-    dim_numbers.set_index_vector_dim(1);
-    auto gather = Gather(operand1, indices, dim_numbers, {1});
-    ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
-    // Everything is constant, result is also contant.
-    EXPECT_FALSE(ComputeDynamismLiteral(client, gather, &b)
-                     .ValueOrDie()
-                     .Get<bool>({0, 0}));
-  }
+  XlaBuilder b(TestName());
+  // Test the analysis on a gather.
+  Shape indices_shape = ShapeUtil::MakeShape(S32, {2});
+  auto operand1 = ConstantR1<int32>(&b, {1, 2});
+  auto operand2 = ConstantR1<int32>(&b, {1, 2});
+  auto indices = Sub(operand1, operand2);
+  GatherDimensionNumbers dim_numbers;
+  dim_numbers.add_offset_dims(1);
+  dim_numbers.add_start_index_map(0);
+  dim_numbers.set_index_vector_dim(1);
+  auto gather = Gather(operand1, indices, dim_numbers, {1});
+  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
+  // Everything is constant, result is also constant.
+  EXPECT_FALSE(
+      ComputeDynamismLiteral(gather, &b).ValueOrDie().Get<bool>({0, 0}));
 }
 
 TEST_F(DynamismInferenceTest, InferThroughPad) {
-  for (ClientType client_type : client_types) {
-    Client* client = ClientOrDie(platform_, client_type);
-    XlaBuilder b(TestName());
-    // Test the analysis on a gather.
-    auto operand1 = ConstantR1<int32>(&b, {1, 2});
-    auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {}), "p0");
-    PaddingConfig padding_config;
-    padding_config.add_dimensions()->set_edge_padding_high(1);
-    // After pad the value is [constant, constant, parameter].
-    auto pad = Pad(operand1, parameter, padding_config);
-    ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
-    // Everything is constant, result is also contant.
-    EXPECT_FALSE(
-        ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get<bool>({0}));
-    EXPECT_FALSE(
-        ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get<bool>({1}));
-    EXPECT_TRUE(
-        ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get<bool>({2}));
-  }
+  XlaBuilder b(TestName());
+  // Test the analysis on a pad.
+  auto operand1 = ConstantR1<int32>(&b, {1, 2});
+  auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {}), "p0");
+  PaddingConfig padding_config;
+  padding_config.add_dimensions()->set_edge_padding_high(1);
+  // After pad the value is [constant, constant, parameter].
+  auto pad = Pad(operand1, parameter, padding_config);
+  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
+  // The two original elements are constant; the padded element is dynamic.
+  EXPECT_FALSE(ComputeDynamismLiteral(pad, &b).ValueOrDie().Get<bool>({0}));
+  EXPECT_FALSE(ComputeDynamismLiteral(pad, &b).ValueOrDie().Get<bool>({1}));
+  EXPECT_TRUE(ComputeDynamismLiteral(pad, &b).ValueOrDie().Get<bool>({2}));
 }
 
 }  // namespace
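For reference, a minimal sketch of the new API in use, mirroring the call sites in xla_expression.cc and the tests above (the builder name and ops here are illustrative, not part of the patch):

  XlaBuilder b("dynamism");
  auto c = ConstantR0<int32>(&b, 42);
  auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
  auto sum = Add(c, p);  // static + dynamic => dynamic
  ValueInference value_inference(&b);
  StatusOr<LiteralSlice> is_dynamic = value_inference.AnalyzeIsDynamic(sum);
  // On success, the PRED literal holds `true`, since `p` is a parameter.

Unlike the old BuildDynamicInferenceGraph + ComputeConstant round trip, no client is needed; the analysis evaluates directly against the builder's instruction protos.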