[XLA] Fix the interpreter to use the dynamic dimension inference when run separately
PiperOrigin-RevId: 297875268 Change-Id: I97e735cfd57ad74122469d29afa606e922f566bd
This commit is contained in:
parent
037f8b1c00
commit
8d352c8b62
@ -104,9 +104,8 @@ StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
|
|||||||
|
|
||||||
VLOG(1) << "Run backend " << hlo_module->name();
|
VLOG(1) << "Run backend " << hlo_module->name();
|
||||||
|
|
||||||
// Typically you would visit the HLO graph, building up a compiled equivalent
|
TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference,
|
||||||
// In this case we are using an HloEvaluator at execution time, so we don't
|
DynamicDimensionInference::Run(hlo_module.get()));
|
||||||
// need to compile anything
|
|
||||||
|
|
||||||
auto evaluator = absl::make_unique<HloEvaluator>();
|
auto evaluator = absl::make_unique<HloEvaluator>();
|
||||||
evaluator->set_use_fast_path(
|
evaluator->set_use_fast_path(
|
||||||
@ -115,8 +114,9 @@ StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
|
|||||||
|
|
||||||
// Create executable from only the Hlo module.
|
// Create executable from only the Hlo module.
|
||||||
std::unique_ptr<Executable> executable =
|
std::unique_ptr<Executable> executable =
|
||||||
absl::make_unique<InterpreterExecutable>(std::move(hlo_module),
|
absl::make_unique<InterpreterExecutable>(
|
||||||
std::move(evaluator));
|
std::move(hlo_module), std::move(evaluator),
|
||||||
|
std::move(dynamic_dimension_inference));
|
||||||
|
|
||||||
return std::move(executable);
|
return std::move(executable);
|
||||||
}
|
}
|
||||||
|
@ -39,10 +39,17 @@ namespace interpreter {
|
|||||||
|
|
||||||
InterpreterExecutable::InterpreterExecutable(
|
InterpreterExecutable::InterpreterExecutable(
|
||||||
std::unique_ptr<HloModule> hlo_module,
|
std::unique_ptr<HloModule> hlo_module,
|
||||||
std::unique_ptr<HloEvaluator> evaluator)
|
std::unique_ptr<HloEvaluator> evaluator,
|
||||||
|
absl::optional<DynamicDimensionInference> dynamic_dymension_inference)
|
||||||
: Executable(std::move(hlo_module), /*hlo_profile_printer_data=*/nullptr,
|
: Executable(std::move(hlo_module), /*hlo_profile_printer_data=*/nullptr,
|
||||||
/*hlo_profile_index_map=*/nullptr),
|
/*hlo_profile_index_map=*/nullptr),
|
||||||
evaluator_(std::move(evaluator)) {}
|
evaluator_(std::move(evaluator)),
|
||||||
|
dynamic_dimension_inference_(std::move(dynamic_dymension_inference)) {
|
||||||
|
if (dynamic_dimension_inference_.has_value()) {
|
||||||
|
evaluator_->set_dynamic_dimension_inference(
|
||||||
|
&dynamic_dimension_inference_.value());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
InterpreterExecutable::~InterpreterExecutable() {}
|
InterpreterExecutable::~InterpreterExecutable() {}
|
||||||
|
|
||||||
|
@ -42,8 +42,10 @@ namespace interpreter {
|
|||||||
// buffer allocation. Refer to interpreter/README.md for more.
|
// buffer allocation. Refer to interpreter/README.md for more.
|
||||||
class InterpreterExecutable : public Executable {
|
class InterpreterExecutable : public Executable {
|
||||||
public:
|
public:
|
||||||
InterpreterExecutable(std::unique_ptr<HloModule> hlo_module,
|
InterpreterExecutable(
|
||||||
std::unique_ptr<HloEvaluator> evaluator);
|
std::unique_ptr<HloModule> hlo_module,
|
||||||
|
std::unique_ptr<HloEvaluator> evaluator,
|
||||||
|
absl::optional<DynamicDimensionInference> dynamic_dymension_inference);
|
||||||
~InterpreterExecutable() override;
|
~InterpreterExecutable() override;
|
||||||
|
|
||||||
StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
|
StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
|
||||||
@ -60,6 +62,7 @@ class InterpreterExecutable : public Executable {
|
|||||||
mutable tensorflow::mutex evaluator_lock_;
|
mutable tensorflow::mutex evaluator_lock_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
absl::optional<DynamicDimensionInference> dynamic_dimension_inference_;
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(InterpreterExecutable);
|
TF_DISALLOW_COPY_AND_ASSIGN(InterpreterExecutable);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -2513,6 +2513,18 @@ xla_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
xla_test(
|
||||||
|
name = "get_dimension_size_test",
|
||||||
|
srcs = ["get_dimension_size_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":hlo_test_base",
|
||||||
|
":test_macros_header",
|
||||||
|
":xla_internal_test_main", # fixdeps: keep
|
||||||
|
"//tensorflow/compiler/xla:debug_options_flags",
|
||||||
|
"//tensorflow/compiler/xla:test",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
xla_test(
|
xla_test(
|
||||||
name = "triangular_solve_test",
|
name = "triangular_solve_test",
|
||||||
srcs = ["triangular_solve_test.cc"],
|
srcs = ["triangular_solve_test.cc"],
|
||||||
|
48
tensorflow/compiler/xla/tests/get_dimension_size_test.cc
Normal file
48
tensorflow/compiler/xla/tests/get_dimension_size_test.cc
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/debug_options_flags.h"
|
||||||
|
#include "tensorflow/compiler/xla/test.h"
|
||||||
|
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||||
|
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class GetDimensionSizeTest : public HloTestBase {};
|
||||||
|
|
||||||
|
// Test that the interpreter can correctly compute get_dimension_size.
|
||||||
|
TEST_F(GetDimensionSizeTest, DoIt) {
|
||||||
|
const char* const kModuleStr = R"(
|
||||||
|
HloModule a_inference_call_110__.55
|
||||||
|
|
||||||
|
ENTRY %a_inference_call_110__.55 (arg0.1: f32[1,8], arg1.2: f32[8], arg2.3: f32[8]) -> s32[] {
|
||||||
|
%constant.37 = f32[] constant(1e-12)
|
||||||
|
%broadcast.38 = f32[1,1]{1,0} broadcast(f32[] %constant.37), dimensions={}
|
||||||
|
%arg0.1 = f32[1,8]{1,0} parameter(0), parameter_replication={false}
|
||||||
|
%reshape.4 = f32[1,8]{1,0} reshape(f32[1,8]{1,0} %arg0.1)
|
||||||
|
%convert.5 = f32[1,8]{1,0} convert(f32[1,8]{1,0} %reshape.4)
|
||||||
|
%constant.6 = f32[] constant(0)
|
||||||
|
%convert.7 = f32[] convert(f32[] %constant.6)
|
||||||
|
ROOT %get-dimension-size.13 = s32[] get-dimension-size(f32[1,8]{1,0} %convert.5), dimensions={1}
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
|
ParseAndReturnVerifiedModule(kModuleStr));
|
||||||
|
EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{0.01, 0.01}));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // anonymous namespace
|
||||||
|
} // namespace xla
|
Loading…
x
Reference in New Issue
Block a user