DynamismInference: Support instructions with and without rewrite in the same graph.
- Dynamism inference is used to decide whether a value is dynamic or not.

- In dynamism inference, we rewrite some instructions into boolean form.
  E.g., we rewrite

      A = constant(0)
      B = parameter(0)
      ROOT A + B

  into

      A' = constant(false)
      B' = constant(true)
      ROOT A' | B'

- We also don't rewrite some instructions. E.g., we rewrite

      A = constant(0)
      B = parameter(0)
      C = constant(0)
      D = parameter(0)
      P = C == D
      ROOT select(P, A, B)

  into

      A' = constant(false)
      B' = constant(true)
      C = constant(0)
      D = parameter(0)
      P = C == D
      ROOT select(P, A', B')

  We don't rewrite P, nor the instructions reachable from P.

- This CL fixes an issue where these two forms are mixed together. E.g., in

      A = constant(0)
      B = parameter(0)
      P = A == B
      ROOT select(P, A, B)

  A and B are needed both unrewritten (as operands of P) and rewritten (as
  the select branches), so the output graph must contain both forms:

      A' = constant(false)
      B' = constant(true)
      A = constant(0)
      B = parameter(0)
      P = A == B
      ROOT select(P, A', B')

  Previously the pass would fail on such graphs.

PiperOrigin-RevId: 327889288
Change-Id: I3dd419ca5d729bb857d3fcac8fd76d47788aa5c2
parent 90d58ce333
commit 9f6e57df30
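The heart of the fix, visible in the xla_builder.cc hunks below, is that deduplication during output-graph construction is now keyed on the pair {instruction handle, need_rewrite} instead of the handle alone, so one source instruction can be materialized in both forms. A minimal standalone sketch of that idea (toy types and a hypothetical Emit helper, not the XLA API):

    #include <cstdint>
    #include <map>
    #include <utility>

    // Toy stand-ins for XLA's instruction handles and output-graph ids.
    using Handle = int64_t;
    using OutputId = int64_t;

    // Keyed on {handle, need_rewrite}: the same instruction may be emitted
    // twice -- once rewritten into predicate form, once copied as-is.
    std::map<std::pair<Handle, bool>, OutputId> seen;
    OutputId next_id = 0;

    OutputId Emit(Handle handle, bool need_rewrite) {
      auto key = std::make_pair(handle, need_rewrite);
      auto it = seen.find(key);
      if (it != seen.end()) return it->second;  // This form already exists.
      OutputId id = ++next_id;                  // Materialize a new node.
      seen[key] = id;
      return id;
    }

    int main() {
      Handle a = 42;
      // P = (A == B) needs A unrewritten; select(P, A, B) needs A rewritten.
      // With handle-only dedup these collide; with pair keys they don't.
      bool distinct =
          Emit(a, /*need_rewrite=*/true) != Emit(a, /*need_rewrite=*/false);
      return distinct ? 0 : 1;
    }

With handle-only keying, whichever form was emitted first would be reused for both consumers, which is exactly the mixed-form failure described above.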
@@ -28,6 +28,7 @@ limitations under the License.
 #include "absl/strings/match.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
+#include "absl/types/span.h"
 #include "tensorflow/compiler/xla/client/sharding_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/comparison_util.h"
@@ -78,16 +79,13 @@ ShapeProto ConvertShapeProtoToPred(const ShapeProto& shape_proto) {
   return ShapeUtil::ChangeElementType(Shape(shape_proto), PRED).ToProto();
 }

-HloInstructionProto CreateConstantInstruction(int64 id, const Shape& shape,
-                                              bool pred) {
-  HloInstructionProto const_instr;
+void SetInstructionAsConstant(HloInstructionProto* instr, int64 id,
+                              const Shape& shape, bool pred) {
   Literal literal = LiteralUtil::CreateR0(pred);
   Literal literal_broadcast = literal.Broadcast(shape, {}).ValueOrDie();
-  *const_instr.mutable_shape() = shape.ToProto();
-  *const_instr.mutable_literal() = literal_broadcast.ToProto();
-  *const_instr.mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
-  const_instr.set_id(id);
-  return const_instr;
+  *instr->mutable_shape() = shape.ToProto();
+  *instr->mutable_literal() = literal_broadcast.ToProto();
+  *instr->mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
 }

 // Converts a HloComputation into ReducerOr with predicate types.
@@ -2971,27 +2969,12 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
   *program_shape->mutable_result() =
       ShapeUtil::ChangeElementType(Shape(root->shape()), PRED).ToProto();

-  std::set<int64> seen;
-  struct WorkItem {
-    explicit WorkItem(int64 handle, bool need_rewrite)
-        : handle(handle), need_rewrite(need_rewrite) {}
-    int64 handle;
-    // If need_rewrite is true, the instruction will be copied and rewrite into
-    // a pred instruction indicating if each value is dynamic. If need_rewrite
-    // is false, simply copy the instruction to the output graph.
-    // E.g.,
-    // For select(P, A, B), we need to rewrite A and B into predicates, but
-    // don't need to rewrite P.
-    bool need_rewrite;
-  };
-  std::queue<WorkItem> worklist;
-  worklist.push(WorkItem(root->id(), true));
-  entry.set_root_id(root->id());
   std::vector<HloComputationProto> called_computatons;
-  // Rewritre instruction with id "from" into the new graph.
-  // Returns more work items that need to finish.
-  auto rewrite_instruction =
-      [&](int64 from, bool need_rewrite) -> StatusOr<std::vector<WorkItem>> {
+  // Process an instruction and copy it into the new graph. The new node in
+  // the new graph will have its id set to `id`.
+  auto process_instruction = [&](const HloInstructionProto* instr_proto,
+                                 bool need_rewrite, int64 id,
+                                 absl::Span<int64 const> operand_ids) {
     // Rewrite the instruction with following rules:
     // - Unary ops: Convert into bitcast (identity) with type Pred.
     // - Binary ops: Convert into binary or.
@@ -3004,22 +2987,20 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
     // - Constant: Convert to constant False.
     // - Other ops: Not supported.
     // Create the instruction for the new handle.
-    TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
-                        LookUpInstructionByHandle(from));
-
     TF_ASSIGN_OR_RETURN(HloOpcode opcode,
                         StringToHloOpcode(instr_proto->opcode()));
-    std::vector<WorkItem> operands_todo;
     auto* new_instr = entry.add_instructions();
     *new_instr = *instr_proto;
-    for (auto operand_id : new_instr->operand_ids()) {
-      operands_todo.emplace_back(operand_id, need_rewrite);
+    new_instr->set_id(id);
+    new_instr->mutable_operand_ids()->Clear();
+    for (auto operand_id : operand_ids) {
+      new_instr->mutable_operand_ids()->Add(operand_id);
     }

     if (!need_rewrite) {
       *new_instr->mutable_name() =
-          GetFullName(instr_proto->opcode(), kNameSeparator, instr_proto->id());
-      return operands_todo;
+          GetFullName(instr_proto->opcode(), kNameSeparator, id);
+      return Status::OK();
     }
     *new_instr->mutable_shape() = ConvertShapeProtoToPred(instr_proto->shape());
     Shape new_shape(new_instr->shape());
@@ -3074,10 +3055,8 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
         *new_instr->mutable_opcode() = HloOpcodeString(HloOpcode::kOr);
         break;
       case HloOpcode::kSelect:
-        operands_todo[0].need_rewrite = false;
         break;
       case HloOpcode::kGather:
-        operands_todo[1].need_rewrite = false;
         break;
       case HloOpcode::kReduce: {
         int64 reducer_id = new_instr->called_computation_ids(0);
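The per-operand rewrite policy removed here does not disappear: in the traversal added in the next hunk, it is applied when each operand is pushed onto the worklist. A standalone restatement of that policy (a hypothetical OperandNeedsRewrite helper with a toy enum standing in for xla::HloOpcode; the kGetDimensionSize case is handled separately in the real loop):

    #include <cstdint>

    // Toy opcode enum standing in for xla::HloOpcode.
    enum class Opcode { kSelect, kGather, kOther };

    // Mirrors the checks in the traversal: select's predicate (operand 0)
    // and gather's start_indices (operand 1) are copied verbatim; every
    // other operand is rewritten into predicate ("is dynamic") form.
    bool OperandNeedsRewrite(Opcode opcode, int64_t operand_index) {
      if (opcode == Opcode::kSelect && operand_index == 0) return false;
      if (opcode == Opcode::kGather && operand_index == 1) return false;
      return true;
    }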
@@ -3099,39 +3078,101 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
         TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
                             LookUpInstructionByHandle(operand_handle));

-        *new_instr = CreateConstantInstruction(
-            from, new_shape,
+        SetInstructionAsConstant(
+            new_instr, id, new_shape,
             operand_proto->shape().is_dynamic_dimension(dimension));
-        operands_todo.clear();
         break;
       }
       case HloOpcode::kConstant:
-        *new_instr = CreateConstantInstruction(from, new_shape, false);
+        SetInstructionAsConstant(new_instr, id, new_shape, false);
         break;
       case HloOpcode::kParameter:
-        *new_instr = CreateConstantInstruction(from, new_shape, true);
+        SetInstructionAsConstant(new_instr, id, new_shape, true);
         break;
       default:
         return InvalidArgument("Dynamic inferencing %s is not supported",
                                instr_proto->DebugString());
     }
     *new_instr->mutable_name() =
-        GetFullName(instr_proto->opcode(), kNameSeparator, instr_proto->id());
-    return operands_todo;
+        GetFullName(instr_proto->opcode(), kNameSeparator, id);
+    return Status::OK();
   };

+  struct WorkItem {
+    explicit WorkItem(int64 handle, bool need_rewrite)
+        : handle(handle), need_rewrite(need_rewrite), visited(false) {}
+    int64 handle;
+    // If need_rewrite is true, the instruction will be copied and rewritten
+    // into a pred instruction indicating if each value is dynamic. If
+    // need_rewrite is false, simply copy the instruction to the output graph.
+    // E.g.,
+    // For select(P, A, B), we need to rewrite A and B into predicates, but
+    // don't need to rewrite P.
+    bool need_rewrite;
+    // Used in dfs to remember the ids of processed operands of this item.
+    std::vector<int64> processed_operands;
+    // Whether this node has been visited before or not.
+    bool visited;
+  };
+  // Only copy each pair of {handle, need_rewrite} once. Value is the id in
+  // the new graph.
+  absl::flat_hash_map<std::pair<int64, bool>, int64> seen;
+  // Monotonically increasing id to assign to new instructions.
+  int64 global_id = 0;
+  // The result id of the last rewritten item -- return value of the last
+  // stack item.
+  int64 stacktop_id = -1;
+  std::vector<WorkItem> worklist;
+  worklist.push_back(WorkItem(root->id(), true));
   while (!worklist.empty()) {
-    WorkItem item = worklist.front();
-    worklist.pop();
-    if (!seen.insert(item.handle).second) {
+    WorkItem& item = worklist.back();
+    auto item_key = std::make_pair(item.handle, item.need_rewrite);
+    auto iter = seen.find(item_key);
+    // Already processed this item. Return previous results.
+    if (iter != seen.end()) {
+      stacktop_id = iter->second;
+      worklist.pop_back();
       continue;
     }
-    TF_ASSIGN_OR_RETURN(auto todos,
-                        rewrite_instruction(item.handle, item.need_rewrite));
-    for (WorkItem& todo : todos) {
-      worklist.push(todo);
+
+    int64 next_operand = item.processed_operands.size();
+    TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
+                        LookUpInstructionByHandle(item.handle));
+    VLOG(3) << "Processing " << instr_proto->name();
+    if (!item.visited) {
+      item.visited = true;
+    } else {
+      // Record the previously processed operand.
+      item.processed_operands.push_back(stacktop_id);
+      next_operand++;
     }
+    TF_ASSIGN_OR_RETURN(HloOpcode opcode,
+                        StringToHloOpcode(instr_proto->opcode()));
+    if (next_operand >= instr_proto->operand_ids_size() ||
+        opcode == HloOpcode::kGetDimensionSize) {
+      // No more operands to process, process self.
+      int64 new_id = ++global_id;
+      VLOG(3) << "new_id: " << new_id << " instr: " << instr_proto->name();
+      TF_RETURN_IF_ERROR(process_instruction(instr_proto, item.need_rewrite,
+                                             new_id, item.processed_operands));
+      stacktop_id = new_id;
+      seen[item_key] = stacktop_id;
+      worklist.pop_back();
+      continue;
+    }
+
+    WorkItem next_item(instr_proto->operand_ids(next_operand), true);
+    if (opcode == HloOpcode::kSelect && next_operand == 0) {
+      next_item.need_rewrite = false;
+    }
+    if (opcode == HloOpcode::kGather && next_operand == 1) {
+      next_item.need_rewrite = false;
+    }
+    // Push next operand into worklist.
+    worklist.push_back(next_item);
   }
+  TF_RET_CHECK(stacktop_id != -1);
+  entry.set_root_id(stacktop_id);
   absl::c_sort(*entry.mutable_instructions(),
                [](const HloInstructionProto& p1,
                   const HloInstructionProto& p2) { return p1.id() < p2.id(); });
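The replacement loop above is an iterative post-order DFS: a work item stays on the stack until all of its operands have been materialized, and each finished child leaves its new id in stacktop_id for the parent to collect into processed_operands. A simplified standalone sketch of the same traversal shape (toy graph types and a hypothetical BuildOutput function; the rewrite rules, per-operand policy, and error handling are elided):

    #include <cstdint>
    #include <map>
    #include <utility>
    #include <vector>

    struct Node {
      int64_t handle;
      std::vector<int64_t> operands;  // Handles of operand nodes.
    };

    // Toy source graph, keyed by handle.
    std::map<int64_t, Node> graph;

    // Iterative post-order DFS, memoized on {handle, need_rewrite}.
    int64_t BuildOutput(int64_t root_handle) {
      struct WorkItem {
        int64_t handle;
        bool need_rewrite;
        bool visited = false;
        std::vector<int64_t> processed_operands;  // New ids of finished operands.
      };
      std::map<std::pair<int64_t, bool>, int64_t> seen;
      int64_t global_id = 0;
      int64_t stacktop_id = -1;  // Result id of the most recently finished item.
      std::vector<WorkItem> worklist{{root_handle, /*need_rewrite=*/true}};
      while (!worklist.empty()) {
        WorkItem& item = worklist.back();
        auto key = std::make_pair(item.handle, item.need_rewrite);
        auto it = seen.find(key);
        if (it != seen.end()) {  // Already emitted: reuse its id.
          stacktop_id = it->second;
          worklist.pop_back();
          continue;
        }
        const Node& node = graph.at(item.handle);
        if (!item.visited) {
          item.visited = true;
        } else {
          // A child just finished; collect the id it left in stacktop_id.
          item.processed_operands.push_back(stacktop_id);
        }
        size_t next_operand = item.processed_operands.size();
        if (next_operand >= node.operands.size()) {
          // All operands done: emit this node (rewrite rules elided).
          stacktop_id = ++global_id;
          seen[key] = stacktop_id;
          worklist.pop_back();
          continue;
        }
        // Descend into the next unprocessed operand.
        worklist.push_back({node.operands[next_operand], /*need_rewrite=*/true});
      }
      return stacktop_id;  // Id of the emitted root.
    }

    int main() {
      graph[1] = {1, {2, 3}};  // Root with two leaf operands.
      graph[2] = {2, {}};
      graph[3] = {3, {}};
      // Post-order: the leaves get ids 1 and 2, the root gets id 3.
      return BuildOutput(1) == 3 ? 0 : 1;
    }

Because the memoization key includes need_rewrite, revisiting a node in the other form emits a second output node instead of reusing the wrong one — the mixed-form case the commit message describes.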
@@ -104,12 +104,26 @@ TEST_F(DynamismInferenceTest, ScalarInt32Literal) {
   }
 }

+TEST_F(DynamismInferenceTest, TupleSimple) {
+  for (ClientType client_type : client_types) {
+    Client* client = ClientOrDie(platform_, client_type);
+    XlaBuilder b(TestName());
+    auto c = ConstantR0<int32>(&b, 42);
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
+
+    auto tuple = Tuple(&b, {c, p});
+    EXPECT_EQ(ComputeDynamismScalar(client, tuple, &b, {0}).ValueOrDie(),
+              false);
+    EXPECT_EQ(ComputeDynamismScalar(client, tuple, &b, {1}).ValueOrDie(), true);
+  }
+}
+
 TEST_F(DynamismInferenceTest, TupleGteKeepsDynamism) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
     XlaBuilder b(TestName());
     auto c = ConstantR0<int32>(&b, 42);
-    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");

     auto tuple = Tuple(&b, {c, p});
     auto gte0 = GetTupleElement(tuple, 0);
@@ -122,12 +136,25 @@ TEST_F(DynamismInferenceTest, TupleGteKeepsDynamism) {
   }
 }

+TEST_F(DynamismInferenceTest, PredValueUsedTwice) {
+  for (ClientType client_type : client_types) {
+    Client* client = ClientOrDie(platform_, client_type);
+    XlaBuilder b(TestName());
+    auto c = ConstantR0<int32>(&b, 42);
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
+    auto pred = Eq(c, p);
+    auto result = Select(pred, p, c);
+    EXPECT_EQ(ComputeDynamismScalar(client, result, &b, {}).ValueOrDie(),
+              false);
+  }
+}
+
 TEST_F(DynamismInferenceTest, ConcatSliceReshapeKeepsDynamism) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
     XlaBuilder b(TestName());
     auto c = ConstantR0<int32>(&b, 42);
-    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");

     auto concat = ConcatScalars(&b, {c, p});
     auto slice0 = SliceInDim(concat, 0, 1, 1, 0);
@@ -146,7 +173,7 @@ TEST_F(DynamismInferenceTest, ParameterIsDynamic) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
     XlaBuilder b(TestName());
-    auto computation = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+    auto computation = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");

     auto value = ComputeDynamismScalar(client, computation, &b);
     ASSERT_TRUE(value.ok()) << value.status();
@@ -160,7 +187,7 @@ TEST_F(DynamismInferenceTest, UnaryOpKeepsDynamism) {
     Client* client = ClientOrDie(platform_, client_type);
     XlaBuilder b(TestName());
     auto c = ConstantR0<int32>(&b, 42);
-    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");

     auto neg0 = Neg(c);
     auto neg1 = Neg(p);
@@ -177,7 +204,7 @@ TEST_F(DynamismInferenceTest, BinaryOpsOrsDynamism) {
     Client* client = ClientOrDie(platform_, client_type);
     XlaBuilder b(TestName());
     auto c = ConstantR0<int32>(&b, 42);
-    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");

     // Static value + static value = static
     auto add1 = Add(c, c);
@@ -198,8 +225,8 @@ TEST_F(DynamismInferenceTest, GetDimensionSize) {
     // param = Param([<=2, 3])
     // get_dimension_size(param, 0) is dynamic
     // get_dimension_size(param, 1) is static
-    auto p =
-        Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "0");
+    auto p = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}),
+                       "p0");

     auto gds0 = GetDimensionSize(p, 0);
     auto gds1 = GetDimensionSize(p, 1);