From 2e3e2bb33559bef5b76cb1e5bc745a75488efaff Mon Sep 17 00:00:00 2001 From: Yunxing Dai Date: Fri, 7 Aug 2020 16:36:02 -0700 Subject: [PATCH] [XLA] Use dynamism inference to infer dynamic dimensions for reshape. - Introduce dynamism inference function in xla builder, which tells if a value is dynamic or static. - Use dynamism inference to infer whether an input to reshape's dimensions is dynamic. - This removes the "-1" hack I made before in the bridge, makes the code cleaner. Plus it can support more complex cases dynamic reshape when the dimension comes from a series of transformations. PiperOrigin-RevId: 325532056 Change-Id: Icc5bad39a857be77537e4736dd6863b833e2fe9d --- .../compiler/tf2xla/kernels/reshape_op.cc | 38 +-- tensorflow/compiler/tf2xla/xla_expression.cc | 42 ++++ tensorflow/compiler/tf2xla/xla_expression.h | 4 + tensorflow/compiler/tf2xla/xla_op_kernel.cc | 42 ++++ tensorflow/compiler/tf2xla/xla_op_kernel.h | 3 + tensorflow/compiler/xla/client/xla_builder.cc | 238 ++++++++++++++++++ tensorflow/compiler/xla/client/xla_builder.h | 25 ++ .../service/dynamic_dimension_inference.cc | 3 +- tensorflow/compiler/xla/shape_util.cc | 15 +- tensorflow/compiler/xla/tests/BUILD | 25 ++ .../xla/tests/dynamism_inference_test.cc | 215 ++++++++++++++++ 11 files changed, 630 insertions(+), 20 deletions(-) create mode 100644 tensorflow/compiler/xla/tests/dynamism_inference_test.cc diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index bf9a9150ea6..a85ba547179 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -109,27 +109,33 @@ class ReshapeOp : public XlaOpKernel { VLOG(2) << "Reshape from " << input_shape.DebugString() << " to " << shape.DebugString() << ", unknown_index=" << unknown_index; - shape_input.clear(); - // Run get input again, this time with dynamic dimension represented as - // "-1" - ctx->set_dynamic_dimension_is_minus_one(true); - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &shape_input)); - int dynamic_dimension = -1; - - for (int d = 0; d < num_dims; ++d) { - const int32 size = shape_input[d]; - if (size == -1) { - if (dynamic_dimension == -1) { + if (ctx->InputXlaShape(0)->is_dynamic()) { + std::vector dynamic_dims; + OP_REQUIRES_OK(ctx, + ctx->ResolveInputDynamismIntoPredVector(1, &dynamic_dims)); + for (int d = 0; d < num_dims; ++d) { + const bool dim_is_dynamic = dynamic_dims[d]; + if (dim_is_dynamic) { dynamic_dimension = d; - } else { - if (unknown_index != d) { - dynamic_dimension = d; - } } } - } + // When reshaping from dynamic dimension, unkwown index is considered + // dynamic. E.g., + // [<=10] + // | + // Reshape + // | + // [2, -1] + // The second dimension is dynamic. + if (dynamic_dimension == -1) { + dynamic_dimension = unknown_index; + } + VLOG(2) << "Reshape from " << ctx->InputXlaShape(0)->ToString() << " to " + << xla::VectorString(shape.dim_sizes()) + << ", dynamic_dim=" << dynamic_dimension; + } // Pass unknown_index to Xla::Reshape as a hint for dynamic shape inference // in XLA to know which output dimension is dynamic. ctx->SetOutput(0, xla::ReshapeWithInferredDimension( diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc index 34e108bb6bf..f0cc8d26709 100644 --- a/tensorflow/compiler/tf2xla/xla_expression.cc +++ b/tensorflow/compiler/tf2xla/xla_expression.cc @@ -101,6 +101,48 @@ xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const { }); } +xla::StatusOr XlaExpression::ResolveDynamism( + xla::Client* client) const { + switch (kind()) { + case Kind::kConstant: { + // Constant values are considered static. + Tensor constant_false(DT_BOOL, constant_value().shape()); + auto flat = constant_false.flat(); + for (int64 i = 0; i < flat.size(); ++i) flat(i) = false; + return constant_false; + } + case Kind::kXlaOp: + break; + case Kind::kTensorList: + TF_FALLTHROUGH_INTENDED; + case Kind::kResource: + TF_FALLTHROUGH_INTENDED; + case Kind::kInvalid: + return errors::InvalidArgument( + "ResolveDynamism called on unsupported XlaExpression: ", + HumanString()); + } + + if (!client) + return errors::InvalidArgument("client is required to resolve constant"); + + TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph, + handle().builder()->BuildDynamicInferenceGraph(handle())); + + TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape()); + + // The XLA layout is specified minor to major, and TensorFlow uses a major to + // minor order. + std::vector layout_indices(shape.dims()); + std::iota(layout_indices.rbegin(), layout_indices.rend(), 0); + xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices); + TF_ASSIGN_OR_RETURN(xla::Literal literal, + client->ComputeConstant(constant_graph, &layout)); + Tensor tensor(DT_BOOL); + TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, DT_BOOL, &tensor)); + return tensor; +} + xla::StatusOr> XlaExpression::ResolveConstant( xla::Client* client, bool dynamic_dimension_is_minus_one) const { switch (kind()) { diff --git a/tensorflow/compiler/tf2xla/xla_expression.h b/tensorflow/compiler/tf2xla/xla_expression.h index 3010964c5b7..3546368ff7b 100644 --- a/tensorflow/compiler/tf2xla/xla_expression.h +++ b/tensorflow/compiler/tf2xla/xla_expression.h @@ -99,6 +99,10 @@ class XlaExpression { xla::StatusOr> ResolveConstant( xla::Client* client, bool dynamic_dimension_is_minus_one = false) const; + // ResolveDynamism computes where a value inside this op is dynamic or can be + // inferred at compile time. + xla::StatusOr ResolveDynamism(xla::Client* client) const; + // Returns the shape of the tensor. // The shape of a resource is the shape of a resource handle (i.e., a scalar), // not the shape of the resource's value. diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 735a6c7291e..07537546d52 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -243,6 +243,48 @@ Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) { return LiteralToFloat64Scalar(literal, out); } +static Status LiteralToPredVector(const xla::LiteralSlice& literal, + std::vector* out) { + if (literal.shape().rank() != 1) { + return errors::InvalidArgument("value is not 1D, rank: ", + literal.shape().rank()); + } + int64 size = xla::ShapeUtil::ElementsIn(literal.shape()); + if (literal.shape().element_type() != xla::PRED) { + return errors::InvalidArgument("value is not PRED"); + } + for (int64 i = 0; i < size; ++i) { + out->push_back(literal.Get({i})); + } + return Status::OK(); +} + +Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector( + int index, std::vector* out) { + xla::Literal literal; + XlaExpression e = InputExpression(index); + auto* client = compiler() ? compiler()->client() : nullptr; + xla::StatusOr dynamism_or_status = e.ResolveDynamism(client); + if (!dynamism_or_status.ok()) { + Status status = dynamism_or_status.status(); + errors::AppendToMessage(&status, "while evaluating input dynamism", index, + " of ", context_->op_kernel().type_string()); + return status; + } + Tensor dynamism = dynamism_or_status.ValueOrDie(); + + Tensor temp(dynamism.dtype()); + TensorShape tensor_shape({InputShape(index).num_elements()}); + if (!temp.CopyFrom(dynamism, tensor_shape)) { + return errors::InvalidArgument( + context_->op_kernel().name(), " input ", index, " has shape ", + dynamism.shape().DebugString(), " which is not a R1 ", tensor_shape); + } + + TF_ASSIGN_OR_RETURN(literal, HostTensorToLiteral(temp)); + return LiteralToPredVector(literal, out); +} + // Converts an int32 or int64 1D literal to an int64 vector. static Status LiteralToInt64Vector(const xla::LiteralSlice& literal, std::vector* out) { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 3cf51e6ec6f..75c3e60171a 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -116,6 +116,9 @@ class XlaOpKernelContext { // returns a one-element list. Status InputList(absl::string_view name, std::vector* handles, std::vector* shapes); + // Evaluates input and returns their dynamism vector in a vector of + // predicates. + Status ResolveInputDynamismIntoPredVector(int index, std::vector* out); // Helper methods for constant inputs. diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 484fb0aabe7..8de8216c005 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/comparison_util.h" #include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -39,6 +40,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/errors.h" namespace xla { @@ -71,6 +73,52 @@ void SetProtoIdAndName(T* entry, const string& base_name, char separator, entry->set_id(id); entry->set_name(GetFullName(base_name, separator, id)); } + +ShapeProto ConvertShapeProtoToPred(const ShapeProto& shape_proto) { + return ShapeUtil::ChangeElementType(Shape(shape_proto), PRED).ToProto(); +} + +HloInstructionProto CreateConstantInstruction(int64 id, const Shape& shape, + bool pred) { + HloInstructionProto const_instr; + Literal literal = LiteralUtil::CreateR0(pred); + Literal literal_broadcast = literal.Broadcast(shape, {}).ValueOrDie(); + *const_instr.mutable_shape() = shape.ToProto(); + *const_instr.mutable_literal() = literal_broadcast.ToProto(); + *const_instr.mutable_opcode() = HloOpcodeString(HloOpcode::kConstant); + const_instr.set_id(id); + return const_instr; +} + +// Converts a HloComputation into ReducerOr with predicate types. +HloComputationProto CreateReduceOr(int64 reducer_id, + HloComputationProto* original_reducer) { + HloComputationProto reducer; + SetProtoIdAndName(&reducer, StrCat("reduce_or"), kNameSeparator, reducer_id); + std::vector operands_id; + for (auto& inst : original_reducer->instructions()) { + // Copy params. + if (StringToHloOpcode(inst.opcode()).ValueOrDie() == + HloOpcode::kParameter) { + HloInstructionProto* new_param = reducer.add_instructions(); + *new_param = inst; + *new_param->mutable_shape() = ConvertShapeProtoToPred(inst.shape()); + operands_id.push_back(inst.id()); + } + if (inst.id() == original_reducer->root_id()) { + HloInstructionProto* new_root = reducer.add_instructions(); + *new_root = inst; + *new_root->mutable_shape() = ConvertShapeProtoToPred(inst.shape()); + *new_root->mutable_opcode() = HloOpcodeString(HloOpcode::kOr); + new_root->clear_operand_ids(); + for (int64 operand_id : operands_id) { + new_root->add_operand_ids(operand_id); + } + reducer.set_root_id(inst.id()); + } + } + return reducer; +} } // namespace namespace internal { @@ -2842,6 +2890,196 @@ StatusOr XlaBuilder::IsConstant(XlaOp operand) const { return is_constant; } +StatusOr XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) { + TF_ASSIGN_OR_RETURN(const HloInstructionProto* root, + LookUpInstruction(root_op)); + + HloComputationProto entry; + SetProtoIdAndName(&entry, StrCat(name_, "_dynamic_inference"), kNameSeparator, + GetNextId()); + ProgramShapeProto* program_shape = entry.mutable_program_shape(); + *program_shape->mutable_result() = + ShapeUtil::ChangeElementType(Shape(root->shape()), PRED).ToProto(); + + std::set seen; + struct WorkItem { + explicit WorkItem(int64 handle, bool need_rewrite) + : handle(handle), need_rewrite(need_rewrite) {} + int64 handle; + // If need_rewrite is true, the instruction will be copied and rewrite into + // a pred instruction indicating if each value is dynamic. If need_rewrite + // is false, simply copy the instruction to the output graph. + // E.g., + // For select(P, A, B), we need to rewrite A and B into predicates, but + // don't need to rewrite P. + bool need_rewrite; + }; + std::queue worklist; + worklist.push(WorkItem(root->id(), true)); + entry.set_root_id(root->id()); + std::vector called_computatons; + // Rewritre instruction with id "from" into the new graph. + // Returns more work items that need to finish. + auto rewrite_instruction = + [&](int64 from, bool need_rewrite) -> StatusOr> { + // Rewrite the instruction with following rules: + // - Unary ops: Convert into bitcast (identity) with type Pred. + // - Binary ops: Convert into binary or. + // - Select: Convert into binary or with its two data operands. + // - Concat / Tuple/ GTE / Bitcast: Copy. + // - Param: Convert to constant True. + // - GetDimensionSize: Convert to constant True if dimension is dynamic, + // contant False if dimension is static. + // - Reduce: Convert to reduce or. + // - Constant: Convert to constant False. + // - Other ops: Not supported. + // Create the instruction for the new handle. + TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto, + LookUpInstructionByHandle(from)); + + TF_ASSIGN_OR_RETURN(HloOpcode opcode, + StringToHloOpcode(instr_proto->opcode())); + std::vector operands_todo; + auto* new_instr = entry.add_instructions(); + *new_instr = *instr_proto; + for (auto operand_id : new_instr->operand_ids()) { + operands_todo.emplace_back(operand_id, need_rewrite); + } + + if (!need_rewrite) { + *new_instr->mutable_name() = + GetFullName(instr_proto->opcode(), kNameSeparator, instr_proto->id()); + return operands_todo; + } + *new_instr->mutable_shape() = ConvertShapeProtoToPred(instr_proto->shape()); + Shape new_shape(new_instr->shape()); + switch (opcode) { + case HloOpcode::kAbs: + case HloOpcode::kRoundNearestAfz: + case HloOpcode::kBitcast: + case HloOpcode::kCeil: + case HloOpcode::kCollectivePermuteDone: + case HloOpcode::kCos: + case HloOpcode::kClz: + case HloOpcode::kExp: + case HloOpcode::kExpm1: + case HloOpcode::kFloor: + case HloOpcode::kImag: + case HloOpcode::kIsFinite: + case HloOpcode::kLog: + case HloOpcode::kLog1p: + case HloOpcode::kNot: + case HloOpcode::kNegate: + case HloOpcode::kPopulationCount: + case HloOpcode::kReal: + case HloOpcode::kRsqrt: + case HloOpcode::kLogistic: + case HloOpcode::kSign: + case HloOpcode::kSin: + case HloOpcode::kConvert: + case HloOpcode::kSqrt: + case HloOpcode::kCbrt: + case HloOpcode::kTanh: + CHECK_EQ(instr_proto->operand_ids_size(), 1); + *new_instr->mutable_opcode() = HloOpcodeString(HloOpcode::kBitcast); + break; + case HloOpcode::kAdd: + case HloOpcode::kAtan2: + case HloOpcode::kDivide: + case HloOpcode::kComplex: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kMultiply: + case HloOpcode::kPower: + case HloOpcode::kRemainder: + case HloOpcode::kSubtract: + case HloOpcode::kCompare: + case HloOpcode::kAnd: + case HloOpcode::kOr: + case HloOpcode::kXor: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: + CHECK_EQ(instr_proto->operand_ids_size(), 2); + *new_instr->mutable_opcode() = HloOpcodeString(HloOpcode::kOr); + break; + case HloOpcode::kSelect: + operands_todo[0].need_rewrite = false; + break; + case HloOpcode::kGather: + operands_todo[1].need_rewrite = false; + break; + case HloOpcode::kReduce: { + int64 reducer_id = new_instr->called_computation_ids(0); + called_computatons.push_back( + CreateReduceOr(reducer_id, &embedded_[reducer_id])); + break; + } + case HloOpcode::kTuple: + case HloOpcode::kTranspose: + case HloOpcode::kGetTupleElement: + case HloOpcode::kSlice: + case HloOpcode::kBroadcast: + case HloOpcode::kConcatenate: + case HloOpcode::kReshape: + break; + case HloOpcode::kGetDimensionSize: { + int64 dimension = instr_proto->dimensions(0); + int64 operand_handle = instr_proto->operand_ids(0); + TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto, + LookUpInstructionByHandle(operand_handle)); + + *new_instr = CreateConstantInstruction( + from, new_shape, + operand_proto->shape().is_dynamic_dimension(dimension)); + operands_todo.clear(); + break; + } + case HloOpcode::kConstant: + *new_instr = CreateConstantInstruction(from, new_shape, false); + break; + case HloOpcode::kParameter: + *new_instr = CreateConstantInstruction(from, new_shape, true); + break; + default: + return InvalidArgument("Dynamic inferencing %s is not supported", + instr_proto->DebugString()); + } + *new_instr->mutable_name() = + GetFullName(instr_proto->opcode(), kNameSeparator, instr_proto->id()); + return operands_todo; + }; + + while (!worklist.empty()) { + WorkItem item = worklist.front(); + worklist.pop(); + if (!seen.insert(item.handle).second) { + continue; + } + TF_ASSIGN_OR_RETURN(auto todos, + rewrite_instruction(item.handle, item.need_rewrite)); + for (WorkItem& todo : todos) { + worklist.push(todo); + } + } + absl::c_sort(*entry.mutable_instructions(), + [](const HloInstructionProto& p1, + const HloInstructionProto& p2) { return p1.id() < p2.id(); }); + XlaComputation computation(entry.id()); + HloModuleProto* module = computation.mutable_proto(); + module->set_name(entry.name()); + module->set_id(entry.id()); + module->set_entry_computation_name(entry.name()); + module->set_entry_computation_id(entry.id()); + *module->mutable_host_program_shape() = *program_shape; + for (auto& called_comp : called_computatons) { + *module->add_computations() = called_comp; + } + *module->add_computations() = std::move(entry); + XLA_VLOG_LINES(3, module->DebugString()); + return std::move(computation); +} + StatusOr XlaBuilder::BuildConstantSubGraph( XlaOp root_op, bool dynamic_dimension_is_minus_one) { TF_ASSIGN_OR_RETURN(bool is_constant, IsConstant(root_op)); diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index aa5074d28d9..6753b6dd919 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -278,6 +278,31 @@ class XlaBuilder { StatusOr BuildConstantSubGraph( XlaOp root_op, bool dynamic_dimension_is_uint_max = false); + // Similar to BuildConstantSubGraph, but with root element type changed to + // boolean. A true value in the root indicates that the value is dynamic while + // false value indicates that the value is a constant. This will copy the + // needed ops/computations to the subgraph. + // + // E.g., + // Compuptation { + // a = 3 + // b = param(0) + // ROOT Tuple(a + b, a + 1, b + 1) + // } + // Calling BuildDynamicInferenceGraph on root will produce the following + // graph: + // + // Compuptation { + // a = False + // b = True + // ROOT Tuple(a | b, a, b) + // } + // + // The result, which is (True, False, True) after evaluation, can be + // interpreted as "First element is dynamic; Second element is static; Third + // element is dynamic". + StatusOr BuildDynamicInferenceGraph(XlaOp root_op); + // Returns the first error that was encountered while building the // computation. When an error is encountered, by default we return a vacuous // XlaOp and inform the user of the error that occurred while diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc index 2f2456863e9..36429d3d755 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc @@ -805,7 +805,8 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) { } if (input_dim_size > output_dim_size) { - TF_RET_CHECK(input_dim_size % output_dim_size == 0); + TF_RET_CHECK(input_dim_size % output_dim_size == 0) + << reshape->ToString(); const int64 divisor = input_dim_size / output_dim_size; HloInstruction* divisor_hlo = hlo->parent()->AddInstruction(HloInstruction::CreateConstant( diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 02fcaafd19d..0833919b124 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -783,9 +783,18 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( /* static */ Shape ShapeUtil::ChangeElementType(const Shape& original, PrimitiveType type) { - Shape new_shape = original; - new_shape.set_element_type(type); - return new_shape; + if (original.IsTuple()) { + std::vector new_operands; + new_operands.reserve(original.tuple_shapes_size()); + for (const Shape& operand : original.tuple_shapes()) { + new_operands.push_back(ChangeElementType(operand, type)); + } + return MakeTupleShape(new_operands); + } else { + Shape new_shape = original; + new_shape.set_element_type(type); + return new_shape; + } } /* static */ bool ShapeUtil::IndexIsValid(const Shape& shape, diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 927f9d14883..17444c042e7 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -2088,6 +2088,31 @@ xla_test( ], ) +xla_test( + name = "dynamism_inference_test", + srcs = ["dynamism_inference_test.cc"], + deps = [ + ":test_macros_header", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:prng", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "@com_google_absl//absl/strings", + ], +) + xla_test( name = "compute_constant_test", srcs = ["compute_constant_test.cc"], diff --git a/tensorflow/compiler/xla/tests/dynamism_inference_test.cc b/tensorflow/compiler/xla/tests/dynamism_inference_test.cc new file mode 100644 index 00000000000..ba4092def16 --- /dev/null +++ b/tensorflow/compiler/xla/tests/dynamism_inference_test.cc @@ -0,0 +1,215 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "absl/strings/match.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/lib/prng.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +// An enumerator for the client types that we want to iterate over in +// the various tests. +enum class ClientType { kLocal, kCompileOnly }; +ClientType client_types[] = {ClientType::kLocal, ClientType::kCompileOnly}; + +class DynamismInferenceTest : public ::testing::Test { + public: + explicit DynamismInferenceTest(se::Platform* platform = nullptr) + : platform_(platform) {} + + string TestName() const { + return ::testing::UnitTest::GetInstance()->current_test_info()->name(); + } + + Client* ClientOrDie(se::Platform* platform, ClientType client_type) { + if (client_type == ClientType::kLocal) { + StatusOr result = + ClientLibrary::GetOrCreateLocalClient(platform); + TF_CHECK_OK(result.status()) + << "could not create LocalClient for testing"; + return result.ValueOrDie(); + } else if (client_type == ClientType::kCompileOnly) { + StatusOr result = + ClientLibrary::GetOrCreateCompileOnlyClient(platform); + TF_CHECK_OK(result.status()) + << "could not create CompileOnlyClient for testing"; + return result.ValueOrDie(); + } + LOG(FATAL) << "invalid client_type value"; + } + + StatusOr ComputeDynamismLiteral(Client* client, XlaOp operand, + XlaBuilder* builder, + Layout* output_layout = nullptr) { + TF_ASSIGN_OR_RETURN(auto subgraph, + builder->BuildDynamicInferenceGraph(operand)); + TF_ASSIGN_OR_RETURN(auto computed, + client->ComputeConstant(subgraph, output_layout)); + return std::move(computed); + } + + StatusOr ComputeDynamismScalar(Client* client, XlaOp operand, + XlaBuilder* builder, + ShapeIndex index = {}) { + TF_ASSIGN_OR_RETURN(auto literal, ComputeDynamismLiteral(client, operand, + builder, nullptr)); + return literal.Get({}, index); + } + + se::Platform* platform_; +}; + +TEST_F(DynamismInferenceTest, ScalarInt32Literal) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + XlaBuilder b(TestName()); + auto computation = ConstantR0(&b, 42); + + auto value = ComputeDynamismScalar(client, computation, &b); + ASSERT_TRUE(value.ok()) << value.status(); + // A constant is not dynamic. + EXPECT_EQ(value.ValueOrDie(), false); + } +} + +TEST_F(DynamismInferenceTest, TupleGteKeepsDynamism) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + XlaBuilder b(TestName()); + auto c = ConstantR0(&b, 42); + auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0"); + + auto tuple = Tuple(&b, {c, p}); + auto gte0 = GetTupleElement(tuple, 0); + auto gte1 = GetTupleElement(tuple, 1); + auto tuple_2 = Tuple(&b, {gte0, gte1}); + EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(), + false); + EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(), + true); + } +} + +TEST_F(DynamismInferenceTest, ConcatSliceReshapeKeepsDynamism) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + XlaBuilder b(TestName()); + auto c = ConstantR0(&b, 42); + auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0"); + + auto concat = ConcatScalars(&b, {c, p}); + auto slice0 = SliceInDim(concat, 0, 1, 1, 0); + auto reshape0 = Reshape(slice0, {}); + auto slice1 = SliceInDim(concat, 1, 2, 1, 0); + auto reshape1 = Reshape(slice1, {}); + auto tuple_2 = Tuple(&b, {reshape0, reshape1}); + EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(), + false); + EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(), + true); + } +} + +TEST_F(DynamismInferenceTest, ParameterIsDynamic) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + XlaBuilder b(TestName()); + auto computation = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0"); + + auto value = ComputeDynamismScalar(client, computation, &b); + ASSERT_TRUE(value.ok()) << value.status(); + // A parameter is considered dynamic. + EXPECT_EQ(value.ValueOrDie(), true); + } +} + +TEST_F(DynamismInferenceTest, UnaryOpKeepsDynamism) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + XlaBuilder b(TestName()); + auto c = ConstantR0(&b, 42); + auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0"); + + auto neg0 = Neg(c); + auto neg1 = Neg(p); + auto tuple_2 = Tuple(&b, {neg0, neg1}); + EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(), + false); + EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(), + true); + } +} + +TEST_F(DynamismInferenceTest, BinaryOpsOrsDynamism) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + XlaBuilder b(TestName()); + auto c = ConstantR0(&b, 42); + auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0"); + + // Static value + static value = static + auto add1 = Add(c, c); + // Dynamic value + dynamic value = dynamic + auto add2 = Add(p, c); + auto tuple_2 = Tuple(&b, {add1, add2}); + EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(), + false); + EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(), + true); + } +} + +TEST_F(DynamismInferenceTest, GetDimensionSize) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + XlaBuilder b(TestName()); + // param = Param([<=2, 3]) + // get_dimension_size(param, 0) is dynamic + // get_dimension_size(param, 1) is static + auto p = + Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "0"); + + auto gds0 = GetDimensionSize(p, 0); + auto gds1 = GetDimensionSize(p, 1); + auto tuple_2 = Tuple(&b, {gds0, gds1}); + EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(), + true); + EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(), + false); + } +} + +} // namespace +} // namespace xla