DynamismInference: Support instructions with and without rewrite in the same graph.

- Dynamism inference is used to decide whether a value is dynamic.
- In dynamism inference, we rewrite some instructions into a boolean form.
E.g.,
 We rewrite
   A = constant(0)
   B = parameter(0)
   ROOT A + B
 into
   A' = constant(false)
   B' = constant(true)
   ROOT A' | B'
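
As a rough illustration, here is a toy C++ model of that rule (Node and
IsDynamic are hypothetical stand-ins, not the XLA implementation):

  #include <vector>

  // Toy expression node: constants are static, parameters are dynamic,
  // and element-wise ops OR the dynamism of their operands.
  struct Node {
    enum Kind { kConstant, kParameter, kElementwise } kind;
    std::vector<const Node*> operands;
  };

  // Returns true if the value produced by `node` may be dynamic.
  bool IsDynamic(const Node* node) {
    switch (node->kind) {
      case Node::kConstant:
        return false;  // A' = constant(false)
      case Node::kParameter:
        return true;   // B' = constant(true)
      case Node::kElementwise: {
        bool dynamic = false;
        for (const Node* operand : node->operands) {
          dynamic |= IsDynamic(operand);  // ROOT A' | B'
        }
        return dynamic;
      }
    }
    return false;
  }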

- We also don't rewrite some instructions:
E.g.,
   A = constant(0)
   B = parameter(0)
   C = constant(0)
   D = parameter(0)
   P = C == D
   ROOT select(P, A, B)
 into
   A' = constant(false)
   B' = constant(true)
   C = constant(0)
   D = parameter(0)
   P = C == D
   ROOT select(P, A', B')

We don't rewrite P, nor any instruction reachable from P: those are copied to the output graph verbatim, since the select needs P's actual value to decide which branch's dynamism bit flows through.
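
Continuing the toy model above (again hypothetical, not the pass itself):
the result's dynamism is the dynamism of whichever branch the predicate
picks, which is exactly what evaluating select(P, A', B') computes:

  // `pred_value` stands in for the result of running the unrewritten
  // predicate subgraph P.
  bool SelectIsDynamic(bool pred_value, const Node* on_true,
                       const Node* on_false) {
    return pred_value ? IsDynamic(on_true) : IsDynamic(on_false);
  }
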
- This CL fixes an issue where these two forms are mixed together:
E.g.,
   A = constant(0)
   B = parameter(0)
   P = A == B
   ROOT select(P, A, B)

Previously the pass would fail, because A and B must appear in both forms -- verbatim as operands of P, and rewritten as the select's value operands -- while each instruction could only be copied once.
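
The fix keys the copy cache on {handle, need_rewrite} so the same
instruction can be materialized in both forms. A minimal sketch of that
scheme (CopyOrReuse is a hypothetical wrapper; the real map appears in
the diff below):

  #include <cstdint>
  #include <utility>
  #include "absl/container/flat_hash_map.h"

  // Each original instruction may be copied up to twice: once rewritten
  // into its dynamism bit, once verbatim.
  absl::flat_hash_map<std::pair<int64_t, bool>, int64_t> seen;

  int64_t CopyOrReuse(int64_t handle, bool need_rewrite, int64_t* next_id) {
    auto key = std::make_pair(handle, need_rewrite);
    auto it = seen.find(key);
    if (it != seen.end()) return it->second;  // This form already emitted.
    int64_t new_id = ++*next_id;  // Otherwise emit a fresh copy.
    seen[key] = new_id;
    return new_id;
  }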

PiperOrigin-RevId: 327889288
Change-Id: I3dd419ca5d729bb857d3fcac8fd76d47788aa5c2
Author:    Yunxing Dai
Date:      2020-08-21 15:58:29 -07:00
Committer: TensorFlower Gardener
Parent:    90d58ce333
Commit:    9f6e57df30
2 changed files with 127 additions and 59 deletions

tensorflow/compiler/xla/client/xla_builder.cc

@@ -28,6 +28,7 @@ limitations under the License.
 #include "absl/strings/match.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
+#include "absl/types/span.h"
 #include "tensorflow/compiler/xla/client/sharding_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/comparison_util.h"
@@ -78,16 +79,13 @@ ShapeProto ConvertShapeProtoToPred(const ShapeProto& shape_proto) {
   return ShapeUtil::ChangeElementType(Shape(shape_proto), PRED).ToProto();
 }
 
-HloInstructionProto CreateConstantInstruction(int64 id, const Shape& shape,
-                                              bool pred) {
-  HloInstructionProto const_instr;
+void SetInstructionAsConstant(HloInstructionProto* instr, int64 id,
+                              const Shape& shape, bool pred) {
   Literal literal = LiteralUtil::CreateR0(pred);
   Literal literal_broadcast = literal.Broadcast(shape, {}).ValueOrDie();
-  *const_instr.mutable_shape() = shape.ToProto();
-  *const_instr.mutable_literal() = literal_broadcast.ToProto();
-  *const_instr.mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
-  const_instr.set_id(id);
-  return const_instr;
+  *instr->mutable_shape() = shape.ToProto();
+  *instr->mutable_literal() = literal_broadcast.ToProto();
+  *instr->mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
 }
 
 // Converts a HloComputation into ReducerOr with predicate types.
@@ -2971,27 +2969,12 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
   *program_shape->mutable_result() =
       ShapeUtil::ChangeElementType(Shape(root->shape()), PRED).ToProto();
 
-  std::set<int64> seen;
-  struct WorkItem {
-    explicit WorkItem(int64 handle, bool need_rewrite)
-        : handle(handle), need_rewrite(need_rewrite) {}
-    int64 handle;
-    // If need_rewrite is true, the instruction will be copied and rewrite into
-    // a pred instruction indicating if each value is dynamic. If need_rewrite
-    // is false, simply copy the instruction to the output graph.
-    // E.g.,
-    // For select(P, A, B), we need to rewrite A and B into predicates, but
-    // don't need to rewrite P.
-    bool need_rewrite;
-  };
-  std::queue<WorkItem> worklist;
-  worklist.push(WorkItem(root->id(), true));
-  entry.set_root_id(root->id());
   std::vector<HloComputationProto> called_computatons;
-  // Rewritre instruction with id "from" into the new graph.
-  // Returns more work items that need to finish.
-  auto rewrite_instruction =
-      [&](int64 from, bool need_rewrite) -> StatusOr<std::vector<WorkItem>> {
+  // Process instruction and copy it into the new graph. The new node in the
+  // new graph will have its id set to `id`.
+  auto process_instruction = [&](const HloInstructionProto* instr_proto,
+                                 bool need_rewrite, int64 id,
+                                 absl::Span<int64 const> operand_ids) {
     // Rewrite the instruction with following rules:
     // - Unary ops: Convert into bitcast (identity) with type Pred.
     // - Binary ops: Convert into binary or.
@@ -3004,22 +2987,20 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
     // - Constant: Convert to constant False.
     // - Other ops: Not supported.
     // Create the instruction for the new handle.
-    TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
-                        LookUpInstructionByHandle(from));
     TF_ASSIGN_OR_RETURN(HloOpcode opcode,
                         StringToHloOpcode(instr_proto->opcode()));
-    std::vector<WorkItem> operands_todo;
     auto* new_instr = entry.add_instructions();
     *new_instr = *instr_proto;
-    for (auto operand_id : new_instr->operand_ids()) {
-      operands_todo.emplace_back(operand_id, need_rewrite);
+    new_instr->set_id(id);
+    new_instr->mutable_operand_ids()->Clear();
+    for (auto operand_id : operand_ids) {
+      new_instr->mutable_operand_ids()->Add(operand_id);
     }
 
     if (!need_rewrite) {
       *new_instr->mutable_name() =
-          GetFullName(instr_proto->opcode(), kNameSeparator, instr_proto->id());
-      return operands_todo;
+          GetFullName(instr_proto->opcode(), kNameSeparator, id);
+      return Status::OK();
     }
     *new_instr->mutable_shape() = ConvertShapeProtoToPred(instr_proto->shape());
     Shape new_shape(new_instr->shape());
@@ -3074,10 +3055,8 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
         *new_instr->mutable_opcode() = HloOpcodeString(HloOpcode::kOr);
         break;
       case HloOpcode::kSelect:
-        operands_todo[0].need_rewrite = false;
         break;
       case HloOpcode::kGather:
-        operands_todo[1].need_rewrite = false;
         break;
       case HloOpcode::kReduce: {
         int64 reducer_id = new_instr->called_computation_ids(0);
@@ -3099,39 +3078,101 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
         TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
                             LookUpInstructionByHandle(operand_handle));
-        *new_instr = CreateConstantInstruction(
-            from, new_shape,
+        SetInstructionAsConstant(
+            new_instr, id, new_shape,
             operand_proto->shape().is_dynamic_dimension(dimension));
-        operands_todo.clear();
         break;
       }
       case HloOpcode::kConstant:
-        *new_instr = CreateConstantInstruction(from, new_shape, false);
+        SetInstructionAsConstant(new_instr, id, new_shape, false);
         break;
       case HloOpcode::kParameter:
-        *new_instr = CreateConstantInstruction(from, new_shape, true);
+        SetInstructionAsConstant(new_instr, id, new_shape, true);
         break;
       default:
        return InvalidArgument("Dynamic inferencing %s is not supported",
                               instr_proto->DebugString());
     }
     *new_instr->mutable_name() =
-        GetFullName(instr_proto->opcode(), kNameSeparator, instr_proto->id());
-    return operands_todo;
+        GetFullName(instr_proto->opcode(), kNameSeparator, id);
+    return Status::OK();
   };
+
+  struct WorkItem {
+    explicit WorkItem(int64 handle, bool need_rewrite)
+        : handle(handle), need_rewrite(need_rewrite), visited(false) {}
+    int64 handle;
+    // If need_rewrite is true, the instruction will be copied and rewritten
+    // into a pred instruction indicating if each value is dynamic. If
+    // need_rewrite is false, simply copy the instruction to the output graph.
+    // E.g.,
+    // For select(P, A, B), we need to rewrite A and B into predicates, but
+    // don't need to rewrite P.
+    bool need_rewrite;
+    // Used in dfs to remember the ids of processed operands of this item.
+    std::vector<int64> processed_operands;
+    // Whether this node has been visited before or not.
+    bool visited;
+  };
+  // Only copy each pair of {handle, need_rewrite} once. Value is the id in
+  // the new graph.
+  absl::flat_hash_map<std::pair<int64, bool>, int64> seen;
+  // Monotonically increasing id to assign to new instructions.
+  int64 global_id = 0;
+  // The result id of the last rewritten item -- return value of the last
+  // stack item.
+  int64 stacktop_id = -1;
+  std::vector<WorkItem> worklist;
+  worklist.push_back(WorkItem(root->id(), true));
   while (!worklist.empty()) {
-    WorkItem item = worklist.front();
-    worklist.pop();
-    if (!seen.insert(item.handle).second) {
+    WorkItem& item = worklist.back();
+    auto item_key = std::make_pair(item.handle, item.need_rewrite);
+    auto iter = seen.find(item_key);
+    // Already processed this item. Return previous results.
+    if (iter != seen.end()) {
+      stacktop_id = iter->second;
+      worklist.pop_back();
       continue;
     }
-    TF_ASSIGN_OR_RETURN(auto todos,
-                        rewrite_instruction(item.handle, item.need_rewrite));
-    for (WorkItem& todo : todos) {
-      worklist.push(todo);
+
+    int64 next_operand = item.processed_operands.size();
+    TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
+                        LookUpInstructionByHandle(item.handle));
+    VLOG(3) << "Processing " << instr_proto->name();
+    if (!item.visited) {
+      item.visited = true;
+    } else {
+      // Record previously processed operand.
+      item.processed_operands.push_back(stacktop_id);
+      next_operand++;
     }
+    TF_ASSIGN_OR_RETURN(HloOpcode opcode,
+                        StringToHloOpcode(instr_proto->opcode()));
+    if (next_operand >= instr_proto->operand_ids_size() ||
+        opcode == HloOpcode::kGetDimensionSize) {
+      // No more operands to process, process self.
+      int64 new_id = ++global_id;
+      VLOG(3) << "new_id: " << new_id << " instr: " << instr_proto->name();
+      TF_RETURN_IF_ERROR(process_instruction(instr_proto, item.need_rewrite,
+                                             new_id, item.processed_operands));
+      stacktop_id = new_id;
+      seen[item_key] = stacktop_id;
+      worklist.pop_back();
+      continue;
+    }
+
+    WorkItem next_item(instr_proto->operand_ids(next_operand), true);
+    if (opcode == HloOpcode::kSelect && next_operand == 0) {
+      next_item.need_rewrite = false;
+    }
+    if (opcode == HloOpcode::kGather && next_operand == 1) {
+      next_item.need_rewrite = false;
+    }
+    // Push next operand onto the worklist.
+    worklist.push_back(next_item);
   }
+  TF_RET_CHECK(stacktop_id != -1);
+  entry.set_root_id(stacktop_id);
   absl::c_sort(*entry.mutable_instructions(),
                [](const HloInstructionProto& p1,
                   const HloInstructionProto& p2) { return p1.id() < p2.id(); });
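
For orientation, a condensed, self-contained sketch of the iterative
post-order traversal the new worklist loop implements (ToyNode and
ProcessGraph are hypothetical; the real code walks HloInstructionProtos
and, as in the diff above, also clears need_rewrite for select operand 0
and gather operand 1):

  #include <cstdint>
  #include <utility>
  #include <vector>
  #include "absl/container/flat_hash_map.h"

  struct ToyNode {
    std::vector<int64_t> operand_handles;
  };

  // Emits every reachable node after all of its operands and returns
  // the new id assigned to the root.
  int64_t ProcessGraph(const absl::flat_hash_map<int64_t, ToyNode>& graph,
                       int64_t root_handle) {
    struct Frame {
      int64_t handle;
      bool need_rewrite;
      std::vector<int64_t> processed_operands;
      bool visited = false;
    };
    // Each {handle, need_rewrite} form is emitted exactly once.
    absl::flat_hash_map<std::pair<int64_t, bool>, int64_t> seen;
    int64_t global_id = 0;
    int64_t stacktop_id = -1;  // Result id of the last finished frame.
    std::vector<Frame> stack = {{root_handle, /*need_rewrite=*/true}};
    while (!stack.empty()) {
      Frame& item = stack.back();
      auto key = std::make_pair(item.handle, item.need_rewrite);
      auto it = seen.find(key);
      if (it != seen.end()) {  // Already emitted; reuse its id.
        stacktop_id = it->second;
        stack.pop_back();
        continue;
      }
      if (!item.visited) {
        item.visited = true;  // First visit; no child finished yet.
      } else {
        // Re-entered after a child frame finished; record its id.
        item.processed_operands.push_back(stacktop_id);
      }
      const ToyNode& node = graph.at(item.handle);
      size_t next = item.processed_operands.size();
      if (next >= node.operand_handles.size()) {
        // All operands processed: emit this node and publish its id.
        stacktop_id = ++global_id;
        seen[key] = stacktop_id;
        stack.pop_back();
        continue;
      }
      // Push the next unprocessed operand; copy fields out of `item`
      // before push_back can reallocate the stack.
      int64_t child_handle = node.operand_handles[next];
      bool child_rewrite = item.need_rewrite;
      stack.push_back({child_handle, child_rewrite});
    }
    return stacktop_id;
  }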

dynamism_inference_test.cc

@@ -104,12 +104,26 @@ TEST_F(DynamismInferenceTest, ScalarInt32Literal) {
   }
 }
 
+TEST_F(DynamismInferenceTest, TupleSimple) {
+  for (ClientType client_type : client_types) {
+    Client* client = ClientOrDie(platform_, client_type);
+    XlaBuilder b(TestName());
+    auto c = ConstantR0<int32>(&b, 42);
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
+
+    auto tuple = Tuple(&b, {c, p});
+    EXPECT_EQ(ComputeDynamismScalar(client, tuple, &b, {0}).ValueOrDie(),
+              false);
+    EXPECT_EQ(ComputeDynamismScalar(client, tuple, &b, {1}).ValueOrDie(), true);
+  }
+}
+
 TEST_F(DynamismInferenceTest, TupleGteKeepsDynamism) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
     XlaBuilder b(TestName());
     auto c = ConstantR0<int32>(&b, 42);
-    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
 
     auto tuple = Tuple(&b, {c, p});
     auto gte0 = GetTupleElement(tuple, 0);
@@ -122,12 +136,25 @@ TEST_F(DynamismInferenceTest, TupleGteKeepsDynamism) {
   }
 }
 
+TEST_F(DynamismInferenceTest, PredValueUsedTwice) {
+  for (ClientType client_type : client_types) {
+    Client* client = ClientOrDie(platform_, client_type);
+    XlaBuilder b(TestName());
+    auto c = ConstantR0<int32>(&b, 42);
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
+    auto pred = Eq(c, p);
+    auto result = Select(pred, p, c);
+    EXPECT_EQ(ComputeDynamismScalar(client, result, &b, {}).ValueOrDie(),
+              false);
+  }
+}
+
 TEST_F(DynamismInferenceTest, ConcatSliceReshapeKeepsDynamism) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
     XlaBuilder b(TestName());
     auto c = ConstantR0<int32>(&b, 42);
-    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
 
     auto concat = ConcatScalars(&b, {c, p});
     auto slice0 = SliceInDim(concat, 0, 1, 1, 0);
@@ -146,7 +173,7 @@ TEST_F(DynamismInferenceTest, ParameterIsDynamic) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
     XlaBuilder b(TestName());
-    auto computation = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+    auto computation = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
 
     auto value = ComputeDynamismScalar(client, computation, &b);
     ASSERT_TRUE(value.ok()) << value.status();
@@ -160,7 +187,7 @@ TEST_F(DynamismInferenceTest, UnaryOpKeepsDynamism) {
     Client* client = ClientOrDie(platform_, client_type);
     XlaBuilder b(TestName());
     auto c = ConstantR0<int32>(&b, 42);
-    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
 
     auto neg0 = Neg(c);
     auto neg1 = Neg(p);
@@ -177,7 +204,7 @@ TEST_F(DynamismInferenceTest, BinaryOpsOrsDynamism) {
     Client* client = ClientOrDie(platform_, client_type);
     XlaBuilder b(TestName());
     auto c = ConstantR0<int32>(&b, 42);
-    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
 
     // Static value + static value = static
     auto add1 = Add(c, c);
@@ -198,8 +225,8 @@ TEST_F(DynamismInferenceTest, GetDimensionSize) {
     // param = Param([<=2, 3])
     // get_dimension_size(param, 0) is dynamic
     // get_dimension_size(param, 1) is static
-    auto p =
-        Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "0");
+    auto p = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}),
+                       "p0");
 
     auto gds0 = GetDimensionSize(p, 0);
     auto gds1 = GetDimensionSize(p, 1);