[XLA] Fix the interpreter to use the dynamic dimension inference when run separately
PiperOrigin-RevId: 297875268 Change-Id: I97e735cfd57ad74122469d29afa606e922f566bd
This commit is contained in:
parent
037f8b1c00
commit
8d352c8b62
@ -104,9 +104,8 @@ StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
|
||||
|
||||
VLOG(1) << "Run backend " << hlo_module->name();
|
||||
|
||||
// Typically you would visit the HLO graph, building up a compiled equivalent
|
||||
// In this case we are using an HloEvaluator at execution time, so we don't
|
||||
// need to compile anything
|
||||
TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference,
|
||||
DynamicDimensionInference::Run(hlo_module.get()));
|
||||
|
||||
auto evaluator = absl::make_unique<HloEvaluator>();
|
||||
evaluator->set_use_fast_path(
|
||||
@ -115,8 +114,9 @@ StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
|
||||
|
||||
// Create executable from only the Hlo module.
|
||||
std::unique_ptr<Executable> executable =
|
||||
absl::make_unique<InterpreterExecutable>(std::move(hlo_module),
|
||||
std::move(evaluator));
|
||||
absl::make_unique<InterpreterExecutable>(
|
||||
std::move(hlo_module), std::move(evaluator),
|
||||
std::move(dynamic_dimension_inference));
|
||||
|
||||
return std::move(executable);
|
||||
}
|
||||
|
@ -39,10 +39,17 @@ namespace interpreter {
|
||||
|
||||
InterpreterExecutable::InterpreterExecutable(
|
||||
std::unique_ptr<HloModule> hlo_module,
|
||||
std::unique_ptr<HloEvaluator> evaluator)
|
||||
std::unique_ptr<HloEvaluator> evaluator,
|
||||
absl::optional<DynamicDimensionInference> dynamic_dymension_inference)
|
||||
: Executable(std::move(hlo_module), /*hlo_profile_printer_data=*/nullptr,
|
||||
/*hlo_profile_index_map=*/nullptr),
|
||||
evaluator_(std::move(evaluator)) {}
|
||||
evaluator_(std::move(evaluator)),
|
||||
dynamic_dimension_inference_(std::move(dynamic_dymension_inference)) {
|
||||
if (dynamic_dimension_inference_.has_value()) {
|
||||
evaluator_->set_dynamic_dimension_inference(
|
||||
&dynamic_dimension_inference_.value());
|
||||
}
|
||||
}
|
||||
|
||||
InterpreterExecutable::~InterpreterExecutable() {}
|
||||
|
||||
|
@ -42,8 +42,10 @@ namespace interpreter {
|
||||
// buffer allocation. Refer to interpreter/README.md for more.
|
||||
class InterpreterExecutable : public Executable {
|
||||
public:
|
||||
InterpreterExecutable(std::unique_ptr<HloModule> hlo_module,
|
||||
std::unique_ptr<HloEvaluator> evaluator);
|
||||
InterpreterExecutable(
|
||||
std::unique_ptr<HloModule> hlo_module,
|
||||
std::unique_ptr<HloEvaluator> evaluator,
|
||||
absl::optional<DynamicDimensionInference> dynamic_dymension_inference);
|
||||
~InterpreterExecutable() override;
|
||||
|
||||
StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
|
||||
@ -60,6 +62,7 @@ class InterpreterExecutable : public Executable {
|
||||
mutable tensorflow::mutex evaluator_lock_;
|
||||
|
||||
private:
|
||||
absl::optional<DynamicDimensionInference> dynamic_dimension_inference_;
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(InterpreterExecutable);
|
||||
};
|
||||
|
||||
|
@ -2513,6 +2513,18 @@ xla_test(
|
||||
],
|
||||
)
|
||||
|
||||
xla_test(
|
||||
name = "get_dimension_size_test",
|
||||
srcs = ["get_dimension_size_test.cc"],
|
||||
deps = [
|
||||
":hlo_test_base",
|
||||
":test_macros_header",
|
||||
":xla_internal_test_main", # fixdeps: keep
|
||||
"//tensorflow/compiler/xla:debug_options_flags",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
],
|
||||
)
|
||||
|
||||
xla_test(
|
||||
name = "triangular_solve_test",
|
||||
srcs = ["triangular_solve_test.cc"],
|
||||
|
48
tensorflow/compiler/xla/tests/get_dimension_size_test.cc
Normal file
48
tensorflow/compiler/xla/tests/get_dimension_size_test.cc
Normal file
@ -0,0 +1,48 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/debug_options_flags.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
class GetDimensionSizeTest : public HloTestBase {};
|
||||
|
||||
// Test that the interpreter can correctly compute get_dimension_size.
|
||||
TEST_F(GetDimensionSizeTest, DoIt) {
|
||||
const char* const kModuleStr = R"(
|
||||
HloModule a_inference_call_110__.55
|
||||
|
||||
ENTRY %a_inference_call_110__.55 (arg0.1: f32[1,8], arg1.2: f32[8], arg2.3: f32[8]) -> s32[] {
|
||||
%constant.37 = f32[] constant(1e-12)
|
||||
%broadcast.38 = f32[1,1]{1,0} broadcast(f32[] %constant.37), dimensions={}
|
||||
%arg0.1 = f32[1,8]{1,0} parameter(0), parameter_replication={false}
|
||||
%reshape.4 = f32[1,8]{1,0} reshape(f32[1,8]{1,0} %arg0.1)
|
||||
%convert.5 = f32[1,8]{1,0} convert(f32[1,8]{1,0} %reshape.4)
|
||||
%constant.6 = f32[] constant(0)
|
||||
%convert.7 = f32[] convert(f32[] %constant.6)
|
||||
ROOT %get-dimension-size.13 = s32[] get-dimension-size(f32[1,8]{1,0} %convert.5), dimensions={1}
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(kModuleStr));
|
||||
EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{0.01, 0.01}));
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace xla
|
Loading…
x
Reference in New Issue
Block a user