[XLA] Use dynamism inference to infer dynamic dimensions for reshape.

- Introduce a dynamism inference function in the XLA builder that tells whether a value is dynamic or static.
- Use dynamism inference to infer whether each element of reshape's dimensions input is dynamic.
- This removes the "-1" hack I made before in the bridge and makes the code cleaner. It also supports more complex dynamic reshape cases, where the dimension comes from a series of transformations.
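
An illustrative sketch (not code from this change) of such a case, written against the xla::XlaBuilder API: a reshape dimension derived from a dynamic size through arithmetic is still recognized as dynamic.

// `size` is dynamic whenever dimension 0 of `input` is dynamic; the derived
// value `half` stays dynamic, so the corresponding output dimension of a
// reshape using it is inferred to be dynamic as well.
xla::XlaOp size = xla::GetDimensionSize(input, 0);
xla::XlaOp half = xla::Div(size, xla::ConstantR0<int32>(&b, 2));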

PiperOrigin-RevId: 325532056
Change-Id: Icc5bad39a857be77537e4736dd6863b833e2fe9d
Yunxing Dai 2020-08-07 16:36:02 -07:00 committed by TensorFlower Gardener
parent ff457d4d01
commit 2e3e2bb335
11 changed files with 630 additions and 20 deletions


@@ -109,27 +109,33 @@ class ReshapeOp : public XlaOpKernel {
VLOG(2) << "Reshape from " << input_shape.DebugString() << " to "
<< shape.DebugString() << ", unknown_index=" << unknown_index;
shape_input.clear();
// Run get input again, this time with dynamic dimension represented as
// "-1"
ctx->set_dynamic_dimension_is_minus_one(true);
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &shape_input));
int dynamic_dimension = -1;
for (int d = 0; d < num_dims; ++d) {
const int32 size = shape_input[d];
if (size == -1) {
if (dynamic_dimension == -1) {
if (ctx->InputXlaShape(0)->is_dynamic()) {
std::vector<bool> dynamic_dims;
OP_REQUIRES_OK(ctx,
ctx->ResolveInputDynamismIntoPredVector(1, &dynamic_dims));
for (int d = 0; d < num_dims; ++d) {
const bool dim_is_dynamic = dynamic_dims[d];
if (dim_is_dynamic) {
dynamic_dimension = d;
} else {
if (unknown_index != d) {
dynamic_dimension = d;
}
}
}
}
// When reshaping from a dynamic dimension, the unknown index is considered
// dynamic. E.g.,
// [<=10]
// |
// Reshape
// |
// [2, -1]
// The second dimension is dynamic.
if (dynamic_dimension == -1) {
dynamic_dimension = unknown_index;
}
VLOG(2) << "Reshape from " << ctx->InputXlaShape(0)->ToString() << " to "
<< xla::VectorString(shape.dim_sizes())
<< ", dynamic_dim=" << dynamic_dimension;
}
// Pass unknown_index to Xla::Reshape as a hint for dynamic shape inference
// in XLA to know which output dimension is dynamic.
ctx->SetOutput(0, xla::ReshapeWithInferredDimension(


@@ -101,6 +101,48 @@ xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const {
});
}
xla::StatusOr<Tensor> XlaExpression::ResolveDynamism(
xla::Client* client) const {
switch (kind()) {
case Kind::kConstant: {
// Constant values are considered static.
Tensor constant_false(DT_BOOL, constant_value().shape());
auto flat = constant_false.flat<bool>();
for (int64 i = 0; i < flat.size(); ++i) flat(i) = false;
return constant_false;
}
case Kind::kXlaOp:
break;
case Kind::kTensorList:
TF_FALLTHROUGH_INTENDED;
case Kind::kResource:
TF_FALLTHROUGH_INTENDED;
case Kind::kInvalid:
return errors::InvalidArgument(
"ResolveDynamism called on unsupported XlaExpression: ",
HumanString());
}
if (!client)
return errors::InvalidArgument("client is required to resolve constant");
TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph,
handle().builder()->BuildDynamicInferenceGraph(handle()));
TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape());
// The XLA layout is specified minor to major, and TensorFlow uses a major to
// minor order.
std::vector<int64> layout_indices(shape.dims());
std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
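// For example, for a rank-3 tensor this fills layout_indices with {2, 1, 0}.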
xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);
TF_ASSIGN_OR_RETURN(xla::Literal literal,
client->ComputeConstant(constant_graph, &layout));
Tensor tensor(DT_BOOL);
TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, DT_BOOL, &tensor));
return tensor;
}
xla::StatusOr<absl::optional<Tensor>> XlaExpression::ResolveConstant(
xla::Client* client, bool dynamic_dimension_is_minus_one) const {
switch (kind()) {


@@ -99,6 +99,10 @@ class XlaExpression {
xla::StatusOr<absl::optional<Tensor>> ResolveConstant(
xla::Client* client, bool dynamic_dimension_is_minus_one = false) const;
// ResolveDynamism computes whether a value inside this op is dynamic or can
// be inferred at compile time.
xla::StatusOr<Tensor> ResolveDynamism(xla::Client* client) const;
// Returns the shape of the tensor.
// The shape of a resource is the shape of a resource handle (i.e., a scalar),
// not the shape of the resource's value.


@@ -243,6 +243,48 @@ Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) {
return LiteralToFloat64Scalar(literal, out);
}
static Status LiteralToPredVector(const xla::LiteralSlice& literal,
std::vector<bool>* out) {
if (literal.shape().rank() != 1) {
return errors::InvalidArgument("value is not 1D, rank: ",
literal.shape().rank());
}
int64 size = xla::ShapeUtil::ElementsIn(literal.shape());
if (literal.shape().element_type() != xla::PRED) {
return errors::InvalidArgument("value is not PRED");
}
for (int64 i = 0; i < size; ++i) {
out->push_back(literal.Get<bool>({i}));
}
return Status::OK();
}
Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector(
int index, std::vector<bool>* out) {
xla::Literal literal;
XlaExpression e = InputExpression(index);
auto* client = compiler() ? compiler()->client() : nullptr;
xla::StatusOr<Tensor> dynamism_or_status = e.ResolveDynamism(client);
if (!dynamism_or_status.ok()) {
Status status = dynamism_or_status.status();
errors::AppendToMessage(&status, "while evaluating input dynamism of input ", index,
" of ", context_->op_kernel().type_string());
return status;
}
Tensor dynamism = dynamism_or_status.ValueOrDie();
Tensor temp(dynamism.dtype());
TensorShape tensor_shape({InputShape(index).num_elements()});
if (!temp.CopyFrom(dynamism, tensor_shape)) {
return errors::InvalidArgument(
context_->op_kernel().name(), " input ", index, " has shape ",
dynamism.shape().DebugString(), " which is not an R1 of shape ",
tensor_shape.DebugString());
}
TF_ASSIGN_OR_RETURN(literal, HostTensorToLiteral(temp));
return LiteralToPredVector(literal, out);
}
// Converts an int32 or int64 1D literal to an int64 vector.
static Status LiteralToInt64Vector(const xla::LiteralSlice& literal,
std::vector<int64>* out) {


@@ -116,6 +116,9 @@ class XlaOpKernelContext {
// returns a one-element list.
Status InputList(absl::string_view name, std::vector<xla::XlaOp>* handles,
std::vector<TensorShape>* shapes);
// Evaluates the input and returns its dynamism as a vector of
// predicates.
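// For example (illustrative values, not from this change): if the input is a
// 1D shape tensor whose first element is a compile-time constant and whose
// second element is derived from a parameter, *out is set to {false, true}.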
Status ResolveInputDynamismIntoPredVector(int index, std::vector<bool>* out);
// Helper methods for constant inputs.


@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -39,6 +40,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/errors.h"
namespace xla {
@@ -71,6 +73,52 @@ void SetProtoIdAndName(T* entry, const string& base_name, char separator,
entry->set_id(id);
entry->set_name(GetFullName(base_name, separator, id));
}
ShapeProto ConvertShapeProtoToPred(const ShapeProto& shape_proto) {
return ShapeUtil::ChangeElementType(Shape(shape_proto), PRED).ToProto();
}
HloInstructionProto CreateConstantInstruction(int64 id, const Shape& shape,
bool pred) {
HloInstructionProto const_instr;
Literal literal = LiteralUtil::CreateR0(pred);
Literal literal_broadcast = literal.Broadcast(shape, {}).ValueOrDie();
*const_instr.mutable_shape() = shape.ToProto();
*const_instr.mutable_literal() = literal_broadcast.ToProto();
*const_instr.mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
const_instr.set_id(id);
return const_instr;
}
// Converts an HloComputation into a reduce-or computation with predicate types.
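// For example (illustration): a reducer computing add(x, y) becomes or(x, y)
// over predicates, so the reduced result is dynamic iff any reduced element
// is dynamic.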
HloComputationProto CreateReduceOr(int64 reducer_id,
HloComputationProto* original_reducer) {
HloComputationProto reducer;
SetProtoIdAndName(&reducer, StrCat("reduce_or"), kNameSeparator, reducer_id);
std::vector<int64> operands_id;
for (auto& inst : original_reducer->instructions()) {
// Copy params.
if (StringToHloOpcode(inst.opcode()).ValueOrDie() ==
HloOpcode::kParameter) {
HloInstructionProto* new_param = reducer.add_instructions();
*new_param = inst;
*new_param->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
operands_id.push_back(inst.id());
}
if (inst.id() == original_reducer->root_id()) {
HloInstructionProto* new_root = reducer.add_instructions();
*new_root = inst;
*new_root->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
*new_root->mutable_opcode() = HloOpcodeString(HloOpcode::kOr);
new_root->clear_operand_ids();
for (int64 operand_id : operands_id) {
new_root->add_operand_ids(operand_id);
}
reducer.set_root_id(inst.id());
}
}
return reducer;
}
} // namespace
namespace internal {
@@ -2842,6 +2890,196 @@ StatusOr<bool> XlaBuilder::IsConstant(XlaOp operand) const {
return is_constant;
}
StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
LookUpInstruction(root_op));
HloComputationProto entry;
SetProtoIdAndName(&entry, StrCat(name_, "_dynamic_inference"), kNameSeparator,
GetNextId());
ProgramShapeProto* program_shape = entry.mutable_program_shape();
*program_shape->mutable_result() =
ShapeUtil::ChangeElementType(Shape(root->shape()), PRED).ToProto();
std::set<int64> seen;
struct WorkItem {
explicit WorkItem(int64 handle, bool need_rewrite)
: handle(handle), need_rewrite(need_rewrite) {}
int64 handle;
// If need_rewrite is true, the instruction will be copied and rewritten into
// a pred instruction indicating whether each value is dynamic. If need_rewrite
// is false, simply copy the instruction to the output graph.
// E.g.,
// For select(P, A, B), we need to rewrite A and B into predicates, but
// don't need to rewrite P.
bool need_rewrite;
};
std::queue<WorkItem> worklist;
worklist.push(WorkItem(root->id(), true));
entry.set_root_id(root->id());
std::vector<HloComputationProto> called_computatons;
// Rewritre instruction with id "from" into the new graph.
// Returns more work items that need to finish.
auto rewrite_instruction =
[&](int64 from, bool need_rewrite) -> StatusOr<std::vector<WorkItem>> {
// Rewrite the instruction with following rules:
// - Unary ops: Convert into bitcast (identity) with type Pred.
// - Binary ops: Convert into binary or.
// - Select: Convert into binary or with its two data operands.
// - Concat / Tuple / GTE / Bitcast: Copy.
// - Param: Convert to constant True.
// - GetDimensionSize: Convert to constant True if dimension is dynamic,
// constant False if dimension is static.
// - Reduce: Convert to reduce or.
// - Constant: Convert to constant False.
// - Other ops: Not supported.
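// E.g. (illustration): add(parameter, constant) is rewritten into
// or(constant(true), constant(false)), which evaluates to true: the sum of a
// dynamic value and a static value is dynamic.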
// Create the instruction for the new handle.
TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
LookUpInstructionByHandle(from));
TF_ASSIGN_OR_RETURN(HloOpcode opcode,
StringToHloOpcode(instr_proto->opcode()));
std::vector<WorkItem> operands_todo;
auto* new_instr = entry.add_instructions();
*new_instr = *instr_proto;
for (auto operand_id : new_instr->operand_ids()) {
operands_todo.emplace_back(operand_id, need_rewrite);
}
if (!need_rewrite) {
*new_instr->mutable_name() =
GetFullName(instr_proto->opcode(), kNameSeparator, instr_proto->id());
return operands_todo;
}
*new_instr->mutable_shape() = ConvertShapeProtoToPred(instr_proto->shape());
Shape new_shape(new_instr->shape());
switch (opcode) {
case HloOpcode::kAbs:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kBitcast:
case HloOpcode::kCeil:
case HloOpcode::kCollectivePermuteDone:
case HloOpcode::kCos:
case HloOpcode::kClz:
case HloOpcode::kExp:
case HloOpcode::kExpm1:
case HloOpcode::kFloor:
case HloOpcode::kImag:
case HloOpcode::kIsFinite:
case HloOpcode::kLog:
case HloOpcode::kLog1p:
case HloOpcode::kNot:
case HloOpcode::kNegate:
case HloOpcode::kPopulationCount:
case HloOpcode::kReal:
case HloOpcode::kRsqrt:
case HloOpcode::kLogistic:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kConvert:
case HloOpcode::kSqrt:
case HloOpcode::kCbrt:
case HloOpcode::kTanh:
CHECK_EQ(instr_proto->operand_ids_size(), 1);
*new_instr->mutable_opcode() = HloOpcodeString(HloOpcode::kBitcast);
break;
case HloOpcode::kAdd:
case HloOpcode::kAtan2:
case HloOpcode::kDivide:
case HloOpcode::kComplex:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kMultiply:
case HloOpcode::kPower:
case HloOpcode::kRemainder:
case HloOpcode::kSubtract:
case HloOpcode::kCompare:
case HloOpcode::kAnd:
case HloOpcode::kOr:
case HloOpcode::kXor:
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical:
CHECK_EQ(instr_proto->operand_ids_size(), 2);
*new_instr->mutable_opcode() = HloOpcodeString(HloOpcode::kOr);
break;
case HloOpcode::kSelect:
operands_todo[0].need_rewrite = false;
break;
case HloOpcode::kGather:
operands_todo[1].need_rewrite = false;
break;
case HloOpcode::kReduce: {
int64 reducer_id = new_instr->called_computation_ids(0);
called_computations.push_back(
CreateReduceOr(reducer_id, &embedded_[reducer_id]));
break;
}
case HloOpcode::kTuple:
case HloOpcode::kTranspose:
case HloOpcode::kGetTupleElement:
case HloOpcode::kSlice:
case HloOpcode::kBroadcast:
case HloOpcode::kConcatenate:
case HloOpcode::kReshape:
break;
case HloOpcode::kGetDimensionSize: {
int64 dimension = instr_proto->dimensions(0);
int64 operand_handle = instr_proto->operand_ids(0);
TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
LookUpInstructionByHandle(operand_handle));
*new_instr = CreateConstantInstruction(
from, new_shape,
operand_proto->shape().is_dynamic_dimension(dimension));
operands_todo.clear();
break;
}
case HloOpcode::kConstant:
*new_instr = CreateConstantInstruction(from, new_shape, false);
break;
case HloOpcode::kParameter:
*new_instr = CreateConstantInstruction(from, new_shape, true);
break;
default:
return InvalidArgument("Dynamism inference on %s is not supported",
instr_proto->DebugString());
}
*new_instr->mutable_name() =
GetFullName(instr_proto->opcode(), kNameSeparator, instr_proto->id());
return operands_todo;
};
while (!worklist.empty()) {
WorkItem item = worklist.front();
worklist.pop();
if (!seen.insert(item.handle).second) {
continue;
}
TF_ASSIGN_OR_RETURN(auto todos,
rewrite_instruction(item.handle, item.need_rewrite));
for (WorkItem& todo : todos) {
worklist.push(todo);
}
}
absl::c_sort(*entry.mutable_instructions(),
[](const HloInstructionProto& p1,
const HloInstructionProto& p2) { return p1.id() < p2.id(); });
XlaComputation computation(entry.id());
HloModuleProto* module = computation.mutable_proto();
module->set_name(entry.name());
module->set_id(entry.id());
module->set_entry_computation_name(entry.name());
module->set_entry_computation_id(entry.id());
*module->mutable_host_program_shape() = *program_shape;
for (auto& called_comp : called_computatons) {
*module->add_computations() = called_comp;
}
*module->add_computations() = std::move(entry);
XLA_VLOG_LINES(3, module->DebugString());
return std::move(computation);
}
StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
XlaOp root_op, bool dynamic_dimension_is_minus_one) {
TF_ASSIGN_OR_RETURN(bool is_constant, IsConstant(root_op));


@@ -278,6 +278,31 @@ class XlaBuilder {
StatusOr<XlaComputation> BuildConstantSubGraph(
XlaOp root_op, bool dynamic_dimension_is_uint_max = false);
// Similar to BuildConstantSubGraph, but with root element type changed to
// boolean. A true value in the root indicates that the value is dynamic, while
// a false value indicates that the value is a constant. This will copy the
// needed ops/computations to the subgraph.
//
// E.g.,
// Computation {
// a = 3
// b = param(0)
// ROOT Tuple(a + b, a + 1, b + 1)
// }
// Calling BuildDynamicInferenceGraph on root will produce the following
// graph:
//
// Computation {
// a = False
// b = True
// ROOT Tuple(a | b, a, b)
// }
//
// The result, which is (True, False, True) after evaluation, can be
// interpreted as "First element is dynamic; Second element is static; Third
// element is dynamic".
StatusOr<XlaComputation> BuildDynamicInferenceGraph(XlaOp root_op);
// Returns the first error that was encountered while building the
// computation. When an error is encountered, by default we return a vacuous
// XlaOp and inform the user of the error that occurred while


@@ -805,7 +805,8 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
}
if (input_dim_size > output_dim_size) {
- TF_RET_CHECK(input_dim_size % output_dim_size == 0);
+ TF_RET_CHECK(input_dim_size % output_dim_size == 0)
+     << reshape->ToString();
const int64 divisor = input_dim_size / output_dim_size;
HloInstruction* divisor_hlo =
hlo->parent()->AddInstruction(HloInstruction::CreateConstant(


@@ -783,9 +783,18 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
/* static */ Shape ShapeUtil::ChangeElementType(const Shape& original,
PrimitiveType type) {
- Shape new_shape = original;
- new_shape.set_element_type(type);
- return new_shape;
+ if (original.IsTuple()) {
+   std::vector<Shape> new_operands;
+   new_operands.reserve(original.tuple_shapes_size());
+   for (const Shape& operand : original.tuple_shapes()) {
+     new_operands.push_back(ChangeElementType(operand, type));
+   }
+   return MakeTupleShape(new_operands);
+ } else {
+   Shape new_shape = original;
+   new_shape.set_element_type(type);
+   return new_shape;
+ }
}
/* static */ bool ShapeUtil::IndexIsValid(const Shape& shape,
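
With tuples now handled recursively, ChangeElementType converts every leaf element type. A minimal sketch of the new behavior (hypothetical shapes, not part of this change):

// tuple<f32[2], s32[3]> becomes tuple<pred[2], pred[3]>
Shape t = ShapeUtil::MakeTupleShape(
    {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeShape(S32, {3})});
Shape pred_tuple = ShapeUtil::ChangeElementType(t, PRED);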


@@ -2088,6 +2088,31 @@ xla_test(
],
)
xla_test(
name = "dynamism_inference_test",
srcs = ["dynamism_inference_test.cc"],
deps = [
":test_macros_header",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:prng",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"@com_google_absl//absl/strings",
],
)
xla_test(
name = "compute_constant_test",
srcs = ["compute_constant_test.cc"],


@@ -0,0 +1,215 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <utility>
#include <vector>
#include "absl/strings/match.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/lib/prng.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
namespace {
// An enumerator for the client types that we want to iterate over in
// the various tests.
enum class ClientType { kLocal, kCompileOnly };
ClientType client_types[] = {ClientType::kLocal, ClientType::kCompileOnly};
class DynamismInferenceTest : public ::testing::Test {
public:
explicit DynamismInferenceTest(se::Platform* platform = nullptr)
: platform_(platform) {}
string TestName() const {
return ::testing::UnitTest::GetInstance()->current_test_info()->name();
}
Client* ClientOrDie(se::Platform* platform, ClientType client_type) {
if (client_type == ClientType::kLocal) {
StatusOr<Client*> result =
ClientLibrary::GetOrCreateLocalClient(platform);
TF_CHECK_OK(result.status())
<< "could not create LocalClient for testing";
return result.ValueOrDie();
} else if (client_type == ClientType::kCompileOnly) {
StatusOr<Client*> result =
ClientLibrary::GetOrCreateCompileOnlyClient(platform);
TF_CHECK_OK(result.status())
<< "could not create CompileOnlyClient for testing";
return result.ValueOrDie();
}
LOG(FATAL) << "invalid client_type value";
}
StatusOr<Literal> ComputeDynamismLiteral(Client* client, XlaOp operand,
XlaBuilder* builder,
Layout* output_layout = nullptr) {
TF_ASSIGN_OR_RETURN(auto subgraph,
builder->BuildDynamicInferenceGraph(operand));
TF_ASSIGN_OR_RETURN(auto computed,
client->ComputeConstant(subgraph, output_layout));
return std::move(computed);
}
StatusOr<bool> ComputeDynamismScalar(Client* client, XlaOp operand,
XlaBuilder* builder,
ShapeIndex index = {}) {
TF_ASSIGN_OR_RETURN(auto literal, ComputeDynamismLiteral(client, operand,
builder, nullptr));
return literal.Get<bool>({}, index);
}
se::Platform* platform_;
};
TEST_F(DynamismInferenceTest, ScalarInt32Literal) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
auto computation = ConstantR0<int32>(&b, 42);
auto value = ComputeDynamismScalar(client, computation, &b);
ASSERT_TRUE(value.ok()) << value.status();
// A constant is not dynamic.
EXPECT_EQ(value.ValueOrDie(), false);
}
}
TEST_F(DynamismInferenceTest, TupleGteKeepsDynamism) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
auto c = ConstantR0<int32>(&b, 42);
auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
auto tuple = Tuple(&b, {c, p});
auto gte0 = GetTupleElement(tuple, 0);
auto gte1 = GetTupleElement(tuple, 1);
auto tuple_2 = Tuple(&b, {gte0, gte1});
EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
false);
EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
true);
}
}
TEST_F(DynamismInferenceTest, ConcatSliceReshapeKeepsDynamism) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
auto c = ConstantR0<int32>(&b, 42);
auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
auto concat = ConcatScalars(&b, {c, p});
auto slice0 = SliceInDim(concat, 0, 1, 1, 0);
auto reshape0 = Reshape(slice0, {});
auto slice1 = SliceInDim(concat, 1, 2, 1, 0);
auto reshape1 = Reshape(slice1, {});
auto tuple_2 = Tuple(&b, {reshape0, reshape1});
EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
false);
EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
true);
}
}
TEST_F(DynamismInferenceTest, ParameterIsDynamic) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
auto computation = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
auto value = ComputeDynamismScalar(client, computation, &b);
ASSERT_TRUE(value.ok()) << value.status();
// A parameter is considered dynamic.
EXPECT_EQ(value.ValueOrDie(), true);
}
}
TEST_F(DynamismInferenceTest, UnaryOpKeepsDynamism) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
auto c = ConstantR0<int32>(&b, 42);
auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
auto neg0 = Neg(c);
auto neg1 = Neg(p);
auto tuple_2 = Tuple(&b, {neg0, neg1});
EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
false);
EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
true);
}
}
TEST_F(DynamismInferenceTest, BinaryOpsOrsDynamism) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
auto c = ConstantR0<int32>(&b, 42);
auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
// Static value + static value = static
auto add1 = Add(c, c);
// Dynamic value + static value = dynamic
auto add2 = Add(p, c);
auto tuple_2 = Tuple(&b, {add1, add2});
EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
false);
EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
true);
}
}
TEST_F(DynamismInferenceTest, GetDimensionSize) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
// param = Param([<=2, 3])
// get_dimension_size(param, 0) is dynamic
// get_dimension_size(param, 1) is static
auto p =
Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "0");
auto gds0 = GetDimensionSize(p, 0);
auto gds1 = GetDimensionSize(p, 1);
auto tuple_2 = Tuple(&b, {gds0, gds1});
EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
true);
EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
false);
}
}
} // namespace
} // namespace xla