[XLA] Enable interpreter backend test by default for xla/tests without having to

use additional backend tags "enable_for_xla_interpreter".

PiperOrigin-RevId: 227598976
This commit is contained in:
Kay Zhu 2019-01-02 16:11:52 -08:00 committed by TensorFlower Gardener
parent 0b6177c2fa
commit dd3adf935a
12 changed files with 115 additions and 140 deletions

View File

@ -93,7 +93,6 @@ cc_library(
xla_test(
name = "constants_test",
srcs = ["constants_test.cc"],
tags = ["enable_for_xla_interpreter"],
deps = [
":constants",
"//tensorflow/compiler/xla:test",
@ -147,7 +146,6 @@ cc_library(
xla_test(
name = "math_test",
srcs = ["math_test.cc"],
tags = ["enable_for_xla_interpreter"],
deps = [
":math",
"//tensorflow/compiler/xla:literal_util",
@ -181,7 +179,6 @@ cc_library(
xla_test(
name = "matrix_test",
srcs = ["matrix_test.cc"],
tags = ["enable_for_xla_interpreter"],
deps = [
":matrix",
":slicing",
@ -295,7 +292,6 @@ cc_library(
xla_test(
name = "slicing_test",
srcs = ["slicing_test.cc"],
tags = ["enable_for_xla_interpreter"],
deps = [
":slicing",
"//tensorflow/compiler/xla:literal_util",
@ -324,7 +320,6 @@ cc_library(
xla_test(
name = "sorting_test",
srcs = ["sorting_test.cc"],
tags = ["enable_for_xla_interpreter"],
deps = [
":sorting",
"//tensorflow/compiler/xla:test",
@ -354,7 +349,6 @@ xla_test(
srcs = ["quantize_test.cc"],
# TODO(b/122119490): re-enable TAP after fixing.
tags = [
"enable_for_xla_interpreter",
"notap",
],
deps = [

View File

@ -138,6 +138,11 @@ StatusOr<Literal> Compare<complex64>(const Shape& shape, HloOpcode opcode,
} // namespace
// Note that unsupported types by the typed visitor does not necessarily imply
// the non-typed HloEvaluator (parent evaluator) would not support them either
// in the type-agnostic handler. For e.g., HandleGetTupleElement in the parent
// type-agnostic evaluator will be able to accept Tuple primitive type, whereas
// HloEvaluatorTypedVisitor cannot.
HloEvaluator::HloEvaluator(int64 max_loop_iterations)
: max_loop_iterations_(max_loop_iterations) {
typed_visitors_[PRED] =

View File

@ -211,6 +211,29 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
Status HandleReduce(HloInstruction* reduce) override;
// Unsupported HLOs, note some of them (such as BatchNorm*) are typically
// expanded in a semantic-preserving way into other HLOs by adding an expansion
// HLO pass to the HLO optimization pass during compilation, which can then be
// handled by the evaluator.
// Each handler below rejects an HLO the evaluator does not implement,
// returning an Unimplemented status with the offending opcode's name.
// Fix: dropped the stray semicolons after each member-function body
// (`};` -> `}`); an extra `;` in a class body is an empty declaration
// that clang flags with -Wextra-semi and pre-C++11 compilers reject.
Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override {
  return Unimplemented("BatchNormGrad HLO is unsupported by the evaluator.");
}
Status HandleBatchNormInference(
    HloInstruction* batch_norm_inference) override {
  return Unimplemented(
      "BatchNormInference HLO is unsupported by the evaluator.");
}
Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override {
  return Unimplemented(
      "BatchNormTraining HLO is unsupported by the evaluator.");
}
Status HandleInfeed(HloInstruction* infeed) override {
  return Unimplemented("Infeed HLO is unsupported by the evaluator.");
}
Status HandleOutfeed(HloInstruction* outfeed) override {
  return Unimplemented("Outfeed HLO is unsupported by the evaluator.");
}
// Returns the already-evaluated literal result for the instruction.
// A Constant instruction is considered evaluated and its literal will be
// returned directly without looking up the cache.

View File

@ -917,7 +917,11 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
Status HandleClamp(HloInstruction* clamp) {
std::function<ElementwiseT(ElementwiseT, ElementwiseT, ElementwiseT)>
clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) {
return std::fmin(high, std::fmax(value, low));
if (std::isnan(low) || std::isnan(high)) {
return static_cast<ElementwiseT>(NAN);
}
return static_cast<ElementwiseT>(
std::fmin(high, std::fmax(value, low)));
};
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[clamp],
@ -2664,11 +2668,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return HandleReducePrecision<ElementwiseT>(reduce_precision);
}
template <typename NativeT,
template <
typename NativeT,
typename std::enable_if<
std::is_same<NativeT, bfloat16>::value ||
std::is_same<NativeT, Eigen::half>::value ||
std::is_integral<NativeT>::value ||
std::is_integral<NativeT>::value || is_complex_t<NativeT>::value ||
std::is_floating_point<NativeT>::value>::type* = nullptr>
Status HandleIota(HloInstruction* instruction) {
auto* iota = Cast<HloIotaInstruction>(instruction);
@ -2700,11 +2705,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
}
template <typename NativeT,
template <
typename NativeT,
typename std::enable_if<
!(std::is_same<NativeT, bfloat16>::value ||
std::is_same<NativeT, Eigen::half>::value ||
std::is_integral<NativeT>::value ||
std::is_integral<NativeT>::value || is_complex_t<NativeT>::value ||
std::is_floating_point<NativeT>::value)>::type* = nullptr>
Status HandleIota(HloInstruction* iota) {
return UnsupportedTypeError(iota);

View File

@ -32,6 +32,7 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/service:algebraic_simplifier",
"//tensorflow/compiler/xla/service:batchnorm_expander",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:executable",
@ -41,12 +42,14 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
"//tensorflow/compiler/xla/service:hlo_cse",
"//tensorflow/compiler/xla/service:hlo_dce",
"//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter",
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
"//tensorflow/compiler/xla/service:layout_assignment",
"//tensorflow/compiler/xla/service:map_inliner",
"//tensorflow/compiler/xla/service:reduce_precision_insertion",
"//tensorflow/compiler/xla/service:reshape_mover",
"//tensorflow/compiler/xla/service:while_loop_simplifier",
"//tensorflow/core:lib",

View File

@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/interpreter/executable.h"
#include "tensorflow/compiler/xla/service/layout_assignment.h"
#include "tensorflow/compiler/xla/service/map_inliner.h"
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
#include "tensorflow/compiler/xla/status_macros.h"
@ -46,6 +47,11 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
pipeline.AddPass<LayoutAssignment>(
hlo_module->mutable_entry_computation_layout(),
LayoutAssignment::InstructionCanChangeLayout);
ReducePrecisionInsertion::AddPasses(
&pipeline, hlo_module->config().debug_options(),
ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION);
return pipeline.Run(hlo_module).status();
}

View File

@ -276,9 +276,6 @@ cc_library(
xla_test(
name = "bad_rng_shape_validation_test",
srcs = ["bad_rng_shape_validation_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
@ -344,9 +341,6 @@ xla_test(
xla_test(
name = "check_execution_arity_test",
srcs = ["check_execution_arity_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
@ -367,9 +361,6 @@ xla_test(
xla_test(
name = "query_inferred_shape_test",
srcs = ["query_inferred_shape_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
@ -387,9 +378,6 @@ xla_test(
xla_test(
name = "while_test",
srcs = ["while_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
@ -413,6 +401,10 @@ xla_test(
xla_test(
name = "xla_hlo_profile_test",
srcs = ["xla_hlo_profile_test.cc"],
blacklisted_backends = [
# Hlo profiles are not supported on the interpreter backend.
"interpreter",
],
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:shape_util",
@ -436,9 +428,6 @@ xla_test(
xla_test(
name = "axpy_simple_test",
srcs = ["axpy_simple_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder",
@ -453,7 +442,6 @@ xla_test(
xla_test(
name = "map_test",
srcs = ["map_test.cc"],
tags = ["enable_for_xla_interpreter"],
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal",
@ -506,9 +494,6 @@ xla_test(
xla_test(
name = "pred_test",
srcs = ["pred_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla/client:local_client",
@ -524,9 +509,6 @@ xla_test(
xla_test(
name = "select_test",
srcs = ["select_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:global_data",
@ -544,7 +526,6 @@ xla_test(
xla_test(
name = "conditional_test",
srcs = ["conditional_test.cc"],
tags = ["enable_for_xla_interpreter"],
deps = [
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:global_data",
@ -562,7 +543,6 @@ xla_test(
xla_test(
name = "unary_op_test",
srcs = ["unary_op_test.cc"],
tags = ["enable_for_xla_interpreter"],
deps = [
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:global_data",
@ -623,9 +603,6 @@ xla_test(
xla_test(
name = "deconstruct_tuple_test",
srcs = ["deconstruct_tuple_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
@ -648,7 +625,6 @@ xla_test(
name = "array_elementwise_ops_test",
srcs = ["array_elementwise_ops_test.cc"],
shard_count = 25,
tags = ["enable_for_xla_interpreter"],
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d",
@ -698,7 +674,6 @@ xla_test(
xla_test(
name = "reduce_precision_test",
srcs = ["reduce_precision_test.cc"],
tags = ["enable_for_xla_interpreter"],
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal",
@ -725,7 +700,6 @@ xla_test(
srcs = ["dot_operation_test.cc"],
shard_count = 20,
tags = [
"enable_for_xla_interpreter",
"optonly",
],
deps = [
@ -806,9 +780,6 @@ xla_test(
xla_test(
name = "transpose_test",
srcs = ["transpose_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:reference_util",
@ -828,9 +799,6 @@ xla_test(
xla_test(
name = "constants_test",
srcs = ["constants_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d",
@ -951,6 +919,11 @@ xla_test(
xla_test(
name = "batch_normalization_test",
srcs = ["batch_normalization_test.cc"],
blacklisted_backends = [
# BatchNorm HLOs are not handled by the interpreter backend, and the
# BatchNorm expander is not run on the interpreter.
"interpreter",
],
shard_count = 40,
deps = [
":test_utils",
@ -1042,9 +1015,6 @@ xla_test(
name = "slice_test",
srcs = ["slice_test.cc"],
shard_count = 40,
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:reference_util",
@ -1065,9 +1035,6 @@ xla_test(
xla_test(
name = "multidimensional_slice_test",
srcs = ["multidimensional_slice_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d",
@ -1085,9 +1052,6 @@ xla_test(
name = "dynamic_ops_test",
timeout = "moderate",
srcs = ["dynamic_ops_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:reference_util",
@ -1113,9 +1077,6 @@ xla_test(
xla_test(
name = "tuple_test",
srcs = ["tuple_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal",
@ -1139,9 +1100,6 @@ xla_test(
xla_test(
name = "vector_ops_reduce_test",
srcs = ["vector_ops_reduce_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d",
@ -1162,7 +1120,6 @@ xla_test(
srcs = ["reduce_test.cc"],
shard_count = 40,
tags = [
"enable_for_xla_interpreter",
"optonly",
],
deps = [
@ -1229,7 +1186,6 @@ xla_test(
srcs = [],
shard_count = 20,
tags = [
"enable_for_xla_interpreter",
"optonly",
],
xla_test_library_deps = [":reduce_window_test_library"],
@ -1241,7 +1197,6 @@ xla_test(
timeout = "long",
srcs = ["select_and_scatter_test.cc"],
tags = [
"enable_for_xla_interpreter",
"optonly",
],
deps = [
@ -1267,9 +1222,6 @@ xla_test(
xla_test(
name = "copy_test",
srcs = ["copy_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [
":client_library_test_base",
"//tensorflow/compiler/xla:array2d",
@ -1290,9 +1242,6 @@ xla_test(
xla_test(
name = "reduce_hlo_test",
srcs = ["reduce_hlo_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
@ -1306,9 +1255,6 @@ xla_test(
xla_test(
name = "token_hlo_test",
srcs = ["token_hlo_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_verifier",
@ -1323,9 +1269,6 @@ xla_test(
xla_test(
name = "call_test",
srcs = ["call_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
@ -1368,9 +1311,6 @@ xla_test(
xla_test(
name = "binop_scaling_test",
srcs = ["binop_scaling_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array4d",
@ -1388,9 +1328,6 @@ xla_test(
xla_test(
name = "broadcast_simple_test",
srcs = ["broadcast_simple_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array4d",
@ -1410,9 +1347,6 @@ xla_test(
xla_test(
name = "pad_test",
srcs = ["pad_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array4d",
@ -1434,9 +1368,6 @@ xla_test(
xla_test(
name = "fmax_test",
srcs = ["fmax_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder",
@ -1451,9 +1382,6 @@ xla_test(
xla_test(
name = "log_test",
srcs = ["log_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder",
@ -1468,9 +1396,6 @@ xla_test(
xla_test(
name = "matrix_ops_simple_test",
srcs = ["matrix_ops_simple_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal",
@ -1497,6 +1422,10 @@ xla_test(
xla_test(
name = "prng_test",
srcs = ["prng_test.cc"],
blacklisted_backends = [
# TODO(b/122047800) support RNGs on the interpreter backend.
"interpreter",
],
deps = [
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
@ -1517,9 +1446,6 @@ xla_test(
name = "reshape_test",
srcs = ["reshape_test.cc"],
shard_count = 30,
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array4d",
@ -1545,9 +1471,6 @@ xla_test(
xla_test(
name = "reverse_test",
srcs = ["reverse_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array4d",
@ -1566,9 +1489,6 @@ xla_test(
xla_test(
name = "vector_ops_simple_test",
srcs = ["vector_ops_simple_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla:array4d",
"//tensorflow/compiler/xla:shape_util",
@ -1592,9 +1512,6 @@ xla_test(
xla_test(
name = "concat_test",
srcs = ["concat_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d",
@ -1615,9 +1532,6 @@ xla_test(
xla_test(
name = "convert_test",
srcs = ["convert_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto",
@ -1637,6 +1551,10 @@ xla_test(
xla_test(
name = "all_reduce_test",
srcs = ["all_reduce_test.cc"],
blacklisted_backends = [
# All reduce is not supported on the interpreter backend.
"interpreter",
],
deps = [
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
@ -1661,9 +1579,6 @@ xla_test(
xla_test(
name = "bitcast_convert_test",
srcs = ["bitcast_convert_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto",
@ -1703,9 +1618,6 @@ xla_test(
xla_test(
name = "floor_ceil_test",
srcs = ["floor_ceil_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder",
@ -1767,6 +1679,10 @@ xla_test(
xla_test(
name = "execution_profile_test",
srcs = ["execution_profile_test.cc"],
blacklisted_backends = [
# Execution profiles are not supported on the interpreter backend.
"interpreter",
],
deps = [
":client_library_test_base",
"//tensorflow/compiler/xla/client:global_data",
@ -1781,6 +1697,10 @@ xla_test(
name = "execution_profile_test_with_xla_hlo_profile",
srcs = ["execution_profile_test.cc"],
args = ["--xla_hlo_profile"],
blacklisted_backends = [
# Hlo profiles are not supported on the interpreter backend.
"interpreter",
],
deps = [
":client_library_test_base",
"//tensorflow/compiler/xla/client:global_data",
@ -1794,9 +1714,6 @@ xla_test(
xla_test(
name = "replay_test",
srcs = ["replay_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:protobuf_util",
@ -1819,9 +1736,6 @@ xla_test(
xla_test(
name = "broadcast_test",
srcs = ["broadcast_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
@ -1883,9 +1797,6 @@ xla_test(
xla_test(
name = "fusion_test",
srcs = ["fusion_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal",
@ -2003,6 +1914,10 @@ xla_test(
xla_test(
name = "outfeed_in_nested_computation_test",
srcs = ["outfeed_in_nested_computation_test.cc"],
blacklisted_backends = [
# Outfeed ops are not supported on the interpreter backend.
"interpreter",
],
deps = [
"//tensorflow/compiler/xla/tests:local_client_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@ -2179,7 +2094,6 @@ xla_test(
srcs = ["iota_test.cc"],
shard_count = 30,
tags = [
"enable_for_xla_interpreter",
# Require optimized builds, iota_test_cpu is very slow in fastbuild.
"optonly",
],

View File

@ -2047,6 +2047,19 @@ XLA_TEST_F(ArrayElementwiseOpTest, NonNanClampF32) {
error_spec_);
}
// Verifies NaN propagation through Clamp: an element whose lower bound
// (index 4 of `minimum`) or upper bound (index 3 of `maximum`) is NaN must
// produce NaN in the result. Fast math is disabled so the backend may not
// assume NaNs are absent.
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32) {
SetFastMathDisabled(true);
XlaBuilder builder(TestName());
auto minimum = ConstantR1<float>(&builder, {1.0f, -6.5f, 1.0f, 2.25f, NAN});
auto argument =
ConstantR1<float>(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 10.0f});
auto maximum = ConstantR1<float>(&builder, {3.0f, 0.5f, 25.5f, NAN, 123.0f});
Clamp(minimum, argument, maximum);
// Elements 3 and 4 are NaN because one of their clamp bounds is NaN.
ComputeAndCompareR1<float>(&builder, {2.0f, 0.5f, 1.0f, NAN, NAN}, {},
error_spec_);
}
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) {
XlaBuilder builder(TestName());
auto minimum = ConstantR0<float>(&builder, 0.0f);

View File

@ -76,7 +76,9 @@ XLA_TEST_F(Bfloat16Test, NegateScalarF16) {
error_spec_);
}
XLA_TEST_F(Bfloat16Test, BatchNormTraining) {
// Disabled on interpreter since the BatchNormExpander pass is not run by
// default on the interpreter backend.
XLA_TEST_F(Bfloat16Test, DISABLED_ON_INTERPRETER(BatchNormTraining)) {
const int kFeatureIndex = 2;
XlaBuilder builder(TestName());
@ -110,7 +112,9 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) {
ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01, 0.02));
}
XLA_TEST_F(Bfloat16Test, BatchNormGrad) {
// Disabled on interpreter since the BatchNormExpander pass is not run by
// default on the interpreter backend.
XLA_TEST_F(Bfloat16Test, DISABLED_ON_INTERPRETER(BatchNormGrad)) {
const int kFeatureIndex = 2;
XlaBuilder builder(TestName());

View File

@ -109,7 +109,10 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) {
/*minor_to_major=*/{1, 0})));
}
XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) {
// Disabled for interpreter since ExecuteAsyncOnStream is not implemented on
// interpreter backend.
XLA_TEST_F(ClientTest,
DISABLED_ON_INTERPRETER(DISABLED_ON_GPU(ExecuteParallel))) {
XlaComputation add_with_one_arg, mul_with_two_args, dot_with_one_arg;
Shape shape = ShapeUtil::MakeShape(S32, {2, 2});

View File

@ -600,7 +600,9 @@ ENTRY main {
// Empty fixture deriving from ClientLibraryTestBase; used by the Gather
// client-library tests below.
class GatherClientLibraryTest : public ClientLibraryTestBase {};
XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) {
// Disabled on interpreter since ExecuteAsyncOnStream is not supported.
XLA_TEST_F(GatherClientLibraryTest,
DISABLED_ON_INTERPRETER(DISABLED_ON_GPU(Basic))) {
// We create this HLO, but using the XlaBuilder API.
//
// ENTRY main {

View File

@ -842,7 +842,8 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) {
LiteralUtil::CreateR0<int64>(123456789000LL)}));
}
XLA_TEST_F(LocalClientExecuteTest, InfeedTest) {
// Disabled on interpreter backend since infeed HLO is unsupported.
XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedTest)) {
XlaBuilder builder(TestName());
const Shape shape = ShapeUtil::MakeShape(F32, {3});
auto in = Infeed(&builder, shape);
@ -867,7 +868,8 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedTest) {
LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, result);
}
XLA_TEST_F(LocalClientExecuteTest, InfeedOutfeedTest) {
// Disabled on interpreter backend since infeed/outfeed HLOs are unsupported.
XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedOutfeedTest)) {
XlaBuilder builder(TestName());
const Shape shape = ShapeUtil::MakeShape(F32, {3});
auto in = Infeed(&builder, shape);