[XLA] Fix the interpreter to use the dynamic dimension inference when run separately

PiperOrigin-RevId: 297875268
Change-Id: I97e735cfd57ad74122469d29afa606e922f566bd
This commit is contained in:
George Karpenkov 2020-02-28 10:33:52 -08:00 committed by TensorFlower Gardener
parent 037f8b1c00
commit 8d352c8b62
5 changed files with 79 additions and 9 deletions

View File

@ -104,9 +104,8 @@ StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
VLOG(1) << "Run backend " << hlo_module->name();
// Typically you would visit the HLO graph, building up a compiled equivalent
// In this case we are using an HloEvaluator at execution time, so we don't
// need to compile anything
TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference,
DynamicDimensionInference::Run(hlo_module.get()));
auto evaluator = absl::make_unique<HloEvaluator>();
evaluator->set_use_fast_path(
@ -115,8 +114,9 @@ StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
// Create executable from only the Hlo module.
std::unique_ptr<Executable> executable =
absl::make_unique<InterpreterExecutable>(std::move(hlo_module),
std::move(evaluator));
absl::make_unique<InterpreterExecutable>(
std::move(hlo_module), std::move(evaluator),
std::move(dynamic_dimension_inference));
return std::move(executable);
}

View File

@ -39,10 +39,17 @@ namespace interpreter {
InterpreterExecutable::InterpreterExecutable(
std::unique_ptr<HloModule> hlo_module,
std::unique_ptr<HloEvaluator> evaluator)
std::unique_ptr<HloEvaluator> evaluator,
absl::optional<DynamicDimensionInference> dynamic_dymension_inference)
: Executable(std::move(hlo_module), /*hlo_profile_printer_data=*/nullptr,
/*hlo_profile_index_map=*/nullptr),
evaluator_(std::move(evaluator)) {}
evaluator_(std::move(evaluator)),
dynamic_dimension_inference_(std::move(dynamic_dymension_inference)) {
if (dynamic_dimension_inference_.has_value()) {
evaluator_->set_dynamic_dimension_inference(
&dynamic_dimension_inference_.value());
}
}
InterpreterExecutable::~InterpreterExecutable() {}

View File

@ -42,8 +42,10 @@ namespace interpreter {
// buffer allocation. Refer to interpreter/README.md for more.
class InterpreterExecutable : public Executable {
public:
InterpreterExecutable(std::unique_ptr<HloModule> hlo_module,
std::unique_ptr<HloEvaluator> evaluator);
InterpreterExecutable(
std::unique_ptr<HloModule> hlo_module,
std::unique_ptr<HloEvaluator> evaluator,
absl::optional<DynamicDimensionInference> dynamic_dymension_inference);
~InterpreterExecutable() override;
StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
@ -60,6 +62,7 @@ class InterpreterExecutable : public Executable {
mutable tensorflow::mutex evaluator_lock_;
private:
absl::optional<DynamicDimensionInference> dynamic_dimension_inference_;
TF_DISALLOW_COPY_AND_ASSIGN(InterpreterExecutable);
};

View File

@ -2513,6 +2513,18 @@ xla_test(
],
)
xla_test(
name = "get_dimension_size_test",
srcs = ["get_dimension_size_test.cc"],
deps = [
":hlo_test_base",
":test_macros_header",
":xla_internal_test_main", # fixdeps: keep
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:test",
],
)
xla_test(
name = "triangular_solve_test",
srcs = ["triangular_solve_test.cc"],

View File

@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
namespace xla {
namespace {
class GetDimensionSizeTest : public HloTestBase {};
// Test that the interpreter can correctly compute get_dimension_size.
TEST_F(GetDimensionSizeTest, DoIt) {
const char* const kModuleStr = R"(
HloModule a_inference_call_110__.55
ENTRY %a_inference_call_110__.55 (arg0.1: f32[1,8], arg1.2: f32[8], arg2.3: f32[8]) -> s32[] {
%constant.37 = f32[] constant(1e-12)
%broadcast.38 = f32[1,1]{1,0} broadcast(f32[] %constant.37), dimensions={}
%arg0.1 = f32[1,8]{1,0} parameter(0), parameter_replication={false}
%reshape.4 = f32[1,8]{1,0} reshape(f32[1,8]{1,0} %arg0.1)
%convert.5 = f32[1,8]{1,0} convert(f32[1,8]{1,0} %reshape.4)
%constant.6 = f32[] constant(0)
%convert.7 = f32[] convert(f32[] %constant.6)
ROOT %get-dimension-size.13 = s32[] get-dimension-size(f32[1,8]{1,0} %convert.5), dimensions={1}
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kModuleStr));
EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{0.01, 0.01}));
}
} // anonymous namespace
} // namespace xla