[XLA] Use dynamism inference to infer dynamic dimensions for reshape.
- Introduce a dynamism inference function in the XLA builder, which tells whether a value is dynamic or static.
- Use dynamism inference to infer whether an input to reshape's dimensions is dynamic.
- This removes the "-1" hack I made before in the bridge and makes the code cleaner. It also supports more complex dynamic reshape cases, where the dimension comes from a series of transformations.

PiperOrigin-RevId: 325532056
Change-Id: Icc5bad39a857be77537e4736dd6863b833e2fe9d
This commit is contained in:
parent ff457d4d01
commit 2e3e2bb335
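In short, dynamism inference rewrites the graph rooted at a value into a PRED-typed graph and evaluates it at compile time; a true element in the result means the corresponding value is dynamic. A minimal sketch of how the pieces introduced in this commit fit together, assuming an existing builder and client (it mirrors ComputeDynamismLiteral in the new test file below):

#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace xla {

// Sketch: derive the dynamism of `op` (true == dynamic, false == static).
StatusOr<Literal> ComputeDynamism(Client* client, XlaOp op,
                                  XlaBuilder* builder) {
  // Rewrite the graph rooted at `op` into a boolean dynamism graph...
  TF_ASSIGN_OR_RETURN(XlaComputation subgraph,
                      builder->BuildDynamicInferenceGraph(op));
  // ...then evaluate it at compile time; the result is a PRED literal.
  return client->ComputeConstant(subgraph);
}

}  // namespace xla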
@@ -109,27 +109,33 @@ class ReshapeOp : public XlaOpKernel {
    VLOG(2) << "Reshape from " << input_shape.DebugString() << " to "
            << shape.DebugString() << ", unknown_index=" << unknown_index;

    shape_input.clear();
    // Read the input again, this time with each dynamic dimension
    // represented as "-1".
    ctx->set_dynamic_dimension_is_minus_one(true);
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &shape_input));

    int dynamic_dimension = -1;

    for (int d = 0; d < num_dims; ++d) {
      const int32 size = shape_input[d];
      if (size == -1) {
        if (dynamic_dimension == -1) {
    if (ctx->InputXlaShape(0)->is_dynamic()) {
      std::vector<bool> dynamic_dims;
      OP_REQUIRES_OK(ctx,
                     ctx->ResolveInputDynamismIntoPredVector(1, &dynamic_dims));
      for (int d = 0; d < num_dims; ++d) {
        const bool dim_is_dynamic = dynamic_dims[d];
        if (dim_is_dynamic) {
          dynamic_dimension = d;
        } else {
          if (unknown_index != d) {
            dynamic_dimension = d;
          }
        }
      }
    }

      // When reshaping from a dynamic dimension, the unknown index is
      // considered dynamic. E.g.,
      //   [<=10]
      //      |
      //   Reshape
      //      |
      //   [2, -1]
      // The second dimension is dynamic.
      if (dynamic_dimension == -1) {
        dynamic_dimension = unknown_index;
      }
      VLOG(2) << "Reshape from " << ctx->InputXlaShape(0)->ToString() << " to "
              << xla::VectorString(shape.dim_sizes())
              << ", dynamic_dim=" << dynamic_dimension;
    }
    // Pass unknown_index to Xla::Reshape as a hint so that dynamic shape
    // inference in XLA knows which output dimension is dynamic.
    ctx->SetOutput(0, xla::ReshapeWithInferredDimension(
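The selection logic above distills into the following standalone sketch (a hypothetical helper for illustration, not part of the change):

#include <vector>

// Hypothetical distillation of the loop above: given the per-element dynamism
// predicate for the reshape's shape operand, pick the dynamic output
// dimension, falling back to unknown_index (the position of the "-1").
int PickDynamicDimension(const std::vector<bool>& dynamic_dims,
                         int unknown_index) {
  int dynamic_dimension = -1;
  for (int d = 0; d < static_cast<int>(dynamic_dims.size()); ++d) {
    if (dynamic_dims[d]) dynamic_dimension = d;
  }
  // E.g. reshaping [<=10] with shape operand {2, -1}: dynamic_dims resolves
  // to {false, true}, so dimension 1 of the output is dynamic.
  return dynamic_dimension == -1 ? unknown_index : dynamic_dimension;
}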
@@ -101,6 +101,48 @@ xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const {
  });
}

xla::StatusOr<Tensor> XlaExpression::ResolveDynamism(
    xla::Client* client) const {
  switch (kind()) {
    case Kind::kConstant: {
      // Constant values are considered static.
      Tensor constant_false(DT_BOOL, constant_value().shape());
      auto flat = constant_false.flat<bool>();
      for (int64 i = 0; i < flat.size(); ++i) flat(i) = false;
      return constant_false;
    }
    case Kind::kXlaOp:
      break;
    case Kind::kTensorList:
      TF_FALLTHROUGH_INTENDED;
    case Kind::kResource:
      TF_FALLTHROUGH_INTENDED;
    case Kind::kInvalid:
      return errors::InvalidArgument(
          "ResolveDynamism called on unsupported XlaExpression: ",
          HumanString());
  }

  if (!client)
    return errors::InvalidArgument("client is required to resolve constant");

  TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph,
                      handle().builder()->BuildDynamicInferenceGraph(handle()));

  TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape());

  // The XLA layout is specified minor to major, and TensorFlow uses a major to
  // minor order.
  std::vector<int64> layout_indices(shape.dims());
  std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
  xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);
  TF_ASSIGN_OR_RETURN(xla::Literal literal,
                      client->ComputeConstant(constant_graph, &layout));
  Tensor tensor(DT_BOOL);
  TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, DT_BOOL, &tensor));
  return tensor;
}

xla::StatusOr<absl::optional<Tensor>> XlaExpression::ResolveConstant(
    xla::Client* client, bool dynamic_dimension_is_minus_one) const {
  switch (kind()) {
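The iota-over-reverse-iterators line above builds a descending (major-to-minor) layout; a small standalone illustration of the trick:

#include <cstdint>
#include <numeric>
#include <vector>

int main() {
  // XLA layouts are specified minor-to-major while TensorFlow shapes are
  // major-to-minor, so for a rank-3 tensor we want the layout {2, 1, 0}.
  // Filling 0, 1, 2 back-to-front via reverse iterators produces exactly that.
  std::vector<int64_t> layout_indices(3);
  std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
  // layout_indices is now {2, 1, 0}.
  return 0;
}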
@@ -99,6 +99,10 @@ class XlaExpression {
  xla::StatusOr<absl::optional<Tensor>> ResolveConstant(
      xla::Client* client, bool dynamic_dimension_is_minus_one = false) const;

  // ResolveDynamism computes whether each value inside this op is dynamic or
  // can be inferred at compile time.
  xla::StatusOr<Tensor> ResolveDynamism(xla::Client* client) const;

  // Returns the shape of the tensor.
  // The shape of a resource is the shape of a resource handle (i.e., a
  // scalar), not the shape of the resource's value.
@@ -243,6 +243,48 @@ Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) {
  return LiteralToFloat64Scalar(literal, out);
}

static Status LiteralToPredVector(const xla::LiteralSlice& literal,
                                  std::vector<bool>* out) {
  if (literal.shape().rank() != 1) {
    return errors::InvalidArgument("value is not 1D, rank: ",
                                   literal.shape().rank());
  }
  int64 size = xla::ShapeUtil::ElementsIn(literal.shape());
  if (literal.shape().element_type() != xla::PRED) {
    return errors::InvalidArgument("value is not PRED");
  }
  for (int64 i = 0; i < size; ++i) {
    out->push_back(literal.Get<bool>({i}));
  }
  return Status::OK();
}

Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector(
    int index, std::vector<bool>* out) {
  xla::Literal literal;
  XlaExpression e = InputExpression(index);
  auto* client = compiler() ? compiler()->client() : nullptr;
  xla::StatusOr<Tensor> dynamism_or_status = e.ResolveDynamism(client);
  if (!dynamism_or_status.ok()) {
    Status status = dynamism_or_status.status();
    errors::AppendToMessage(&status, "while evaluating input dynamism ", index,
                            " of ", context_->op_kernel().type_string());
    return status;
  }
  Tensor dynamism = dynamism_or_status.ValueOrDie();

  Tensor temp(dynamism.dtype());
  TensorShape tensor_shape({InputShape(index).num_elements()});
  if (!temp.CopyFrom(dynamism, tensor_shape)) {
    return errors::InvalidArgument(
        context_->op_kernel().name(), " input ", index, " has shape ",
        dynamism.shape().DebugString(), " which is not an R1 ", tensor_shape);
  }

  TF_ASSIGN_OR_RETURN(literal, HostTensorToLiteral(temp));
  return LiteralToPredVector(literal, out);
}

// Converts an int32 or int64 1D literal to an int64 vector.
static Status LiteralToInt64Vector(const xla::LiteralSlice& literal,
                                   std::vector<int64>* out) {
@@ -116,6 +116,9 @@ class XlaOpKernelContext {
  // returns a one-element list.
  Status InputList(absl::string_view name, std::vector<xla::XlaOp>* handles,
                   std::vector<TensorShape>* shapes);
  // Evaluates the input and returns its dynamism as a vector of predicates.
  Status ResolveInputDynamismIntoPredVector(int index, std::vector<bool>* out);

  // Helper methods for constant inputs.

@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -39,6 +40,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/errors.h"

namespace xla {
@@ -71,6 +73,52 @@ void SetProtoIdAndName(T* entry, const string& base_name, char separator,
  entry->set_id(id);
  entry->set_name(GetFullName(base_name, separator, id));
}

ShapeProto ConvertShapeProtoToPred(const ShapeProto& shape_proto) {
  return ShapeUtil::ChangeElementType(Shape(shape_proto), PRED).ToProto();
}

HloInstructionProto CreateConstantInstruction(int64 id, const Shape& shape,
                                              bool pred) {
  HloInstructionProto const_instr;
  Literal literal = LiteralUtil::CreateR0(pred);
  Literal literal_broadcast = literal.Broadcast(shape, {}).ValueOrDie();
  *const_instr.mutable_shape() = shape.ToProto();
  *const_instr.mutable_literal() = literal_broadcast.ToProto();
  *const_instr.mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
  const_instr.set_id(id);
  return const_instr;
}

// Converts an HloComputation into a reduce-or computation with predicate
// types.
HloComputationProto CreateReduceOr(int64 reducer_id,
                                   HloComputationProto* original_reducer) {
  HloComputationProto reducer;
  SetProtoIdAndName(&reducer, StrCat("reduce_or"), kNameSeparator, reducer_id);
  std::vector<int64> operands_id;
  for (auto& inst : original_reducer->instructions()) {
    // Copy params.
    if (StringToHloOpcode(inst.opcode()).ValueOrDie() ==
        HloOpcode::kParameter) {
      HloInstructionProto* new_param = reducer.add_instructions();
      *new_param = inst;
      *new_param->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
      operands_id.push_back(inst.id());
    }
    if (inst.id() == original_reducer->root_id()) {
      HloInstructionProto* new_root = reducer.add_instructions();
      *new_root = inst;
      *new_root->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
      *new_root->mutable_opcode() = HloOpcodeString(HloOpcode::kOr);
      new_root->clear_operand_ids();
      for (int64 operand_id : operands_id) {
        new_root->add_operand_ids(operand_id);
      }
      reducer.set_root_id(inst.id());
    }
  }
  return reducer;
}
}  // namespace

namespace internal {
@@ -2842,6 +2890,196 @@ StatusOr<bool> XlaBuilder::IsConstant(XlaOp operand) const {
  return is_constant;
}

StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
  TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
                      LookUpInstruction(root_op));

  HloComputationProto entry;
  SetProtoIdAndName(&entry, StrCat(name_, "_dynamic_inference"), kNameSeparator,
                    GetNextId());
  ProgramShapeProto* program_shape = entry.mutable_program_shape();
  *program_shape->mutable_result() =
      ShapeUtil::ChangeElementType(Shape(root->shape()), PRED).ToProto();

  std::set<int64> seen;
  struct WorkItem {
    explicit WorkItem(int64 handle, bool need_rewrite)
        : handle(handle), need_rewrite(need_rewrite) {}
    int64 handle;
    // If need_rewrite is true, the instruction will be copied and rewritten
    // into a pred instruction indicating whether each value is dynamic. If
    // need_rewrite is false, simply copy the instruction to the output graph.
    // E.g.,
    // For select(P, A, B), we need to rewrite A and B into predicates, but
    // don't need to rewrite P.
    bool need_rewrite;
  };
  std::queue<WorkItem> worklist;
  worklist.push(WorkItem(root->id(), true));
  entry.set_root_id(root->id());
  std::vector<HloComputationProto> called_computations;
  // Rewrites the instruction with id "from" into the new graph.
  // Returns additional work items that still need to be processed.
  auto rewrite_instruction =
      [&](int64 from, bool need_rewrite) -> StatusOr<std::vector<WorkItem>> {
    // Rewrite the instruction with the following rules:
    // - Unary ops: Convert into bitcast (identity) with type Pred.
    // - Binary ops: Convert into binary or.
    // - Select: Convert into binary or with its two data operands.
    // - Concat / Tuple / GTE / Bitcast: Copy.
    // - Param: Convert to constant True.
    // - GetDimensionSize: Convert to constant True if the dimension is
    //   dynamic, constant False if the dimension is static.
    // - Reduce: Convert to reduce or.
    // - Constant: Convert to constant False.
    // - Other ops: Not supported.
    // Create the instruction for the new handle.
    TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
                        LookUpInstructionByHandle(from));

    TF_ASSIGN_OR_RETURN(HloOpcode opcode,
                        StringToHloOpcode(instr_proto->opcode()));
    std::vector<WorkItem> operands_todo;
    auto* new_instr = entry.add_instructions();
    *new_instr = *instr_proto;
    for (auto operand_id : new_instr->operand_ids()) {
      operands_todo.emplace_back(operand_id, need_rewrite);
    }

    if (!need_rewrite) {
      *new_instr->mutable_name() =
          GetFullName(instr_proto->opcode(), kNameSeparator, instr_proto->id());
      return operands_todo;
    }
    *new_instr->mutable_shape() = ConvertShapeProtoToPred(instr_proto->shape());
    Shape new_shape(new_instr->shape());
    switch (opcode) {
      case HloOpcode::kAbs:
      case HloOpcode::kRoundNearestAfz:
      case HloOpcode::kBitcast:
      case HloOpcode::kCeil:
      case HloOpcode::kCollectivePermuteDone:
      case HloOpcode::kCos:
      case HloOpcode::kClz:
      case HloOpcode::kExp:
      case HloOpcode::kExpm1:
      case HloOpcode::kFloor:
      case HloOpcode::kImag:
      case HloOpcode::kIsFinite:
      case HloOpcode::kLog:
      case HloOpcode::kLog1p:
      case HloOpcode::kNot:
      case HloOpcode::kNegate:
      case HloOpcode::kPopulationCount:
      case HloOpcode::kReal:
      case HloOpcode::kRsqrt:
      case HloOpcode::kLogistic:
      case HloOpcode::kSign:
      case HloOpcode::kSin:
      case HloOpcode::kConvert:
      case HloOpcode::kSqrt:
      case HloOpcode::kCbrt:
      case HloOpcode::kTanh:
        CHECK_EQ(instr_proto->operand_ids_size(), 1);
        *new_instr->mutable_opcode() = HloOpcodeString(HloOpcode::kBitcast);
        break;
      case HloOpcode::kAdd:
      case HloOpcode::kAtan2:
      case HloOpcode::kDivide:
      case HloOpcode::kComplex:
      case HloOpcode::kMaximum:
      case HloOpcode::kMinimum:
      case HloOpcode::kMultiply:
      case HloOpcode::kPower:
      case HloOpcode::kRemainder:
      case HloOpcode::kSubtract:
      case HloOpcode::kCompare:
      case HloOpcode::kAnd:
      case HloOpcode::kOr:
      case HloOpcode::kXor:
      case HloOpcode::kShiftLeft:
      case HloOpcode::kShiftRightArithmetic:
      case HloOpcode::kShiftRightLogical:
        CHECK_EQ(instr_proto->operand_ids_size(), 2);
        *new_instr->mutable_opcode() = HloOpcodeString(HloOpcode::kOr);
        break;
      case HloOpcode::kSelect:
        operands_todo[0].need_rewrite = false;
        break;
      case HloOpcode::kGather:
        operands_todo[1].need_rewrite = false;
        break;
      case HloOpcode::kReduce: {
        int64 reducer_id = new_instr->called_computation_ids(0);
        called_computations.push_back(
            CreateReduceOr(reducer_id, &embedded_[reducer_id]));
        break;
      }
      case HloOpcode::kTuple:
      case HloOpcode::kTranspose:
      case HloOpcode::kGetTupleElement:
      case HloOpcode::kSlice:
      case HloOpcode::kBroadcast:
      case HloOpcode::kConcatenate:
      case HloOpcode::kReshape:
        break;
      case HloOpcode::kGetDimensionSize: {
        int64 dimension = instr_proto->dimensions(0);
        int64 operand_handle = instr_proto->operand_ids(0);
        TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
                            LookUpInstructionByHandle(operand_handle));

        *new_instr = CreateConstantInstruction(
            from, new_shape,
            operand_proto->shape().is_dynamic_dimension(dimension));
        operands_todo.clear();
        break;
      }
      case HloOpcode::kConstant:
        *new_instr = CreateConstantInstruction(from, new_shape, false);
        break;
      case HloOpcode::kParameter:
        *new_instr = CreateConstantInstruction(from, new_shape, true);
        break;
      default:
        return InvalidArgument("Dynamic inferencing %s is not supported",
                               instr_proto->DebugString());
    }
    *new_instr->mutable_name() =
        GetFullName(instr_proto->opcode(), kNameSeparator, instr_proto->id());
    return operands_todo;
  };

  while (!worklist.empty()) {
    WorkItem item = worklist.front();
    worklist.pop();
    if (!seen.insert(item.handle).second) {
      continue;
    }
    TF_ASSIGN_OR_RETURN(auto todos,
                        rewrite_instruction(item.handle, item.need_rewrite));
    for (WorkItem& todo : todos) {
      worklist.push(todo);
    }
  }
  absl::c_sort(*entry.mutable_instructions(),
               [](const HloInstructionProto& p1,
                  const HloInstructionProto& p2) { return p1.id() < p2.id(); });
  XlaComputation computation(entry.id());
  HloModuleProto* module = computation.mutable_proto();
  module->set_name(entry.name());
  module->set_id(entry.id());
  module->set_entry_computation_name(entry.name());
  module->set_entry_computation_id(entry.id());
  *module->mutable_host_program_shape() = *program_shape;
  for (auto& called_comp : called_computations) {
    *module->add_computations() = called_comp;
  }
  *module->add_computations() = std::move(entry);
  XLA_VLOG_LINES(3, module->DebugString());
  return std::move(computation);
}

StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
    XlaOp root_op, bool dynamic_dimension_is_minus_one) {
  TF_ASSIGN_OR_RETURN(bool is_constant, IsConstant(root_op));
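As a worked example of these rewrite rules, consider add(parameter, constant). This sketch only sets up the graph; the actual rewrite behavior is exercised by the tests added below:

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

// add(parameter, constant): the parameter is rewritten to constant True, the
// constant to constant False, and the add to a logical Or, so the resulting
// dynamism graph evaluates to True | False == True -- the sum is dynamic.
StatusOr<XlaComputation> ExampleDynamismGraph() {
  XlaBuilder b("example");
  XlaOp p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p");
  XlaOp c = ConstantR0<int32>(&b, 42);
  XlaOp sum = Add(p, c);
  return b.BuildDynamicInferenceGraph(sum);
}

}  // namespace xla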
@@ -278,6 +278,31 @@ class XlaBuilder {
  StatusOr<XlaComputation> BuildConstantSubGraph(
      XlaOp root_op, bool dynamic_dimension_is_uint_max = false);

  // Similar to BuildConstantSubGraph, but with the root element type changed
  // to boolean. A true value in the root indicates that the value is dynamic,
  // while a false value indicates that the value is a constant. This will
  // copy the needed ops/computations to the subgraph.
  //
  // E.g.,
  // Computation {
  //   a = 3
  //   b = param(0)
  //   ROOT Tuple(a + b, a + 1, b + 1)
  // }
  // Calling BuildDynamicInferenceGraph on the root will produce the following
  // graph:
  //
  // Computation {
  //   a = False
  //   b = True
  //   ROOT Tuple(a | b, a, b)
  // }
  //
  // The result, which is (True, False, True) after evaluation, can be
  // interpreted as "the first element is dynamic; the second element is
  // static; the third element is dynamic".
  StatusOr<XlaComputation> BuildDynamicInferenceGraph(XlaOp root_op);

  // Returns the first error that was encountered while building the
  // computation. When an error is encountered, by default we return a vacuous
  // XlaOp and inform the user of the error that occurred while
@@ -805,7 +805,8 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
  }

  if (input_dim_size > output_dim_size) {
    TF_RET_CHECK(input_dim_size % output_dim_size == 0);
    TF_RET_CHECK(input_dim_size % output_dim_size == 0)
        << reshape->ToString();
    const int64 divisor = input_dim_size / output_dim_size;
    HloInstruction* divisor_hlo =
        hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
@@ -783,9 +783,18 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(

/* static */ Shape ShapeUtil::ChangeElementType(const Shape& original,
                                                PrimitiveType type) {
  Shape new_shape = original;
  new_shape.set_element_type(type);
  return new_shape;
  if (original.IsTuple()) {
    std::vector<Shape> new_operands;
    new_operands.reserve(original.tuple_shapes_size());
    for (const Shape& operand : original.tuple_shapes()) {
      new_operands.push_back(ChangeElementType(operand, type));
    }
    return MakeTupleShape(new_operands);
  } else {
    Shape new_shape = original;
    new_shape.set_element_type(type);
    return new_shape;
  }
}

/* static */ bool ShapeUtil::IndexIsValid(const Shape& shape,
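With the tuple case added, ChangeElementType now recurses into nested tuple shapes. A small illustrative snippet (shapes invented for the example):

#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

// (s32[2,3], (f32[4])) becomes (pred[2,3], (pred[4])): every leaf element
// type is replaced while dimensions and tuple structure are preserved.
Shape TupleToPredExample() {
  Shape inner = ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})});
  Shape tuple = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(S32, {2, 3}), inner});
  return ShapeUtil::ChangeElementType(tuple, PRED);
}

}  // namespace xla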
@@ -2088,6 +2088,31 @@ xla_test(
    ],
)

xla_test(
    name = "dynamism_inference_test",
    srcs = ["dynamism_inference_test.cc"],
    deps = [
        ":test_macros_header",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:client_library",
        "//tensorflow/compiler/xla/client:global_data",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/client:xla_computation",
        "//tensorflow/compiler/xla/client/lib:prng",
        "//tensorflow/compiler/xla/tests:literal_test_util",
        "//tensorflow/compiler/xla/tests:test_utils",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
        "@com_google_absl//absl/strings",
    ],
)

xla_test(
    name = "compute_constant_test",
    srcs = ["compute_constant_test.cc"],
tensorflow/compiler/xla/tests/dynamism_inference_test.cc (new file, 215 lines)
@@ -0,0 +1,215 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <utility>
#include <vector>

#include "absl/strings/match.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/lib/prng.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace {

// An enumerator for the client types that we want to iterate over in
// the various tests.
enum class ClientType { kLocal, kCompileOnly };
ClientType client_types[] = {ClientType::kLocal, ClientType::kCompileOnly};

class DynamismInferenceTest : public ::testing::Test {
 public:
  explicit DynamismInferenceTest(se::Platform* platform = nullptr)
      : platform_(platform) {}

  string TestName() const {
    return ::testing::UnitTest::GetInstance()->current_test_info()->name();
  }

  Client* ClientOrDie(se::Platform* platform, ClientType client_type) {
    if (client_type == ClientType::kLocal) {
      StatusOr<Client*> result =
          ClientLibrary::GetOrCreateLocalClient(platform);
      TF_CHECK_OK(result.status())
          << "could not create LocalClient for testing";
      return result.ValueOrDie();
    } else if (client_type == ClientType::kCompileOnly) {
      StatusOr<Client*> result =
          ClientLibrary::GetOrCreateCompileOnlyClient(platform);
      TF_CHECK_OK(result.status())
          << "could not create CompileOnlyClient for testing";
      return result.ValueOrDie();
    }
    LOG(FATAL) << "invalid client_type value";
  }

  StatusOr<Literal> ComputeDynamismLiteral(Client* client, XlaOp operand,
                                           XlaBuilder* builder,
                                           Layout* output_layout = nullptr) {
    TF_ASSIGN_OR_RETURN(auto subgraph,
                        builder->BuildDynamicInferenceGraph(operand));
    TF_ASSIGN_OR_RETURN(auto computed,
                        client->ComputeConstant(subgraph, output_layout));
    return std::move(computed);
  }

  StatusOr<bool> ComputeDynamismScalar(Client* client, XlaOp operand,
                                       XlaBuilder* builder,
                                       ShapeIndex index = {}) {
    TF_ASSIGN_OR_RETURN(auto literal, ComputeDynamismLiteral(client, operand,
                                                             builder, nullptr));
    return literal.Get<bool>({}, index);
  }

  se::Platform* platform_;
};

TEST_F(DynamismInferenceTest, ScalarInt32Literal) {
  for (ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    XlaBuilder b(TestName());
    auto computation = ConstantR0<int32>(&b, 42);

    auto value = ComputeDynamismScalar(client, computation, &b);
    ASSERT_TRUE(value.ok()) << value.status();
    // A constant is not dynamic.
    EXPECT_EQ(value.ValueOrDie(), false);
  }
}

TEST_F(DynamismInferenceTest, TupleGteKeepsDynamism) {
  for (ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    XlaBuilder b(TestName());
    auto c = ConstantR0<int32>(&b, 42);
    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");

    auto tuple = Tuple(&b, {c, p});
    auto gte0 = GetTupleElement(tuple, 0);
    auto gte1 = GetTupleElement(tuple, 1);
    auto tuple_2 = Tuple(&b, {gte0, gte1});
    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
              false);
    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
              true);
  }
}

TEST_F(DynamismInferenceTest, ConcatSliceReshapeKeepsDynamism) {
  for (ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    XlaBuilder b(TestName());
    auto c = ConstantR0<int32>(&b, 42);
    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");

    auto concat = ConcatScalars(&b, {c, p});
    auto slice0 = SliceInDim(concat, 0, 1, 1, 0);
    auto reshape0 = Reshape(slice0, {});
    auto slice1 = SliceInDim(concat, 1, 2, 1, 0);
    auto reshape1 = Reshape(slice1, {});
    auto tuple_2 = Tuple(&b, {reshape0, reshape1});
    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
              false);
    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
              true);
  }
}

TEST_F(DynamismInferenceTest, ParameterIsDynamic) {
  for (ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    XlaBuilder b(TestName());
    auto computation = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");

    auto value = ComputeDynamismScalar(client, computation, &b);
    ASSERT_TRUE(value.ok()) << value.status();
    // A parameter is considered dynamic.
    EXPECT_EQ(value.ValueOrDie(), true);
  }
}

TEST_F(DynamismInferenceTest, UnaryOpKeepsDynamism) {
  for (ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    XlaBuilder b(TestName());
    auto c = ConstantR0<int32>(&b, 42);
    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");

    auto neg0 = Neg(c);
    auto neg1 = Neg(p);
    auto tuple_2 = Tuple(&b, {neg0, neg1});
    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
              false);
    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
              true);
  }
}

TEST_F(DynamismInferenceTest, BinaryOpsOrsDynamism) {
  for (ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    XlaBuilder b(TestName());
    auto c = ConstantR0<int32>(&b, 42);
    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");

    // Static value + static value = static.
    auto add1 = Add(c, c);
    // Dynamic value + static value = dynamic.
    auto add2 = Add(p, c);
    auto tuple_2 = Tuple(&b, {add1, add2});
    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
              false);
    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
              true);
  }
}

TEST_F(DynamismInferenceTest, GetDimensionSize) {
  for (ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    XlaBuilder b(TestName());
    // param = Param([<=2, 3])
    // get_dimension_size(param, 0) is dynamic
    // get_dimension_size(param, 1) is static
    auto p =
        Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "0");

    auto gds0 = GetDimensionSize(p, 0);
    auto gds1 = GetDimensionSize(p, 1);
    auto tuple_2 = Tuple(&b, {gds0, gds1});
    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
              true);
    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
              false);
  }
}

}  // namespace
}  // namespace xla