STT-tensorflow/tensorflow/compiler/xla/tools/run_hlo_module.h
Adrian Kuegel 3d21ad0e16 Open source run_hlo_module.
PiperOrigin-RevId: 284166663
Change-Id: I395f6a0a8efeb60784bdcca4e5227f0ef470f6f7
2019-12-06 05:34:57 -08:00

77 lines
2.9 KiB
C++

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_RUN_HLO_MODULE_H_
#define TENSORFLOW_COMPILER_XLA_TOOLS_RUN_HLO_MODULE_H_
#include <functional>
#include <random>
#include <string>
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/stream_executor/platform.h"
namespace xla {
// Command-line options accepted by this tool. Each flag's description lives
// next to main() in run_hlo_module_main.cc. Every field carries its default
// as an in-class initializer, so a default-constructed instance already holds
// the tool's defaults and no hand-written constructor is needed.
struct RunHloModuleOptions {
  // Name of the platform under test. Empty by default.
  // NOTE(review): presumably must be set by the caller before use — confirm.
  std::string platform = "";
  // Name of the platform used as the reference; defaults to "default".
  std::string reference_platform = "default";
  // Whether literals are printed; off by default.
  // NOTE(review): presumably the argument/result literals — confirm.
  bool print_literals = false;
  // Whether HLO passes are run for the test platform; on by default.
  bool run_test_hlo_passes = true;
  // Whether HLO passes are run for the reference platform; on by default.
  bool run_reference_hlo_passes = true;
  // On by default. NOTE(review): presumably widens the range of randomly
  // generated floating-point inputs — confirm.
  bool use_large_float_range = true;
  // Absolute and relative tolerances, both 0.1 by default.
  // TODO(b/68721786): These tolerances are set to match the values in the
  // isolation test. The goal is to lower these to 0.001.
  float abs_error_bound = 0.1f;
  float rel_error_bound = 0.1f;
  // Format of the input module; defaults to "hlo".
  std::string input_format = "hlo";
  // Input module identifier; empty by default.
  // NOTE(review): presumably a file path — confirm.
  std::string input_module = "";
  // Defaults to 1. NOTE(review): presumably the number of runs — confirm.
  int iterations = 1;
};
// Reads a HloModule from 'hlo_filename', runs it on the platform with the name
// 'test_platform_name', and if 'reference_platform_name' is non-empty, it also
// runs it on the platform with the name 'reference_platform_name' and compares
// the results. 'reference_module_modifier_hook' can be used to transform the
// HloModule before it is run on the reference platform. This may be necessary
// to match the numerics of the test platform.
//
// Parameters:
//   hlo_filename: path of the file the module is read from.
//     NOTE(review): presumably parsed according to 'options.input_format';
//     confirm against the definition.
//   test_platform_name: name of the platform under test.
//   reference_platform_name: name of the platform to compare against. An
//     empty string skips the reference run and the comparison.
//   engine: caller-owned random engine, passed by non-const pointer, so the
//     callee may advance its state.
//     NOTE(review): presumably used to generate the module's input arguments,
//     so a seeded engine gives reproducible runs. Whether nullptr is accepted
//     is not specified here; confirm.
//   options: the tool's command-line options (see RunHloModuleOptions).
//   reference_module_modifier_hook: optional callback that receives a
//     const HloModule&, a const stream_executor Platform::Id&, and a mutable
//     HloModule*, and returns a Status. It defaults to an empty std::function.
//     NOTE(review): presumably an empty hook means "no transformation", the
//     arguments are (test module, test platform id, reference module to
//     modify), and a non-OK Status aborts the run. Confirm all three.
//
// Returns a ::testing::AssertionResult.
// NOTE(review): presumably success when the results agree within
// 'options.abs_error_bound' / 'options.rel_error_bound', and a failure
// describing the mismatch otherwise. Confirm.
::testing::AssertionResult RunAndCompare(
const std::string& hlo_filename, const std::string& test_platform_name,
const std::string& reference_platform_name, std::minstd_rand0* engine,
const RunHloModuleOptions& options,
std::function<Status(const HloModule&,
const ::stream_executor::Platform::Id&, HloModule*)>
reference_module_modifier_hook = {});
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TOOLS_RUN_HLO_MODULE_H_