[XLA] Enable interpreter backend test by default for xla/tests without having to

using additional backend tags "enable_for_xla_interpreter".

PiperOrigin-RevId: 227598976
This commit is contained in:
Kay Zhu 2019-01-02 16:11:52 -08:00 committed by TensorFlower Gardener
parent 0b6177c2fa
commit dd3adf935a
12 changed files with 115 additions and 140 deletions

View File

@ -93,7 +93,6 @@ cc_library(
xla_test( xla_test(
name = "constants_test", name = "constants_test",
srcs = ["constants_test.cc"], srcs = ["constants_test.cc"],
tags = ["enable_for_xla_interpreter"],
deps = [ deps = [
":constants", ":constants",
"//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test",
@ -147,7 +146,6 @@ cc_library(
xla_test( xla_test(
name = "math_test", name = "math_test",
srcs = ["math_test.cc"], srcs = ["math_test.cc"],
tags = ["enable_for_xla_interpreter"],
deps = [ deps = [
":math", ":math",
"//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:literal_util",
@ -181,7 +179,6 @@ cc_library(
xla_test( xla_test(
name = "matrix_test", name = "matrix_test",
srcs = ["matrix_test.cc"], srcs = ["matrix_test.cc"],
tags = ["enable_for_xla_interpreter"],
deps = [ deps = [
":matrix", ":matrix",
":slicing", ":slicing",
@ -295,7 +292,6 @@ cc_library(
xla_test( xla_test(
name = "slicing_test", name = "slicing_test",
srcs = ["slicing_test.cc"], srcs = ["slicing_test.cc"],
tags = ["enable_for_xla_interpreter"],
deps = [ deps = [
":slicing", ":slicing",
"//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:literal_util",
@ -324,7 +320,6 @@ cc_library(
xla_test( xla_test(
name = "sorting_test", name = "sorting_test",
srcs = ["sorting_test.cc"], srcs = ["sorting_test.cc"],
tags = ["enable_for_xla_interpreter"],
deps = [ deps = [
":sorting", ":sorting",
"//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test",
@ -354,7 +349,6 @@ xla_test(
srcs = ["quantize_test.cc"], srcs = ["quantize_test.cc"],
# TODO(b/122119490): re-enable TAP after fixing. # TODO(b/122119490): re-enable TAP after fixing.
tags = [ tags = [
"enable_for_xla_interpreter",
"notap", "notap",
], ],
deps = [ deps = [

View File

@ -138,6 +138,11 @@ StatusOr<Literal> Compare<complex64>(const Shape& shape, HloOpcode opcode,
} // namespace } // namespace
// Note that unsupported types by the typed visitor does not necessarily imply
// the non-typed HloEvaluator (parent evaluator) would not support them either
// in the type-agnostic handler. For e.g., HandleGetTupleElement in the parent
// type-agnostic evaluator will be able to accept Tuple primitive type, whereas
// HloEvaluatorTypedVisitor cannot.
HloEvaluator::HloEvaluator(int64 max_loop_iterations) HloEvaluator::HloEvaluator(int64 max_loop_iterations)
: max_loop_iterations_(max_loop_iterations) { : max_loop_iterations_(max_loop_iterations) {
typed_visitors_[PRED] = typed_visitors_[PRED] =

View File

@ -211,6 +211,29 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
Status HandleReduce(HloInstruction* reduce) override; Status HandleReduce(HloInstruction* reduce) override;
// Unsupported HLOs, note some of them (such as BatchNorm*) are typically
// expanded in a semantic-preserving way into other HLOs by adding exanpsion
// HLO pass to the HLO optimization pass during compilation, which can then be
// handled by the evaluator.
Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override {
return Unimplemented("BatchNormGrad HLO is unsupported by the evaluator.");
};
Status HandleBatchNormInference(
HloInstruction* batch_norm_inference) override {
return Unimplemented(
"BatchNormInference HLO is unsupported by the evaluator.");
};
Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override {
return Unimplemented(
"BatchNormTraining HLO is unsupported by the evaluator.");
};
Status HandleInfeed(HloInstruction* infeed) override {
return Unimplemented("Infeed HLO is unsupported by the evaluator.");
};
Status HandleOutfeed(HloInstruction* outfeed) override {
return Unimplemented("Outfeed HLO is unsupported by the evaluator.");
};
// Returns the already-evaluated literal result for the instruction. // Returns the already-evaluated literal result for the instruction.
// A Constant instruction is considered evaluated and its literal will be // A Constant instruction is considered evaluated and its literal will be
// returned directly without looking up the cache. // returned directly without looking up the cache.

View File

@ -917,7 +917,11 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
Status HandleClamp(HloInstruction* clamp) { Status HandleClamp(HloInstruction* clamp) {
std::function<ElementwiseT(ElementwiseT, ElementwiseT, ElementwiseT)> std::function<ElementwiseT(ElementwiseT, ElementwiseT, ElementwiseT)>
clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) { clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) {
return std::fmin(high, std::fmax(value, low)); if (std::isnan(low) || std::isnan(high)) {
return static_cast<ElementwiseT>(NAN);
}
return static_cast<ElementwiseT>(
std::fmin(high, std::fmax(value, low)));
}; };
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
parent_->evaluated_[clamp], parent_->evaluated_[clamp],
@ -2664,12 +2668,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return HandleReducePrecision<ElementwiseT>(reduce_precision); return HandleReducePrecision<ElementwiseT>(reduce_precision);
} }
template <typename NativeT, template <
typename std::enable_if< typename NativeT,
std::is_same<NativeT, bfloat16>::value || typename std::enable_if<
std::is_same<NativeT, Eigen::half>::value || std::is_same<NativeT, bfloat16>::value ||
std::is_integral<NativeT>::value || std::is_same<NativeT, Eigen::half>::value ||
std::is_floating_point<NativeT>::value>::type* = nullptr> std::is_integral<NativeT>::value || is_complex_t<NativeT>::value ||
std::is_floating_point<NativeT>::value>::type* = nullptr>
Status HandleIota(HloInstruction* instruction) { Status HandleIota(HloInstruction* instruction) {
auto* iota = Cast<HloIotaInstruction>(instruction); auto* iota = Cast<HloIotaInstruction>(instruction);
const int64 iota_size = iota->shape().dimensions(iota->iota_dimension()); const int64 iota_size = iota->shape().dimensions(iota->iota_dimension());
@ -2700,12 +2705,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK(); return Status::OK();
} }
template <typename NativeT, template <
typename std::enable_if< typename NativeT,
!(std::is_same<NativeT, bfloat16>::value || typename std::enable_if<
std::is_same<NativeT, Eigen::half>::value || !(std::is_same<NativeT, bfloat16>::value ||
std::is_integral<NativeT>::value || std::is_same<NativeT, Eigen::half>::value ||
std::is_floating_point<NativeT>::value)>::type* = nullptr> std::is_integral<NativeT>::value || is_complex_t<NativeT>::value ||
std::is_floating_point<NativeT>::value)>::type* = nullptr>
Status HandleIota(HloInstruction* iota) { Status HandleIota(HloInstruction* iota) {
return UnsupportedTypeError(iota); return UnsupportedTypeError(iota);
} }

View File

@ -32,6 +32,7 @@ cc_library(
"//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/service:algebraic_simplifier", "//tensorflow/compiler/xla/service:algebraic_simplifier",
"//tensorflow/compiler/xla/service:batchnorm_expander",
"//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:executable",
@ -41,12 +42,14 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:hlo_cost_analysis",
"//tensorflow/compiler/xla/service:hlo_cse", "//tensorflow/compiler/xla/service:hlo_cse",
"//tensorflow/compiler/xla/service:hlo_dce", "//tensorflow/compiler/xla/service:hlo_dce",
"//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter",
"//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service:hlo_pass_pipeline", "//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification", "//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
"//tensorflow/compiler/xla/service:layout_assignment", "//tensorflow/compiler/xla/service:layout_assignment",
"//tensorflow/compiler/xla/service:map_inliner", "//tensorflow/compiler/xla/service:map_inliner",
"//tensorflow/compiler/xla/service:reduce_precision_insertion",
"//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:reshape_mover",
"//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/compiler/xla/service:while_loop_simplifier",
"//tensorflow/core:lib", "//tensorflow/core:lib",

View File

@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/interpreter/executable.h" #include "tensorflow/compiler/xla/service/interpreter/executable.h"
#include "tensorflow/compiler/xla/service/layout_assignment.h" #include "tensorflow/compiler/xla/service/layout_assignment.h"
#include "tensorflow/compiler/xla/service/map_inliner.h" #include "tensorflow/compiler/xla/service/map_inliner.h"
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/status_macros.h"
@ -46,6 +47,11 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
pipeline.AddPass<LayoutAssignment>( pipeline.AddPass<LayoutAssignment>(
hlo_module->mutable_entry_computation_layout(), hlo_module->mutable_entry_computation_layout(),
LayoutAssignment::InstructionCanChangeLayout); LayoutAssignment::InstructionCanChangeLayout);
ReducePrecisionInsertion::AddPasses(
&pipeline, hlo_module->config().debug_options(),
ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION);
return pipeline.Run(hlo_module).status(); return pipeline.Run(hlo_module).status();
} }

View File

@ -276,9 +276,6 @@ cc_library(
xla_test( xla_test(
name = "bad_rng_shape_validation_test", name = "bad_rng_shape_validation_test",
srcs = ["bad_rng_shape_validation_test.cc"], srcs = ["bad_rng_shape_validation_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test",
@ -344,9 +341,6 @@ xla_test(
xla_test( xla_test(
name = "check_execution_arity_test", name = "check_execution_arity_test",
srcs = ["check_execution_arity_test.cc"], srcs = ["check_execution_arity_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
@ -367,9 +361,6 @@ xla_test(
xla_test( xla_test(
name = "query_inferred_shape_test", name = "query_inferred_shape_test",
srcs = ["query_inferred_shape_test.cc"], srcs = ["query_inferred_shape_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:statusor",
@ -387,9 +378,6 @@ xla_test(
xla_test( xla_test(
name = "while_test", name = "while_test",
srcs = ["while_test.cc"], srcs = ["while_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
@ -413,6 +401,10 @@ xla_test(
xla_test( xla_test(
name = "xla_hlo_profile_test", name = "xla_hlo_profile_test",
srcs = ["xla_hlo_profile_test.cc"], srcs = ["xla_hlo_profile_test.cc"],
blacklisted_backends = [
# Hlo profiles are not supported on the interpreter backend.
"interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
@ -436,9 +428,6 @@ xla_test(
xla_test( xla_test(
name = "axpy_simple_test", name = "axpy_simple_test",
srcs = ["axpy_simple_test.cc"], srcs = ["axpy_simple_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_builder",
@ -453,7 +442,6 @@ xla_test(
xla_test( xla_test(
name = "map_test", name = "map_test",
srcs = ["map_test.cc"], srcs = ["map_test.cc"],
tags = ["enable_for_xla_interpreter"],
deps = [ deps = [
"//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal",
@ -506,9 +494,6 @@ xla_test(
xla_test( xla_test(
name = "pred_test", name = "pred_test",
srcs = ["pred_test.cc"], srcs = ["pred_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:local_client",
@ -524,9 +509,6 @@ xla_test(
xla_test( xla_test(
name = "select_test", name = "select_test",
srcs = ["select_test.cc"], srcs = ["select_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:global_data",
@ -544,7 +526,6 @@ xla_test(
xla_test( xla_test(
name = "conditional_test", name = "conditional_test",
srcs = ["conditional_test.cc"], srcs = ["conditional_test.cc"],
tags = ["enable_for_xla_interpreter"],
deps = [ deps = [
"//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:global_data",
@ -562,7 +543,6 @@ xla_test(
xla_test( xla_test(
name = "unary_op_test", name = "unary_op_test",
srcs = ["unary_op_test.cc"], srcs = ["unary_op_test.cc"],
tags = ["enable_for_xla_interpreter"],
deps = [ deps = [
"//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:global_data",
@ -623,9 +603,6 @@ xla_test(
xla_test( xla_test(
name = "deconstruct_tuple_test", name = "deconstruct_tuple_test",
srcs = ["deconstruct_tuple_test.cc"], srcs = ["deconstruct_tuple_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
@ -648,7 +625,6 @@ xla_test(
name = "array_elementwise_ops_test", name = "array_elementwise_ops_test",
srcs = ["array_elementwise_ops_test.cc"], srcs = ["array_elementwise_ops_test.cc"],
shard_count = 25, shard_count = 25,
tags = ["enable_for_xla_interpreter"],
deps = [ deps = [
"//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array3d",
@ -698,7 +674,6 @@ xla_test(
xla_test( xla_test(
name = "reduce_precision_test", name = "reduce_precision_test",
srcs = ["reduce_precision_test.cc"], srcs = ["reduce_precision_test.cc"],
tags = ["enable_for_xla_interpreter"],
deps = [ deps = [
"//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal",
@ -725,7 +700,6 @@ xla_test(
srcs = ["dot_operation_test.cc"], srcs = ["dot_operation_test.cc"],
shard_count = 20, shard_count = 20,
tags = [ tags = [
"enable_for_xla_interpreter",
"optonly", "optonly",
], ],
deps = [ deps = [
@ -806,9 +780,6 @@ xla_test(
xla_test( xla_test(
name = "transpose_test", name = "transpose_test",
srcs = ["transpose_test.cc"], srcs = ["transpose_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:reference_util",
@ -828,9 +799,6 @@ xla_test(
xla_test( xla_test(
name = "constants_test", name = "constants_test",
srcs = ["constants_test.cc"], srcs = ["constants_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array3d",
@ -951,6 +919,11 @@ xla_test(
xla_test( xla_test(
name = "batch_normalization_test", name = "batch_normalization_test",
srcs = ["batch_normalization_test.cc"], srcs = ["batch_normalization_test.cc"],
blacklisted_backends = [
# BatchNorm HLOs are not handled by the interpreter backend, and the
# BatchNorm expander is not run on the interpreter.
"interpreter",
],
shard_count = 40, shard_count = 40,
deps = [ deps = [
":test_utils", ":test_utils",
@ -1042,9 +1015,6 @@ xla_test(
name = "slice_test", name = "slice_test",
srcs = ["slice_test.cc"], srcs = ["slice_test.cc"],
shard_count = 40, shard_count = 40,
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:reference_util",
@ -1065,9 +1035,6 @@ xla_test(
xla_test( xla_test(
name = "multidimensional_slice_test", name = "multidimensional_slice_test",
srcs = ["multidimensional_slice_test.cc"], srcs = ["multidimensional_slice_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array3d",
@ -1085,9 +1052,6 @@ xla_test(
name = "dynamic_ops_test", name = "dynamic_ops_test",
timeout = "moderate", timeout = "moderate",
srcs = ["dynamic_ops_test.cc"], srcs = ["dynamic_ops_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:reference_util",
@ -1113,9 +1077,6 @@ xla_test(
xla_test( xla_test(
name = "tuple_test", name = "tuple_test",
srcs = ["tuple_test.cc"], srcs = ["tuple_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal",
@ -1139,9 +1100,6 @@ xla_test(
xla_test( xla_test(
name = "vector_ops_reduce_test", name = "vector_ops_reduce_test",
srcs = ["vector_ops_reduce_test.cc"], srcs = ["vector_ops_reduce_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array3d",
@ -1162,7 +1120,6 @@ xla_test(
srcs = ["reduce_test.cc"], srcs = ["reduce_test.cc"],
shard_count = 40, shard_count = 40,
tags = [ tags = [
"enable_for_xla_interpreter",
"optonly", "optonly",
], ],
deps = [ deps = [
@ -1229,7 +1186,6 @@ xla_test(
srcs = [], srcs = [],
shard_count = 20, shard_count = 20,
tags = [ tags = [
"enable_for_xla_interpreter",
"optonly", "optonly",
], ],
xla_test_library_deps = [":reduce_window_test_library"], xla_test_library_deps = [":reduce_window_test_library"],
@ -1241,7 +1197,6 @@ xla_test(
timeout = "long", timeout = "long",
srcs = ["select_and_scatter_test.cc"], srcs = ["select_and_scatter_test.cc"],
tags = [ tags = [
"enable_for_xla_interpreter",
"optonly", "optonly",
], ],
deps = [ deps = [
@ -1267,9 +1222,6 @@ xla_test(
xla_test( xla_test(
name = "copy_test", name = "copy_test",
srcs = ["copy_test.cc"], srcs = ["copy_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
":client_library_test_base", ":client_library_test_base",
"//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array2d",
@ -1290,9 +1242,6 @@ xla_test(
xla_test( xla_test(
name = "reduce_hlo_test", name = "reduce_hlo_test",
srcs = ["reduce_hlo_test.cc"], srcs = ["reduce_hlo_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base",
@ -1306,9 +1255,6 @@ xla_test(
xla_test( xla_test(
name = "token_hlo_test", name = "token_hlo_test",
srcs = ["token_hlo_test.cc"], srcs = ["token_hlo_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:hlo_verifier",
@ -1323,9 +1269,6 @@ xla_test(
xla_test( xla_test(
name = "call_test", name = "call_test",
srcs = ["call_test.cc"], srcs = ["call_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:literal_util",
@ -1368,9 +1311,6 @@ xla_test(
xla_test( xla_test(
name = "binop_scaling_test", name = "binop_scaling_test",
srcs = ["binop_scaling_test.cc"], srcs = ["binop_scaling_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:array4d",
@ -1388,9 +1328,6 @@ xla_test(
xla_test( xla_test(
name = "broadcast_simple_test", name = "broadcast_simple_test",
srcs = ["broadcast_simple_test.cc"], srcs = ["broadcast_simple_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:array4d",
@ -1410,9 +1347,6 @@ xla_test(
xla_test( xla_test(
name = "pad_test", name = "pad_test",
srcs = ["pad_test.cc"], srcs = ["pad_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:array4d",
@ -1434,9 +1368,6 @@ xla_test(
xla_test( xla_test(
name = "fmax_test", name = "fmax_test",
srcs = ["fmax_test.cc"], srcs = ["fmax_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_builder",
@ -1451,9 +1382,6 @@ xla_test(
xla_test( xla_test(
name = "log_test", name = "log_test",
srcs = ["log_test.cc"], srcs = ["log_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_builder",
@ -1468,9 +1396,6 @@ xla_test(
xla_test( xla_test(
name = "matrix_ops_simple_test", name = "matrix_ops_simple_test",
srcs = ["matrix_ops_simple_test.cc"], srcs = ["matrix_ops_simple_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal",
@ -1497,6 +1422,10 @@ xla_test(
xla_test( xla_test(
name = "prng_test", name = "prng_test",
srcs = ["prng_test.cc"], srcs = ["prng_test.cc"],
blacklisted_backends = [
# TODO(b/122047800) support RNGs on the interpreter backend.
"interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
@ -1517,9 +1446,6 @@ xla_test(
name = "reshape_test", name = "reshape_test",
srcs = ["reshape_test.cc"], srcs = ["reshape_test.cc"],
shard_count = 30, shard_count = 30,
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:array4d",
@ -1545,9 +1471,6 @@ xla_test(
xla_test( xla_test(
name = "reverse_test", name = "reverse_test",
srcs = ["reverse_test.cc"], srcs = ["reverse_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:array4d",
@ -1566,9 +1489,6 @@ xla_test(
xla_test( xla_test(
name = "vector_ops_simple_test", name = "vector_ops_simple_test",
srcs = ["vector_ops_simple_test.cc"], srcs = ["vector_ops_simple_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:array4d",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
@ -1592,9 +1512,6 @@ xla_test(
xla_test( xla_test(
name = "concat_test", name = "concat_test",
srcs = ["concat_test.cc"], srcs = ["concat_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array3d",
@ -1615,9 +1532,6 @@ xla_test(
xla_test( xla_test(
name = "convert_test", name = "convert_test",
srcs = ["convert_test.cc"], srcs = ["convert_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_data_proto",
@ -1637,6 +1551,10 @@ xla_test(
xla_test( xla_test(
name = "all_reduce_test", name = "all_reduce_test",
srcs = ["all_reduce_test.cc"], srcs = ["all_reduce_test.cc"],
blacklisted_backends = [
# All reduce is not supported on the interpreter backend.
"interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
@ -1661,9 +1579,6 @@ xla_test(
xla_test( xla_test(
name = "bitcast_convert_test", name = "bitcast_convert_test",
srcs = ["bitcast_convert_test.cc"], srcs = ["bitcast_convert_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_data_proto",
@ -1703,9 +1618,6 @@ xla_test(
xla_test( xla_test(
name = "floor_ceil_test", name = "floor_ceil_test",
srcs = ["floor_ceil_test.cc"], srcs = ["floor_ceil_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_builder",
@ -1767,6 +1679,10 @@ xla_test(
xla_test( xla_test(
name = "execution_profile_test", name = "execution_profile_test",
srcs = ["execution_profile_test.cc"], srcs = ["execution_profile_test.cc"],
blacklisted_backends = [
# Execution profiles are not supported on the interpreter backend.
"interpreter",
],
deps = [ deps = [
":client_library_test_base", ":client_library_test_base",
"//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:global_data",
@ -1781,6 +1697,10 @@ xla_test(
name = "execution_profile_test_with_xla_hlo_profile", name = "execution_profile_test_with_xla_hlo_profile",
srcs = ["execution_profile_test.cc"], srcs = ["execution_profile_test.cc"],
args = ["--xla_hlo_profile"], args = ["--xla_hlo_profile"],
blacklisted_backends = [
# Hlo profiles are not supported on the interpreter backend.
"interpreter",
],
deps = [ deps = [
":client_library_test_base", ":client_library_test_base",
"//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:global_data",
@ -1794,9 +1714,6 @@ xla_test(
xla_test( xla_test(
name = "replay_test", name = "replay_test",
srcs = ["replay_test.cc"], srcs = ["replay_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:protobuf_util",
@ -1819,9 +1736,6 @@ xla_test(
xla_test( xla_test(
name = "broadcast_test", name = "broadcast_test",
srcs = ["broadcast_test.cc"], srcs = ["broadcast_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
@ -1883,9 +1797,6 @@ xla_test(
xla_test( xla_test(
name = "fusion_test", name = "fusion_test",
srcs = ["fusion_test.cc"], srcs = ["fusion_test.cc"],
tags = [
"enable_for_xla_interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal",
@ -2003,6 +1914,10 @@ xla_test(
xla_test( xla_test(
name = "outfeed_in_nested_computation_test", name = "outfeed_in_nested_computation_test",
srcs = ["outfeed_in_nested_computation_test.cc"], srcs = ["outfeed_in_nested_computation_test.cc"],
blacklisted_backends = [
# Outfeed ops are not supported on the interpreter backend.
"interpreter",
],
deps = [ deps = [
"//tensorflow/compiler/xla/tests:local_client_test_base", "//tensorflow/compiler/xla/tests:local_client_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/compiler/xla/tests:xla_internal_test_main",
@ -2179,7 +2094,6 @@ xla_test(
srcs = ["iota_test.cc"], srcs = ["iota_test.cc"],
shard_count = 30, shard_count = 30,
tags = [ tags = [
"enable_for_xla_interpreter",
# Require optimized builds, iota_test_cpu is very slow in fastbuild. # Require optimized builds, iota_test_cpu is very slow in fastbuild.
"optonly", "optonly",
], ],

View File

@ -2047,6 +2047,19 @@ XLA_TEST_F(ArrayElementwiseOpTest, NonNanClampF32) {
error_spec_); error_spec_);
} }
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32) {
SetFastMathDisabled(true);
XlaBuilder builder(TestName());
auto minimum = ConstantR1<float>(&builder, {1.0f, -6.5f, 1.0f, 2.25f, NAN});
auto argument =
ConstantR1<float>(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 10.0f});
auto maximum = ConstantR1<float>(&builder, {3.0f, 0.5f, 25.5f, NAN, 123.0f});
Clamp(minimum, argument, maximum);
ComputeAndCompareR1<float>(&builder, {2.0f, 0.5f, 1.0f, NAN, NAN}, {},
error_spec_);
}
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) { XLA_TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) {
XlaBuilder builder(TestName()); XlaBuilder builder(TestName());
auto minimum = ConstantR0<float>(&builder, 0.0f); auto minimum = ConstantR0<float>(&builder, 0.0f);

View File

@ -76,7 +76,9 @@ XLA_TEST_F(Bfloat16Test, NegateScalarF16) {
error_spec_); error_spec_);
} }
XLA_TEST_F(Bfloat16Test, BatchNormTraining) { // Disabled on interpreter since BatchNormExanper is not run by default on the
// intepreter backend.
XLA_TEST_F(Bfloat16Test, DISABLED_ON_INTERPRETER(BatchNormTraining)) {
const int kFeatureIndex = 2; const int kFeatureIndex = 2;
XlaBuilder builder(TestName()); XlaBuilder builder(TestName());
@ -110,7 +112,9 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) {
ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01, 0.02)); ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01, 0.02));
} }
XLA_TEST_F(Bfloat16Test, BatchNormGrad) { // Disabled on interpreter since BatchNormExanper is not run by default on the
// intepreter backend.
XLA_TEST_F(Bfloat16Test, DISABLED_ON_INTERPRETER(BatchNormGrad)) {
const int kFeatureIndex = 2; const int kFeatureIndex = 2;
XlaBuilder builder(TestName()); XlaBuilder builder(TestName());

View File

@ -109,7 +109,10 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) {
/*minor_to_major=*/{1, 0}))); /*minor_to_major=*/{1, 0})));
} }
XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) { // Disabled for interpreter since ExecuteAsyncOnStream is not implemented on
// interpreter backend.
XLA_TEST_F(ClientTest,
DISABLED_ON_INTERPRETER(DISABLED_ON_GPU(ExecuteParallel))) {
XlaComputation add_with_one_arg, mul_with_two_args, dot_with_one_arg; XlaComputation add_with_one_arg, mul_with_two_args, dot_with_one_arg;
Shape shape = ShapeUtil::MakeShape(S32, {2, 2}); Shape shape = ShapeUtil::MakeShape(S32, {2, 2});

View File

@ -600,7 +600,9 @@ ENTRY main {
class GatherClientLibraryTest : public ClientLibraryTestBase {}; class GatherClientLibraryTest : public ClientLibraryTestBase {};
XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { // Disabled on interpreter since ExectuteAsyncOnStream is not supported.
XLA_TEST_F(GatherClientLibraryTest,
DISABLED_ON_INTERPRETER(DISABLED_ON_GPU(Basic))) {
// We create this HLO, but using the XlaBuilder API. // We create this HLO, but using the XlaBuilder API.
// //
// ENTRY main { // ENTRY main {

View File

@ -842,7 +842,8 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) {
LiteralUtil::CreateR0<int64>(123456789000LL)})); LiteralUtil::CreateR0<int64>(123456789000LL)}));
} }
XLA_TEST_F(LocalClientExecuteTest, InfeedTest) { // Disabled on interpreter backend since infeed HLO is unsupported.
XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedTest)) {
XlaBuilder builder(TestName()); XlaBuilder builder(TestName());
const Shape shape = ShapeUtil::MakeShape(F32, {3}); const Shape shape = ShapeUtil::MakeShape(F32, {3});
auto in = Infeed(&builder, shape); auto in = Infeed(&builder, shape);
@ -867,7 +868,8 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedTest) {
LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, result); LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, result);
} }
XLA_TEST_F(LocalClientExecuteTest, InfeedOutfeedTest) { // Disabled on interpreter backend since infeed/outfeed HLOs are unsupported.
XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedOutfeedTest)) {
XlaBuilder builder(TestName()); XlaBuilder builder(TestName());
const Shape shape = ShapeUtil::MakeShape(F32, {3}); const Shape shape = ShapeUtil::MakeShape(F32, {3});
auto in = Infeed(&builder, shape); auto in = Infeed(&builder, shape);