DynamismInference: Support instructions with and without rewrite in the same graph.

- Dynamism inference is used to decide if a value is dynamic or not.
- In dynamism inference, we rewrite some instructions into boolean form.
E.g.,
 We rewrite
   A = constant(0)
   B = parameter(0)
   ROOT A + B
 into
   A' = constant(false)
   B' = constant(true)
   ROOT A' | B'

- We also don't rewrite some instructions:
E.g.,
   A = constant(0)
   B = parameter(0)
   C = constant(0)
   D = parameter(0)
   P = C == D
  ROOT select(P,A,B)
 into
   A' = constant(false)
   B' = constant(true)
   C = constant(0)
   D = parameter(0)
   P = C == D
  ROOT select(P,A',B')

We don't rewrite P, or any instruction reachable from P.
- This CL fixes an issue where these two forms are mixed together:
E.g.,
   A = constant(0)
   B = parameter(0)
   P = A == B
  ROOT select(P,A,B)

Previously the pass would fail.
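With the fix, values like A and B are copied twice when needed: once
rewritten into predicate form (feeding the select), and once verbatim
(feeding P). A sketch of the expected output graph (primed names are
ours, for illustration):
   A' = constant(false)
   B' = constant(true)
   A  = constant(0)
   B  = parameter(0)
   P  = A == B
   ROOT select(P, A', B')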

PiperOrigin-RevId: 327889288
Change-Id: I3dd419ca5d729bb857d3fcac8fd76d47788aa5c2
Author: Yunxing Dai (committed by TensorFlower Gardener)
Date: 2020-08-21 15:58:29 -07:00
Commit: 9f6e57df30 (parent: 90d58ce333)
2 changed files with 127 additions and 59 deletions

@@ -28,6 +28,7 @@ limitations under the License.
 #include "absl/strings/match.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
+#include "absl/types/span.h"
 #include "tensorflow/compiler/xla/client/sharding_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/comparison_util.h"
@@ -78,16 +79,13 @@ ShapeProto ConvertShapeProtoToPred(const ShapeProto& shape_proto) {
   return ShapeUtil::ChangeElementType(Shape(shape_proto), PRED).ToProto();
 }
 
-HloInstructionProto CreateConstantInstruction(int64 id, const Shape& shape,
-                                              bool pred) {
-  HloInstructionProto const_instr;
+void SetInstructionAsConstant(HloInstructionProto* instr, int64 id,
+                              const Shape& shape, bool pred) {
   Literal literal = LiteralUtil::CreateR0(pred);
   Literal literal_broadcast = literal.Broadcast(shape, {}).ValueOrDie();
-  *const_instr.mutable_shape() = shape.ToProto();
-  *const_instr.mutable_literal() = literal_broadcast.ToProto();
-  *const_instr.mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
-  const_instr.set_id(id);
-  return const_instr;
+  *instr->mutable_shape() = shape.ToProto();
+  *instr->mutable_literal() = literal_broadcast.ToProto();
+  *instr->mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
 }
 
 // Converts a HloComputation into ReducerOr with predicate types.
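The refactored helper mutates a node the caller has already added to the
graph instead of building and returning a fresh proto, so the caller
controls the node's identity and operand list. A minimal usage sketch,
mirroring process_instruction further down (names as in the diff):

   auto* new_instr = entry.add_instructions();
   new_instr->set_id(id);  // identity is fixed by the caller
   // The helper only fills in shape, literal and opcode, turning the
   // node into a broadcasted pred constant.
   SetInstructionAsConstant(new_instr, id, new_shape, /*pred=*/false);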
@@ -2971,27 +2969,12 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
   *program_shape->mutable_result() =
       ShapeUtil::ChangeElementType(Shape(root->shape()), PRED).ToProto();
-  std::set<int64> seen;
-  struct WorkItem {
-    explicit WorkItem(int64 handle, bool need_rewrite)
-        : handle(handle), need_rewrite(need_rewrite) {}
-    int64 handle;
-    // If need_rewrite is true, the instruction will be copied and rewritten
-    // into a pred instruction indicating if each value is dynamic. If
-    // need_rewrite is false, simply copy the instruction to the output graph.
-    // E.g.,
-    // For select(P, A, B), we need to rewrite A and B into predicates, but
-    // don't need to rewrite P.
-    bool need_rewrite;
-  };
-  std::queue<WorkItem> worklist;
-  worklist.push(WorkItem(root->id(), true));
   entry.set_root_id(root->id());
   std::vector<HloComputationProto> called_computatons;
-  // Rewrite the instruction with id "from" into the new graph.
-  // Returns more work items that need to finish.
-  auto rewrite_instruction =
-      [&](int64 from, bool need_rewrite) -> StatusOr<std::vector<WorkItem>> {
+  // Process an instruction and copy it into the new graph. The new node in
+  // the new graph will have its id set to `id`.
+  auto process_instruction = [&](const HloInstructionProto* instr_proto,
+                                 bool need_rewrite, int64 id,
+                                 absl::Span<int64 const> operand_ids) {
     // Rewrite the instruction with following rules:
     // - Unary ops: Convert into bitcast (identity) with type Pred.
     // - Binary ops: Convert into binary or.
@@ -3004,22 +2987,20 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
     // - Constant: Convert to constant False.
     // - Other ops: Not supported.
     // Create the instruction for the new handle.
-    TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
-                        LookUpInstructionByHandle(from));
     TF_ASSIGN_OR_RETURN(HloOpcode opcode,
                         StringToHloOpcode(instr_proto->opcode()));
-    std::vector<WorkItem> operands_todo;
     auto* new_instr = entry.add_instructions();
     *new_instr = *instr_proto;
-    for (auto operand_id : new_instr->operand_ids()) {
-      operands_todo.emplace_back(operand_id, need_rewrite);
+    new_instr->set_id(id);
+    new_instr->mutable_operand_ids()->Clear();
+    for (auto operand_id : operand_ids) {
+      new_instr->mutable_operand_ids()->Add(operand_id);
     }
     if (!need_rewrite) {
       *new_instr->mutable_name() =
-          GetFullName(instr_proto->opcode(), kNameSeparator, instr_proto->id());
-      return operands_todo;
+          GetFullName(instr_proto->opcode(), kNameSeparator, id);
+      return Status::OK();
     }
     *new_instr->mutable_shape() = ConvertShapeProtoToPred(instr_proto->shape());
     Shape new_shape(new_instr->shape());
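Applied to concrete HLO, the first two rewrite rules look like this (a
sketch; x' denotes the pred-typed rewrite of x):

   neg(x)    =>  bitcast(x')   // unary: dynamism passes through unchanged
   add(x, y) =>  or(x', y')    // binary: dynamic if either operand is dynamic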
@@ -3074,10 +3055,8 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
         *new_instr->mutable_opcode() = HloOpcodeString(HloOpcode::kOr);
         break;
       case HloOpcode::kSelect:
-        operands_todo[0].need_rewrite = false;
         break;
       case HloOpcode::kGather:
-        operands_todo[1].need_rewrite = false;
         break;
       case HloOpcode::kReduce: {
         int64 reducer_id = new_instr->called_computation_ids(0);
@@ -3099,39 +3078,101 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
         TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
                             LookUpInstructionByHandle(operand_handle));
-        *new_instr = CreateConstantInstruction(
-            from, new_shape,
+        SetInstructionAsConstant(
+            new_instr, id, new_shape,
             operand_proto->shape().is_dynamic_dimension(dimension));
-        operands_todo.clear();
         break;
       }
       case HloOpcode::kConstant:
-        *new_instr = CreateConstantInstruction(from, new_shape, false);
+        SetInstructionAsConstant(new_instr, id, new_shape, false);
         break;
       case HloOpcode::kParameter:
-        *new_instr = CreateConstantInstruction(from, new_shape, true);
+        SetInstructionAsConstant(new_instr, id, new_shape, true);
         break;
       default:
         return InvalidArgument("Dynamic inferencing %s is not supported",
                                instr_proto->DebugString());
     }
     *new_instr->mutable_name() =
-        GetFullName(instr_proto->opcode(), kNameSeparator, instr_proto->id());
-    return operands_todo;
+        GetFullName(instr_proto->opcode(), kNameSeparator, id);
+    return Status::OK();
   };
+  struct WorkItem {
+    explicit WorkItem(int64 handle, bool need_rewrite)
+        : handle(handle), need_rewrite(need_rewrite), visited(false) {}
+    int64 handle;
+    // If need_rewrite is true, the instruction will be copied and rewritten
+    // into a pred instruction indicating if each value is dynamic. If
+    // need_rewrite is false, simply copy the instruction to the output graph.
+    // E.g.,
+    // For select(P, A, B), we need to rewrite A and B into predicates, but
+    // don't need to rewrite P.
+    bool need_rewrite;
+    // Used in the DFS to remember the ids of processed operands of this item.
+    std::vector<int64> processed_operands;
+    // Whether this node has been visited before or not.
+    bool visited;
+  };
+  // Only copy each pair of {handle, need_rewrite} once. Value is the id in
+  // the new graph.
+  absl::flat_hash_map<std::pair<int64, bool>, int64> seen;
+  // Monotonically increasing id to assign to new instructions.
+  int64 global_id = 0;
+  // The result id of the last rewritten item -- return value of the last
+  // stack item.
+  int64 stacktop_id = -1;
+  std::vector<WorkItem> worklist;
+  worklist.push_back(WorkItem(root->id(), true));
   while (!worklist.empty()) {
-    WorkItem item = worklist.front();
-    worklist.pop();
-    if (!seen.insert(item.handle).second) {
+    WorkItem& item = worklist.back();
+    auto item_key = std::make_pair(item.handle, item.need_rewrite);
+    auto iter = seen.find(item_key);
+    // Already processed this item. Return previous results.
+    if (iter != seen.end()) {
+      stacktop_id = iter->second;
+      worklist.pop_back();
       continue;
     }
-    TF_ASSIGN_OR_RETURN(auto todos,
-                        rewrite_instruction(item.handle, item.need_rewrite));
-    for (WorkItem& todo : todos) {
-      worklist.push(todo);
-    }
+    int64 next_operand = item.processed_operands.size();
+    TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
+                        LookUpInstructionByHandle(item.handle));
+    VLOG(3) << "Processing " << instr_proto->name();
+    if (!item.visited) {
+      item.visited = true;
+    } else {
+      // Record the previously processed operand.
+      item.processed_operands.push_back(stacktop_id);
+      next_operand++;
+    }
+    TF_ASSIGN_OR_RETURN(HloOpcode opcode,
+                        StringToHloOpcode(instr_proto->opcode()));
+    if (next_operand >= instr_proto->operand_ids_size() ||
+        opcode == HloOpcode::kGetDimensionSize) {
+      // No more operands to process, process self.
+      int64 new_id = ++global_id;
+      VLOG(3) << "new_id: " << new_id << " instr: " << instr_proto->name();
+      TF_RETURN_IF_ERROR(process_instruction(instr_proto, item.need_rewrite,
+                                             new_id, item.processed_operands));
+      stacktop_id = new_id;
+      seen[item_key] = stacktop_id;
+      worklist.pop_back();
+      continue;
+    }
+    WorkItem next_item(instr_proto->operand_ids(next_operand), true);
+    if (opcode == HloOpcode::kSelect && next_operand == 0) {
+      next_item.need_rewrite = false;
+    }
+    if (opcode == HloOpcode::kGather && next_operand == 1) {
+      next_item.need_rewrite = false;
+    }
+    // Push the next operand into the worklist.
+    worklist.push_back(next_item);
   }
+  TF_RET_CHECK(stacktop_id != -1);
+  entry.set_root_id(stacktop_id);
+  absl::c_sort(*entry.mutable_instructions(),
+               [](const HloInstructionProto& p1,
+                  const HloInstructionProto& p2) { return p1.id() < p2.id(); });
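The worklist above is an iterative post-order DFS: a node stays on the
stack until all of its operands have finished, and results are memoized
per (handle, need_rewrite) pair so the same instruction can be copied in
both rewritten and verbatim form. A stripped-down sketch of the same
traversal pattern, with toy types and names of our own (not the XLA code):

   #include <cstdint>
   #include <map>
   #include <utility>
   #include <vector>

   struct Node {
     int64_t id;
     std::vector<int64_t> operands;
   };

   struct Frame {
     int64_t handle;
     bool need_rewrite;
     bool visited = false;
     std::vector<int64_t> done_operands;  // new ids of finished operands
   };

   // Walks the DAG reachable from `root` in post-order, assigning fresh
   // ids. Returns the new id of (root, need_rewrite=true).
   int64_t PostOrderCopy(const std::map<int64_t, Node>& graph, int64_t root) {
     std::map<std::pair<int64_t, bool>, int64_t> seen;
     int64_t next_id = 0;      // monotonically increasing ids in the new graph
     int64_t top_result = -1;  // result of the most recently finished frame
     std::vector<Frame> stack;
     stack.push_back({root, /*need_rewrite=*/true});
     while (!stack.empty()) {
       Frame& f = stack.back();
       auto key = std::make_pair(f.handle, f.need_rewrite);
       auto it = seen.find(key);
       if (it != seen.end()) {
         top_result = it->second;  // already copied in this mode; reuse it
         stack.pop_back();
         continue;
       }
       if (!f.visited) {
         f.visited = true;  // first time on top of the stack
       } else {
         f.done_operands.push_back(top_result);  // a child just finished
       }
       const Node& node = graph.at(f.handle);
       if (f.done_operands.size() >= node.operands.size()) {
         // All operands processed: "process" the node itself.
         top_result = ++next_id;
         seen[key] = top_result;
         stack.pop_back();
         continue;
       }
       // Push the next unprocessed operand. A real rewriter would clear
       // need_rewrite here for, e.g., the predicate operand of a select.
       stack.push_back({node.operands[f.done_operands.size()],
                        /*need_rewrite=*/true});
     }
     return top_result;
   }

Memoizing on the pair rather than on the handle alone is exactly what lets
one instruction appear in the output twice, once per rewrite mode.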

@@ -104,12 +104,26 @@ TEST_F(DynamismInferenceTest, ScalarInt32Literal) {
   }
 }
 
+TEST_F(DynamismInferenceTest, TupleSimple) {
+  for (ClientType client_type : client_types) {
+    Client* client = ClientOrDie(platform_, client_type);
+    XlaBuilder b(TestName());
+    auto c = ConstantR0<int32>(&b, 42);
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
+
+    auto tuple = Tuple(&b, {c, p});
+    EXPECT_EQ(ComputeDynamismScalar(client, tuple, &b, {0}).ValueOrDie(),
+              false);
+    EXPECT_EQ(ComputeDynamismScalar(client, tuple, &b, {1}).ValueOrDie(), true);
+  }
+}
+
 TEST_F(DynamismInferenceTest, TupleGteKeepsDynamism) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
     XlaBuilder b(TestName());
     auto c = ConstantR0<int32>(&b, 42);
-    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
 
     auto tuple = Tuple(&b, {c, p});
     auto gte0 = GetTupleElement(tuple, 0);
@@ -122,12 +136,25 @@ TEST_F(DynamismInferenceTest, TupleGteKeepsDynamism) {
   }
 }
 
+TEST_F(DynamismInferenceTest, PredValueUsedTwice) {
+  for (ClientType client_type : client_types) {
+    Client* client = ClientOrDie(platform_, client_type);
+    XlaBuilder b(TestName());
+    auto c = ConstantR0<int32>(&b, 42);
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
+    auto pred = Eq(c, p);
+    auto result = Select(pred, p, c);
+    EXPECT_EQ(ComputeDynamismScalar(client, result, &b, {}).ValueOrDie(),
+              false);
+  }
+}
+
 TEST_F(DynamismInferenceTest, ConcatSliceReshapeKeepsDynamism) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
     XlaBuilder b(TestName());
     auto c = ConstantR0<int32>(&b, 42);
-    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
 
     auto concat = ConcatScalars(&b, {c, p});
     auto slice0 = SliceInDim(concat, 0, 1, 1, 0);
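PredValueUsedTwice is the regression test for the mixed case from the
commit message: p feeds both the unrewritten predicate and a rewritten
select operand. The graph it builds, in HLO-style pseudocode:

   c    = constant(42)
   p    = parameter(0)
   pred = c == p
   ROOT select(pred, p, c)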
@@ -146,7 +173,7 @@ TEST_F(DynamismInferenceTest, ParameterIsDynamic) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
     XlaBuilder b(TestName());
-    auto computation = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+    auto computation = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
 
     auto value = ComputeDynamismScalar(client, computation, &b);
     ASSERT_TRUE(value.ok()) << value.status();
@@ -160,7 +187,7 @@ TEST_F(DynamismInferenceTest, UnaryOpKeepsDynamism) {
     Client* client = ClientOrDie(platform_, client_type);
     XlaBuilder b(TestName());
     auto c = ConstantR0<int32>(&b, 42);
-    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
 
     auto neg0 = Neg(c);
     auto neg1 = Neg(p);
@@ -177,7 +204,7 @@ TEST_F(DynamismInferenceTest, BinaryOpsOrsDynamism) {
     Client* client = ClientOrDie(platform_, client_type);
     XlaBuilder b(TestName());
     auto c = ConstantR0<int32>(&b, 42);
-    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
 
     // Static value + static value = static
     auto add1 = Add(c, c);
@@ -198,8 +225,8 @@ TEST_F(DynamismInferenceTest, GetDimensionSize) {
     // param = Param([<=2, 3])
     // get_dimension_size(param, 0) is dynamic
     // get_dimension_size(param, 1) is static
-    auto p =
-        Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "0");
+    auto p = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}),
+                       "p0");
 
     auto gds0 = GetDimensionSize(p, 0);
     auto gds1 = GetDimensionSize(p, 1);