XLA value inference.

XLA value inference is designed to be a one-stop service that analyzes the attributes of each value in a tensor. The attributes include:

- What's the upper-bound of each value in a tensor.
- What's the lower-bound of each value in a tensor.
- What's the constant value of each tensor.
- Whether or not each value in a tensor is dynamic.

This CL implements the dynamism-inference part, replacing the old implementation in the XLA builder, which had become too complex to maintain over time.
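As a rough usage sketch (illustrative only, not part of this change), the new API is driven from an XlaBuilder and returns a PRED literal describing which values are dynamic:

  xla::XlaBuilder b("example");
  auto c = xla::ConstantR0<int32_t>(&b, 42);
  auto p = xla::Parameter(&b, 0, xla::ShapeUtil::MakeScalarShape(xla::S32), "p0");
  auto sum = xla::Add(c, p);
  xla::ValueInference value_inference(&b);
  // True means the value depends on a dynamic input (here: the parameter).
  xla::StatusOr<xla::LiteralSlice> is_dynamic = value_inference.AnalyzeIsDynamic(sum);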

PiperOrigin-RevId: 360753561
Change-Id: I55c6ec9abf43f1502c5173dafcc47ac6c41bea09
Yunxing Dai 2021-03-03 14:09:30 -08:00 committed by TensorFlower Gardener
parent 47f244679c
commit 71ac02ec1f
10 changed files with 637 additions and 648 deletions

View File

@ -485,6 +485,7 @@ cc_library(
":xla_resource",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client",
"//tensorflow/compiler/xla/client:value_inference",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/core:framework",
"//tensorflow/core:lib",

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/client/value_inference.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
@ -136,18 +137,15 @@ xla::StatusOr<Tensor> XlaExpression::ResolveDynamism(
if (!client)
return errors::InvalidArgument("client is required to resolve constant");
TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph,
handle().builder()->BuildDynamicInferenceGraph(handle()));
TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape());
// The XLA layout is specified minor to major, and TensorFlow uses a major to
// minor order.
std::vector<int64> layout_indices(shape.dims());
std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);
TF_ASSIGN_OR_RETURN(xla::Literal literal,
client->ComputeConstant(constant_graph, &layout));
xla::ValueInference value_inference(handle().builder());
TF_ASSIGN_OR_RETURN(xla::LiteralSlice literal,
value_inference.AnalyzeIsDynamic(handle()));
Tensor tensor(DT_BOOL);
TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, DT_BOOL, &tensor));
return tensor;

View File

@ -211,6 +211,25 @@ cc_library(
],
)
cc_library(
name = "value_inference",
srcs = ["value_inference.cc"],
hdrs = ["value_inference.h"],
visibility = ["//visibility:public"],
deps = [
":xla_builder",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_evaluator",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/types:optional",
],
)
cc_library(
name = "xla_builder",
srcs = ["xla_builder.cc"],

View File

@ -0,0 +1,384 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/value_inference.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace {
Literal CreatePredLiteral(bool pred, const Shape& reference_shape) {
Literal literal = LiteralUtil::CreateR0(pred);
Literal literal_broadcast =
literal
.Broadcast(ShapeUtil::ChangeElementType(Shape(reference_shape), PRED),
{})
.ValueOrDie();
return literal_broadcast;
}
using GetOperand = std::function<StatusOr<LiteralSlice>(int64 operand_index,
int64 operand_handle)>;
// HloProtoEvaluator evaluates an HLO proto and returns a literal. The user has
// to provide the operands as literals through the get_operand function.
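// A typical use in this file (sketch; see the call sites below) evaluates an
// instruction over the dynamism literals of its operands, forcing a PRED
// element type and an `or` opcode:
//   HloProtoEvaluator(*root, get_operand)
//       .WithPrimitiveType(PRED)
//       .WithOpCode(HloOpcode::kOr)
//       .Evaluate();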
struct HloProtoEvaluator {
explicit HloProtoEvaluator(HloInstructionProto inst, GetOperand get_operand)
: inst(std::move(inst)),
get_operand(get_operand),
module("EmptyModuleForEvaluation", HloModuleConfig()) {}
// WithComputation changes the called computation of the instruction being
// evaluated.
HloProtoEvaluator& WithComputation(
std::unique_ptr<HloComputation> new_computation) {
computation = new_computation.get();
computation->ClearUniqueIdInternal();
for (HloInstruction* inst : computation->instructions()) {
inst->ClearUniqueIdInternal();
}
module.AddEmbeddedComputation(std::move(new_computation));
return *this;
}
// WithPrimitiveType changes the primitive type of the instruction being evaluated.
HloProtoEvaluator& WithPrimitiveType(PrimitiveType new_primitive_type) {
primitive_type = new_primitive_type;
return *this;
}
// WithOpCode changes the opcode of the instruction being evaluated.
HloProtoEvaluator& WithOpCode(HloOpcode new_opcode) {
opcode = new_opcode;
return *this;
}
StatusOr<Literal> Evaluate() {
// Evaluate the instruction by swapping its operands with constant
// instructions holding the given literals.
HloComputation::Builder builder("EmptyComputation");
absl::flat_hash_map<int64, HloInstruction*> operand_map;
for (int64 i = 0; i < inst.operand_ids_size(); ++i) {
int64 operand_handle = inst.operand_ids(i);
TF_ASSIGN_OR_RETURN(auto literal, get_operand(i, inst.operand_ids(i)));
std::unique_ptr<HloInstruction> operand =
HloInstruction::CreateConstant(literal.Clone());
operand_map[operand_handle] = operand.get();
builder.AddInstruction(std::move(operand));
}
if (primitive_type.has_value()) {
*inst.mutable_shape() = ShapeUtil::ChangeElementType(
Shape(inst.shape()), primitive_type.value())
.ToProto();
}
if (opcode.has_value()) {
*inst.mutable_opcode() = HloOpcodeString(opcode.value());
}
absl::flat_hash_map<int64, HloComputation*> computation_map;
if (inst.called_computation_ids_size() != 0) {
TF_RET_CHECK(inst.called_computation_ids_size() == 1 &&
computation != nullptr)
<< inst.DebugString();
computation_map[inst.called_computation_ids(0)] = computation;
}
TF_ASSIGN_OR_RETURN(
auto new_instruction,
HloInstruction::CreateFromProto(inst, operand_map, computation_map));
new_instruction->ClearUniqueIdInternal();
builder.AddInstruction(std::move(new_instruction));
auto computation = builder.Build();
module.AddEntryComputation(std::move(computation));
HloEvaluator evaluator;
return evaluator.Evaluate(module.entry_computation()->root_instruction());
}
HloInstructionProto inst;
GetOperand get_operand;
HloModule module;
HloComputation* computation = nullptr;
absl::optional<PrimitiveType> primitive_type = absl::nullopt;
absl::optional<HloOpcode> opcode = absl::nullopt;
};
} // namespace
StatusOr<Literal> ValueInference::AnalyzeConstantLiteral(int64 handle) {
TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
builder_->LookUpInstructionByHandle(handle));
if (constant_.contains(handle)) {
return constant_[handle].Clone();
}
TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
switch (opcode) {
case HloOpcode::kGetDimensionSize: {
int64 dimension = root->dimensions(0);
int64 operand_handle = root->operand_ids(0);
TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
builder_->LookUpInstructionByHandle(operand_handle));
if (operand_proto->shape().is_dynamic_dimension(dimension)) {
return InvalidArgument(
"AnalyzeConstant is called on a GetDimensionSize on dynamic "
"dimension.");
} else {
return LiteralUtil::CreateR0<int32>(
operand_proto->shape().dimensions(dimension));
}
}
// Non-functional ops.
case HloOpcode::kRng:
case HloOpcode::kAllReduce:
// TODO(b/33009255): Implement constant folding for cross replica sum.
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kCall:
// TODO(b/32495713): We aren't checking the to_apply computation itself,
// so we conservatively say that computations containing the Call op
// cannot be constant. We cannot set is_functional=false in other similar
// cases since we're already relying on IsConstant to return true.
case HloOpcode::kCustomCall:
case HloOpcode::kWhile:
case HloOpcode::kConditional:
// TODO(b/32495713): We aren't checking the condition and body
// computations themselves.
case HloOpcode::kSend:
case HloOpcode::kRecv:
case HloOpcode::kParameter: {
return InvalidArgument("Can't analyze constant values on instruction %s",
root->DebugString());
}
case HloOpcode::kReduce:
case HloOpcode::kScatter:
case HloOpcode::kReduceWindow: {
HloComputationProto computation_proto =
builder_->embedded_[root->called_computation_ids(0)];
TF_ASSIGN_OR_RETURN(auto computation, HloComputation::CreateFromProto(
computation_proto, {}));
return HloProtoEvaluator(*root,
[&](int64 operand_index, int64 operand_handle) {
return AnalyzeConstant(operand_handle);
})
.WithComputation(std::move(computation))
.Evaluate();
}
default:
return HloProtoEvaluator(*root,
[&](int64 operand_index, int64 operand_handle) {
return AnalyzeConstant(operand_handle);
})
.Evaluate();
}
}
StatusOr<Literal> ValueInference::AnalyzeIsDynamicLiteral(int64 handle) {
TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
builder_->LookUpInstructionByHandle(handle));
if (is_dynamic_.contains(handle)) {
return is_dynamic_[handle].Clone();
}
TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
switch (opcode) {
case HloOpcode::kGetDimensionSize: {
int64 dimension = root->dimensions(0);
int64 operand_handle = root->operand_ids(0);
TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
builder_->LookUpInstructionByHandle(operand_handle));
return LiteralUtil::CreateR0<bool>(
operand_proto->shape().is_dynamic_dimension(dimension));
}
case HloOpcode::kAbs:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kBitcast:
case HloOpcode::kCeil:
case HloOpcode::kCollectivePermuteDone:
case HloOpcode::kCos:
case HloOpcode::kClz:
case HloOpcode::kExp:
case HloOpcode::kExpm1:
case HloOpcode::kFloor:
case HloOpcode::kImag:
case HloOpcode::kIsFinite:
case HloOpcode::kLog:
case HloOpcode::kLog1p:
case HloOpcode::kNot:
case HloOpcode::kNegate:
case HloOpcode::kPopulationCount:
case HloOpcode::kReal:
case HloOpcode::kRsqrt:
case HloOpcode::kLogistic:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kConvert:
case HloOpcode::kSqrt:
case HloOpcode::kCbrt:
case HloOpcode::kTanh: {
// Forward the operand, as unary ops don't change whether a value is dynamic
// or static.
int64 operand_handle = root->operand_ids(0);
TF_ASSIGN_OR_RETURN(auto literal, AnalyzeIsDynamic(operand_handle));
return literal.Clone();
}
case HloOpcode::kAdd:
case HloOpcode::kAtan2:
case HloOpcode::kDivide:
case HloOpcode::kComplex:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kMultiply:
case HloOpcode::kPower:
case HloOpcode::kRemainder:
case HloOpcode::kSubtract:
case HloOpcode::kCompare:
case HloOpcode::kAnd:
case HloOpcode::kOr:
case HloOpcode::kXor:
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical: {
return HloProtoEvaluator(*root,
[&](int64 operand_index, int64 operand_handle) {
return AnalyzeIsDynamic(operand_handle);
})
.WithPrimitiveType(PRED)
.WithOpCode(HloOpcode::kOr)
.Evaluate();
}
case HloOpcode::kTuple:
case HloOpcode::kTranspose:
case HloOpcode::kSlice:
case HloOpcode::kBroadcast:
case HloOpcode::kConcatenate:
case HloOpcode::kReshape:
case HloOpcode::kPad: {
return HloProtoEvaluator(*root,
[&](int64 operand_index, int64 operand_handle) {
return AnalyzeIsDynamic(operand_handle);
})
.WithPrimitiveType(PRED)
.Evaluate();
}
case HloOpcode::kGetTupleElement: {
int64 operand_handle = root->operand_ids(0);
TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
builder_->LookUpInstructionByHandle(operand_handle));
TF_ASSIGN_OR_RETURN(HloOpcode operand_opcode,
StringToHloOpcode(operand_proto->opcode()));
if (operand_opcode == HloOpcode::kParameter) {
// Don't materialize the whole parameter if it's followed by a GTE.
return CreatePredLiteral(true, Shape(root->shape()));
}
return HloProtoEvaluator(*root,
[&](int64 operand_index, int64 operand_handle) {
return AnalyzeIsDynamic(operand_handle);
})
.WithPrimitiveType(PRED)
.Evaluate();
}
case HloOpcode::kReduce: {
std::vector<std::unique_ptr<HloInstruction>> operand_storage;
absl::flat_hash_map<int64, HloInstruction*> operand_map;
absl::flat_hash_map<int64, HloComputation*> computation_map;
Shape scalar_shape = ShapeUtil::MakeScalarShape(xla::PRED);
HloComputation::Builder b("reduce_or");
auto lhs = b.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
auto rhs = b.AddInstruction(
HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
b.AddInstruction(
HloInstruction::CreateBinary(scalar_shape, HloOpcode::kOr, lhs, rhs));
auto reduce_computation = b.Build();
return HloProtoEvaluator(*root,
[&](int64 operand_index, int64 operand_handle) {
return AnalyzeIsDynamic(operand_handle);
})
.WithPrimitiveType(PRED)
.WithComputation(std::move(reduce_computation))
.Evaluate();
}
case HloOpcode::kConstant:
case HloOpcode::kIota: {
return CreatePredLiteral(false, Shape(root->shape()));
}
case HloOpcode::kParameter: {
return CreatePredLiteral(true, Shape(root->shape()));
}
case HloOpcode::kSelect: {
if (!AnalyzeConstant(root->operand_ids(0)).ok()) {
// If the predicate operand is not constant, conservatively assume that
// all result values are dynamic.
return CreatePredLiteral(true, Shape(root->shape()));
}
return HloProtoEvaluator(*root,
[&](int64 operand_index, int64 operand_handle) {
if (operand_index == 0) {
return AnalyzeConstant(operand_handle);
} else {
return AnalyzeIsDynamic(operand_handle);
}
})
.WithPrimitiveType(PRED)
.Evaluate();
}
case HloOpcode::kGather: {
if (!AnalyzeConstant(root->operand_ids(1)).ok()) {
// If the index operand is not constant, conservatively assume that all
// result values are dynamic.
return CreatePredLiteral(true, Shape(root->shape()));
}
return HloProtoEvaluator(*root,
[&](int64 operand_index, int64 operand_handle) {
if (operand_index == 1) {
return AnalyzeConstant(operand_handle);
} else {
return AnalyzeIsDynamic(operand_handle);
}
})
.WithPrimitiveType(PRED)
.Evaluate();
}
case HloOpcode::kCustomCall: {
if (root->custom_call_target() == "SetBound") {
return CreatePredLiteral(true, Shape(root->shape()));
} else {
return InvalidArgument(
"Dynamic inferencing on custom call %s is not supported",
root->DebugString());
}
break;
}
default:
return Unimplemented("Can't infer dynamism through %s: %s",
root->opcode(), root->DebugString());
}
}
StatusOr<LiteralSlice> ValueInference::AnalyzeIsDynamic(int64 handle) {
TF_ASSIGN_OR_RETURN(Literal literal, AnalyzeIsDynamicLiteral(handle));
is_dynamic_[handle] = std::move(literal);
return LiteralSlice(is_dynamic_[handle]);
}
StatusOr<LiteralSlice> ValueInference::AnalyzeConstant(int64 handle) {
TF_ASSIGN_OR_RETURN(Literal literal, AnalyzeConstantLiteral(handle));
constant_[handle] = std::move(literal);
return LiteralSlice(constant_[handle]);
}
} // namespace xla

View File

@ -0,0 +1,70 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_VALUE_INFERENCE_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_VALUE_INFERENCE_H_
#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
class ValueInference {
public:
// ValueInference analyzes values in an XlaOp and answers the following
// questions:
// - What's the upper-bound of each value in a tensor.
// - What's the lower-bound of each value in a tensor.
// - What's the constant value of each tensor.
// - Whether or not each value in a tensor is dynamic.
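// For example (illustrative only): for Tuple(Add(constant, parameter),
// constant), AnalyzeIsDynamic reports (true, false): the first element
// depends on a parameter and is dynamic, the second is static.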
explicit ValueInference(XlaBuilder* builder) : builder_(builder) {}
StatusOr<LiteralSlice> AnalyzeUpperBound(XlaOp op) {
return Unimplemented("Analyzing upper-bound is not implemented yet.");
}
StatusOr<LiteralSlice> AnalyzeLowerBound(XlaOp op) {
return Unimplemented("Analyzing lower-bound is not implemented yet.");
}
StatusOr<LiteralSlice> AnalyzeIsDynamic(XlaOp op) {
return AnalyzeIsDynamic(op.handle());
}
StatusOr<LiteralSlice> AnalyzeConstant(XlaOp op) {
return AnalyzeConstant(op.handle());
}
private:
StatusOr<LiteralSlice> AnalyzeIsDynamic(int64 handle);
StatusOr<LiteralSlice> AnalyzeConstant(int64 handle);
StatusOr<Literal> AnalyzeIsDynamicLiteral(int64 handle);
StatusOr<Literal> AnalyzeConstantLiteral(int64 handle);
XlaBuilder* builder_;
// Cache to avoid re-evaluating. Mapping of xla handle to evaluated
// literals.
absl::flat_hash_map<int64, Literal> upper_bound_;
absl::flat_hash_map<int64, Literal> lower_bound_;
absl::flat_hash_map<int64, Literal> is_dynamic_;
absl::flat_hash_map<int64, Literal> constant_;
HloEvaluator evaluator_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_VALUE_INFERENCE_H_

View File

@ -83,61 +83,6 @@ void SetProtoIdAndName(T* entry, const string& base_name, char separator,
entry->set_name(GetFullName(base_name, separator, id));
}
ShapeProto ConvertShapeProtoToPred(const ShapeProto& shape_proto) {
return ShapeUtil::ChangeElementType(Shape(shape_proto), PRED).ToProto();
}
void SetInstructionAsConstant(HloInstructionProto* instr, int64 id,
const Shape& shape, bool pred) {
Literal literal = LiteralUtil::CreateR0(pred);
Literal literal_broadcast = literal.Broadcast(shape, {}).ValueOrDie();
*instr->mutable_shape() = shape.ToProto();
*instr->mutable_literal() = literal_broadcast.ToProto();
*instr->mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
}
// Copy `original_reducer` into a new computation proto with `reducer_id` as new
// id. If `rewrite_into_pred` is true, the instructions in the reducer are
// rewritten into predicate form.
HloComputationProto CopyReducer(int64 reducer_id,
HloComputationProto* original_reducer,
bool rewrite_into_pred, int64* global_id) {
HloComputationProto reducer;
SetProtoIdAndName(&reducer, StrCat("reduce_or"), kNameSeparator, reducer_id);
std::vector<int64> operands_id;
for (auto& inst : original_reducer->instructions()) {
// Copy params.
if (StringToHloOpcode(inst.opcode()).ValueOrDie() ==
HloOpcode::kParameter) {
HloInstructionProto* new_param = reducer.add_instructions();
*new_param = inst;
new_param->set_id((*global_id)++);
*new_param->mutable_name() =
GetFullName(inst.name(), '.', new_param->id());
if (rewrite_into_pred) {
*new_param->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
}
operands_id.push_back(new_param->id());
}
if (inst.id() == original_reducer->root_id()) {
HloInstructionProto* new_root = reducer.add_instructions();
*new_root = inst;
new_root->set_id((*global_id)++);
*new_root->mutable_name() = GetFullName(inst.name(), '.', new_root->id());
if (rewrite_into_pred) {
*new_root->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
*new_root->mutable_opcode() = HloOpcodeString(HloOpcode::kOr);
}
new_root->clear_operand_ids();
for (int64 operand_id : operands_id) {
new_root->add_operand_ids(operand_id);
}
reducer.set_root_id(new_root->id());
}
}
return reducer;
}
bool InstrIsSetBound(const HloInstructionProto* instr_proto) {
HloOpcode opcode = StringToHloOpcode(instr_proto->opcode()).ValueOrDie();
if (opcode == HloOpcode::kCustomCall &&
@ -3444,348 +3389,6 @@ StatusOr<bool> XlaBuilder::IsConstant(XlaOp operand) const {
return is_constant;
}
StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
LookUpInstruction(root_op));
HloComputationProto entry;
SetProtoIdAndName(&entry, StrCat(name_, "_dynamic_inference"), kNameSeparator,
GetNextId());
ProgramShapeProto* program_shape = entry.mutable_program_shape();
*program_shape->mutable_result() =
ShapeUtil::ChangeElementType(Shape(root->shape()), PRED).ToProto();
std::vector<HloComputationProto> called_computations;
auto operand_is_constant = [&](const HloInstructionProto* instr_proto,
int64 operand_index) -> StatusOr<bool> {
int64 operand_id = instr_proto->operand_ids(operand_index);
bool is_constant = true;
absl::flat_hash_set<int64> visited;
IsConstantVisitor(operand_id, &visited, &is_constant);
return is_constant;
};
// Process an instruction and copy it into the new graph. The new node in the
// new graph will have its id set to `id`.
auto process_instruction = [&](const HloInstructionProto* instr_proto,
bool need_rewrite, int64 id,
absl::Span<int64 const> operand_ids,
int64* global_id) {
// Rewrite the instruction with the following rules:
// - Unary ops: Convert into bitcast (identity) with type Pred.
// - Binary ops: Convert into binary or.
// - Select: Convert into binary or with its two data operands.
// - Concat / Tuple/ GTE / Bitcast: Copy.
// - Param: Convert to constant True.
// - GetDimensionSize: Convert to constant True if dimension is dynamic,
// constant False if dimension is static.
// - Reduce: Convert to reduce or.
// - Constant: Convert to constant False.
// - Reshape, slice, transpose, pad:
// Convert into predicate type with same opcode.
// - Other ops: Not supported.
// Create the instruction for the new handle.
TF_ASSIGN_OR_RETURN(HloOpcode opcode,
StringToHloOpcode(instr_proto->opcode()));
auto* new_instr = entry.add_instructions();
*new_instr = *instr_proto;
new_instr->set_id(id);
new_instr->mutable_operand_ids()->Clear();
for (auto operand_id : operand_ids) {
new_instr->mutable_operand_ids()->Add(operand_id);
}
if (!need_rewrite) {
*new_instr->mutable_name() =
GetFullName(instr_proto->opcode(), kNameSeparator, id);
if (opcode == HloOpcode::kReduce) {
// Copy the reducer to the new module, with a new id that's same as the
// reduce op.
HloComputationProto* reducer =
&embedded_[new_instr->called_computation_ids(0)];
int64 reducer_id = (*global_id)++;
new_instr->clear_called_computation_ids();
new_instr->add_called_computation_ids(reducer_id);
called_computations.push_back(CopyReducer(
reducer_id, reducer, /*rewrite_into_pred=*/false, global_id));
}
return Status::OK();
}
*new_instr->mutable_shape() = ConvertShapeProtoToPred(instr_proto->shape());
Shape new_shape(new_instr->shape());
switch (opcode) {
case HloOpcode::kAbs:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kBitcast:
case HloOpcode::kCeil:
case HloOpcode::kCollectivePermuteDone:
case HloOpcode::kCos:
case HloOpcode::kClz:
case HloOpcode::kExp:
case HloOpcode::kExpm1:
case HloOpcode::kFloor:
case HloOpcode::kImag:
case HloOpcode::kIsFinite:
case HloOpcode::kLog:
case HloOpcode::kLog1p:
case HloOpcode::kNot:
case HloOpcode::kNegate:
case HloOpcode::kPopulationCount:
case HloOpcode::kReal:
case HloOpcode::kRsqrt:
case HloOpcode::kLogistic:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kConvert:
case HloOpcode::kSqrt:
case HloOpcode::kCbrt:
case HloOpcode::kTanh:
CHECK_EQ(instr_proto->operand_ids_size(), 1);
*new_instr->mutable_opcode() = HloOpcodeString(HloOpcode::kBitcast);
break;
case HloOpcode::kAdd:
case HloOpcode::kAtan2:
case HloOpcode::kDivide:
case HloOpcode::kComplex:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kMultiply:
case HloOpcode::kPower:
case HloOpcode::kRemainder:
case HloOpcode::kSubtract:
case HloOpcode::kCompare:
case HloOpcode::kAnd:
case HloOpcode::kOr:
case HloOpcode::kXor:
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical:
CHECK_EQ(instr_proto->operand_ids_size(), 2);
*new_instr->mutable_opcode() = HloOpcodeString(HloOpcode::kOr);
break;
case HloOpcode::kSelect: {
TF_ASSIGN_OR_RETURN(bool constant_predicate,
operand_is_constant(instr_proto, 0));
if (!constant_predicate) {
// If the predicate operand is not constant, conservatively assume that
// all result values are dynamic.
SetInstructionAsConstant(new_instr, id, new_shape, true);
}
break;
}
case HloOpcode::kGather: {
TF_ASSIGN_OR_RETURN(bool constant_indices,
operand_is_constant(instr_proto, 1));
if (!constant_indices) {
// If the indices operand is not constant, conservatively assume that all
// result values are dynamic.
SetInstructionAsConstant(new_instr, id, new_shape, true);
}
break;
}
case HloOpcode::kReduce: {
auto* reducer = &embedded_[new_instr->called_computation_ids(0)];
int64 reducer_id = (*global_id)++;
new_instr->clear_called_computation_ids();
new_instr->add_called_computation_ids(reducer_id);
called_computations.push_back(CopyReducer(
reducer_id, reducer, /*rewrite_into_pred=*/true, global_id));
break;
}
case HloOpcode::kTuple:
case HloOpcode::kTranspose:
case HloOpcode::kSlice:
case HloOpcode::kBroadcast:
case HloOpcode::kConcatenate:
case HloOpcode::kReshape:
case HloOpcode::kPad:
break;
case HloOpcode::kGetTupleElement: {
// Rewrite parameter followed by gte into constants to avoid
// rematerializing the tuple parameter (could be very large).
int64 operand_handle = instr_proto->operand_ids(0);
TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
LookUpInstructionByHandle(operand_handle));
TF_ASSIGN_OR_RETURN(HloOpcode operand_opcode,
StringToHloOpcode(operand_proto->opcode()));
if (operand_opcode == HloOpcode::kParameter) {
SetInstructionAsConstant(new_instr, id, new_shape, true);
}
break;
}
case HloOpcode::kGetDimensionSize: {
int64 dimension = instr_proto->dimensions(0);
int64 operand_handle = instr_proto->operand_ids(0);
TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
LookUpInstructionByHandle(operand_handle));
SetInstructionAsConstant(
new_instr, id, new_shape,
operand_proto->shape().is_dynamic_dimension(dimension));
break;
}
case HloOpcode::kConstant:
case HloOpcode::kIota:
SetInstructionAsConstant(new_instr, id, new_shape, false);
break;
case HloOpcode::kCustomCall:
if (instr_proto->custom_call_target() == "SetBound") {
SetInstructionAsConstant(new_instr, id, new_shape, true);
break;
} else {
return InvalidArgument(
"Dynamic inferencing on custom call %s is not supported",
instr_proto->DebugString());
}
case HloOpcode::kParameter:
SetInstructionAsConstant(new_instr, id, new_shape, true);
break;
default:
return InvalidArgument("Dynamic inferencing %s is not supported",
instr_proto->DebugString());
}
*new_instr->mutable_name() =
GetFullName(instr_proto->opcode(), kNameSeparator, id);
return Status::OK();
};
struct WorkItem {
explicit WorkItem(int64 handle, bool need_rewrite)
: handle(handle), need_rewrite(need_rewrite), visited(false) {}
int64 handle;
// If need_rewrite is true, the instruction will be copied and rewritten into
// a pred instruction indicating if each value is dynamic. If need_rewrite
// is false, simply copy the instruction to the output graph.
// E.g.,
// For select(P, A, B), we need to rewrite A and B into predicates, but
// don't need to rewrite P.
bool need_rewrite;
// Used in dfs to remember the ids of processed operands of this item.
std::vector<int64> processed_operands;
// Whether this node has been visited before or not.
bool visited;
};
// Only copy each pair of {handle, need_rewrite} once. Value is the id in the
// new graph.
absl::flat_hash_map<std::pair<int64, bool>, int64> seen;
// Monotonically increasing id to assign to new instructions.
int64 global_id = 0;
// The result id of the last rewritten item -- return value of last stack
// item.
int64 stacktop_id = -1;
std::vector<WorkItem> worklist;
worklist.push_back(WorkItem(root->id(), true));
while (!worklist.empty()) {
WorkItem& item = worklist.back();
auto item_key = std::make_pair(item.handle, item.need_rewrite);
auto iter = seen.find(item_key);
// Already processed this item. Return previous results.
if (iter != seen.end()) {
stacktop_id = iter->second;
worklist.pop_back();
continue;
}
int64 next_operand = item.processed_operands.size();
TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
LookUpInstructionByHandle(item.handle));
VLOG(3) << "Processing" << instr_proto->name();
if (!item.visited) {
item.visited = true;
} else {
// Record previous processed operand.
item.processed_operands.push_back(stacktop_id);
next_operand++;
}
TF_ASSIGN_OR_RETURN(HloOpcode opcode,
StringToHloOpcode(instr_proto->opcode()));
bool should_visit_operand = true;
if (opcode == HloOpcode::kGetDimensionSize) {
// We rewrite GetDimensionSize instructions into constants based on the
// operand shape, so there is no need to visit their operands.
should_visit_operand = false;
}
if (opcode == HloOpcode::kGetTupleElement) {
int64 operand_handle = instr_proto->operand_ids(0);
TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
LookUpInstructionByHandle(operand_handle));
TF_ASSIGN_OR_RETURN(HloOpcode operand_opcode,
StringToHloOpcode(operand_proto->opcode()));
if (operand_opcode == HloOpcode::kParameter) {
// Don't rematerialize the whole parameter if it's followed by a GTE.
should_visit_operand = false;
}
}
if (opcode == HloOpcode::kSelect) {
TF_ASSIGN_OR_RETURN(bool constant_predicate,
operand_is_constant(instr_proto, 0));
// If the predicate operand (first operand) is non-constant, we don't
// visit operands and we say that all result values are dynamic.
should_visit_operand = constant_predicate;
}
if (opcode == HloOpcode::kGather) {
TF_ASSIGN_OR_RETURN(bool constant_indices,
operand_is_constant(instr_proto, 1));
// If the indices operand (second operand) is non-constant, we don't
// visit operands and we say that all result values are dynamic.
should_visit_operand = constant_indices;
}
if (opcode == HloOpcode::kGetDimensionSize && next_operand == 0) {
// Always rewrite get dimension size into constant.
item.need_rewrite = true;
}
if (next_operand >= instr_proto->operand_ids_size() ||
!should_visit_operand || InstrIsSetBound(instr_proto)) {
// No more operands to process, process self.
int64 new_id = global_id++;
VLOG(3) << "new_id: " << new_id << "instr: " << instr_proto->name();
TF_RETURN_IF_ERROR(process_instruction(instr_proto, item.need_rewrite,
new_id, item.processed_operands,
&global_id));
stacktop_id = new_id;
seen[item_key] = stacktop_id;
worklist.pop_back();
} else {
// Visit and process operand. If an operand doesn't need rewrite
// (predicate of kSelect, or indices of kGather), we also don't rewrite
// its ancestors.
WorkItem next_item(instr_proto->operand_ids(next_operand),
item.need_rewrite);
if (opcode == HloOpcode::kSelect && next_operand == 0) {
next_item.need_rewrite = false;
}
if (opcode == HloOpcode::kGather && next_operand == 1) {
next_item.need_rewrite = false;
}
// Push next operand into worklist.
worklist.push_back(next_item);
}
}
TF_RET_CHECK(stacktop_id != -1);
entry.set_root_id(stacktop_id);
absl::c_sort(*entry.mutable_instructions(),
[](const HloInstructionProto& p1,
const HloInstructionProto& p2) { return p1.id() < p2.id(); });
XlaComputation computation(entry.id());
HloModuleProto* module = computation.mutable_proto();
module->set_name(entry.name());
module->set_id(entry.id());
module->set_entry_computation_name(entry.name());
module->set_entry_computation_id(entry.id());
*module->mutable_host_program_shape() = *program_shape;
for (auto& called_comp : called_computations) {
*module->add_computations() = called_comp;
}
*module->add_computations() = std::move(entry);
// Make sure all ids appear in the computation with ascending order.
absl::c_sort(*module->mutable_computations(),
[](const HloComputationProto& c1,
const HloComputationProto& c2) { return c1.id() < c2.id(); });
XLA_VLOG_LINES(3, module->DebugString());
return std::move(computation);
}
StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
XlaOp root_op, bool dynamic_dimension_is_minus_one) {
TF_ASSIGN_OR_RETURN(bool is_constant, IsConstant(root_op));

View File

@ -114,6 +114,7 @@ class XlaOp {
int64 handle() const { return handle_; }
friend class XlaBuilder;
friend class ValueInference;
friend class MlirHloBuilder;
friend struct internal::XlaBuilderFriend;
@ -296,31 +297,6 @@ class XlaBuilder {
StatusOr<XlaComputation> BuildConstantSubGraph(
XlaOp root_op, bool dynamic_dimension_is_uint_max = false);
// Similar to BuildConstantSubGraph, but with root element type changed to
// boolean. A true value in the root indicates that the value is dynamic while
// false value indicates that the value is a constant. This will copy the
// needed ops/computations to the subgraph.
//
// E.g.,
// Computation {
// a = 3
// b = param(0)
// ROOT Tuple(a + b, a + 1, b + 1)
// }
// Calling BuildDynamicInferenceGraph on root will produce the following
// graph:
//
// Computation {
// a = False
// b = True
// ROOT Tuple(a | b, a, b)
// }
//
// The result, which is (True, False, True) after evaluation, can be
// interpreted as "First element is dynamic; Second element is static; Third
// element is dynamic".
StatusOr<XlaComputation> BuildDynamicInferenceGraph(XlaOp root_op);
// Returns the first error that was encountered while building the
// computation. When an error is encountered, by default we return a vacuous
// XlaOp and inform the user of the error that occurred while
@ -1478,6 +1454,8 @@ class XlaBuilder {
}
friend struct internal::XlaBuilderFriend;
friend class ValueInference;
};
// RAII-style object: sets the current sharding assignment in builder on

View File

@ -497,7 +497,7 @@ class HloInstruction {
static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
const HloInstructionProto& proto,
const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
const absl::flat_hash_map<int64, HloComputation*>& computation_map,
const absl::flat_hash_map<int64, HloComputation*>& computation_map = {},
bool prohibit_empty_literal = true);
// Creates a parameter-retrieving instruction.
@ -1043,6 +1043,7 @@ class HloInstruction {
// Returns the opcode for this instruction.
HloOpcode opcode() const { return opcode_; }
HloOpcode* mutable_opcode() { return &opcode_; }
// Returns true if this instruction has a side effect, irrespective of whether
// any called computations may contain an instruction with side effects.

View File

@ -2097,6 +2097,7 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:value_inference",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/prng.h"
#include "tensorflow/compiler/xla/client/value_inference.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout_util.h"
@ -44,7 +45,6 @@ namespace {
// An enumerator for the client types that we want to iterate over in
// the various tests.
enum class ClientType { kLocal, kCompileOnly };
ClientType client_types[] = {ClientType::kLocal, ClientType::kCompileOnly};
class DynamismInferenceTest : public ::testing::Test {
public:
@ -72,21 +72,18 @@ class DynamismInferenceTest : public ::testing::Test {
LOG(FATAL) << "invalid client_type value";
}
StatusOr<Literal> ComputeDynamismLiteral(Client* client, XlaOp operand,
XlaBuilder* builder,
StatusOr<Literal> ComputeDynamismLiteral(XlaOp operand, XlaBuilder* builder,
Layout* output_layout = nullptr) {
TF_ASSIGN_OR_RETURN(auto subgraph,
builder->BuildDynamicInferenceGraph(operand));
TF_ASSIGN_OR_RETURN(auto computed,
client->ComputeConstant(subgraph, output_layout));
return std::move(computed);
ValueInference value_inference(builder);
TF_ASSIGN_OR_RETURN(auto literal_slice,
value_inference.AnalyzeIsDynamic(operand));
return literal_slice.Clone();
}
StatusOr<bool> ComputeDynamismScalar(Client* client, XlaOp operand,
XlaBuilder* builder,
StatusOr<bool> ComputeDynamismScalar(XlaOp operand, XlaBuilder* builder,
ShapeIndex index = {}) {
TF_ASSIGN_OR_RETURN(auto literal, ComputeDynamismLiteral(client, operand,
builder, nullptr));
TF_ASSIGN_OR_RETURN(auto literal,
ComputeDynamismLiteral(operand, builder, nullptr));
return literal.Get<bool>({}, index);
}
@ -94,265 +91,202 @@ class DynamismInferenceTest : public ::testing::Test {
};
TEST_F(DynamismInferenceTest, ScalarInt32Literal) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
auto computation = ConstantR0<int32>(&b, 42);
XlaBuilder b(TestName());
auto computation = ConstantR0<int32>(&b, 42);
auto value = ComputeDynamismScalar(client, computation, &b);
ASSERT_TRUE(value.ok()) << value.status();
// A constant is not dynamic.
EXPECT_EQ(value.ValueOrDie(), false);
}
auto value = ComputeDynamismScalar(computation, &b);
ASSERT_TRUE(value.ok()) << value.status();
// A constant is not dynamic.
EXPECT_EQ(value.ValueOrDie(), false);
}
TEST_F(DynamismInferenceTest, Iota) {
// The output of iota is considered static.
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
auto computation = Iota(&b, S32, 2);
// Iota is not dynamic.
EXPECT_FALSE(ComputeDynamismLiteral(client, computation, &b)
.ValueOrDie()
.Get<bool>({0}));
}
XlaBuilder b(TestName());
auto computation = Iota(&b, S32, 2);
// Iota is not dynamic.
EXPECT_FALSE(
ComputeDynamismLiteral(computation, &b).ValueOrDie().Get<bool>({0}));
}
TEST_F(DynamismInferenceTest, TupleSimple) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
auto c = ConstantR0<int32>(&b, 42);
auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
XlaBuilder b(TestName());
auto c = ConstantR0<int32>(&b, 42);
auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
auto tuple = Tuple(&b, {c, p});
EXPECT_EQ(ComputeDynamismScalar(client, tuple, &b, {0}).ValueOrDie(),
false);
EXPECT_EQ(ComputeDynamismScalar(client, tuple, &b, {1}).ValueOrDie(), true);
}
auto tuple = Tuple(&b, {c, p});
EXPECT_EQ(ComputeDynamismScalar(tuple, &b, {0}).ValueOrDie(), false);
EXPECT_EQ(ComputeDynamismScalar(tuple, &b, {1}).ValueOrDie(), true);
}
TEST_F(DynamismInferenceTest, TupleGteKeepsDynamism) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
auto c = ConstantR0<int32>(&b, 42);
auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
XlaBuilder b(TestName());
auto c = ConstantR0<int32>(&b, 42);
auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
auto tuple = Tuple(&b, {c, p});
auto gte0 = GetTupleElement(tuple, 0);
auto gte1 = GetTupleElement(tuple, 1);
auto tuple_2 = Tuple(&b, {gte0, gte1});
EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
false);
EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
true);
}
auto tuple = Tuple(&b, {c, p});
auto gte0 = GetTupleElement(tuple, 0);
auto gte1 = GetTupleElement(tuple, 1);
auto tuple_2 = Tuple(&b, {gte0, gte1});
EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {0}).ValueOrDie(), false);
EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {1}).ValueOrDie(), true);
}
TEST_F(DynamismInferenceTest, PredValueUsedTwice) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
auto c = ConstantR0<int32>(&b, 42);
auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
auto pred = Eq(c, p);
auto result = Select(pred, p, c);
EXPECT_EQ(ComputeDynamismScalar(client, result, &b, {}).ValueOrDie(), true);
}
XlaBuilder b(TestName());
auto c = ConstantR0<int32>(&b, 42);
auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
auto pred = Eq(c, p);
auto result = Select(pred, p, c);
EXPECT_EQ(ComputeDynamismScalar(result, &b, {}).ValueOrDie(), true);
}
TEST_F(DynamismInferenceTest, ReduceUsedTwice) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
auto c = ConstantR0<int32>(&b, 42);
auto p = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2}), "p0");
auto zero = ConstantR0<int32>(&b, 0);
XlaComputation add_s32 = CreateScalarAddComputation(S32, &b);
auto reduce = Reduce(p, zero, add_s32, {0});
auto pred = Eq(c, reduce);
auto result = Select(pred, reduce, c);
EXPECT_EQ(ComputeDynamismScalar(client, result, &b, {}).ValueOrDie(), true);
}
XlaBuilder b(TestName());
auto c = ConstantR0<int32>(&b, 42);
auto p = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2}), "p0");
auto zero = ConstantR0<int32>(&b, 0);
XlaComputation add_s32 = CreateScalarAddComputation(S32, &b);
auto reduce = Reduce(p, zero, add_s32, {0});
auto pred = Eq(c, reduce);
auto result = Select(pred, reduce, c);
EXPECT_EQ(ComputeDynamismScalar(result, &b, {}).ValueOrDie(), true);
}
TEST_F(DynamismInferenceTest, ConcatSliceReshapeKeepsDynamism) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
auto c = ConstantR0<int32>(&b, 42);
auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
XlaBuilder b(TestName());
auto c = ConstantR0<int32>(&b, 42);
auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
auto concat = ConcatScalars(&b, {c, p});
auto slice0 = SliceInDim(concat, 0, 1, 1, 0);
auto reshape0 = Reshape(slice0, {});
auto slice1 = SliceInDim(concat, 1, 2, 1, 0);
auto reshape1 = Reshape(slice1, {});
auto tuple_2 = Tuple(&b, {reshape0, reshape1});
EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
false);
EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
true);
}
auto concat = ConcatScalars(&b, {c, p});
auto slice0 = SliceInDim(concat, 0, 1, 1, 0);
auto reshape0 = Reshape(slice0, {});
auto slice1 = SliceInDim(concat, 1, 2, 1, 0);
auto reshape1 = Reshape(slice1, {});
auto tuple_2 = Tuple(&b, {reshape0, reshape1});
EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {0}).ValueOrDie(), false);
EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {1}).ValueOrDie(), true);
}
TEST_F(DynamismInferenceTest, ParameterIsDynamic) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
auto computation = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
XlaBuilder b(TestName());
auto computation = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
auto value = ComputeDynamismScalar(client, computation, &b);
ASSERT_TRUE(value.ok()) << value.status();
// A parameter is considered dynamic.
EXPECT_EQ(value.ValueOrDie(), true);
}
auto value = ComputeDynamismScalar(computation, &b);
ASSERT_TRUE(value.ok()) << value.status();
// A parameter is considered dynamic.
EXPECT_EQ(value.ValueOrDie(), true);
}
TEST_F(DynamismInferenceTest, UnaryOpKeepsDynamism) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
auto c = ConstantR0<int32>(&b, 42);
auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
XlaBuilder b(TestName());
auto c = ConstantR0<int32>(&b, 42);
auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
auto neg0 = Neg(c);
auto neg1 = Neg(p);
auto tuple_2 = Tuple(&b, {neg0, neg1});
EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
false);
EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
true);
}
auto neg0 = Neg(c);
auto neg1 = Neg(p);
auto tuple_2 = Tuple(&b, {neg0, neg1});
EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {0}).ValueOrDie(), false);
EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {1}).ValueOrDie(), true);
}
TEST_F(DynamismInferenceTest, BinaryOpsOrsDynamism) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
auto c = ConstantR0<int32>(&b, 42);
auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
XlaBuilder b(TestName());
auto c = ConstantR0<int32>(&b, 42);
auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
// Static value + static value = static
auto add1 = Add(c, c);
// Dynamic value + static value = dynamic
auto add2 = Add(p, c);
auto tuple_2 = Tuple(&b, {add1, add2});
EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
false);
EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
true);
}
// Static value + static value = static
auto add1 = Add(c, c);
// Dynamic value + static value = dynamic
auto add2 = Add(p, c);
auto tuple_2 = Tuple(&b, {add1, add2});
EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {0}).ValueOrDie(), false);
EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {1}).ValueOrDie(), true);
}
TEST_F(DynamismInferenceTest, GetDimensionSize) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
// param = Param([<=2, 3])
// get_dimension_size(param, 0) is dynamic
// get_dimension_size(param, 1) is static
auto p = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}),
"p0");
XlaBuilder b(TestName());
// param = Param([<=2, 3])
// get_dimension_size(param, 0) is dynamic
// get_dimension_size(param, 1) is static
auto p =
Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "p0");
auto gds0 = GetDimensionSize(p, 0);
auto gds1 = GetDimensionSize(p, 1);
auto tuple_2 = Tuple(&b, {gds0, gds1});
EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
true);
EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
false);
}
auto gds0 = GetDimensionSize(p, 0);
auto gds1 = GetDimensionSize(p, 1);
auto tuple_2 = Tuple(&b, {gds0, gds1});
EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {0}).ValueOrDie(), true);
EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {1}).ValueOrDie(), false);
}
TEST_F(DynamismInferenceTest, GatherWithCommonParent) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
// Test the analysis on a gather where first operand and second operand have
// common parents.
Shape indices_shape = ShapeUtil::MakeShape(S32, {2});
XlaBuilder b(TestName());
// Test the analysis on a gather where first operand and second operand have
// common parents.
Shape indices_shape = ShapeUtil::MakeShape(S32, {2});
auto operand1 = Parameter(&b, 0, indices_shape, "p1");
auto operand2 = Parameter(&b, 1, indices_shape, "p2");
auto indices = Sub(operand1, operand2);
GatherDimensionNumbers dim_numbers;
dim_numbers.add_offset_dims(1);
dim_numbers.add_start_index_map(0);
dim_numbers.set_index_vector_dim(1);
auto gather = Gather(operand1, indices, dim_numbers, {1});
ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
EXPECT_TRUE(ComputeDynamismLiteral(client, gather, &b)
.ValueOrDie()
.Get<bool>({0, 0}));
}
auto operand1 = Parameter(&b, 0, indices_shape, "p1");
auto operand2 = Parameter(&b, 1, indices_shape, "p2");
auto indices = Sub(operand1, operand2);
GatherDimensionNumbers dim_numbers;
dim_numbers.add_offset_dims(1);
dim_numbers.add_start_index_map(0);
dim_numbers.set_index_vector_dim(1);
auto gather = Gather(operand1, indices, dim_numbers, {1});
ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
EXPECT_TRUE(
ComputeDynamismLiteral(gather, &b).ValueOrDie().Get<bool>({0, 0}));
}
TEST_F(DynamismInferenceTest, GatherWithConstantParent) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
// Test the analysis on a gather.
Shape indices_shape = ShapeUtil::MakeShape(S32, {2});
auto data_operand = ConstantR1<int32>(&b, {1, 2});
auto indices = ConstantR1<int32>(&b, {1, 2});
GatherDimensionNumbers dim_numbers;
dim_numbers.add_offset_dims(1);
dim_numbers.add_start_index_map(0);
dim_numbers.set_index_vector_dim(1);
auto gather = Gather(data_operand, indices, dim_numbers, {1});
ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
// Everything is constant, so the result is also constant.
EXPECT_FALSE(ComputeDynamismLiteral(client, gather, &b)
.ValueOrDie()
.Get<bool>({0, 0}));
}
XlaBuilder b(TestName());
// Test the analysis on a gather.
Shape indices_shape = ShapeUtil::MakeShape(S32, {2});
auto data_operand = ConstantR1<int32>(&b, {1, 2});
auto indices = ConstantR1<int32>(&b, {1, 2});
GatherDimensionNumbers dim_numbers;
dim_numbers.add_offset_dims(1);
dim_numbers.add_start_index_map(0);
dim_numbers.set_index_vector_dim(1);
auto gather = Gather(data_operand, indices, dim_numbers, {1});
ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
// Everything is constant, so the result is also constant.
EXPECT_FALSE(
ComputeDynamismLiteral(gather, &b).ValueOrDie().Get<bool>({0, 0}));
}
TEST_F(DynamismInferenceTest, GatherWithSharedConstantParent) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
// Test the analysis on a gather.
Shape indices_shape = ShapeUtil::MakeShape(S32, {2});
auto operand1 = ConstantR1<int32>(&b, {1, 2});
auto operand2 = ConstantR1<int32>(&b, {1, 2});
auto indices = Sub(operand1, operand2);
GatherDimensionNumbers dim_numbers;
dim_numbers.add_offset_dims(1);
dim_numbers.add_start_index_map(0);
dim_numbers.set_index_vector_dim(1);
auto gather = Gather(operand1, indices, dim_numbers, {1});
ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
// Everything is constant, so the result is also constant.
EXPECT_FALSE(ComputeDynamismLiteral(client, gather, &b)
.ValueOrDie()
.Get<bool>({0, 0}));
}
XlaBuilder b(TestName());
// Test the analysis on a gather.
Shape indices_shape = ShapeUtil::MakeShape(S32, {2});
auto operand1 = ConstantR1<int32>(&b, {1, 2});
auto operand2 = ConstantR1<int32>(&b, {1, 2});
auto indices = Sub(operand1, operand2);
GatherDimensionNumbers dim_numbers;
dim_numbers.add_offset_dims(1);
dim_numbers.add_start_index_map(0);
dim_numbers.set_index_vector_dim(1);
auto gather = Gather(operand1, indices, dim_numbers, {1});
ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
// Everything is constant, so the result is also constant.
EXPECT_FALSE(
ComputeDynamismLiteral(gather, &b).ValueOrDie().Get<bool>({0, 0}));
}
TEST_F(DynamismInferenceTest, InferThroughPad) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
// Test the analysis on a pad.
auto operand1 = ConstantR1<int32>(&b, {1, 2});
auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {}), "p0");
PaddingConfig padding_config;
padding_config.add_dimensions()->set_edge_padding_high(1);
// After pad the value is [constant, constant, parameter].
auto pad = Pad(operand1, parameter, padding_config);
ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
// The operand elements are constant; the padded element comes from a
// parameter and is dynamic.
EXPECT_FALSE(
ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get<bool>({0}));
EXPECT_FALSE(
ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get<bool>({1}));
EXPECT_TRUE(
ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get<bool>({2}));
}
XlaBuilder b(TestName());
// Test the analysis on a pad.
auto operand1 = ConstantR1<int32>(&b, {1, 2});
auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {}), "p0");
PaddingConfig padding_config;
padding_config.add_dimensions()->set_edge_padding_high(1);
// After pad the value is [constant, constant, parameter].
auto pad = Pad(operand1, parameter, padding_config);
ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
// The operand elements are constant; the padded element comes from a
// parameter and is dynamic.
EXPECT_FALSE(ComputeDynamismLiteral(pad, &b).ValueOrDie().Get<bool>({0}));
EXPECT_FALSE(ComputeDynamismLiteral(pad, &b).ValueOrDie().Get<bool>({1}));
EXPECT_TRUE(ComputeDynamismLiteral(pad, &b).ValueOrDie().Get<bool>({2}));
}
} // namespace