Use VerifiedHloModule in a few more tests.

Also enable some tests that are now passing on the GPU backend.
Finally, remove some unused hlo_parser.h includes and the corresponding
dependency.

PiperOrigin-RevId: 275224726
Change-Id: Icc206d85b40c439abe8232aaa748ed4a07b50b09
This commit is contained in:
Adrian Kuegel 2019-10-17 03:31:35 -07:00 committed by TensorFlower Gardener
parent 6caee8ef9b
commit 6f81dbf07a
17 changed files with 42 additions and 62 deletions

View File

@ -227,7 +227,6 @@ cc_library(
deps = [
":codegen_test_base",
":filecheck",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:llvm_compiler",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:test",
@ -333,7 +332,6 @@ xla_test(
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:bfloat16_normalization",
"//tensorflow/compiler/xla/service:despecializer",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@ -359,7 +357,6 @@ xla_test(
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:bfloat16_normalization",
"//tensorflow/compiler/xla/service:despecializer",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@ -372,8 +369,6 @@ xla_test(
timeout = "long",
srcs = ["grouped_convolution_test.cc"],
blacklisted_backends = [
# disabled because of a break b/119590850.
"gpu",
# disabled because it times out.
"cpu",
],
@ -386,7 +381,6 @@ xla_test(
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:bfloat16_normalization",
"//tensorflow/compiler/xla/service:despecializer",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@ -749,7 +743,6 @@ cc_library(
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:bfloat16_normalization",
"//tensorflow/compiler/xla/service:despecializer",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@ -995,7 +988,6 @@ xla_test(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
@ -1009,7 +1001,6 @@ xla_test(
":test_macros_header",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
@ -1549,7 +1540,6 @@ xla_test(
srcs = ["reduce_hlo_test.cc"],
deps = [
":test_macros_header",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
@ -1563,7 +1553,6 @@ xla_test(
srcs = ["token_hlo_test.cc"],
deps = [
":test_macros_header",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@ -1889,7 +1878,6 @@ xla_test(
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core:test",
@ -2123,6 +2111,7 @@ tf_cc_test(
srcs = ["llvm_compiler_test.cc"],
tags = tf_cuda_tests_tags(),
deps = [
":verified_hlo_module",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/service:backend",
@ -2175,7 +2164,6 @@ xla_test(
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",

View File

@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"

View File

@ -18,7 +18,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
#include "tensorflow/compiler/xla/service/despecializer.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"

View File

@ -20,7 +20,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
#include "tensorflow/compiler/xla/service/despecializer.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"

View File

@ -21,7 +21,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
#include "tensorflow/compiler/xla/service/despecializer.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"

View File

@ -18,7 +18,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
#include "tensorflow/compiler/xla/service/despecializer.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"

View File

@ -45,7 +45,7 @@ class CopyOpTest : public HloTestBase {
builder.AddInstruction(HloInstruction::CreateUnary(
constant->shape(), HloOpcode::kCopy, constant));
auto computation = builder.Build();
auto module = CreateNewUnverifiedModule();
auto module = CreateNewVerifiedModule();
module->AddEntryComputation(std::move(computation));
Literal result = ExecuteAndTransfer(std::move(module), {});
@ -98,7 +98,7 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) {
auto computation = builder.Build();
auto module = CreateNewUnverifiedModule();
auto module = CreateNewVerifiedModule();
module->AddEntryComputation(std::move(computation));
Literal result = ExecuteAndTransfer(std::move(module), {&literal});
@ -119,7 +119,7 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) {
auto computation = builder.Build();
auto module = CreateNewUnverifiedModule();
auto module = CreateNewVerifiedModule();
module->AddEntryComputation(std::move(computation));
Literal result = ExecuteAndTransfer(std::move(module), {});
LiteralTestUtil::ExpectR2Near<float>({{1.0, 2.0}, {3.0, 4.0}}, result,
@ -145,7 +145,7 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) {
std::unique_ptr<HloComputation> computation = builder.Build();
auto module = CreateNewUnverifiedModule();
auto module = CreateNewVerifiedModule();
module->AddEntryComputation(std::move(computation));
Literal result = ExecuteAndTransfer(std::move(module), {});
@ -177,7 +177,7 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) {
std::unique_ptr<HloComputation> computation = builder.Build();
auto module = CreateNewUnverifiedModule();
auto module = CreateNewVerifiedModule();
module->AddEntryComputation(std::move(computation));
ForceResultLayout(module.get(), LayoutUtil::MakeLayout({1, 2, 0}));
Literal result = ExecuteAndTransfer(std::move(module), {});
@ -211,7 +211,7 @@ void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3,
std::unique_ptr<HloComputation> computation = builder.Build();
auto module = CreateNewUnverifiedModule();
auto module = CreateNewVerifiedModule();
module->AddEntryComputation(std::move(computation));
ForceResultLayout(module.get(), LayoutUtil::MakeLayout(permutation));
Literal result = ExecuteAndTransfer(std::move(module), {});

View File

@ -35,7 +35,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"

View File

@ -79,7 +79,7 @@ class CustomCallTest : public HloTestBase {
};
XLA_TEST_F(CustomCallTest, CustomCallR0F32Add2) {
auto module = CreateNewUnverifiedModule();
auto module = CreateNewVerifiedModule();
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
@ -94,7 +94,7 @@ XLA_TEST_F(CustomCallTest, CustomCallR0F32Add2) {
}
XLA_TEST_F(CustomCallTest, CustomCallR2F32Reduce) {
auto module = CreateNewUnverifiedModule();
auto module = CreateNewVerifiedModule();
auto builder = HloComputation::Builder(TestName());
Array2D<float> array(2, 2);
@ -115,7 +115,7 @@ XLA_TEST_F(CustomCallTest, CustomCallR2F32Reduce) {
}
XLA_TEST_F(CustomCallTest, UsedInOtherComputations) {
auto module = CreateNewUnverifiedModule();
auto module = CreateNewVerifiedModule();
auto b = HloComputation::Builder(TestName());
auto input = b.AddInstruction(
@ -139,7 +139,7 @@ XLA_TEST_F(CustomCallTest, UsedInOtherComputations) {
}
XLA_TEST_F(CustomCallTest, InputAndOutputLayoutDiffer) {
auto module = CreateNewUnverifiedModule();
auto module = CreateNewVerifiedModule();
auto b = HloComputation::Builder(TestName());
auto input =
@ -164,7 +164,7 @@ XLA_TEST_F(CustomCallTest, LayoutConstrained) {
// The argument and result of the computation are set to different layouts,
// but the custom call is layout constrained to a fixed operand and result
// layout, so the correct result should be produced.
auto module = CreateNewUnverifiedModule();
auto module = CreateNewVerifiedModule();
auto b = HloComputation::Builder(TestName());
auto input =

View File

@ -15,7 +15,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
@ -38,7 +37,7 @@ class GatherOperationTest : public HloTestBase {
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_text, config));
ParseAndReturnVerifiedModule(hlo_text, config));
EXPECT_TRUE(RunAndCompare(std::move(module), args, nullopt));
}
};

View File

@ -18,7 +18,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
#include "tensorflow/compiler/xla/service/despecializer.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/llvm_compiler.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/backend.h"
@ -22,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/stream_executor/stream_executor.h"
@ -68,7 +70,7 @@ class LLVMCompilerTest : public ::testing::Test {
builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
auto hlo_module = CreateNewUnverifiedModule();
auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(builder.Build());
compiler->SetPreOptimizationHook(pre_opt_hook);
@ -90,7 +92,7 @@ class LLVMCompilerTest : public ::testing::Test {
builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
std::unique_ptr<HloModule> hlo_module = CreateNewUnverifiedModule();
std::unique_ptr<HloModule> hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(builder.Build());
auto module_group = absl::make_unique<HloModuleGroup>("test_module_group");
@ -124,10 +126,13 @@ class LLVMCompilerTest : public ::testing::Test {
return ::testing::UnitTest::GetInstance()->current_test_info()->name();
}
static std::unique_ptr<HloModule> CreateNewUnverifiedModule() {
std::unique_ptr<HloModule> CreateNewVerifiedModule() {
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsFromFlags());
return absl::make_unique<HloModule>(TestName(), config);
return absl::make_unique<VerifiedHloModule>(
TestName(), config, /*verifier_layout_sensitive=*/false,
/*allow_mixed_precision_in_hlo_verifier=*/true,
backend_->compiler()->ShapeSizeBytesFunction());
}
};

View File

@ -18,7 +18,6 @@ limitations under the License.
#include <functional>
#include <utility>
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/tests/filecheck.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@ -66,7 +65,7 @@ void LlvmIrGenTestBase::CompileAndVerifyIr(const string& hlo_text,
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_text, config));
ParseAndReturnVerifiedModule(hlo_text, config));
CompileAndVerifyIr(std::move(module), expected_llvm_ir, match_optimized_ir);
}

View File

@ -62,7 +62,7 @@ class MultiOutputFusionTest : public HloTestBase {
void RunTest2D(bool manual_fusion, int64 size) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewUnverifiedModule();
auto hlo_module = CreateNewVerifiedModule();
const Shape elem_shape0 = ShapeUtil::MakeShapeWithLayout(F32, {}, {});
const Shape elem_shape2 =
@ -122,7 +122,7 @@ class MultiOutputFusionTest : public HloTestBase {
void RunTest1D(bool manual_fusion, int size) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewUnverifiedModule();
auto hlo_module = CreateNewVerifiedModule();
const Shape elem_shape_F32 =
ShapeUtil::MakeShapeWithDescendingLayout(F32, {size});

View File

@ -17,7 +17,6 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/platform/test.h"
@ -50,10 +49,10 @@ void PrintTo(const ReduceLayout& reduce_layout, ::std::ostream* os) {
class ReduceWithLayoutTest
: public HloTestBase,
public ::testing::WithParamInterface<ReduceLayout> {};
StatusOr<std::unique_ptr<HloModule>> GetParsedModule() {
const char* const hlo_string = R"(
public ::testing::WithParamInterface<ReduceLayout> {
public:
StatusOr<std::unique_ptr<HloModule>> GetParsedModule() {
const char* const hlo_string = R"(
HloModule BadReduce
Sum {
@ -70,12 +69,11 @@ ENTRY reduce.1 {
}
)";
return ParseAndReturnUnverifiedModule(hlo_string);
}
return ParseAndReturnVerifiedModule(hlo_string);
}
};
// TODO(b/72454718): XLA:GPU does not support executing code compiled without
// optimizations.
XLA_TEST_P(ReduceWithLayoutTest, DISABLED_ON_GPU(Reduce)) {
XLA_TEST_P(ReduceWithLayoutTest, Reduce) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, GetParsedModule());
HloInstruction* reduce_instruction =
module->entry_computation()->root_instruction()->mutable_operand(0);

View File

@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
@ -36,7 +35,7 @@ class ScatterTest : public HloTestBase {
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_text, config));
ParseAndReturnVerifiedModule(hlo_text, config));
EXPECT_TRUE(RunAndCompare(std::move(module), args, nullopt));
}
};
@ -158,7 +157,7 @@ ENTRY main {
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_text, config));
ParseAndReturnVerifiedModule(hlo_text, config));
auto actual = ExecuteAndTransfer(std::move(module), {&permutation});
Literal expected =
LiteralUtil::CreateR2<int32>({{3, 0, 2, 1}, {1, 3, 2, 0}, {3, 2, 0, 1}});

View File

@ -16,7 +16,6 @@ limitations under the License.
#include <array>
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
@ -29,7 +28,7 @@ namespace {
class TokenHloTest : public HloTestBase {};
XLA_TEST_F(TokenHloTest, SingleTokenInstruction) {
std::unique_ptr<HloModule> module = CreateNewUnverifiedModule();
std::unique_ptr<HloModule> module = CreateNewVerifiedModule();
auto builder = HloComputation::Builder(TestName());
builder.AddInstruction(HloInstruction::CreateToken());
@ -40,7 +39,7 @@ XLA_TEST_F(TokenHloTest, SingleTokenInstruction) {
}
XLA_TEST_F(TokenHloTest, TokenInTuple) {
std::unique_ptr<HloModule> module = CreateNewUnverifiedModule();
std::unique_ptr<HloModule> module = CreateNewVerifiedModule();
auto builder = HloComputation::Builder(TestName());
auto token = builder.AddInstruction(HloInstruction::CreateToken());
builder.AddInstruction(HloInstruction::CreateTuple({token}));
@ -54,7 +53,7 @@ XLA_TEST_F(TokenHloTest, TokenInTuple) {
}
XLA_TEST_F(TokenHloTest, TokenTree) {
std::unique_ptr<HloModule> module = CreateNewUnverifiedModule();
std::unique_ptr<HloModule> module = CreateNewVerifiedModule();
auto builder = HloComputation::Builder(TestName());
auto token0 = builder.AddInstruction(HloInstruction::CreateToken());
auto token1 = builder.AddInstruction(HloInstruction::CreateToken());
@ -222,7 +221,7 @@ ENTRY %AddDependency (p0: f32[], p1: f32[]) -> f32[] {
)";
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_string, GetModuleConfigForTest()));
ParseAndReturnVerifiedModule(module_string, GetModuleConfigForTest()));
auto p0 = LiteralUtil::CreateR0<float>(10.0);
auto p1 = LiteralUtil::CreateR0<float>(3.0);
auto expected = LiteralUtil::CreateR0<float>(-156.0);
@ -243,7 +242,7 @@ ENTRY %AddDependency (p0: f32[]) -> f32[] {
)";
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_string, GetModuleConfigForTest()));
ParseAndReturnVerifiedModule(module_string, GetModuleConfigForTest()));
auto p0 = LiteralUtil::CreateR0<float>(10.0);
auto expected = LiteralUtil::CreateR0<float>(420.0);
EXPECT_EQ(expected, ExecuteNoHloPasses(std::move(module), {&p0}));
@ -261,7 +260,7 @@ ENTRY %AddDependency (p: f32[3]) -> f32[3] {
)";
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_string, GetModuleConfigForTest()));
ParseAndReturnVerifiedModule(module_string, GetModuleConfigForTest()));
auto input = LiteralUtil::CreateR1<float>({1.0, 3.0, 7.0});
auto expected = LiteralUtil::CreateR1<float>({-1.0, -3.0, -7.0});
EXPECT_EQ(expected, ExecuteNoHloPasses(std::move(module), {&input}));
@ -284,7 +283,7 @@ ENTRY %TupleShapedAddDependency (p0: f32[3], p1: f32[3]) -> f32[3] {
)";
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_string, GetModuleConfigForTest()));
ParseAndReturnVerifiedModule(module_string, GetModuleConfigForTest()));
auto p0 = LiteralUtil::CreateR1<float>({3.0, 3.0, 47.0});
auto p1 = LiteralUtil::CreateR1<float>({1.0, -2.0, 2.0});
auto expected = LiteralUtil::CreateR1<float>({2.0, 5.0, 45.0});