XLA value inference.
XLA value inference is designed to be a one-stop service that analyzes the attributes of each value in a tensor. The attributes include:
- What's the upper bound of each value in a tensor.
- What's the lower bound of each value in a tensor.
- What's the constant value of each tensor.
- Whether or not each value in a tensor is dynamic.

This CL implements the dynamism-inference part, which replaces the old implementation in the XLA builder that had become too complex to maintain.

PiperOrigin-RevId: 360753561
Change-Id: I55c6ec9abf43f1502c5173dafcc47ac6c41bea09

parent 47f244679c
commit 71ac02ec1f
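For orientation, a minimal usage sketch of the new API, assembled from the tests updated below (a sketch, not part of the diff):

    XlaBuilder b("example");
    auto c = ConstantR0<int32>(&b, 42);
    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
    auto sum = Add(c, p);
    ValueInference value_inference(&b);
    // AnalyzeIsDynamic returns a PRED literal: true where a value is dynamic.
    StatusOr<LiteralSlice> is_dynamic = value_inference.AnalyzeIsDynamic(sum);
    // Here it yields true, since one operand is a parameter.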
@@ -485,6 +485,7 @@ cc_library(
        ":xla_resource",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla/client",
        "//tensorflow/compiler/xla/client:value_inference",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
@@ -17,6 +17,7 @@ limitations under the License.

#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/client/value_inference.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -136,18 +137,15 @@ xla::StatusOr<Tensor> XlaExpression::ResolveDynamism(
  if (!client)
    return errors::InvalidArgument("client is required to resolve constant");

  TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph,
                      handle().builder()->BuildDynamicInferenceGraph(handle()));

  TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape());

  // The XLA layout is specified minor to major, and TensorFlow uses a major to
  // minor order.
  std::vector<int64> layout_indices(shape.dims());
  std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
  xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);
  TF_ASSIGN_OR_RETURN(xla::Literal literal,
                      client->ComputeConstant(constant_graph, &layout));
  xla::ValueInference value_inference(handle().builder());
  TF_ASSIGN_OR_RETURN(xla::LiteralSlice literal,
                      value_inference.AnalyzeIsDynamic(handle()));
  Tensor tensor(DT_BOOL);
  TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, DT_BOOL, &tensor));
  return tensor;
@@ -211,6 +211,25 @@ cc_library(
    ],
)

cc_library(
    name = "value_inference",
    srcs = ["value_inference.cc"],
    hdrs = ["value_inference.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":xla_builder",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:hlo_evaluator",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/types:optional",
    ],
)

cc_library(
    name = "xla_builder",
    srcs = ["xla_builder.cc"],
tensorflow/compiler/xla/client/value_inference.cc (new file, 384 lines)
@@ -0,0 +1,384 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/value_inference.h"

#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace {
Literal CreatePredLiteral(bool pred, const Shape& reference_shape) {
  Literal literal = LiteralUtil::CreateR0(pred);
  Literal literal_broadcast =
      literal
          .Broadcast(ShapeUtil::ChangeElementType(Shape(reference_shape), PRED),
                     {})
          .ValueOrDie();
  return literal_broadcast;
}
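// Editorial sketch: e.g. CreatePredLiteral(true, f32[2,3]) produces a
// pred[2,3] literal with every element true, the conservative "all values
// dynamic" answer broadcast to the analyzed value's shape.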

using GetOperand = std::function<StatusOr<LiteralSlice>(int64 operand_index,
                                                        int64 operand_handle)>;
// HloProtoEvaluator evaluates an HLO proto and returns a literal. The user has
// to provide operands as literals through the get_operand function.
struct HloProtoEvaluator {
  explicit HloProtoEvaluator(HloInstructionProto inst, GetOperand get_operand)
      : inst(std::move(inst)),
        get_operand(get_operand),
        module("EmptyModuleForEvaluation", HloModuleConfig()) {}

  // WithComputation changes the called computation of the instruction being
  // evaluated.
  HloProtoEvaluator& WithComputation(
      std::unique_ptr<HloComputation> new_computation) {
    computation = new_computation.get();
    computation->ClearUniqueIdInternal();
    for (HloInstruction* inst : computation->instructions()) {
      inst->ClearUniqueIdInternal();
    }
    module.AddEmbeddedComputation(std::move(new_computation));
    return *this;
  }

  // WithPrimitiveType changes the primitive type of the instruction being
  // evaluated.
  HloProtoEvaluator& WithPrimitiveType(PrimitiveType new_primitive_type) {
    primitive_type = new_primitive_type;
    return *this;
  }

  // WithOpCode changes the opcode of the instruction being evaluated.
  HloProtoEvaluator& WithOpCode(HloOpcode new_opcode) {
    opcode = new_opcode;
    return *this;
  }

  StatusOr<Literal> Evaluate() {
    // Evaluate the instruction by swapping its operands with constant
    // instructions with given literals.
    HloComputation::Builder builder("EmptyComputation");
    absl::flat_hash_map<int64, HloInstruction*> operand_map;
    for (int64 i = 0; i < inst.operand_ids_size(); ++i) {
      int64 operand_handle = inst.operand_ids(i);
      TF_ASSIGN_OR_RETURN(auto literal, get_operand(i, inst.operand_ids(i)));
      std::unique_ptr<HloInstruction> operand =
          HloInstruction::CreateConstant(literal.Clone());
      operand_map[operand_handle] = operand.get();
      builder.AddInstruction(std::move(operand));
    }

    if (primitive_type.has_value()) {
      *inst.mutable_shape() = ShapeUtil::ChangeElementType(
                                  Shape(inst.shape()), primitive_type.value())
                                  .ToProto();
    }
    if (opcode.has_value()) {
      *inst.mutable_opcode() = HloOpcodeString(opcode.value());
    }
    absl::flat_hash_map<int64, HloComputation*> computation_map;
    if (inst.called_computation_ids_size() != 0) {
      TF_RET_CHECK(inst.called_computation_ids_size() == 1 &&
                   computation != nullptr)
          << inst.DebugString();
      computation_map[inst.called_computation_ids(0)] = computation;
    }
    TF_ASSIGN_OR_RETURN(
        auto new_instruction,
        HloInstruction::CreateFromProto(inst, operand_map, computation_map));
    new_instruction->ClearUniqueIdInternal();
    builder.AddInstruction(std::move(new_instruction));
    auto computation = builder.Build();
    module.AddEntryComputation(std::move(computation));
    HloEvaluator evaluator;
    return evaluator.Evaluate(module.entry_computation()->root_instruction());
  }

  HloInstructionProto inst;
  GetOperand get_operand;
  HloModule module;
  HloComputation* computation = nullptr;
  absl::optional<PrimitiveType> primitive_type = absl::nullopt;
  absl::optional<HloOpcode> opcode = absl::nullopt;
};
}  // namespace
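// Editorial sketch of how the evaluator is driven below (add_proto and
// get_operand are placeholder names): to decide whether each element of
// add(x, y) is dynamic, the analysis re-types the proto to PRED and swaps the
// opcode, effectively evaluating or(dyn(x), dyn(y)):
//
//   HloProtoEvaluator(add_proto, get_operand)
//       .WithPrimitiveType(PRED)
//       .WithOpCode(HloOpcode::kOr)
//       .Evaluate();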

StatusOr<Literal> ValueInference::AnalyzeConstantLiteral(int64 handle) {
  TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
                      builder_->LookUpInstructionByHandle(handle));
  if (constant_.contains(handle)) {
    return constant_[handle].Clone();
  }
  TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
  switch (opcode) {
    case HloOpcode::kGetDimensionSize: {
      int64 dimension = root->dimensions(0);
      int64 operand_handle = root->operand_ids(0);
      TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
                          builder_->LookUpInstructionByHandle(operand_handle));
      if (operand_proto->shape().is_dynamic_dimension(dimension)) {
        return InvalidArgument(
            "AnalyzeConstant is called on a GetDimensionSize on dynamic "
            "dimension.");
      } else {
        return LiteralUtil::CreateR0<int32>(
            operand_proto->shape().dimensions(dimension));
      }
    }
    // Non-functional ops.
    case HloOpcode::kRng:
    case HloOpcode::kAllReduce:
      // TODO(b/33009255): Implement constant folding for cross replica sum.
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kCall:
      // TODO(b/32495713): We aren't checking the to_apply computation itself,
      // so we conservatively say that computations containing the Call op
      // cannot be constant. We cannot set is_functional=false in other similar
      // cases since we're already relying on IsConstant to return true.
    case HloOpcode::kCustomCall:
    case HloOpcode::kWhile:
    case HloOpcode::kConditional:
      // TODO(b/32495713): We aren't checking the condition and body
      // computations themselves.
    case HloOpcode::kSend:
    case HloOpcode::kRecv:
    case HloOpcode::kParameter: {
      return InvalidArgument("Can't analyze constant values on instruction %s",
                             root->DebugString());
    }
    case HloOpcode::kReduce:
    case HloOpcode::kScatter:
    case HloOpcode::kReduceWindow: {
      HloComputationProto computation_proto =
          builder_->embedded_[root->called_computation_ids(0)];
      TF_ASSIGN_OR_RETURN(auto computation, HloComputation::CreateFromProto(
                                                computation_proto, {}));
      return HloProtoEvaluator(*root,
                               [&](int64 operand_index, int64 operand_handle) {
                                 return AnalyzeConstant(operand_handle);
                               })
          .WithComputation(std::move(computation))
          .Evaluate();
    }
    default:
      return HloProtoEvaluator(*root,
                               [&](int64 operand_index, int64 operand_handle) {
                                 return AnalyzeConstant(operand_handle);
                               })
          .Evaluate();
  }
}
StatusOr<Literal> ValueInference::AnalyzeIsDynamicLiteral(int64 handle) {
  TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
                      builder_->LookUpInstructionByHandle(handle));
  if (is_dynamic_.contains(handle)) {
    return is_dynamic_[handle].Clone();
  }
  TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
  switch (opcode) {
    case HloOpcode::kGetDimensionSize: {
      int64 dimension = root->dimensions(0);
      int64 operand_handle = root->operand_ids(0);
      TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
                          builder_->LookUpInstructionByHandle(operand_handle));
      return LiteralUtil::CreateR0<bool>(
          operand_proto->shape().is_dynamic_dimension(dimension));
    }
    case HloOpcode::kAbs:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kBitcast:
    case HloOpcode::kCeil:
    case HloOpcode::kCollectivePermuteDone:
    case HloOpcode::kCos:
    case HloOpcode::kClz:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFloor:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kNot:
    case HloOpcode::kNegate:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kReal:
    case HloOpcode::kRsqrt:
    case HloOpcode::kLogistic:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kConvert:
    case HloOpcode::kSqrt:
    case HloOpcode::kCbrt:
    case HloOpcode::kTanh: {
      // Forward the operand, as these ops don't change whether a value is
      // dynamic or static.
      int64 operand_handle = root->operand_ids(0);
      TF_ASSIGN_OR_RETURN(auto literal, AnalyzeIsDynamic(operand_handle));
      return literal.Clone();
    }
    case HloOpcode::kAdd:
    case HloOpcode::kAtan2:
    case HloOpcode::kDivide:
    case HloOpcode::kComplex:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kPower:
    case HloOpcode::kRemainder:
    case HloOpcode::kSubtract:
    case HloOpcode::kCompare:
    case HloOpcode::kAnd:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical: {
      return HloProtoEvaluator(*root,
                               [&](int64 operand_index, int64 operand_handle) {
                                 return AnalyzeIsDynamic(operand_handle);
                               })
          .WithPrimitiveType(PRED)
          .WithOpCode(HloOpcode::kOr)
          .Evaluate();
    }
    case HloOpcode::kTuple:
    case HloOpcode::kTranspose:
    case HloOpcode::kSlice:
    case HloOpcode::kBroadcast:
    case HloOpcode::kConcatenate:
    case HloOpcode::kReshape:
    case HloOpcode::kPad: {
      return HloProtoEvaluator(*root,
                               [&](int64 operand_index, int64 operand_handle) {
                                 return AnalyzeIsDynamic(operand_handle);
                               })
          .WithPrimitiveType(PRED)
          .Evaluate();
    }
    case HloOpcode::kGetTupleElement: {
      int64 operand_handle = root->operand_ids(0);
      TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
                          builder_->LookUpInstructionByHandle(operand_handle));
      TF_ASSIGN_OR_RETURN(HloOpcode operand_opcode,
                          StringToHloOpcode(operand_proto->opcode()));
      if (operand_opcode == HloOpcode::kParameter) {
        // Don't materialize the whole parameter if it's followed by a GTE.
        return CreatePredLiteral(true, Shape(root->shape()));
      }
      return HloProtoEvaluator(*root,
                               [&](int64 operand_index, int64 operand_handle) {
                                 return AnalyzeIsDynamic(operand_handle);
                               })
          .WithPrimitiveType(PRED)
          .Evaluate();
    }

    case HloOpcode::kReduce: {
      std::vector<std::unique_ptr<HloInstruction>> operand_storage;
      absl::flat_hash_map<int64, HloInstruction*> operand_map;
      absl::flat_hash_map<int64, HloComputation*> computation_map;

      Shape scalar_shape = ShapeUtil::MakeScalarShape(xla::PRED);
      HloComputation::Builder b("reduce_or");
      auto lhs = b.AddInstruction(
          HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
      auto rhs = b.AddInstruction(
          HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
      b.AddInstruction(
          HloInstruction::CreateBinary(scalar_shape, HloOpcode::kOr, lhs, rhs));
      auto reduce_computation = b.Build();
      return HloProtoEvaluator(*root,
                               [&](int64 operand_index, int64 operand_handle) {
                                 return AnalyzeIsDynamic(operand_handle);
                               })
          .WithPrimitiveType(PRED)
          .WithComputation(std::move(reduce_computation))
          .Evaluate();
    }
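    // Editorial sketch: the synthesized reduce_or computation above reduces
    // the operand's PRED dynamism literal with kOr, so a reduced element is
    // dynamic exactly when any element feeding it is dynamic.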
    case HloOpcode::kConstant:
    case HloOpcode::kIota: {
      return CreatePredLiteral(false, Shape(root->shape()));
    }
    case HloOpcode::kParameter: {
      return CreatePredLiteral(true, Shape(root->shape()));
    }
    case HloOpcode::kSelect: {
      if (!AnalyzeConstant(root->operand_ids(0)).ok()) {
        // If the predicate operand is not constant, conservatively assume
        // that all result values are dynamic.
        return CreatePredLiteral(true, Shape(root->shape()));
      }
      return HloProtoEvaluator(*root,
                               [&](int64 operand_index, int64 operand_handle) {
                                 if (operand_index == 0) {
                                   return AnalyzeConstant(operand_handle);
                                 } else {
                                   return AnalyzeIsDynamic(operand_handle);
                                 }
                               })
          .WithPrimitiveType(PRED)
          .Evaluate();
    }
    case HloOpcode::kGather: {
      if (!AnalyzeConstant(root->operand_ids(1)).ok()) {
        // If the index operand is not constant, conservatively assume that
        // all result values are dynamic.
        return CreatePredLiteral(true, Shape(root->shape()));
      }
      return HloProtoEvaluator(*root,
                               [&](int64 operand_index, int64 operand_handle) {
                                 if (operand_index == 1) {
                                   return AnalyzeConstant(operand_handle);
                                 } else {
                                   return AnalyzeIsDynamic(operand_handle);
                                 }
                               })
          .WithPrimitiveType(PRED)
          .Evaluate();
    }
    case HloOpcode::kCustomCall: {
      if (root->custom_call_target() == "SetBound") {
        return CreatePredLiteral(true, Shape(root->shape()));
      } else {
        return InvalidArgument(
            "Dynamic inferencing on custom call %s is not supported",
            root->DebugString());
      }

      break;
    }
    default:
      return Unimplemented("Can't infer dynamism through %s: %s",
                           root->opcode(), root->DebugString());
  }
}
StatusOr<LiteralSlice> ValueInference::AnalyzeIsDynamic(int64 handle) {
  TF_ASSIGN_OR_RETURN(Literal literal, AnalyzeIsDynamicLiteral(handle));
  is_dynamic_[handle] = std::move(literal);
  return LiteralSlice(is_dynamic_[handle]);
}

StatusOr<LiteralSlice> ValueInference::AnalyzeConstant(int64 handle) {
  TF_ASSIGN_OR_RETURN(Literal literal, AnalyzeConstantLiteral(handle));
  constant_[handle] = std::move(literal);
  return LiteralSlice(constant_[handle]);
}

}  // namespace xla
tensorflow/compiler/xla/client/value_inference.h (new file, 70 lines)
@@ -0,0 +1,70 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_VALUE_INFERENCE_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_VALUE_INFERENCE_H_

#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
class ValueInference {
 public:
  // ValueInference analyzes values in an XlaOp and answers the following
  // questions:
  // - What's the upper bound of each value in a tensor.
  // - What's the lower bound of each value in a tensor.
  // - What's the constant value of each tensor.
  // - Whether or not each value in a tensor is dynamic.
  explicit ValueInference(XlaBuilder* builder) : builder_(builder) {}
  StatusOr<LiteralSlice> AnalyzeUpperBound(XlaOp op) {
    return Unimplemented("Analyzing upper-bound is not implemented yet.");
  }
  StatusOr<LiteralSlice> AnalyzeLowerBound(XlaOp op) {
    return Unimplemented("Analyzing lower-bound is not implemented yet.");
  }
  StatusOr<LiteralSlice> AnalyzeIsDynamic(XlaOp op) {
    return AnalyzeIsDynamic(op.handle());
  }
  StatusOr<LiteralSlice> AnalyzeConstant(XlaOp op) {
    return AnalyzeConstant(op.handle());
  }

 private:
  StatusOr<LiteralSlice> AnalyzeIsDynamic(int64 handle);
  StatusOr<LiteralSlice> AnalyzeConstant(int64 handle);

  StatusOr<Literal> AnalyzeIsDynamicLiteral(int64 handle);
  StatusOr<Literal> AnalyzeConstantLiteral(int64 handle);

  XlaBuilder* builder_;
  // Cache to avoid re-evaluating. Mapping of xla handle to evaluated
  // literals.
  absl::flat_hash_map<int64, Literal> upper_bound_;
  absl::flat_hash_map<int64, Literal> lower_bound_;
  absl::flat_hash_map<int64, Literal> is_dynamic_;
  absl::flat_hash_map<int64, Literal> constant_;
  HloEvaluator evaluator_;
};
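// Editorial note (sketch): results are memoized per handle in the maps above,
// so a subgraph reached through several uses is analyzed only once, and
// AnalyzeIsDynamic/AnalyzeConstant return LiteralSlices viewing the cached
// literals.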
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_VALUE_INFERENCE_H_
@@ -83,61 +83,6 @@ void SetProtoIdAndName(T* entry, const string& base_name, char separator,
  entry->set_name(GetFullName(base_name, separator, id));
}

ShapeProto ConvertShapeProtoToPred(const ShapeProto& shape_proto) {
  return ShapeUtil::ChangeElementType(Shape(shape_proto), PRED).ToProto();
}

void SetInstructionAsConstant(HloInstructionProto* instr, int64 id,
                              const Shape& shape, bool pred) {
  Literal literal = LiteralUtil::CreateR0(pred);
  Literal literal_broadcast = literal.Broadcast(shape, {}).ValueOrDie();
  *instr->mutable_shape() = shape.ToProto();
  *instr->mutable_literal() = literal_broadcast.ToProto();
  *instr->mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
}

// Copy `original_reducer` into a new computation proto with `reducer_id` as
// the new id. If `rewrite_into_pred` is true, the instructions in the reducer
// are rewritten into predicate form.
HloComputationProto CopyReducer(int64 reducer_id,
                                HloComputationProto* original_reducer,
                                bool rewrite_into_pred, int64* global_id) {
  HloComputationProto reducer;
  SetProtoIdAndName(&reducer, StrCat("reduce_or"), kNameSeparator, reducer_id);
  std::vector<int64> operands_id;
  for (auto& inst : original_reducer->instructions()) {
    // Copy params.
    if (StringToHloOpcode(inst.opcode()).ValueOrDie() ==
        HloOpcode::kParameter) {
      HloInstructionProto* new_param = reducer.add_instructions();
      *new_param = inst;
      new_param->set_id((*global_id)++);
      *new_param->mutable_name() =
          GetFullName(inst.name(), '.', new_param->id());
      if (rewrite_into_pred) {
        *new_param->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
      }
      operands_id.push_back(new_param->id());
    }
    if (inst.id() == original_reducer->root_id()) {
      HloInstructionProto* new_root = reducer.add_instructions();
      *new_root = inst;
      new_root->set_id((*global_id)++);
      *new_root->mutable_name() = GetFullName(inst.name(), '.', new_root->id());
      if (rewrite_into_pred) {
        *new_root->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
        *new_root->mutable_opcode() = HloOpcodeString(HloOpcode::kOr);
      }
      new_root->clear_operand_ids();
      for (int64 operand_id : operands_id) {
        new_root->add_operand_ids(operand_id);
      }
      reducer.set_root_id(new_root->id());
    }
  }
  return reducer;
}

bool InstrIsSetBound(const HloInstructionProto* instr_proto) {
  HloOpcode opcode = StringToHloOpcode(instr_proto->opcode()).ValueOrDie();
  if (opcode == HloOpcode::kCustomCall &&
@@ -3444,348 +3389,6 @@ StatusOr<bool> XlaBuilder::IsConstant(XlaOp operand) const {
  return is_constant;
}

StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
  TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
                      LookUpInstruction(root_op));

  HloComputationProto entry;
  SetProtoIdAndName(&entry, StrCat(name_, "_dynamic_inference"), kNameSeparator,
                    GetNextId());
  ProgramShapeProto* program_shape = entry.mutable_program_shape();
  *program_shape->mutable_result() =
      ShapeUtil::ChangeElementType(Shape(root->shape()), PRED).ToProto();

  std::vector<HloComputationProto> called_computations;
  auto operand_is_constant = [&](const HloInstructionProto* instr_proto,
                                 int64 operand_index) -> StatusOr<bool> {
    int64 operand_id = instr_proto->operand_ids(operand_index);
    bool is_constant = true;
    absl::flat_hash_set<int64> visited;
    IsConstantVisitor(operand_id, &visited, &is_constant);
    return is_constant;
  };
  // Process an instruction and copy it into the new graph. The new node in
  // the new graph will have its id set to `id`.
  auto process_instruction = [&](const HloInstructionProto* instr_proto,
                                 bool need_rewrite, int64 id,
                                 absl::Span<int64 const> operand_ids,
                                 int64* global_id) {
    // Rewrite the instruction with the following rules (see the example after
    // this list):
    // - Unary ops: Convert into bitcast (identity) with type Pred.
    // - Binary ops: Convert into binary or.
    // - Select: Convert into binary or with its two data operands.
    // - Concat / Tuple / GTE / Bitcast: Copy.
    // - Param: Convert to constant True.
    // - GetDimensionSize: Convert to constant True if dimension is dynamic,
    //   constant False if dimension is static.
    // - Reduce: Convert to reduce or.
    // - Constant: Convert to constant False.
    // - Reshape, slice, transpose, pad:
    //   Convert into predicate type with same opcode.
    // - Other ops: Not supported.
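    // Example (editorial sketch of the rules above): a graph
    //   ROOT add(parameter(0), constant(3))
    // is rewritten into
    //   ROOT or(constant(true), constant(false))
    // over PRED, which evaluates to true: the sum is dynamic because one
    // input is a parameter.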
    // Create the instruction for the new handle.
    TF_ASSIGN_OR_RETURN(HloOpcode opcode,
                        StringToHloOpcode(instr_proto->opcode()));
    auto* new_instr = entry.add_instructions();
    *new_instr = *instr_proto;
    new_instr->set_id(id);
    new_instr->mutable_operand_ids()->Clear();
    for (auto operand_id : operand_ids) {
      new_instr->mutable_operand_ids()->Add(operand_id);
    }

    if (!need_rewrite) {
      *new_instr->mutable_name() =
          GetFullName(instr_proto->opcode(), kNameSeparator, id);
      if (opcode == HloOpcode::kReduce) {
        // Copy the reducer to the new module, with a new id that's same as the
        // reduce op.
        HloComputationProto* reducer =
            &embedded_[new_instr->called_computation_ids(0)];
        int64 reducer_id = (*global_id)++;
        new_instr->clear_called_computation_ids();
        new_instr->add_called_computation_ids(reducer_id);
        called_computations.push_back(CopyReducer(
            reducer_id, reducer, /*rewrite_into_pred=*/false, global_id));
      }
      return Status::OK();
    }
    *new_instr->mutable_shape() = ConvertShapeProtoToPred(instr_proto->shape());
    Shape new_shape(new_instr->shape());
    switch (opcode) {
      case HloOpcode::kAbs:
      case HloOpcode::kRoundNearestAfz:
      case HloOpcode::kBitcast:
      case HloOpcode::kCeil:
      case HloOpcode::kCollectivePermuteDone:
      case HloOpcode::kCos:
      case HloOpcode::kClz:
      case HloOpcode::kExp:
      case HloOpcode::kExpm1:
      case HloOpcode::kFloor:
      case HloOpcode::kImag:
      case HloOpcode::kIsFinite:
      case HloOpcode::kLog:
      case HloOpcode::kLog1p:
      case HloOpcode::kNot:
      case HloOpcode::kNegate:
      case HloOpcode::kPopulationCount:
      case HloOpcode::kReal:
      case HloOpcode::kRsqrt:
      case HloOpcode::kLogistic:
      case HloOpcode::kSign:
      case HloOpcode::kSin:
      case HloOpcode::kConvert:
      case HloOpcode::kSqrt:
      case HloOpcode::kCbrt:
      case HloOpcode::kTanh:
        CHECK_EQ(instr_proto->operand_ids_size(), 1);
        *new_instr->mutable_opcode() = HloOpcodeString(HloOpcode::kBitcast);
        break;
      case HloOpcode::kAdd:
      case HloOpcode::kAtan2:
      case HloOpcode::kDivide:
      case HloOpcode::kComplex:
      case HloOpcode::kMaximum:
      case HloOpcode::kMinimum:
      case HloOpcode::kMultiply:
      case HloOpcode::kPower:
      case HloOpcode::kRemainder:
      case HloOpcode::kSubtract:
      case HloOpcode::kCompare:
      case HloOpcode::kAnd:
      case HloOpcode::kOr:
      case HloOpcode::kXor:
      case HloOpcode::kShiftLeft:
      case HloOpcode::kShiftRightArithmetic:
      case HloOpcode::kShiftRightLogical:
        CHECK_EQ(instr_proto->operand_ids_size(), 2);
        *new_instr->mutable_opcode() = HloOpcodeString(HloOpcode::kOr);
        break;
      case HloOpcode::kSelect: {
        TF_ASSIGN_OR_RETURN(bool constant_predicate,
                            operand_is_constant(instr_proto, 0));
        if (!constant_predicate) {
          // If the predicate operand is not constant, conservatively assume
          // that all result values are dynamic.
          SetInstructionAsConstant(new_instr, id, new_shape, true);
        }
        break;
      }
      case HloOpcode::kGather: {
        TF_ASSIGN_OR_RETURN(bool constant_indices,
                            operand_is_constant(instr_proto, 1));
        if (!constant_indices) {
          // If the indices operand is not constant, conservatively assume
          // that all result values are dynamic.
          SetInstructionAsConstant(new_instr, id, new_shape, true);
        }
        break;
      }
      case HloOpcode::kReduce: {
        auto* reducer = &embedded_[new_instr->called_computation_ids(0)];
        int64 reducer_id = (*global_id)++;
        new_instr->clear_called_computation_ids();
        new_instr->add_called_computation_ids(reducer_id);
        called_computations.push_back(CopyReducer(
            reducer_id, reducer, /*rewrite_into_pred=*/true, global_id));
        break;
      }
      case HloOpcode::kTuple:
      case HloOpcode::kTranspose:
      case HloOpcode::kSlice:
      case HloOpcode::kBroadcast:
      case HloOpcode::kConcatenate:
      case HloOpcode::kReshape:
      case HloOpcode::kPad:
        break;
      case HloOpcode::kGetTupleElement: {
        // Rewrite a parameter followed by a GTE into constants to avoid
        // rematerializing the tuple parameter (could be very large).
        int64 operand_handle = instr_proto->operand_ids(0);
        TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
                            LookUpInstructionByHandle(operand_handle));
        TF_ASSIGN_OR_RETURN(HloOpcode operand_opcode,
                            StringToHloOpcode(operand_proto->opcode()));
        if (operand_opcode == HloOpcode::kParameter) {
          SetInstructionAsConstant(new_instr, id, new_shape, true);
        }
        break;
      }
      case HloOpcode::kGetDimensionSize: {
        int64 dimension = instr_proto->dimensions(0);
        int64 operand_handle = instr_proto->operand_ids(0);
        TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
                            LookUpInstructionByHandle(operand_handle));

        SetInstructionAsConstant(
            new_instr, id, new_shape,
            operand_proto->shape().is_dynamic_dimension(dimension));
        break;
      }
      case HloOpcode::kConstant:
      case HloOpcode::kIota:
        SetInstructionAsConstant(new_instr, id, new_shape, false);
        break;
      case HloOpcode::kCustomCall:
        if (instr_proto->custom_call_target() == "SetBound") {
          SetInstructionAsConstant(new_instr, id, new_shape, true);
          break;
        } else {
          return InvalidArgument(
              "Dynamic inferencing on custom call %s is not supported",
              instr_proto->DebugString());
        }
      case HloOpcode::kParameter:
        SetInstructionAsConstant(new_instr, id, new_shape, true);
        break;
      default:
        return InvalidArgument("Dynamic inferencing %s is not supported",
                               instr_proto->DebugString());
    }
    *new_instr->mutable_name() =
        GetFullName(instr_proto->opcode(), kNameSeparator, id);
    return Status::OK();
  };

  struct WorkItem {
    explicit WorkItem(int64 handle, bool need_rewrite)
        : handle(handle), need_rewrite(need_rewrite), visited(false) {}
    int64 handle;
    // If need_rewrite is true, the instruction will be copied and rewritten
    // into a pred instruction indicating if each value is dynamic. If
    // need_rewrite is false, simply copy the instruction to the output graph.
    // E.g.,
    // for select(P, A, B), we need to rewrite A and B into predicates, but
    // don't need to rewrite P.
    bool need_rewrite;
    // Used in dfs to remember the ids of processed operands of this item.
    std::vector<int64> processed_operands;
    // Whether this node has been visited before or not.
    bool visited;
  };
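  // Editorial trace (sketch) of the worklist loop below for select(p, a, b):
  // the root is pushed with need_rewrite=true; operand 0 (p) is pushed with
  // need_rewrite=false while operands 1 and 2 (a, b) keep need_rewrite=true;
  // once all three results are recorded in processed_operands, the select
  // itself is processed.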
  // Only copy each pair of {handle, need_rewrite} once. Value is the id in the
  // new graph.
  absl::flat_hash_map<std::pair<int64, bool>, int64> seen;
  // Monotonically increasing id to assign to new instructions.
  int64 global_id = 0;
  // The result id of the last rewritten item -- return value of last stack
  // item.
  int64 stacktop_id = -1;
  std::vector<WorkItem> worklist;
  worklist.push_back(WorkItem(root->id(), true));
  while (!worklist.empty()) {
    WorkItem& item = worklist.back();
    auto item_key = std::make_pair(item.handle, item.need_rewrite);
    auto iter = seen.find(item_key);
    // Already processed this item. Return previous results.
    if (iter != seen.end()) {
      stacktop_id = iter->second;
      worklist.pop_back();
      continue;
    }

    int64 next_operand = item.processed_operands.size();
    TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
                        LookUpInstructionByHandle(item.handle));
    VLOG(3) << "Processing " << instr_proto->name();
    if (!item.visited) {
      item.visited = true;
    } else {
      // Record previous processed operand.
      item.processed_operands.push_back(stacktop_id);
      next_operand++;
    }
    TF_ASSIGN_OR_RETURN(HloOpcode opcode,
                        StringToHloOpcode(instr_proto->opcode()));
    bool should_visit_operand = true;
    if (opcode == HloOpcode::kGetDimensionSize) {
      // We rewrite GetDimensionSize instructions into constants based on their
      // operand shapes, so there is no need to visit their operands.
      should_visit_operand = false;
    }

    if (opcode == HloOpcode::kGetTupleElement) {
      int64 operand_handle = instr_proto->operand_ids(0);
      TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
                          LookUpInstructionByHandle(operand_handle));
      TF_ASSIGN_OR_RETURN(HloOpcode operand_opcode,
                          StringToHloOpcode(operand_proto->opcode()));
      if (operand_opcode == HloOpcode::kParameter) {
        // Don't rematerialize the whole parameter if it's followed by a GTE.
        should_visit_operand = false;
      }
    }

    if (opcode == HloOpcode::kSelect) {
      TF_ASSIGN_OR_RETURN(bool constant_predicate,
                          operand_is_constant(instr_proto, 0));
      // If the predicate operand (first operand) is non-constant, we don't
      // visit operands and we say that all result values are dynamic.
      should_visit_operand = constant_predicate;
    }
    if (opcode == HloOpcode::kGather) {
      TF_ASSIGN_OR_RETURN(bool constant_indices,
                          operand_is_constant(instr_proto, 1));
      // If the indices operand (second operand) is non-constant, we don't
      // visit operands and we say that all result values are dynamic.
      should_visit_operand = constant_indices;
    }
    if (opcode == HloOpcode::kGetDimensionSize && next_operand == 0) {
      // Always rewrite get dimension size into constant.
      item.need_rewrite = true;
    }
    if (next_operand >= instr_proto->operand_ids_size() ||
        !should_visit_operand || InstrIsSetBound(instr_proto)) {
      // No more operands to process, process self.
      int64 new_id = global_id++;
      VLOG(3) << "new_id: " << new_id << ", instr: " << instr_proto->name();
      TF_RETURN_IF_ERROR(process_instruction(instr_proto, item.need_rewrite,
                                             new_id, item.processed_operands,
                                             &global_id));
      stacktop_id = new_id;
      seen[item_key] = stacktop_id;
      worklist.pop_back();
    } else {
      // Visit and process the operand. If an operand doesn't need rewrite
      // (predicate of kSelect, or indices of kGather), we also don't rewrite
      // its ancestors.
      WorkItem next_item(instr_proto->operand_ids(next_operand),
                         item.need_rewrite);
      if (opcode == HloOpcode::kSelect && next_operand == 0) {
        next_item.need_rewrite = false;
      }
      if (opcode == HloOpcode::kGather && next_operand == 1) {
        next_item.need_rewrite = false;
      }
      // Push the next operand onto the worklist.
      worklist.push_back(next_item);
    }
  }
  TF_RET_CHECK(stacktop_id != -1);
  entry.set_root_id(stacktop_id);
  absl::c_sort(*entry.mutable_instructions(),
               [](const HloInstructionProto& p1,
                  const HloInstructionProto& p2) { return p1.id() < p2.id(); });
  XlaComputation computation(entry.id());
  HloModuleProto* module = computation.mutable_proto();
  module->set_name(entry.name());
  module->set_id(entry.id());
  module->set_entry_computation_name(entry.name());
  module->set_entry_computation_id(entry.id());
  *module->mutable_host_program_shape() = *program_shape;
  for (auto& called_comp : called_computations) {
    *module->add_computations() = called_comp;
  }
  *module->add_computations() = std::move(entry);
  // Make sure all ids appear in the computation in ascending order.
  absl::c_sort(*module->mutable_computations(),
               [](const HloComputationProto& c1,
                  const HloComputationProto& c2) { return c1.id() < c2.id(); });
  XLA_VLOG_LINES(3, module->DebugString());
  return std::move(computation);
}

StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
    XlaOp root_op, bool dynamic_dimension_is_minus_one) {
  TF_ASSIGN_OR_RETURN(bool is_constant, IsConstant(root_op));
@@ -114,6 +114,7 @@ class XlaOp {
  int64 handle() const { return handle_; }

  friend class XlaBuilder;
  friend class ValueInference;
  friend class MlirHloBuilder;
  friend struct internal::XlaBuilderFriend;

@@ -296,31 +297,6 @@ class XlaBuilder {
  StatusOr<XlaComputation> BuildConstantSubGraph(
      XlaOp root_op, bool dynamic_dimension_is_uint_max = false);

  // Similar to BuildConstantSubGraph, but with the root element type changed
  // to boolean. A true value in the root indicates that the value is dynamic,
  // while a false value indicates that the value is a constant. This will copy
  // the needed ops/computations to the subgraph.
  //
  // E.g.,
  // Computation {
  //   a = 3
  //   b = param(0)
  //   ROOT Tuple(a + b, a + 1, b + 1)
  // }
  // Calling BuildDynamicInferenceGraph on the root will produce the following
  // graph:
  //
  // Computation {
  //   a = False
  //   b = True
  //   ROOT Tuple(a | b, a, b)
  // }
  //
  // The result, which is (True, False, True) after evaluation, can be
  // interpreted as "First element is dynamic; Second element is static; Third
  // element is dynamic".
  StatusOr<XlaComputation> BuildDynamicInferenceGraph(XlaOp root_op);

  // Returns the first error that was encountered while building the
  // computation. When an error is encountered, by default we return a vacuous
  // XlaOp and inform the user of the error that occurred while
@@ -1478,6 +1454,8 @@ class XlaBuilder {
  }

  friend struct internal::XlaBuilderFriend;

  friend class ValueInference;
};

// RAII-style object: sets the current sharding assignment in builder on
@@ -497,7 +497,7 @@ class HloInstruction {
  static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
      const HloInstructionProto& proto,
      const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
      const absl::flat_hash_map<int64, HloComputation*>& computation_map,
      const absl::flat_hash_map<int64, HloComputation*>& computation_map = {},
      bool prohibit_empty_literal = true);

  // Creates a parameter-retrieving instruction.
@@ -1043,6 +1043,7 @@ class HloInstruction {

  // Returns the opcode for this instruction.
  HloOpcode opcode() const { return opcode_; }
  HloOpcode* mutable_opcode() { return &opcode_; }

  // Returns true if this instruction has a side effect, irrespective of whether
  // any called computations may contain an instruction with side effects.
@@ -2097,6 +2097,7 @@ xla_test(
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:client_library",
        "//tensorflow/compiler/xla/client:global_data",
        "//tensorflow/compiler/xla/client:value_inference",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/client:xla_computation",
        "//tensorflow/compiler/xla/client/lib:arithmetic",
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/prng.h"
#include "tensorflow/compiler/xla/client/value_inference.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout_util.h"
@@ -44,7 +45,6 @@ namespace {
// An enumerator for the client types that we want to iterate over in
// the various tests.
enum class ClientType { kLocal, kCompileOnly };
ClientType client_types[] = {ClientType::kLocal, ClientType::kCompileOnly};

class DynamismInferenceTest : public ::testing::Test {
 public:
@@ -72,21 +72,18 @@ class DynamismInferenceTest : public ::testing::Test {
        LOG(FATAL) << "invalid client_type value";
  }

  StatusOr<Literal> ComputeDynamismLiteral(Client* client, XlaOp operand,
                                           XlaBuilder* builder,
  StatusOr<Literal> ComputeDynamismLiteral(XlaOp operand, XlaBuilder* builder,
                                           Layout* output_layout = nullptr) {
    TF_ASSIGN_OR_RETURN(auto subgraph,
                        builder->BuildDynamicInferenceGraph(operand));
    TF_ASSIGN_OR_RETURN(auto computed,
                        client->ComputeConstant(subgraph, output_layout));
    return std::move(computed);
    ValueInference value_inference(builder);
    TF_ASSIGN_OR_RETURN(auto literal_slice,
                        value_inference.AnalyzeIsDynamic(operand));
    return literal_slice.Clone();
  }

  StatusOr<bool> ComputeDynamismScalar(Client* client, XlaOp operand,
                                       XlaBuilder* builder,
  StatusOr<bool> ComputeDynamismScalar(XlaOp operand, XlaBuilder* builder,
                                       ShapeIndex index = {}) {
    TF_ASSIGN_OR_RETURN(auto literal, ComputeDynamismLiteral(client, operand,
                                                             builder, nullptr));
    TF_ASSIGN_OR_RETURN(auto literal,
                        ComputeDynamismLiteral(operand, builder, nullptr));
    return literal.Get<bool>({}, index);
  }
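  // Editorial note: with this change the helpers above run the analysis
  // client-side through ValueInference instead of building a dynamic-inference
  // subgraph and calling Client::ComputeConstant, which is why the tests below
  // drop their Client and per-client-type loops.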
@@ -94,265 +91,202 @@ class DynamismInferenceTest : public ::testing::Test {
};

TEST_F(DynamismInferenceTest, ScalarInt32Literal) {
  for (ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    XlaBuilder b(TestName());
    auto computation = ConstantR0<int32>(&b, 42);
  XlaBuilder b(TestName());
  auto computation = ConstantR0<int32>(&b, 42);

    auto value = ComputeDynamismScalar(client, computation, &b);
    ASSERT_TRUE(value.ok()) << value.status();
    // A constant is not dynamic.
    EXPECT_EQ(value.ValueOrDie(), false);
  }
  auto value = ComputeDynamismScalar(computation, &b);
  ASSERT_TRUE(value.ok()) << value.status();
  // A constant is not dynamic.
  EXPECT_EQ(value.ValueOrDie(), false);
}
TEST_F(DynamismInferenceTest, Iota) {
  // The output of iota is considered static.
  for (ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    XlaBuilder b(TestName());
    auto computation = Iota(&b, S32, 2);
    // Iota is not dynamic.
    EXPECT_FALSE(ComputeDynamismLiteral(client, computation, &b)
                     .ValueOrDie()
                     .Get<bool>({0}));
  }
  XlaBuilder b(TestName());
  auto computation = Iota(&b, S32, 2);
  // Iota is not dynamic.
  EXPECT_FALSE(
      ComputeDynamismLiteral(computation, &b).ValueOrDie().Get<bool>({0}));
}

TEST_F(DynamismInferenceTest, TupleSimple) {
  for (ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    XlaBuilder b(TestName());
    auto c = ConstantR0<int32>(&b, 42);
    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
  XlaBuilder b(TestName());
  auto c = ConstantR0<int32>(&b, 42);
  auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");

    auto tuple = Tuple(&b, {c, p});
    EXPECT_EQ(ComputeDynamismScalar(client, tuple, &b, {0}).ValueOrDie(),
              false);
    EXPECT_EQ(ComputeDynamismScalar(client, tuple, &b, {1}).ValueOrDie(), true);
  }
  auto tuple = Tuple(&b, {c, p});
  EXPECT_EQ(ComputeDynamismScalar(tuple, &b, {0}).ValueOrDie(), false);
  EXPECT_EQ(ComputeDynamismScalar(tuple, &b, {1}).ValueOrDie(), true);
}
TEST_F(DynamismInferenceTest, TupleGteKeepsDynamism) {
  for (ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    XlaBuilder b(TestName());
    auto c = ConstantR0<int32>(&b, 42);
    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
  XlaBuilder b(TestName());
  auto c = ConstantR0<int32>(&b, 42);
  auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");

    auto tuple = Tuple(&b, {c, p});
    auto gte0 = GetTupleElement(tuple, 0);
    auto gte1 = GetTupleElement(tuple, 1);
    auto tuple_2 = Tuple(&b, {gte0, gte1});
    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
              false);
    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
              true);
  }
  auto tuple = Tuple(&b, {c, p});
  auto gte0 = GetTupleElement(tuple, 0);
  auto gte1 = GetTupleElement(tuple, 1);
  auto tuple_2 = Tuple(&b, {gte0, gte1});
  EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {0}).ValueOrDie(), false);
  EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {1}).ValueOrDie(), true);
}

TEST_F(DynamismInferenceTest, PredValueUsedTwice) {
  for (ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    XlaBuilder b(TestName());
    auto c = ConstantR0<int32>(&b, 42);
    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
    auto pred = Eq(c, p);
    auto result = Select(pred, p, c);
    EXPECT_EQ(ComputeDynamismScalar(client, result, &b, {}).ValueOrDie(), true);
  }
  XlaBuilder b(TestName());
  auto c = ConstantR0<int32>(&b, 42);
  auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
  auto pred = Eq(c, p);
  auto result = Select(pred, p, c);
  EXPECT_EQ(ComputeDynamismScalar(result, &b, {}).ValueOrDie(), true);
}

TEST_F(DynamismInferenceTest, ReduceUsedTwice) {
  for (ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    XlaBuilder b(TestName());
    auto c = ConstantR0<int32>(&b, 42);
    auto p = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2}), "p0");
    auto zero = ConstantR0<int32>(&b, 0);
    XlaComputation add_s32 = CreateScalarAddComputation(S32, &b);
    auto reduce = Reduce(p, zero, add_s32, {0});
    auto pred = Eq(c, reduce);
    auto result = Select(pred, reduce, c);
    EXPECT_EQ(ComputeDynamismScalar(client, result, &b, {}).ValueOrDie(), true);
  }
  XlaBuilder b(TestName());
  auto c = ConstantR0<int32>(&b, 42);
  auto p = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2}), "p0");
  auto zero = ConstantR0<int32>(&b, 0);
  XlaComputation add_s32 = CreateScalarAddComputation(S32, &b);
  auto reduce = Reduce(p, zero, add_s32, {0});
  auto pred = Eq(c, reduce);
  auto result = Select(pred, reduce, c);
  EXPECT_EQ(ComputeDynamismScalar(result, &b, {}).ValueOrDie(), true);
}
TEST_F(DynamismInferenceTest, ConcatSliceReshapeKeepsDynamism) {
  for (ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    XlaBuilder b(TestName());
    auto c = ConstantR0<int32>(&b, 42);
    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
  XlaBuilder b(TestName());
  auto c = ConstantR0<int32>(&b, 42);
  auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");

    auto concat = ConcatScalars(&b, {c, p});
    auto slice0 = SliceInDim(concat, 0, 1, 1, 0);
    auto reshape0 = Reshape(slice0, {});
    auto slice1 = SliceInDim(concat, 1, 2, 1, 0);
    auto reshape1 = Reshape(slice1, {});
    auto tuple_2 = Tuple(&b, {reshape0, reshape1});
    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
              false);
    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
              true);
  }
  auto concat = ConcatScalars(&b, {c, p});
  auto slice0 = SliceInDim(concat, 0, 1, 1, 0);
  auto reshape0 = Reshape(slice0, {});
  auto slice1 = SliceInDim(concat, 1, 2, 1, 0);
  auto reshape1 = Reshape(slice1, {});
  auto tuple_2 = Tuple(&b, {reshape0, reshape1});
  EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {0}).ValueOrDie(), false);
  EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {1}).ValueOrDie(), true);
}

TEST_F(DynamismInferenceTest, ParameterIsDynamic) {
  for (ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    XlaBuilder b(TestName());
    auto computation = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
  XlaBuilder b(TestName());
  auto computation = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");

    auto value = ComputeDynamismScalar(client, computation, &b);
    ASSERT_TRUE(value.ok()) << value.status();
    // A parameter is considered dynamic.
    EXPECT_EQ(value.ValueOrDie(), true);
  }
  auto value = ComputeDynamismScalar(computation, &b);
  ASSERT_TRUE(value.ok()) << value.status();
  // A parameter is considered dynamic.
  EXPECT_EQ(value.ValueOrDie(), true);
}

TEST_F(DynamismInferenceTest, UnaryOpKeepsDynamism) {
  for (ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    XlaBuilder b(TestName());
    auto c = ConstantR0<int32>(&b, 42);
    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
  XlaBuilder b(TestName());
  auto c = ConstantR0<int32>(&b, 42);
  auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");

    auto neg0 = Neg(c);
    auto neg1 = Neg(p);
    auto tuple_2 = Tuple(&b, {neg0, neg1});
    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
              false);
    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
              true);
  }
  auto neg0 = Neg(c);
  auto neg1 = Neg(p);
  auto tuple_2 = Tuple(&b, {neg0, neg1});
  EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {0}).ValueOrDie(), false);
  EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {1}).ValueOrDie(), true);
}
TEST_F(DynamismInferenceTest, BinaryOpsOrsDynamism) {
  for (ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    XlaBuilder b(TestName());
    auto c = ConstantR0<int32>(&b, 42);
    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
  XlaBuilder b(TestName());
  auto c = ConstantR0<int32>(&b, 42);
  auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");

    // Static value + static value = static
    auto add1 = Add(c, c);
    // Dynamic value + static value = dynamic
    auto add2 = Add(p, c);
    auto tuple_2 = Tuple(&b, {add1, add2});
    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
              false);
    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
              true);
  }
  // Static value + static value = static
  auto add1 = Add(c, c);
  // Dynamic value + static value = dynamic
  auto add2 = Add(p, c);
  auto tuple_2 = Tuple(&b, {add1, add2});
  EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {0}).ValueOrDie(), false);
  EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {1}).ValueOrDie(), true);
}

TEST_F(DynamismInferenceTest, GetDimensionSize) {
  for (ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    XlaBuilder b(TestName());
    // param = Param([<=2, 3])
    // get_dimension_size(param, 0) is dynamic
    // get_dimension_size(param, 1) is static
    auto p = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}),
                       "p0");
  XlaBuilder b(TestName());
  // param = Param([<=2, 3])
  // get_dimension_size(param, 0) is dynamic
  // get_dimension_size(param, 1) is static
  auto p =
      Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "p0");

    auto gds0 = GetDimensionSize(p, 0);
    auto gds1 = GetDimensionSize(p, 1);
    auto tuple_2 = Tuple(&b, {gds0, gds1});
    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
              true);
    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
              false);
  }
  auto gds0 = GetDimensionSize(p, 0);
  auto gds1 = GetDimensionSize(p, 1);
  auto tuple_2 = Tuple(&b, {gds0, gds1});
  EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {0}).ValueOrDie(), true);
  EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {1}).ValueOrDie(), false);
}
TEST_F(DynamismInferenceTest, GatherWithCommonParent) {
  for (ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    XlaBuilder b(TestName());
    // Test the analysis on a gather where the first and second operands have
    // common parents.
    Shape indices_shape = ShapeUtil::MakeShape(S32, {2});
  XlaBuilder b(TestName());
  // Test the analysis on a gather where the first and second operands have
  // common parents.
  Shape indices_shape = ShapeUtil::MakeShape(S32, {2});

    auto operand1 = Parameter(&b, 0, indices_shape, "p1");
    auto operand2 = Parameter(&b, 1, indices_shape, "p2");
    auto indices = Sub(operand1, operand2);
    GatherDimensionNumbers dim_numbers;
    dim_numbers.add_offset_dims(1);
    dim_numbers.add_start_index_map(0);
    dim_numbers.set_index_vector_dim(1);
    auto gather = Gather(operand1, indices, dim_numbers, {1});
    ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
    EXPECT_TRUE(ComputeDynamismLiteral(client, gather, &b)
                    .ValueOrDie()
                    .Get<bool>({0, 0}));
  }
  auto operand1 = Parameter(&b, 0, indices_shape, "p1");
  auto operand2 = Parameter(&b, 1, indices_shape, "p2");
  auto indices = Sub(operand1, operand2);
  GatherDimensionNumbers dim_numbers;
  dim_numbers.add_offset_dims(1);
  dim_numbers.add_start_index_map(0);
  dim_numbers.set_index_vector_dim(1);
  auto gather = Gather(operand1, indices, dim_numbers, {1});
  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
  EXPECT_TRUE(
      ComputeDynamismLiteral(gather, &b).ValueOrDie().Get<bool>({0, 0}));
}
TEST_F(DynamismInferenceTest, GatherWithConstantParent) {
  for (ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    XlaBuilder b(TestName());
    // Test the analysis on a gather.
    Shape indices_shape = ShapeUtil::MakeShape(S32, {2});
    auto data_operand = ConstantR1<int32>(&b, {1, 2});
    auto indices = ConstantR1<int32>(&b, {1, 2});
    GatherDimensionNumbers dim_numbers;
    dim_numbers.add_offset_dims(1);
    dim_numbers.add_start_index_map(0);
    dim_numbers.set_index_vector_dim(1);
    auto gather = Gather(data_operand, indices, dim_numbers, {1});
    ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
    // Everything is constant, so the result is also constant.
    EXPECT_FALSE(ComputeDynamismLiteral(client, gather, &b)
                     .ValueOrDie()
                     .Get<bool>({0, 0}));
  }
  XlaBuilder b(TestName());
  // Test the analysis on a gather.
  Shape indices_shape = ShapeUtil::MakeShape(S32, {2});
  auto data_operand = ConstantR1<int32>(&b, {1, 2});
  auto indices = ConstantR1<int32>(&b, {1, 2});
  GatherDimensionNumbers dim_numbers;
  dim_numbers.add_offset_dims(1);
  dim_numbers.add_start_index_map(0);
  dim_numbers.set_index_vector_dim(1);
  auto gather = Gather(data_operand, indices, dim_numbers, {1});
  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
  // Everything is constant, so the result is also constant.
  EXPECT_FALSE(
      ComputeDynamismLiteral(gather, &b).ValueOrDie().Get<bool>({0, 0}));
}
TEST_F(DynamismInferenceTest, GatherWithSharedConstantParent) {
  for (ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    XlaBuilder b(TestName());
    // Test the analysis on a gather.
    Shape indices_shape = ShapeUtil::MakeShape(S32, {2});
    auto operand1 = ConstantR1<int32>(&b, {1, 2});
    auto operand2 = ConstantR1<int32>(&b, {1, 2});
    auto indices = Sub(operand1, operand2);
    GatherDimensionNumbers dim_numbers;
    dim_numbers.add_offset_dims(1);
    dim_numbers.add_start_index_map(0);
    dim_numbers.set_index_vector_dim(1);
    auto gather = Gather(operand1, indices, dim_numbers, {1});
    ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
    // Everything is constant, so the result is also constant.
    EXPECT_FALSE(ComputeDynamismLiteral(client, gather, &b)
                     .ValueOrDie()
                     .Get<bool>({0, 0}));
  }
  XlaBuilder b(TestName());
  // Test the analysis on a gather.
  Shape indices_shape = ShapeUtil::MakeShape(S32, {2});
  auto operand1 = ConstantR1<int32>(&b, {1, 2});
  auto operand2 = ConstantR1<int32>(&b, {1, 2});
  auto indices = Sub(operand1, operand2);
  GatherDimensionNumbers dim_numbers;
  dim_numbers.add_offset_dims(1);
  dim_numbers.add_start_index_map(0);
  dim_numbers.set_index_vector_dim(1);
  auto gather = Gather(operand1, indices, dim_numbers, {1});
  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
  // Everything is constant, so the result is also constant.
  EXPECT_FALSE(
      ComputeDynamismLiteral(gather, &b).ValueOrDie().Get<bool>({0, 0}));
}
TEST_F(DynamismInferenceTest, InferThroughPad) {
  for (ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    XlaBuilder b(TestName());
    // Test the analysis on a pad.
    auto operand1 = ConstantR1<int32>(&b, {1, 2});
    auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {}), "p0");
    PaddingConfig padding_config;
    padding_config.add_dimensions()->set_edge_padding_high(1);
    // After the pad the value is [constant, constant, parameter].
    auto pad = Pad(operand1, parameter, padding_config);
    ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
    // The first two elements are constant; the pad value is a parameter.
    EXPECT_FALSE(
        ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get<bool>({0}));
    EXPECT_FALSE(
        ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get<bool>({1}));
    EXPECT_TRUE(
        ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get<bool>({2}));
  }
  XlaBuilder b(TestName());
  // Test the analysis on a pad.
  auto operand1 = ConstantR1<int32>(&b, {1, 2});
  auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {}), "p0");
  PaddingConfig padding_config;
  padding_config.add_dimensions()->set_edge_padding_high(1);
  // After the pad the value is [constant, constant, parameter].
  auto pad = Pad(operand1, parameter, padding_config);
  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
  // The first two elements are constant; the pad value is a parameter.
  EXPECT_FALSE(ComputeDynamismLiteral(pad, &b).ValueOrDie().Get<bool>({0}));
  EXPECT_FALSE(ComputeDynamismLiteral(pad, &b).ValueOrDie().Get<bool>({1}));
  EXPECT_TRUE(ComputeDynamismLiteral(pad, &b).ValueOrDie().Get<bool>({2}));
}

}  // namespace