DynamismInference: Support instruction with/without rewrite in the same graph.
- Dynamism inference is used to decide if a value is dynamic or not.
- In dynamism inference, we rewrite some instructions into boolean form. E.g., we rewrite

    A = constant(0)
    B = parameter(0)
    ROOT A + B

  into

    A' = constant(false)
    B' = constant(true)
    ROOT A' | B'

- We also don't rewrite some instructions. E.g., we rewrite

    A = constant(0)
    B = parameter(0)
    C = constant(0)
    D = parameter(0)
    P = C == D
    ROOT select(P,A,B)

  into

    A' = constant(false)
    B' = constant(true)
    C = constant(0)
    D = parameter(0)
    P = C == D
    ROOT select(P,A',B')

  We don't rewrite P, nor the instructions reachable from P.
- This CL fixes an issue where these two forms are mixed together. E.g.,

    A = constant(0)
    B = parameter(0)
    P = A == B
    ROOT select(P,A,B)

  Here A and B are needed both unrewritten (as operands of P) and rewritten (as value operands of the select). Previously the pass would fail, because it deduplicated work by instruction handle alone and could only emit each instruction in one of the two forms.

PiperOrigin-RevId: 327889288
Change-Id: I3dd419ca5d729bb857d3fcac8fd76d47788aa5c2
parent 90d58ce333
commit 9f6e57df30
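The essence of the fix is visible in the data structures changed below: the old pass deduplicated work items by instruction handle alone (std::set<int64> seen), so an instruction needed both in rewritten (pred) form and in verbatim form could only be emitted once. The new pass keys its memo table on the pair {handle, need_rewrite}. A minimal sketch of that idea, with illustrative names (GetOrEmit, EmitCopy, next_id are hypothetical, not the actual XLA helpers):

    #include <cstdint>
    #include <utility>
    #include "absl/container/flat_hash_map.h"

    // Each (handle, need_rewrite) pair is copied into the output graph at
    // most once; the same handle may legitimately appear twice, once per form.
    absl::flat_hash_map<std::pair<int64_t, bool>, int64_t> seen;
    int64_t next_id = 0;

    // Returns the id of the copy of `handle` in the requested form, emitting
    // it on first use.
    int64_t GetOrEmit(int64_t handle, bool need_rewrite) {
      auto key = std::make_pair(handle, need_rewrite);
      auto it = seen.find(key);
      if (it != seen.end()) return it->second;  // Reuse this exact form.
      int64_t id = ++next_id;
      // EmitCopy(handle, need_rewrite, id);  // Hypothetical emission step.
      seen[key] = id;
      return id;
    }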
@@ -28,6 +28,7 @@ limitations under the License.
 #include "absl/strings/match.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
 #include "absl/types/span.h"
 #include "tensorflow/compiler/xla/client/sharding_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/comparison_util.h"
@@ -78,16 +79,13 @@ ShapeProto ConvertShapeProtoToPred(const ShapeProto& shape_proto) {
   return ShapeUtil::ChangeElementType(Shape(shape_proto), PRED).ToProto();
 }
 
-HloInstructionProto CreateConstantInstruction(int64 id, const Shape& shape,
-                                              bool pred) {
-  HloInstructionProto const_instr;
+void SetInstructionAsConstant(HloInstructionProto* instr, int64 id,
+                              const Shape& shape, bool pred) {
   Literal literal = LiteralUtil::CreateR0(pred);
   Literal literal_broadcast = literal.Broadcast(shape, {}).ValueOrDie();
-  *const_instr.mutable_shape() = shape.ToProto();
-  *const_instr.mutable_literal() = literal_broadcast.ToProto();
-  *const_instr.mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
-  const_instr.set_id(id);
-  return const_instr;
+  *instr->mutable_shape() = shape.ToProto();
+  *instr->mutable_literal() = literal_broadcast.ToProto();
+  *instr->mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
 }
 
 // Converts a HloComputation into ReducerOr with predicate types.
@@ -2971,27 +2969,12 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
   *program_shape->mutable_result() =
       ShapeUtil::ChangeElementType(Shape(root->shape()), PRED).ToProto();
 
-  std::set<int64> seen;
-  struct WorkItem {
-    explicit WorkItem(int64 handle, bool need_rewrite)
-        : handle(handle), need_rewrite(need_rewrite) {}
-    int64 handle;
-    // If need_rewrite is true, the instruction will be copied and rewrite into
-    // a pred instruction indicating if each value is dynamic. If need_rewrite
-    // is false, simply copy the instruction to the output graph.
-    // E.g.,
-    // For select(P, A, B), we need to rewrite A and B into predicates, but
-    // don't need to rewrite P.
-    bool need_rewrite;
-  };
-  std::queue<WorkItem> worklist;
-  worklist.push(WorkItem(root->id(), true));
-  entry.set_root_id(root->id());
   std::vector<HloComputationProto> called_computatons;
-  // Rewritre instruction with id "from" into the new graph.
-  // Returns more work items that need to finish.
-  auto rewrite_instruction =
-      [&](int64 from, bool need_rewrite) -> StatusOr<std::vector<WorkItem>> {
+  // Process instruction and copy it into the new graph. The new node in the new
+  // graph with have id set to `id`.
+  auto process_instruction = [&](const HloInstructionProto* instr_proto,
+                                 bool need_rewrite, int64 id,
+                                 absl::Span<int64 const> operand_ids) {
     // Rewrite the instruction with following rules:
     // - Unary ops: Convert into bitcast (identity) with type Pred.
     // - Binary ops: Convert into binary or.
@@ -3004,22 +2987,20 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
     // - Constant: Convert to constant False.
     // - Other ops: Not supported.
     // Create the instruction for the new handle.
-    TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
-                        LookUpInstructionByHandle(from));
-
     TF_ASSIGN_OR_RETURN(HloOpcode opcode,
                         StringToHloOpcode(instr_proto->opcode()));
-    std::vector<WorkItem> operands_todo;
     auto* new_instr = entry.add_instructions();
     *new_instr = *instr_proto;
-    for (auto operand_id : new_instr->operand_ids()) {
-      operands_todo.emplace_back(operand_id, need_rewrite);
+    new_instr->set_id(id);
+    new_instr->mutable_operand_ids()->Clear();
+    for (auto operand_id : operand_ids) {
+      new_instr->mutable_operand_ids()->Add(operand_id);
     }
 
     if (!need_rewrite) {
       *new_instr->mutable_name() =
-          GetFullName(instr_proto->opcode(), kNameSeparator, instr_proto->id());
-      return operands_todo;
+          GetFullName(instr_proto->opcode(), kNameSeparator, id);
+      return Status::OK();
     }
     *new_instr->mutable_shape() = ConvertShapeProtoToPred(instr_proto->shape());
     Shape new_shape(new_instr->shape());
@@ -3074,10 +3055,8 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
         *new_instr->mutable_opcode() = HloOpcodeString(HloOpcode::kOr);
         break;
       case HloOpcode::kSelect:
-        operands_todo[0].need_rewrite = false;
         break;
       case HloOpcode::kGather:
-        operands_todo[1].need_rewrite = false;
         break;
       case HloOpcode::kReduce: {
         int64 reducer_id = new_instr->called_computation_ids(0);
@@ -3099,39 +3078,101 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
         TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
                             LookUpInstructionByHandle(operand_handle));
 
-        *new_instr = CreateConstantInstruction(
-            from, new_shape,
+        SetInstructionAsConstant(
+            new_instr, id, new_shape,
             operand_proto->shape().is_dynamic_dimension(dimension));
-        operands_todo.clear();
         break;
       }
       case HloOpcode::kConstant:
-        *new_instr = CreateConstantInstruction(from, new_shape, false);
+        SetInstructionAsConstant(new_instr, id, new_shape, false);
         break;
       case HloOpcode::kParameter:
-        *new_instr = CreateConstantInstruction(from, new_shape, true);
+        SetInstructionAsConstant(new_instr, id, new_shape, true);
        break;
       default:
         return InvalidArgument("Dynamic inferencing %s is not supported",
                                instr_proto->DebugString());
     }
     *new_instr->mutable_name() =
-        GetFullName(instr_proto->opcode(), kNameSeparator, instr_proto->id());
-    return operands_todo;
+        GetFullName(instr_proto->opcode(), kNameSeparator, id);
+    return Status::OK();
   };
 
+  struct WorkItem {
+    explicit WorkItem(int64 handle, bool need_rewrite)
+        : handle(handle), need_rewrite(need_rewrite), visited(false) {}
+    int64 handle;
+    // If need_rewrite is true, the instruction will be copied and rewrite into
+    // a pred instruction indicating if each value is dynamic. If need_rewrite
+    // is false, simply copy the instruction to the output graph.
+    // E.g.,
+    // For select(P, A, B), we need to rewrite A and B into predicates, but
+    // don't need to rewrite P.
+    bool need_rewrite;
+    // Used in dfs to remember the ids of processed operands of this item.
+    std::vector<int64> processed_operands;
+    // Whether this node been visited before or not.
+    bool visited;
+  };
+  // Only copy each pair of {handle, need_rewrite} once. Value is the id in the
+  // new graph.
+  absl::flat_hash_map<std::pair<int64, bool>, int64> seen;
+  // Monotonically increasing id to assign to new instructions.
+  int64 global_id = 0;
+  // The result id of the last rewritten item -- return value of last stack
+  // item.
+  int64 stacktop_id = -1;
+  std::vector<WorkItem> worklist;
+  worklist.push_back(WorkItem(root->id(), true));
   while (!worklist.empty()) {
-    WorkItem item = worklist.front();
-    worklist.pop();
-    if (!seen.insert(item.handle).second) {
+    WorkItem& item = worklist.back();
+    auto item_key = std::make_pair(item.handle, item.need_rewrite);
+    auto iter = seen.find(item_key);
+    // Already processed this item. Return previous results.
+    if (iter != seen.end()) {
+      stacktop_id = iter->second;
+      worklist.pop_back();
       continue;
     }
-    TF_ASSIGN_OR_RETURN(auto todos,
-                        rewrite_instruction(item.handle, item.need_rewrite));
-    for (WorkItem& todo : todos) {
-      worklist.push(todo);
+
+    int64 next_operand = item.processed_operands.size();
+    TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
+                        LookUpInstructionByHandle(item.handle));
+    VLOG(3) << "Processing" << instr_proto->name();
+    if (!item.visited) {
+      item.visited = true;
+    } else {
+      // Record previous processed operand.
+      item.processed_operands.push_back(stacktop_id);
+      next_operand++;
+    }
+    TF_ASSIGN_OR_RETURN(HloOpcode opcode,
+                        StringToHloOpcode(instr_proto->opcode()));
+    if (next_operand >= instr_proto->operand_ids_size() ||
+        opcode == HloOpcode::kGetDimensionSize) {
+      // No more operands to process, process self.
+      int64 new_id = ++global_id;
+      VLOG(3) << "new_id: " << new_id << "instr: " << instr_proto->name();
+      TF_RETURN_IF_ERROR(process_instruction(instr_proto, item.need_rewrite,
+                                             new_id, item.processed_operands));
+      stacktop_id = new_id;
+      seen[item_key] = stacktop_id;
+      worklist.pop_back();
+      continue;
     }
+
+    WorkItem next_item(instr_proto->operand_ids(next_operand), true);
+    if (opcode == HloOpcode::kSelect && next_operand == 0) {
+      next_item.need_rewrite = false;
+    }
+    if (opcode == HloOpcode::kGather && next_operand == 1) {
+      next_item.need_rewrite = false;
+    }
+    // Push next operand into worklist.
+    worklist.push_back(next_item);
   }
+  TF_RET_CHECK(stacktop_id != -1);
+  entry.set_root_id(stacktop_id);
+  absl::c_sort(*entry.mutable_instructions(),
+               [](const HloInstructionProto& p1,
+                  const HloInstructionProto& p2) { return p1.id() < p2.id(); });
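For readers following the new control flow above: the commit replaces the BFS queue with an explicit-stack, iterative post-order DFS, so each instruction is emitted only after all of its (possibly rewritten) operands have been assigned new ids — which process_instruction needs when it fills in operand_ids. A compact, self-contained sketch of the same traversal pattern, using hypothetical names (Node, Frame, Emit) rather than the XLA API:

    #include <cstdint>
    #include <utility>
    #include <vector>
    #include "absl/container/flat_hash_map.h"

    // Hypothetical graph node: operands are indices into `nodes`.
    struct Node { std::vector<int64_t> operands; };

    // Iterative post-order DFS mirroring the commit's worklist: a frame stays
    // on the stack until every operand has reported its new id via `result`.
    int64_t Emit(const std::vector<Node>& nodes, int64_t root) {
      struct Frame {
        int64_t handle;
        bool need_rewrite;
        std::vector<int64_t> processed_operands;
        bool visited = false;
      };
      absl::flat_hash_map<std::pair<int64_t, bool>, int64_t> seen;
      int64_t global_id = 0;
      int64_t result = -1;  // Return value of the last finished frame.
      std::vector<Frame> stack;
      stack.push_back({root, /*need_rewrite=*/true});
      while (!stack.empty()) {
        Frame& top = stack.back();
        auto key = std::make_pair(top.handle, top.need_rewrite);
        auto it = seen.find(key);
        if (it != seen.end()) {
          result = it->second;  // Reuse the copy made for this exact form.
          stack.pop_back();
          continue;
        }
        if (!top.visited) {
          top.visited = true;
        } else {
          top.processed_operands.push_back(result);  // A child just finished.
        }
        size_t next = top.processed_operands.size();
        const Node& node = nodes[top.handle];
        if (next == node.operands.size()) {
          // All operands done: "emit" this node and publish its new id.
          result = seen[key] = ++global_id;
          stack.pop_back();
          continue;
        }
        stack.push_back({node.operands[next], /*need_rewrite=*/true});
      }
      return result;
    }

In the real pass the pushed child's need_rewrite flag is cleared for select's predicate operand and gather's indices operand; everything else inherits the rewritten form.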
@@ -104,12 +104,26 @@ TEST_F(DynamismInferenceTest, ScalarInt32Literal) {
   }
 }
 
+TEST_F(DynamismInferenceTest, TupleSimple) {
+  for (ClientType client_type : client_types) {
+    Client* client = ClientOrDie(platform_, client_type);
+    XlaBuilder b(TestName());
+    auto c = ConstantR0<int32>(&b, 42);
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
+
+    auto tuple = Tuple(&b, {c, p});
+    EXPECT_EQ(ComputeDynamismScalar(client, tuple, &b, {0}).ValueOrDie(),
+              false);
+    EXPECT_EQ(ComputeDynamismScalar(client, tuple, &b, {1}).ValueOrDie(), true);
+  }
+}
+
 TEST_F(DynamismInferenceTest, TupleGteKeepsDynamism) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
     XlaBuilder b(TestName());
     auto c = ConstantR0<int32>(&b, 42);
-    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
 
     auto tuple = Tuple(&b, {c, p});
     auto gte0 = GetTupleElement(tuple, 0);
@@ -122,12 +136,25 @@ TEST_F(DynamismInferenceTest, TupleGteKeepsDynamism) {
   }
 }
 
+TEST_F(DynamismInferenceTest, PredValueUsedTwice) {
+  for (ClientType client_type : client_types) {
+    Client* client = ClientOrDie(platform_, client_type);
+    XlaBuilder b(TestName());
+    auto c = ConstantR0<int32>(&b, 42);
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
+    auto pred = Eq(c, p);
+    auto result = Select(pred, p, c);
+    EXPECT_EQ(ComputeDynamismScalar(client, result, &b, {}).ValueOrDie(),
+              false);
+  }
+}
+
 TEST_F(DynamismInferenceTest, ConcatSliceReshapeKeepsDynamism) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
     XlaBuilder b(TestName());
     auto c = ConstantR0<int32>(&b, 42);
-    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
 
     auto concat = ConcatScalars(&b, {c, p});
     auto slice0 = SliceInDim(concat, 0, 1, 1, 0);
@@ -146,7 +173,7 @@ TEST_F(DynamismInferenceTest, ParameterIsDynamic) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
     XlaBuilder b(TestName());
-    auto computation = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+    auto computation = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
 
     auto value = ComputeDynamismScalar(client, computation, &b);
     ASSERT_TRUE(value.ok()) << value.status();
@@ -160,7 +187,7 @@ TEST_F(DynamismInferenceTest, UnaryOpKeepsDynamism) {
     Client* client = ClientOrDie(platform_, client_type);
     XlaBuilder b(TestName());
     auto c = ConstantR0<int32>(&b, 42);
-    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
 
     auto neg0 = Neg(c);
     auto neg1 = Neg(p);
@@ -177,7 +204,7 @@ TEST_F(DynamismInferenceTest, BinaryOpsOrsDynamism) {
     Client* client = ClientOrDie(platform_, client_type);
     XlaBuilder b(TestName());
     auto c = ConstantR0<int32>(&b, 42);
-    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
 
     // Static value + static value = static
     auto add1 = Add(c, c);
@@ -198,8 +225,8 @@ TEST_F(DynamismInferenceTest, GetDimensionSize) {
     // param = Param([<=2, 3])
     // get_dimension_size(param, 0) is dynamic
     // get_dimension_size(param, 1) is static
-    auto p =
-        Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "0");
+    auto p = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}),
+                       "p0");
 
     auto gds0 = GetDimensionSize(p, 0);
     auto gds1 = GetDimensionSize(p, 1);