From 8d352c8b621ff519303f21e85759494ab89703d0 Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Fri, 28 Feb 2020 10:33:52 -0800 Subject: [PATCH] [XLA] Fix the interpreter to use the dynamic dimension inference when run separately PiperOrigin-RevId: 297875268 Change-Id: I97e735cfd57ad74122469d29afa606e922f566bd --- .../xla/service/interpreter/compiler.cc | 10 ++-- .../xla/service/interpreter/executable.cc | 11 ++++- .../xla/service/interpreter/executable.h | 7 ++- tensorflow/compiler/xla/tests/BUILD | 12 +++++ .../xla/tests/get_dimension_size_test.cc | 48 +++++++++++++++++++ 5 files changed, 79 insertions(+), 9 deletions(-) create mode 100644 tensorflow/compiler/xla/tests/get_dimension_size_test.cc diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index d034762cb15..1649be2ca8f 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -104,9 +104,8 @@ StatusOr> InterpreterCompiler::RunBackend( VLOG(1) << "Run backend " << hlo_module->name(); - // Typically you would visit the HLO graph, building up a compiled equivalent - // In this case we are using an HloEvaluator at execution time, so we don't - // need to compile anything + TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference, + DynamicDimensionInference::Run(hlo_module.get())); auto evaluator = absl::make_unique(); evaluator->set_use_fast_path( @@ -115,8 +114,9 @@ StatusOr> InterpreterCompiler::RunBackend( // Create executable from only the Hlo module. std::unique_ptr executable = - absl::make_unique(std::move(hlo_module), - std::move(evaluator)); + absl::make_unique( + std::move(hlo_module), std::move(evaluator), + std::move(dynamic_dimension_inference)); return std::move(executable); } diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index f82a439fdb0..a0ff1cb4e60 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -39,10 +39,17 @@ namespace interpreter { InterpreterExecutable::InterpreterExecutable( std::unique_ptr hlo_module, - std::unique_ptr evaluator) + std::unique_ptr evaluator, + absl::optional dynamic_dymension_inference) : Executable(std::move(hlo_module), /*hlo_profile_printer_data=*/nullptr, /*hlo_profile_index_map=*/nullptr), - evaluator_(std::move(evaluator)) {} + evaluator_(std::move(evaluator)), + dynamic_dimension_inference_(std::move(dynamic_dymension_inference)) { + if (dynamic_dimension_inference_.has_value()) { + evaluator_->set_dynamic_dimension_inference( + &dynamic_dimension_inference_.value()); + } +} InterpreterExecutable::~InterpreterExecutable() {} diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h index 1bea6773fdd..5df13dfb368 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.h +++ b/tensorflow/compiler/xla/service/interpreter/executable.h @@ -42,8 +42,10 @@ namespace interpreter { // buffer allocation. Refer to interpreter/README.md for more. class InterpreterExecutable : public Executable { public: - InterpreterExecutable(std::unique_ptr hlo_module, - std::unique_ptr evaluator); + InterpreterExecutable( + std::unique_ptr hlo_module, + std::unique_ptr evaluator, + absl::optional dynamic_dymension_inference); ~InterpreterExecutable() override; StatusOr ExecuteAsyncOnStream( @@ -60,6 +62,7 @@ class InterpreterExecutable : public Executable { mutable tensorflow::mutex evaluator_lock_; private: + absl::optional dynamic_dimension_inference_; TF_DISALLOW_COPY_AND_ASSIGN(InterpreterExecutable); }; diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 540a63405ef..23010d6ce70 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -2513,6 +2513,18 @@ xla_test( ], ) +xla_test( + name = "get_dimension_size_test", + srcs = ["get_dimension_size_test.cc"], + deps = [ + ":hlo_test_base", + ":test_macros_header", + ":xla_internal_test_main", # fixdeps: keep + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:test", + ], +) + xla_test( name = "triangular_solve_test", srcs = ["triangular_solve_test.cc"], diff --git a/tensorflow/compiler/xla/tests/get_dimension_size_test.cc b/tensorflow/compiler/xla/tests/get_dimension_size_test.cc new file mode 100644 index 00000000000..05ac332f4bd --- /dev/null +++ b/tensorflow/compiler/xla/tests/get_dimension_size_test.cc @@ -0,0 +1,48 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +class GetDimensionSizeTest : public HloTestBase {}; + +// Test that the interpreter can correctly compute get_dimension_size. +TEST_F(GetDimensionSizeTest, DoIt) { + const char* const kModuleStr = R"( +HloModule a_inference_call_110__.55 + +ENTRY %a_inference_call_110__.55 (arg0.1: f32[1,8], arg1.2: f32[8], arg2.3: f32[8]) -> s32[] { + %constant.37 = f32[] constant(1e-12) + %broadcast.38 = f32[1,1]{1,0} broadcast(f32[] %constant.37), dimensions={} + %arg0.1 = f32[1,8]{1,0} parameter(0), parameter_replication={false} + %reshape.4 = f32[1,8]{1,0} reshape(f32[1,8]{1,0} %arg0.1) + %convert.5 = f32[1,8]{1,0} convert(f32[1,8]{1,0} %reshape.4) + %constant.6 = f32[] constant(0) + %convert.7 = f32[] convert(f32[] %constant.6) + ROOT %get-dimension-size.13 = s32[] get-dimension-size(f32[1,8]{1,0} %convert.5), dimensions={1} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{0.01, 0.01})); +} + +} // anonymous namespace +} // namespace xla