[XLA] Enable interpreter backend test by default for xla/tests without having to
using additional backend tags "enable_for_xla_interpreter". PiperOrigin-RevId: 227598976
This commit is contained in:
parent
0b6177c2fa
commit
dd3adf935a
@ -93,7 +93,6 @@ cc_library(
|
||||
xla_test(
|
||||
name = "constants_test",
|
||||
srcs = ["constants_test.cc"],
|
||||
tags = ["enable_for_xla_interpreter"],
|
||||
deps = [
|
||||
":constants",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
@ -147,7 +146,6 @@ cc_library(
|
||||
xla_test(
|
||||
name = "math_test",
|
||||
srcs = ["math_test.cc"],
|
||||
tags = ["enable_for_xla_interpreter"],
|
||||
deps = [
|
||||
":math",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
@ -181,7 +179,6 @@ cc_library(
|
||||
xla_test(
|
||||
name = "matrix_test",
|
||||
srcs = ["matrix_test.cc"],
|
||||
tags = ["enable_for_xla_interpreter"],
|
||||
deps = [
|
||||
":matrix",
|
||||
":slicing",
|
||||
@ -295,7 +292,6 @@ cc_library(
|
||||
xla_test(
|
||||
name = "slicing_test",
|
||||
srcs = ["slicing_test.cc"],
|
||||
tags = ["enable_for_xla_interpreter"],
|
||||
deps = [
|
||||
":slicing",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
@ -324,7 +320,6 @@ cc_library(
|
||||
xla_test(
|
||||
name = "sorting_test",
|
||||
srcs = ["sorting_test.cc"],
|
||||
tags = ["enable_for_xla_interpreter"],
|
||||
deps = [
|
||||
":sorting",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
@ -354,7 +349,6 @@ xla_test(
|
||||
srcs = ["quantize_test.cc"],
|
||||
# TODO(b/122119490): re-enable TAP after fixing.
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
"notap",
|
||||
],
|
||||
deps = [
|
||||
|
@ -138,6 +138,11 @@ StatusOr<Literal> Compare<complex64>(const Shape& shape, HloOpcode opcode,
|
||||
|
||||
} // namespace
|
||||
|
||||
// Note that unsupported types by the typed visitor does not necessarily imply
|
||||
// the non-typed HloEvaluator (parent evaluator) would not support them either
|
||||
// in the type-agnostic handler. For e.g., HandleGetTupleElement in the parent
|
||||
// type-agnostic evaluator will be able to accept Tuple primitive type, whereas
|
||||
// HloEvaluatorTypedVisitor cannot.
|
||||
HloEvaluator::HloEvaluator(int64 max_loop_iterations)
|
||||
: max_loop_iterations_(max_loop_iterations) {
|
||||
typed_visitors_[PRED] =
|
||||
|
@ -211,6 +211,29 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
|
||||
|
||||
Status HandleReduce(HloInstruction* reduce) override;
|
||||
|
||||
// Unsupported HLOs, note some of them (such as BatchNorm*) are typically
|
||||
// expanded in a semantic-preserving way into other HLOs by adding exanpsion
|
||||
// HLO pass to the HLO optimization pass during compilation, which can then be
|
||||
// handled by the evaluator.
|
||||
Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override {
|
||||
return Unimplemented("BatchNormGrad HLO is unsupported by the evaluator.");
|
||||
};
|
||||
Status HandleBatchNormInference(
|
||||
HloInstruction* batch_norm_inference) override {
|
||||
return Unimplemented(
|
||||
"BatchNormInference HLO is unsupported by the evaluator.");
|
||||
};
|
||||
Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override {
|
||||
return Unimplemented(
|
||||
"BatchNormTraining HLO is unsupported by the evaluator.");
|
||||
};
|
||||
Status HandleInfeed(HloInstruction* infeed) override {
|
||||
return Unimplemented("Infeed HLO is unsupported by the evaluator.");
|
||||
};
|
||||
Status HandleOutfeed(HloInstruction* outfeed) override {
|
||||
return Unimplemented("Outfeed HLO is unsupported by the evaluator.");
|
||||
};
|
||||
|
||||
// Returns the already-evaluated literal result for the instruction.
|
||||
// A Constant instruction is considered evaluated and its literal will be
|
||||
// returned directly without looking up the cache.
|
||||
|
@ -917,7 +917,11 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
Status HandleClamp(HloInstruction* clamp) {
|
||||
std::function<ElementwiseT(ElementwiseT, ElementwiseT, ElementwiseT)>
|
||||
clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) {
|
||||
return std::fmin(high, std::fmax(value, low));
|
||||
if (std::isnan(low) || std::isnan(high)) {
|
||||
return static_cast<ElementwiseT>(NAN);
|
||||
}
|
||||
return static_cast<ElementwiseT>(
|
||||
std::fmin(high, std::fmax(value, low)));
|
||||
};
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
parent_->evaluated_[clamp],
|
||||
@ -2664,12 +2668,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
return HandleReducePrecision<ElementwiseT>(reduce_precision);
|
||||
}
|
||||
|
||||
template <typename NativeT,
|
||||
typename std::enable_if<
|
||||
std::is_same<NativeT, bfloat16>::value ||
|
||||
std::is_same<NativeT, Eigen::half>::value ||
|
||||
std::is_integral<NativeT>::value ||
|
||||
std::is_floating_point<NativeT>::value>::type* = nullptr>
|
||||
template <
|
||||
typename NativeT,
|
||||
typename std::enable_if<
|
||||
std::is_same<NativeT, bfloat16>::value ||
|
||||
std::is_same<NativeT, Eigen::half>::value ||
|
||||
std::is_integral<NativeT>::value || is_complex_t<NativeT>::value ||
|
||||
std::is_floating_point<NativeT>::value>::type* = nullptr>
|
||||
Status HandleIota(HloInstruction* instruction) {
|
||||
auto* iota = Cast<HloIotaInstruction>(instruction);
|
||||
const int64 iota_size = iota->shape().dimensions(iota->iota_dimension());
|
||||
@ -2700,12 +2705,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
template <typename NativeT,
|
||||
typename std::enable_if<
|
||||
!(std::is_same<NativeT, bfloat16>::value ||
|
||||
std::is_same<NativeT, Eigen::half>::value ||
|
||||
std::is_integral<NativeT>::value ||
|
||||
std::is_floating_point<NativeT>::value)>::type* = nullptr>
|
||||
template <
|
||||
typename NativeT,
|
||||
typename std::enable_if<
|
||||
!(std::is_same<NativeT, bfloat16>::value ||
|
||||
std::is_same<NativeT, Eigen::half>::value ||
|
||||
std::is_integral<NativeT>::value || is_complex_t<NativeT>::value ||
|
||||
std::is_floating_point<NativeT>::value)>::type* = nullptr>
|
||||
Status HandleIota(HloInstruction* iota) {
|
||||
return UnsupportedTypeError(iota);
|
||||
}
|
||||
|
@ -32,6 +32,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/service:algebraic_simplifier",
|
||||
"//tensorflow/compiler/xla/service:batchnorm_expander",
|
||||
"//tensorflow/compiler/xla/service:compiler",
|
||||
"//tensorflow/compiler/xla/service:computation_placer",
|
||||
"//tensorflow/compiler/xla/service:executable",
|
||||
@ -41,12 +42,14 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
|
||||
"//tensorflow/compiler/xla/service:hlo_cse",
|
||||
"//tensorflow/compiler/xla/service:hlo_dce",
|
||||
"//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter",
|
||||
"//tensorflow/compiler/xla/service:hlo_module_config",
|
||||
"//tensorflow/compiler/xla/service:hlo_pass",
|
||||
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
|
||||
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
|
||||
"//tensorflow/compiler/xla/service:layout_assignment",
|
||||
"//tensorflow/compiler/xla/service:map_inliner",
|
||||
"//tensorflow/compiler/xla/service:reduce_precision_insertion",
|
||||
"//tensorflow/compiler/xla/service:reshape_mover",
|
||||
"//tensorflow/compiler/xla/service:while_loop_simplifier",
|
||||
"//tensorflow/core:lib",
|
||||
|
@ -31,6 +31,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/interpreter/executable.h"
|
||||
#include "tensorflow/compiler/xla/service/layout_assignment.h"
|
||||
#include "tensorflow/compiler/xla/service/map_inliner.h"
|
||||
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
|
||||
#include "tensorflow/compiler/xla/service/reshape_mover.h"
|
||||
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
@ -46,6 +47,11 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
|
||||
pipeline.AddPass<LayoutAssignment>(
|
||||
hlo_module->mutable_entry_computation_layout(),
|
||||
LayoutAssignment::InstructionCanChangeLayout);
|
||||
|
||||
ReducePrecisionInsertion::AddPasses(
|
||||
&pipeline, hlo_module->config().debug_options(),
|
||||
ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION);
|
||||
|
||||
return pipeline.Run(hlo_module).status();
|
||||
}
|
||||
|
||||
|
@ -276,9 +276,6 @@ cc_library(
|
||||
xla_test(
|
||||
name = "bad_rng_shape_validation_test",
|
||||
srcs = ["bad_rng_shape_validation_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
@ -344,9 +341,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "check_execution_arity_test",
|
||||
srcs = ["check_execution_arity_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
@ -367,9 +361,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "query_inferred_shape_test",
|
||||
srcs = ["query_inferred_shape_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
@ -387,9 +378,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "while_test",
|
||||
srcs = ["while_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
@ -413,6 +401,10 @@ xla_test(
|
||||
xla_test(
|
||||
name = "xla_hlo_profile_test",
|
||||
srcs = ["xla_hlo_profile_test.cc"],
|
||||
blacklisted_backends = [
|
||||
# Hlo profiles are not supported on the interpreter backend.
|
||||
"interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
@ -436,9 +428,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "axpy_simple_test",
|
||||
srcs = ["axpy_simple_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
@ -453,7 +442,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "map_test",
|
||||
srcs = ["map_test.cc"],
|
||||
tags = ["enable_for_xla_interpreter"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
@ -506,9 +494,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "pred_test",
|
||||
srcs = ["pred_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
@ -524,9 +509,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "select_test",
|
||||
srcs = ["select_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/client:global_data",
|
||||
@ -544,7 +526,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "conditional_test",
|
||||
srcs = ["conditional_test.cc"],
|
||||
tags = ["enable_for_xla_interpreter"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/client:global_data",
|
||||
@ -562,7 +543,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "unary_op_test",
|
||||
srcs = ["unary_op_test.cc"],
|
||||
tags = ["enable_for_xla_interpreter"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/client:global_data",
|
||||
@ -623,9 +603,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "deconstruct_tuple_test",
|
||||
srcs = ["deconstruct_tuple_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
@ -648,7 +625,6 @@ xla_test(
|
||||
name = "array_elementwise_ops_test",
|
||||
srcs = ["array_elementwise_ops_test.cc"],
|
||||
shard_count = 25,
|
||||
tags = ["enable_for_xla_interpreter"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:array3d",
|
||||
@ -698,7 +674,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "reduce_precision_test",
|
||||
srcs = ["reduce_precision_test.cc"],
|
||||
tags = ["enable_for_xla_interpreter"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
@ -725,7 +700,6 @@ xla_test(
|
||||
srcs = ["dot_operation_test.cc"],
|
||||
shard_count = 20,
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
"optonly",
|
||||
],
|
||||
deps = [
|
||||
@ -806,9 +780,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "transpose_test",
|
||||
srcs = ["transpose_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:reference_util",
|
||||
@ -828,9 +799,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "constants_test",
|
||||
srcs = ["constants_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:array3d",
|
||||
@ -951,6 +919,11 @@ xla_test(
|
||||
xla_test(
|
||||
name = "batch_normalization_test",
|
||||
srcs = ["batch_normalization_test.cc"],
|
||||
blacklisted_backends = [
|
||||
# BatchNorm HLOs are not handled by the interpreter backend, and the
|
||||
# BatchNorm expander is not run on the interpreter.
|
||||
"interpreter",
|
||||
],
|
||||
shard_count = 40,
|
||||
deps = [
|
||||
":test_utils",
|
||||
@ -1042,9 +1015,6 @@ xla_test(
|
||||
name = "slice_test",
|
||||
srcs = ["slice_test.cc"],
|
||||
shard_count = 40,
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:reference_util",
|
||||
@ -1065,9 +1035,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "multidimensional_slice_test",
|
||||
srcs = ["multidimensional_slice_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:array3d",
|
||||
@ -1085,9 +1052,6 @@ xla_test(
|
||||
name = "dynamic_ops_test",
|
||||
timeout = "moderate",
|
||||
srcs = ["dynamic_ops_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:reference_util",
|
||||
@ -1113,9 +1077,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "tuple_test",
|
||||
srcs = ["tuple_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
@ -1139,9 +1100,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "vector_ops_reduce_test",
|
||||
srcs = ["vector_ops_reduce_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:array3d",
|
||||
@ -1162,7 +1120,6 @@ xla_test(
|
||||
srcs = ["reduce_test.cc"],
|
||||
shard_count = 40,
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
"optonly",
|
||||
],
|
||||
deps = [
|
||||
@ -1229,7 +1186,6 @@ xla_test(
|
||||
srcs = [],
|
||||
shard_count = 20,
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
"optonly",
|
||||
],
|
||||
xla_test_library_deps = [":reduce_window_test_library"],
|
||||
@ -1241,7 +1197,6 @@ xla_test(
|
||||
timeout = "long",
|
||||
srcs = ["select_and_scatter_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
"optonly",
|
||||
],
|
||||
deps = [
|
||||
@ -1267,9 +1222,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "copy_test",
|
||||
srcs = ["copy_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
":client_library_test_base",
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
@ -1290,9 +1242,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "reduce_hlo_test",
|
||||
srcs = ["reduce_hlo_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
@ -1306,9 +1255,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "token_hlo_test",
|
||||
srcs = ["token_hlo_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/service:hlo_verifier",
|
||||
@ -1323,9 +1269,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "call_test",
|
||||
srcs = ["call_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
@ -1368,9 +1311,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "binop_scaling_test",
|
||||
srcs = ["binop_scaling_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:array4d",
|
||||
@ -1388,9 +1328,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "broadcast_simple_test",
|
||||
srcs = ["broadcast_simple_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:array4d",
|
||||
@ -1410,9 +1347,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "pad_test",
|
||||
srcs = ["pad_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:array4d",
|
||||
@ -1434,9 +1368,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "fmax_test",
|
||||
srcs = ["fmax_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
@ -1451,9 +1382,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "log_test",
|
||||
srcs = ["log_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
@ -1468,9 +1396,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "matrix_ops_simple_test",
|
||||
srcs = ["matrix_ops_simple_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
@ -1497,6 +1422,10 @@ xla_test(
|
||||
xla_test(
|
||||
name = "prng_test",
|
||||
srcs = ["prng_test.cc"],
|
||||
blacklisted_backends = [
|
||||
# TODO(b/122047800) support RNGs on the interpreter backend.
|
||||
"interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
@ -1517,9 +1446,6 @@ xla_test(
|
||||
name = "reshape_test",
|
||||
srcs = ["reshape_test.cc"],
|
||||
shard_count = 30,
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:array4d",
|
||||
@ -1545,9 +1471,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "reverse_test",
|
||||
srcs = ["reverse_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:array4d",
|
||||
@ -1566,9 +1489,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "vector_ops_simple_test",
|
||||
srcs = ["vector_ops_simple_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:array4d",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
@ -1592,9 +1512,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "concat_test",
|
||||
srcs = ["concat_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:array3d",
|
||||
@ -1615,9 +1532,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "convert_test",
|
||||
srcs = ["convert_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
@ -1637,6 +1551,10 @@ xla_test(
|
||||
xla_test(
|
||||
name = "all_reduce_test",
|
||||
srcs = ["all_reduce_test.cc"],
|
||||
blacklisted_backends = [
|
||||
# All reduce is not supported on the interpreter backend.
|
||||
"interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
@ -1661,9 +1579,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "bitcast_convert_test",
|
||||
srcs = ["bitcast_convert_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
@ -1703,9 +1618,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "floor_ceil_test",
|
||||
srcs = ["floor_ceil_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
@ -1767,6 +1679,10 @@ xla_test(
|
||||
xla_test(
|
||||
name = "execution_profile_test",
|
||||
srcs = ["execution_profile_test.cc"],
|
||||
blacklisted_backends = [
|
||||
# Execution profiles are not supported on the interpreter backend.
|
||||
"interpreter",
|
||||
],
|
||||
deps = [
|
||||
":client_library_test_base",
|
||||
"//tensorflow/compiler/xla/client:global_data",
|
||||
@ -1781,6 +1697,10 @@ xla_test(
|
||||
name = "execution_profile_test_with_xla_hlo_profile",
|
||||
srcs = ["execution_profile_test.cc"],
|
||||
args = ["--xla_hlo_profile"],
|
||||
blacklisted_backends = [
|
||||
# Hlo profiles are not supported on the interpreter backend.
|
||||
"interpreter",
|
||||
],
|
||||
deps = [
|
||||
":client_library_test_base",
|
||||
"//tensorflow/compiler/xla/client:global_data",
|
||||
@ -1794,9 +1714,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "replay_test",
|
||||
srcs = ["replay_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:protobuf_util",
|
||||
@ -1819,9 +1736,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "broadcast_test",
|
||||
srcs = ["broadcast_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
@ -1883,9 +1797,6 @@ xla_test(
|
||||
xla_test(
|
||||
name = "fusion_test",
|
||||
srcs = ["fusion_test.cc"],
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
@ -2003,6 +1914,10 @@ xla_test(
|
||||
xla_test(
|
||||
name = "outfeed_in_nested_computation_test",
|
||||
srcs = ["outfeed_in_nested_computation_test.cc"],
|
||||
blacklisted_backends = [
|
||||
# Outfeed ops are not supported on the interpreter backend.
|
||||
"interpreter",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla/tests:local_client_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
@ -2179,7 +2094,6 @@ xla_test(
|
||||
srcs = ["iota_test.cc"],
|
||||
shard_count = 30,
|
||||
tags = [
|
||||
"enable_for_xla_interpreter",
|
||||
# Require optimized builds, iota_test_cpu is very slow in fastbuild.
|
||||
"optonly",
|
||||
],
|
||||
|
@ -2047,6 +2047,19 @@ XLA_TEST_F(ArrayElementwiseOpTest, NonNanClampF32) {
|
||||
error_spec_);
|
||||
}
|
||||
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32) {
|
||||
SetFastMathDisabled(true);
|
||||
XlaBuilder builder(TestName());
|
||||
auto minimum = ConstantR1<float>(&builder, {1.0f, -6.5f, 1.0f, 2.25f, NAN});
|
||||
auto argument =
|
||||
ConstantR1<float>(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 10.0f});
|
||||
auto maximum = ConstantR1<float>(&builder, {3.0f, 0.5f, 25.5f, NAN, 123.0f});
|
||||
Clamp(minimum, argument, maximum);
|
||||
|
||||
ComputeAndCompareR1<float>(&builder, {2.0f, 0.5f, 1.0f, NAN, NAN}, {},
|
||||
error_spec_);
|
||||
}
|
||||
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) {
|
||||
XlaBuilder builder(TestName());
|
||||
auto minimum = ConstantR0<float>(&builder, 0.0f);
|
||||
|
@ -76,7 +76,9 @@ XLA_TEST_F(Bfloat16Test, NegateScalarF16) {
|
||||
error_spec_);
|
||||
}
|
||||
|
||||
XLA_TEST_F(Bfloat16Test, BatchNormTraining) {
|
||||
// Disabled on interpreter since BatchNormExanper is not run by default on the
|
||||
// intepreter backend.
|
||||
XLA_TEST_F(Bfloat16Test, DISABLED_ON_INTERPRETER(BatchNormTraining)) {
|
||||
const int kFeatureIndex = 2;
|
||||
XlaBuilder builder(TestName());
|
||||
|
||||
@ -110,7 +112,9 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) {
|
||||
ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01, 0.02));
|
||||
}
|
||||
|
||||
XLA_TEST_F(Bfloat16Test, BatchNormGrad) {
|
||||
// Disabled on interpreter since BatchNormExanper is not run by default on the
|
||||
// intepreter backend.
|
||||
XLA_TEST_F(Bfloat16Test, DISABLED_ON_INTERPRETER(BatchNormGrad)) {
|
||||
const int kFeatureIndex = 2;
|
||||
XlaBuilder builder(TestName());
|
||||
|
||||
|
@ -109,7 +109,10 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) {
|
||||
/*minor_to_major=*/{1, 0})));
|
||||
}
|
||||
|
||||
XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) {
|
||||
// Disabled for interpreter since ExecuteAsyncOnStream is not implemented on
|
||||
// interpreter backend.
|
||||
XLA_TEST_F(ClientTest,
|
||||
DISABLED_ON_INTERPRETER(DISABLED_ON_GPU(ExecuteParallel))) {
|
||||
XlaComputation add_with_one_arg, mul_with_two_args, dot_with_one_arg;
|
||||
Shape shape = ShapeUtil::MakeShape(S32, {2, 2});
|
||||
|
||||
|
@ -600,7 +600,9 @@ ENTRY main {
|
||||
|
||||
class GatherClientLibraryTest : public ClientLibraryTestBase {};
|
||||
|
||||
XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) {
|
||||
// Disabled on interpreter since ExectuteAsyncOnStream is not supported.
|
||||
XLA_TEST_F(GatherClientLibraryTest,
|
||||
DISABLED_ON_INTERPRETER(DISABLED_ON_GPU(Basic))) {
|
||||
// We create this HLO, but using the XlaBuilder API.
|
||||
//
|
||||
// ENTRY main {
|
||||
|
@ -842,7 +842,8 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) {
|
||||
LiteralUtil::CreateR0<int64>(123456789000LL)}));
|
||||
}
|
||||
|
||||
XLA_TEST_F(LocalClientExecuteTest, InfeedTest) {
|
||||
// Disabled on interpreter backend since infeed HLO is unsupported.
|
||||
XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedTest)) {
|
||||
XlaBuilder builder(TestName());
|
||||
const Shape shape = ShapeUtil::MakeShape(F32, {3});
|
||||
auto in = Infeed(&builder, shape);
|
||||
@ -867,7 +868,8 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedTest) {
|
||||
LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, result);
|
||||
}
|
||||
|
||||
XLA_TEST_F(LocalClientExecuteTest, InfeedOutfeedTest) {
|
||||
// Disabled on interpreter backend since infeed/outfeed HLOs are unsupported.
|
||||
XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedOutfeedTest)) {
|
||||
XlaBuilder builder(TestName());
|
||||
const Shape shape = ShapeUtil::MakeShape(F32, {3});
|
||||
auto in = Infeed(&builder, shape);
|
||||
|
Loading…
Reference in New Issue
Block a user