Branch 160346151 (#11094)
* Properly handle ops that don't have a CPU kernel PiperOrigin-RevId: 159655906 * Selected BUILD cleanup in tensorflow/contrib/... PiperOrigin-RevId: 159673079 * Remove redundant `get` calls on smart pointers PiperOrigin-RevId: 159675809 * PiperOrigin-RevId: 159698321 * Migrate kernels to boosted_trees. PiperOrigin-RevId: 159698656 * Fix a bug in the memory optimizer when two inputs to a node are both recomputed PiperOrigin-RevId: 159700457 * Fixed memory leak that can be triggered by a failed node evaluation PiperOrigin-RevId: 159707380 * Updates get_started tutorial. PiperOrigin-RevId: 159709158 * [XLA] Remove unused factory in local_service PiperOrigin-RevId: 159712806 * Fix typo in docstring PiperOrigin-RevId: 159714414 * Migrate ops for new version of TensorForest. PiperOrigin-RevId: 159718610 * Added parameterized tests to reduce window tests. PiperOrigin-RevId: 159721784 * Use C API to implement Operation.device property PiperOrigin-RevId: 159723490 * Several Estimator changes: - support configurable input_fn calling in Estimator subclasses. - pass params and config to the input_fn. - allow callables for model_fn and input_fn. PiperOrigin-RevId: 159725554 * Fixed the scalar output for shard api when outputs_from_all_shards=True. PiperOrigin-RevId: 159726444 * Automated g4 rollback of changelist 159718610 PiperOrigin-RevId: 159728380 * Adding missing deps to targets in llvm.BUILD. This was only working in non-sandboxed builds. PiperOrigin-RevId: 159729295 * [XLA:HLO] Move sequence functions from hlo_ordering.h to hlo_scheduling.h. This is required for upcoming changes to convert the sequence creation functions (and HeapSimulator and BufferAssignment) over to using the new Hlo{Dataflow,Alias}Analysis. It's required because otherwise there's a dependency cycle: Hlo{Dataflow,Alias}Analysis depends on HloOrdering CreateMemoryMinimizingSequence will depend on Hlo{Dataflow,Alias}Analysis There's already a cycle here, if both HloOrdering and CreateMemoryMinimizingSequence are in the same file. Also note that: MinimumMemoryForSequence depends on HeapSimulator HeapSimulator will depend on Hlo{Dataflow,Alias}Analysis Hlo{Dataflow,Alias}Analysis depends on HloOrdering Splitting out the sequence functions resolves the cycle. Refactoring only; no functional changes. PiperOrigin-RevId: 159731836 * [XLA:HLO] Split Hlo{Value,Buffer} out of Hlo{Dataflow,Alias}Analysis. This will make dependencies cleaner for upcoming CLs that will convert HeapSimulator and HloOrdering to use the new analyses. No change in functionality. PiperOrigin-RevId: 159737265 * Internal change PiperOrigin-RevId: 159738215 * Suggest people need to do some build environment ./configur'ing. Fixes #4279 PiperOrigin-RevId: 159738412 * Rewrite SameDefinedShape function in ShapeRefiner PiperOrigin-RevId: 159745894 * [XLA] Remove xla_cpu_*_eigen flags from CPU backends. These flags are currently de-facto unused; parallelism should be controlled through the cpu_parallel backend. For configuring Eigen, if needed, the options should be piped more directly to the code. PiperOrigin-RevId: 159746509 * Updates layers tutorial and corresponding example. PiperOrigin-RevId: 159749528 * Further BUILD cleanup PiperOrigin-RevId: 159749869 * Use more efficient squared_difference PiperOrigin-RevId: 159751209 * Add log_step_count_steps to RunConfig and allow it to flow to the MonitoredSession. PiperOrigin-RevId: 159753935 * [XLA] Remove xla_hlo_test_generate_hlo_graph, which is now redundant. PiperOrigin-RevId: 159755688 * Do not use SSE4.1 instructions on Android builds. PiperOrigin-RevId: 159756104 * Add nonpublic helper `tf.distributions.util.tridiag` op. PiperOrigin-RevId: 159757904 * [XLA] Remove dead "in-client" code. Remove Service::runs_in_client_process_ field and it's dead user. This was previously used by the "InProcess" methods which have been replaced with the LocalClient API. PiperOrigin-RevId: 159759455 * [tf contrib seq2seq] Add monotonic attention mechanisms * Add monotonic_attention and safe_cumprod helper functions. * Add _BaseMonotonicAttentionMechanism base class. * Add BahdanauMonotonicAttention and LuongMonotonicAttention classes. These attention mechanisms are proposed in Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck, "Online and Linear-Time Attention by Enforcing Monotonic Alignments." ICML 2017. https://arxiv.org/abs/1704.00784 PiperOrigin-RevId: 159760073 * Add ability for argmax to output int32 indices. Default remains int64. Change is made in a backwards and forward compatible manner, since we add a new attribute with a default that remains the same, and simply register a few new kernels. PiperOrigin-RevId: 159761347 * Automated g4 rollback of changelist 159746509 PiperOrigin-RevId: 159763112 * Raise ValueError if invalid dtype for random_uniform. PiperOrigin-RevId: 159764956 * Internal change. PiperOrigin-RevId: 159769520 * Support zero shapes for random_poisson. This matches random_uniform. PiperOrigin-RevId: 159771215 * Blacklist the quantized ops since they have too many issues (incorrect shape functions, memory corruptions, ...) PiperOrigin-RevId: 159772801 * Fixed the shape functions of the QuantizedAdd and QuantizedMul ops PiperOrigin-RevId: 159772841 * Switch from assigning namedtuple.__new__.__defaults__ to overwriting __new__. Assigning __defaults__ relies on an implementation detail of CPython, confuses type checkers (and developers :)), and is error-prone since it doesn't make the relationship between parameter names and default values explicit. This CL switches to overloading __new__ instead. PiperOrigin-RevId: 159773922 * Made sure that we can call the constant folding code twice safely. PiperOrigin-RevId: 159781607 * Added batch_matmul op dependence to android_extended_ops PiperOrigin-RevId: 159787178 * Fixes a TODO in head_test. PiperOrigin-RevId: 159789178 * When configuring per-session thread pools, allow a pool to be a global pool. This allows a division between large and small pools, without needing to make new pool for each session. PiperOrigin-RevId: 159789678 * Add a multi-head TensorForest estimator. PiperOrigin-RevId: 159820487 * Have RestoreV2's shape fn set all outputs to unknown shape. PiperOrigin-RevId: 159835723 * VectorExponential added to distributions. PiperOrigin-RevId: 159840822 * Fold as many nodes as possible instead of giving up if there is any error. PiperOrigin-RevId: 159841935 * Removed deprecated summary usage from estimators. Made name_space usage consistent. PiperOrigin-RevId: 159846928 * Adding missing license notice to toolchain build files PiperOrigin-RevId: 159847551 * [XLA] Remove unused flags and move debugging flag to debug options. PiperOrigin-RevId: 159849759 * Fixes some docstrings in feature_column. PiperOrigin-RevId: 159850619 * TpuEstimator: Replicate the input_fn to the worker CPU for each shard. The batch size is configured as follows: The user may specify a global batch size in their hyperparameters. If the 'batch_size' field is set, then we convert the global batch size into a per-shard batch size by dividing by num_shards before running their input_fn. PiperOrigin-RevId: 159851773 * Modify beam search decoder to use symbolic shape for vocab size if the static shape is not present. PiperOrigin-RevId: 159852297 * Generalize cluster initialization to span multiple mini-batches if necessary. PiperOrigin-RevId: 159852557 * Use a single threaded session for SDCALinearRegressorTest to avoid incorrect threading test failures (tsan). PiperOrigin-RevId: 159852818 * Migrate ops for new version of TensorForest. PiperOrigin-RevId: 159852889 * Replaced constant inputs with variables to ensure most of the graph doesn't get optimized away PiperOrigin-RevId: 159853171 * For candidate sampling, add facility to colocate the logit computation with the sharded embeddings. PiperOrigin-RevId: 159854706 * Added a utility to create parsing spec for regressors (canned estimator) PiperOrigin-RevId: 159855254 * Fix cuda_kernel_helper_test. std::numeric_limits<int32>::max() doesn't pass, so I didn't use that. PiperOrigin-RevId: 159869169 * In tfcompile, prune nodes that are not reachable from the fetches before building the Graph. This allows loading a graph that contains ops not needed for the compiled binary. PiperOrigin-RevId: 159869692 * Fix bugs related to distributions over integers. - Ensure that the max number of categories does not exceed largest integer-form float. - Make dtype inference consistent between Categorical and Multinomial distributions. - Improve documentation to better reflect that the Categorical distribution is analogous to `argmax{OneHotCategorical}` (itself being identical to `argmax{Multinomial(p,n=1)}` but not Multinomial. - Fix validation_args Heisenberg uncertainty: only validation logic should live under self.validate_args. E.g., validate_args=True would sometimes imply `x=floor(x)` which changes behavior thus making debugging impossible because enabling validation *changes* values. - Corrected `Geometric` swapping of validate_args` and `allow_nan_stats` default-values. Fixes #10149 PiperOrigin-RevId: 159872532 * Make HloModule clonable This CL makes HloModule clonable, which is necessary when we want to run the same compilation twice with the same input. PiperOrigin-RevId: 159874256 * Internal change. PiperOrigin-RevId: 159876942 * Implement alternative `monte_carlo.expectation_v2`. This function implements the reparameterization and score-gradient tricks and does not depend on tf.Distribution like inputs. PiperOrigin-RevId: 159877923 * In SE_ASSIGN_OR_RETURN change ConsumeValueOrDie to the preferred std::move ValueOrDie. PiperOrigin-RevId: 159879754 * If rank is unknown, do not add output shapes to transpose nodes. PiperOrigin-RevId: 159879840 * Move sparse_fill_empty_rows to new, *significantly* faster, C++ kernel for everyone. Also fix a bug in the C++ op when the input ST has 0 elements. PiperOrigin-RevId: 159880044 * Add support of label_keys to DebugClassifier PiperOrigin-RevId: 159883986 * Register devices under their legacy names Because some higher level APIs continue to use the legacy name format, when using ClusterSpec propagation, we need to ensure that we register the devices under their legacy names as well as their canonical names. PiperOrigin-RevId: 159885777 * [BatchNorm] Minor fixes to TF doc PiperOrigin-RevId: 159886125 * Generating TBAA metadata causes the LLVM to miscompile after https://reviews.llvm.org/rL305938). Disable TBAA (to stop the miscompiles) while we fix the root issue. PiperOrigin-RevId: 159895736 * Improve score-trick to be a valid Csiszar f-Divergence yet numerically stable. PiperOrigin-RevId: 159896013 * Support advisor in all places (Command line, APIs) Add expensive operation checker PiperOrigin-RevId: 159897279 * Added canned estimators to Tensorflow library. List of added estimators: * DNNClassifier * DNNRegressor * LinearClassifer * LinearRegressor * DNNLinearCombinedClassifier * DNNLinearCombinedRegressor PiperOrigin-RevId: 159898954 * Alligned how model-fns handled params among linear/dnn/combined estimators. PiperOrigin-RevId: 159899925 * Fixed cmake tests. PiperOrigin-RevId: 159901417 * [XLA:CPU] Add VLOGs to cpu_compiler.cc PiperOrigin-RevId: 159902919 * Make occurence (op run times and op definition) selectable in all views to address the loop problem. When a node is in loop, its execution times are accumulated, its run times will increase. PiperOrigin-RevId: 159912429 * [XLA] Small error message improvement in binop shape inference. PiperOrigin-RevId: 159920109 * Follow upstream API change from r306058. PiperOrigin-RevId: 159938416 * [TF:XLA] Update LLVM to upstream revision r306085. PiperOrigin-RevId: 159946562 * [XLA] Remove unused xla_cpu flag and move another to DebugOptions. PiperOrigin-RevId: 159952124 * Updates linear.md tutorial PiperOrigin-RevId: 159956867 * Add TraceMe instrumentation of RunStep in GRPC distributed runtime. A unique ID is added to each RunStep call that allows the client and server events to be correlated. PiperOrigin-RevId: 159956950 * [XLA] Add general F32 implementation for ReducePrecision operation. This only tests with parameter inputs (which is needed to ensure we actually test on GPUs as well as CPUs); there's no point in separately testing with constants. PiperOrigin-RevId: 159961430 * Java: NativeLibrary: Fix URL in error message. And add some detail. Inspired by #11015 PiperOrigin-RevId: 159962478 * Increase rtol for util_test. PiperOrigin-RevId: 159971136 * Re-enable IR dumping for the sequential CPU backend. PiperOrigin-RevId: 159974126 * tfdbg: a few minor fixes and improvements * Let DumpingDebugWrapperSession and DumpingDebugHook create session_root if it doesn't exist * Add README.md to tensorflow/python/debug * Add section "Debugging Keras Models with TFDBG" in debugger.md PiperOrigin-RevId: 159976070 * Add None check for save_path when restoring checkpoints as if something is wrong in tf.train.latest_checkpoint, it will often return None and it's nice to have a common sense check in restore for this. This way log.error says what has happened. PiperOrigin-RevId: 159979481 * Don't crash if a metagraph fails to load. PiperOrigin-RevId: 159981628 * Prepare to not include node_def.proto.h in node_def_util.h The goal is to make kernels mostly independent of proto headers, which will let us lock down our .so imports. This CL makes a bunch of .cc files either include node_def.proto.h themselves or not need the definition of NodeDef; a second CL will make node_def_util.h not include node_def.proto.h. RELNOTES: n/a PiperOrigin-RevId: 159982117 * Add a few diagnostic flags to help narrow down issues with the LLVM backends. PiperOrigin-RevId: 159982441 * Updated wide-n-deep tutorial code to use core version of estimators and feature-columns. PiperOrigin-RevId: 159984663 * Modify ControlFlowContext to also respect import_scope in 'values_' and keys of 'external_values_' PiperOrigin-RevId: 159985290 * Add item's graph to partition_graphs in virtual cluster's run method. Put node op name in timeline_label instead of node_name. PiperOrigin-RevId: 159986583 * Use short-proto for logging purposes. A short proto will be output on a single log line, making it easier for certain automated tools to handle. PiperOrigin-RevId: 159994005 * Sinh, ArcSinh, Cosh, LogCosh functions added to distributions/python/ops/trig. Care is taken to ensure a fair bit of stability. PiperOrigin-RevId: 159995514 * Updates some examples in examples/learn. PiperOrigin-RevId: 159996397 * Add kernel tests for boosted_trees. PiperOrigin-RevId: 160002696 * Avoid doing unecessary work in the OptimizeGraph() function whenever possible PiperOrigin-RevId: 160003173 * Use std::shared_ptr instead of core::RefCounted for Node::Properties Also changes Node::Properties to a struct and removes underscores from public member variables. This change should make it easier to work with Properties moving forward as the refcount will be automatically updated. PiperOrigin-RevId: 160003281 * Make the CPU compiler dump optimized IR along with the unoptimized IR. PiperOrigin-RevId: 160005257 * Disable flaky run_metadata_test. PiperOrigin-RevId: 160015399 * BUILD cleanup in tensorflow/tools/... PiperOrigin-RevId: 160018623 * SinhArcSinh bijector added. This two-parameter diffeomorphism from R --> R allows for skewness and fatter or thinner tails. See docstring and also http://oro.open.ac.uk/22510/1/sinhasinh.pdf PiperOrigin-RevId: 160019380 * Avoid hardcoded names for temporary files in tests. These tests (and examples that are run as tests) were using hardcoded names for temporary files. This failed when multiple copies of these tests were run in parallel, or even successively by different users, where the second run could not overwrite files left by the first. This change uses the TEST_TMPDIR environment variable used by bazel's test runner to choose a temporary directory. If that directory is not set, /tmp is used, as before. PiperOrigin-RevId: 160026924 * Fix multinomial doc-string, input arg logits expects to log-probabilities and not log-odds. PiperOrigin-RevId: 160036709 * Made TensorFlow documentation on LSTMs slightly more accurate. PiperOrigin-RevId: 160047054 * Follow LLVM/ORC upstream API change in r306166. PiperOrigin-RevId: 160108102 * Move resampler from sonnet to contrib. PiperOrigin-RevId: 160134565 * [TPUEstimator] Make input_fn invoked properly with eval on CPU. PiperOrigin-RevId: 160151890 * Deletes iris_val_based_early_stopping example, which uses deprecated ValidationMonitor. PiperOrigin-RevId: 160154863 * [XLA] Move HLO dumping flags from service_flags to debug_options_flags This also removes the duplication in the xla_generate_hlo_graph flag. This CL also moves the actual dumping logic from Executable to the hlo_graph_dumper namespace, where it belongs; this is in preparation for removing the hlo_dumper callback altogether, since it isn't serving any role beyond what a direct call to hlo_graph_dumper would have (b/62872831 has more details). PiperOrigin-RevId: 160154869 * Fix missing variable unref Direct leak of 56 byte(s) in 1 object(s) allocated from: #0 0xf5ee272 in operator new(unsigned long) (/build/cas/5d2/5d2be3b530580573ff7269adcab7cbac+0xf5ee272) #1 0x1b51394c in tensorflow::AssignVariableOp<Eigen::ThreadPoolDevice, float>::Compute(tensorflow::OpKernelContext*)::'lambda'(tensorflow::Var**)::operator()(tensorflow::Var**) const (/build/cas/5d2/5d2be3b530580573ff7269adcab7cbac+0x1b51394c) #2 0x1b5136c0 in std::_Function_handler<tensorflow::Status (tensorflow::Var**), tensorflow::AssignVariableOp<Eigen::ThreadPoolDevice, float>::Compute(tensorflow::OpKernelContext*)::'lambda'(tensorflow::Var**)>::_M_invoke(std::_Any_data const&, tensorflow::Var**) (/build/cas/5d2/5d2be3b530580573ff7269adcab7cbac+0x1b5136c0) #3 0x1b50b289 in std::function<tensorflow::Status (tensorflow::Var**)>::operator()(tensorflow::Var**) const (/build/cas/5d2/5d2be3b530580573ff7269adcab7cbac+0x1b50b289) #4 0x1b50af88 in tensorflow::Status tensorflow::ResourceMgr::LookupOrCreate<tensorflow::Var>(basic_string<char, std::char_traits<char>, std::allocator<char> > const&, basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tensorflow::Var**, std::function<tensorflow::Status (tensorflow::Var**)>) (/build/cas/5d2/5d2be3b530580573ff7269adcab7cbac+0x1b50af88) #5 0x1b50ac10 in tensorflow::Status tensorflow::LookupOrCreateResource<tensorflow::Var>(tensorflow::OpKernelContext*, tensorflow::ResourceHandle const&, tensorflow::Var**, std::function<tensorflow::Status (tensorflow::Var**)>) (/build/cas/5d2/5d2be3b530580573ff7269adcab7cbac+0x1b50ac10) #6 0x1b512f1e in tensorflow::AssignVariableOp<Eigen::ThreadPoolDevice, float>::Compute(tensorflow::OpKernelContext*) (/build/cas/5d2/5d2be3b530580573ff7269adcab7cbac+0x1b512f1e) #7 0x1d1881c7 in tensorflow::ThreadPoolDevice::Compute(tensorflow::OpKernel*, tensorflow::OpKernelContext*) (/build/cas/5d2/5d2be3b530580573ff7269adcab7cbac+0x1d1881c7) #8 0xf96e0fe in tensorflow::KernelAndDevice::Run(std::vector<tensorflow::Tensor, std::allocator<tensorflow::Tensor> >*, std::vector<tensorflow::Tensor, std::allocator<tensorflow::Tensor> >*) (/build/cas/5d2/5d2be3b530580573ff7269adcab7cbac+0xf96e0fe) #9 0xf94f9c8 in TFE_Execute (/build/cas/5d2/5d2be3b530580573ff7269adcab7cbac+0xf94f9c8) #10 0xf94356d in TFE_Py_Execute(TFE_Context*, int, char const*, tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4>*, _object*, tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2>*, TF_Status*) (/build/cas/5d2/5d2be3b530580573ff7269adcab7cbac+0xf94356d) PiperOrigin-RevId: 160160101 * Simplify strided_slice's shape handling Now that TensorShape and PartialTensorShape share memory representations, there's no need for an abstract class that makes TensorShape and TensorShapeProto look the same. RELNOTES: n/a PiperOrigin-RevId: 160161618 * Added a tool to report the static information that can be extracted from a TF model. PiperOrigin-RevId: 160162256 * Properly handle RefEnter, RefExit and RefNextIteration nodes. PiperOrigin-RevId: 160162338 * Switch tfprof to use proto3 PiperOrigin-RevId: 160163483 * Fixes to cuda_config.h. PiperOrigin-RevId: 160168545 * Update ops-related pbtxt files. PiperOrigin-RevId: 160171187 * Adds notes to prevent overfitting for Experiment continous_train_and_eval. PiperOrigin-RevId: 160172692 * Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 160172985 * Merge changes from github. END_PUBLIC Note: this CL will break builds. cl/159887762 to follow to fix all the breakages. --- Commit2336cdf7f
authored by Maxwell Paul Brickner<mbrickn@users.noreply.github.com> Committed by gunan<gunan@google.com>: Updated link to use HTTPS (#10998) Howdy! I just updated a link to use https instead of http. Thanks! --- Commitad0892df1
authored by Luke Iwanski<luke@codeplay.com> Committed by Luke Iwanski<luke@codeplay.com>: [OpenCL] Fixes run_metadata_test for SYCL This test is designed to test CUDA specific behavior --- Commit6b37a0725
authored by Todd Wang<toddwang@gmail.com> Committed by GitHub<noreply@github.com>: Update comments --- Commit1699d904a
authored by John Lawson<john@codeplay.com> Committed by Luke Iwanski<luke@codeplay.com>: [OpenCL] Fixes CUDA specific test run on SYCL (#56) The testBadParentValuesOnGPU should only be run on CUDA devices, as the test checks for particular CUDA behaviour. We don't actually provide a SYCL kernel for GatherTree and so it's not a problem that the tests don't target SYCL. --- Commit3c1946230
authored by myPrecious<Moriadry@users.noreply.github.com> Committed by Shanqing Cai<cais@google.com>: Java API to get the size of specified input list of operations. (#10865) * Java API to get the size of specified input list of operations * remove unnecessary explain to avoid bring a new term to users. --- Commite911c7480
authored by Luke Iwanski<luke@codeplay.com> Committed by Luke Iwanski<luke@codeplay.com>: [OpenCL] REGISTER -> REGISTER6 --- Commitfbf6c4cec
authored by superryanguo<superryanguo@gmail.com> Committed by superryanguo<superryanguo@gmail.com>: Simplify the Quickstart section with the weblink is better --- Commit72e2918cc
authored by Taehoon Lee<taehoonlee@snu.ac.kr> Committed by Taehoon Lee<taehoonlee@snu.ac.kr>: Fix typos --- Commit90c4406b7
authored by Rishabh Patel<patelrishabh@users.noreply.github.com> Committed by GitHub<noreply@github.com>: Correct the learning rate as per the code snippet --- Commit03da61134
authored by Todd Wang<toddwang@gmail.com> Committed by GitHub<noreply@github.com>: Update ir_array.cc --- Commit2df6cd3ac
authored by Todd Wang<toddwang@gmail.com> Committed by GitHub<noreply@github.com>: Another try --- Commitaf0cbace1
authored by Luke Iwanski<luke@codeplay.com> Committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>: [OpenCL] Transpose to go through Eigen (#10321) --- Commitfc7361081
authored by Luke Iwanski<luke@codeplay.com> Committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>: [OpenCL] Registers RGBToHSV and HSVToRGB (#91) (#10848) * [OpenCL] Added RGBToHSV and HSVToRGB * Aligning '\' --- Commit832894ef8
authored by Luke Iwanski<luke@codeplay.com> Committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>: [OpenCL] Registers AdjustContrastv2 (#10949) * [OpenCL] Registers AdjustContrastv2 (#93) * [OpenCL] Extended adjust_contrast_op_benchmark_test for OpenCL (#96) * [OpenCL] Extended adjust_contrast_op_benchmark_test for OpenCL * simplified to #ifndef * Changed to "#if GOOGLE_CUDA" * Update adjust_contrast_op_benchmark_test.cc * Added comments --- Commitcb4c2f8d1
authored by Yifei Feng<yifeif@google.com> Committed by Yifei Feng<yifeif@google.com>: Make TransferBufferToInFeed not virual so it compiles. --- Commite89f04d80
authored by Yifei Feng<yifeif@google.com> Committed by Yifei Feng<yifeif@google.com>: Fix calling Literal member functions. --- Commit15a8df724
authored by Yifei Feng<yifeif@google.com> Committed by Yifei Feng<yifeif@google.com>: Fix mac build clone from meheff's change: [XLA] Change return type of DeviceAssignment::Deserialize to fix build breakage on mac. The mac build had the following error: error: incomplete type 'xla::DeviceAssignment' used in type trait expression This was due to a static method returning a StatusOr<DeviceAssignment> inside of the definition of DeviceAssignment. --- Commita54d43fa4
authored by Yifei Feng<yifeif@google.com> Committed by Yifei Feng<yifeif@google.com>: Replace LiteralUtil to Literal in compiler/plugin/executor --- Commit88a6bb80c
authored by Guenther Schmuelling<guschmue@microsoft.com> Committed by Guenther Schmuelling<guschmue@microsoft.com>: expand inline for debug builds to limit number of symbols --- Commit62fb49d31
authored by Yifei Feng<yifeif@google.com> Committed by Yifei Feng<yifeif@google.com>: Fix visibility error for contrib/remote_fused_graph/pylib/BUILD. --- Commit4c75252f2
authored by Mark Neumann<markn@allenai.org> Committed by Mark Neumann<markn@allenai.org>: fix initial test values to avoid numerical instability --- Commitb58d98353
authored by sj6077<epik03sj@gmail.com> Committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>: Fixes of AutoParallel bug (#10368) * Fix the bug that auto_parallel could replicate variable snapshot name * Use NodeName in grappler:utils instead of substr, convert variables->variable_def of grappler item * remove variable_def from grappler item, exclude snapshot nodes from dont_replicate_nodes in auto_parallel --- Commita286b7db8
authored by Yifei Feng<yifeif@google.com> Committed by Yifei Feng<yifeif@google.com>: Make debug_test slice integer. --- Commit97fcfdfa6
authored by Toby Boyd<tobyboyd@google.com> Committed by GitHub<noreply@github.com>: Fixed path to seq2seq.py and minor formatting --- Commit63c1befb8
authored by Anish Shah<shah.anish07@gmail.com> Committed by Anish Shah<shah.anish07@gmail.com>: Improve docs for tf.nn.depthwise_conv2d_native --- Commit8d42202b2
authored by Yong Tang<yong.tang.github@outlook.com> Committed by Yong Tang<yong.tang.github@outlook.com>: Fix mismatched delete in mkl_tfconv_op.cc This fix fixes mismatched new[]-delete in mkl_tfconv_op.cc (the file went through clang-format so there are some additional changes) Signed-off-by: Yong Tang <yong.tang.github@outlook.com> --- Commit26301bd55
authored by Danny Goodman<goodman.danny@gmail.com> Committed by Danny Goodman<goodman.danny@gmail.com>: fix error format --- Commitb3f33ad46
authored by Yao Zhang<yaozhang@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Make changes to prepare for the fused option of batch norm to be set to None (None means using fused batch norm if possible). PiperOrigin-RevId: 159649743 --- Commita4a469832
authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: [XLA] Add tests for select ops and while loops that produce tuples that contain predicates. PiperOrigin-RevId: 159645900 --- Commit980d3f2be
authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Use C API to implement Operation.name property This name property is used in many existing tests including those that already run with C API enabled (math_ops_test, framework_ops_test, session_test, session_partial_run_test, math_ops_test_gpu, etc). PiperOrigin-RevId: 159645767 --- Commit26239c706
authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Previously we didn't have an implementation of BatchNormInference and BatchNormTraining, which gives a linker error if anyone ever tries to call that. A dummy implementation is friendlier than a linker error. PiperOrigin-RevId: 159645612 --- Commitf671c5caa
authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: BEGIN_PUBLIC Automated g4 rollback of changelist 159570549 PiperOrigin-RevId: 160182040 * Update ops-related pbtxt files. PiperOrigin-RevId: 160183349 * Merge changes from github followup. PiperOrigin-RevId: 160183498 * Automated g4 rollback of changelist 160183498 PiperOrigin-RevId: 160189134 * Automated g4 rollback of changelist 160182040 PiperOrigin-RevId: 160190881 * [XLA] Disallow fuse X into Y if there are paths from X to Y which don't fuse Just because X can fuse into all of its consumers does not mean that those consumers can fuse into anything. Depending on the structure of the graph, this can either result in no performance win at all or, in the case of recurrent networks, a big performance deficit. PiperOrigin-RevId: 160194058 * First draft of Tensors segment of the programmer's guide. PiperOrigin-RevId: 160196550 * First draft of variables unit of programmer's guide. PiperOrigin-RevId: 160196566 * Make xla::Literal moveable. PiperOrigin-RevId: 160197273 * Automated g4 rollback of changelist 159897279 PiperOrigin-RevId: 160198598 * Updates text_classification example. PiperOrigin-RevId: 160200457 * Fix backward compatibility test broken by rollback. PiperOrigin-RevId: 160222187 * Support advisor in all places (Command line, APIs) Add expensive operation checker PiperOrigin-RevId: 160222348 * [XLA] Simplify the fusion heuristic We had two different aspects of the fusion heuristic: - Don't fuse a producer into a consumer if there exists a path from the producer to the consumer which cannot be fused. - Don't fuse a producer into a consumer if any consumer of the producer cannot fuse. These can be combined into one, simpler, heuristic. PiperOrigin-RevId: 160222771 * Automated g4 rollback of changelist 160196566 PiperOrigin-RevId: 160222930 * Automated g4 rollback of changelist 160196550 PiperOrigin-RevId: 160222942 * Lets the HParam parser also accept True and False as inputs, since that's how python prints booleans. PiperOrigin-RevId: 160234658 * Automated g4 rollback of changelist 155070869 PiperOrigin-RevId: 160249526 * [TF:XLA] Inline the sigmoid operation instead of mapping it elementwise. PiperOrigin-RevId: 160274436 * Make sure all convolution tests are testing non-trivial cases, i.e. where not all inputs are 0, leading to an all-0 output, which masks most possible bugs. We do not check-fail on 0-sized dimensions as tests for these special cases exist. PiperOrigin-RevId: 160274593 * Explicitly use "dns" URI scheme when using DNS names or literal IP addresses with gRPC. This avoids problems in environments in which the default URI scheme is something other than "dns". PiperOrigin-RevId: 160276862 * Add RWSE (root weighted squared error) to the WALS estimator. PiperOrigin-RevId: 160276937 * Don't include node_def.proto.h in node_def_util.h The goal is to make kernels mostly independent of proto headers, which will let us lock down our .so imports. RELNOTES: n/a PiperOrigin-RevId: 160278032 * [XLA] Add tuple support to Literal::CreateFromShape. PiperOrigin-RevId: 160278561 * Updates some more examples in examples/learn. PiperOrigin-RevId: 160278757 * Automated g4 rollback of changelist 160278032 PiperOrigin-RevId: 160280961 * Fixed the bug that Estimator does not make deepcopy of params in constructor PiperOrigin-RevId: 160281247 * Clean out the config and params in TPUEstimator. PiperOrigin-RevId: 160281507 * [XLA] Remove the "hlo dumper" parameter of xla::Compiler and its piping. This dumper is no longer necessary since the restructuring of HLO dumping and the addition of MaybeDumpHloModule which heeds to the right flags. The remaining bits didn't have additional functionality, but constituted a lot of boilerplate that has to be propagated throughout the backends. PiperOrigin-RevId: 160281798 * [TF:XLA] Refactor the sigmoid op as a rescaled tanh. PiperOrigin-RevId: 160282472 * Fix uninitialized values in TensorForest code. PiperOrigin-RevId: 160284420 * [TF:XLA] Update Tensorflow LLVM release to upstream r306370. Fix broken XLA build. PiperOrigin-RevId: 160284588 * tfdbg example: fix --tensor_size issue in debug_fibonacci PiperOrigin-RevId: 160290541 * [SE] ThenConvolveWithAlgorithm vlogs algorithm configs. PiperOrigin-RevId: 160292762 * Fix documentation of Estimator class (invalid quotes). PiperOrigin-RevId: 160292803 * Shrink the test size to avoid OOM error on old GPUs. PiperOrigin-RevId: 160292834 * [TF:XLA] Reject operators with resource outputs on CPU and GPU devices. We were checking for resource inputs but not resource outputs, which led to accidental fusion of some TensorArray ops on CPU and GPU. PiperOrigin-RevId: 160294302 * Add a functionality of remote fused graph transformation to fuse graphs by op type PiperOrigin-RevId: 160300039 * Cudnn compatible LSTMCell and LSTMBlockCell PiperOrigin-RevId: 160300668 * [XLA] Remove "operand" argument from HandleReducePrecision. PiperOrigin-RevId: 160301461 * Added more reduce window tests. PiperOrigin-RevId: 160301509 * Updates more text classification examples in examples/learn. PiperOrigin-RevId: 160305131 * Use C API to implement Operation._output_types This change first converts the _output_types member to a property and then implements it using C API if it is enabled. PiperOrigin-RevId: 160306227 * Add more tests for BatchNormTraining. RELNOTES: n/a PiperOrigin-RevId: 160307959 * Update path to print_selective_registration_header.py in comment PiperOrigin-RevId: 160308173 * Migrate TensorForest v4 python to contrib. PiperOrigin-RevId: 160308805 * Automated g4 rollback of changelist 159454657 PiperOrigin-RevId: 160314706 * TESTFIX: distributions:trig_test wasn't passing in ASAN mode. PiperOrigin-RevId: 160315597 * tfdbg doc: fixes and improvements PiperOrigin-RevId: 160318411 * Add a time estimation to HloCostAnalysis and represent properties as a map so that adding more properties will be easier, e.g. in a sub-class. PiperOrigin-RevId: 160318494 * tfdbg: revert dns:/// prefix in gRPC mode PiperOrigin-RevId: 160319348 * Moves TensorCApi from c_api.cc to c_api_internal.h, where it can be used by other code that require access to the underlying TensorBuffers. PiperOrigin-RevId: 160323362 * Readd the new tensors and variables documents, with tests passing. PiperOrigin-RevId: 160324191 * Make ResourceHandle not be a proto I'm trying to make core/kernels independent of protos. Currently the dtype ResourceHandle is itself a proto. After this CL, ResourceHandle is a normal C++ type which gets converted to/from ResourceHandleProto at (de)serialization time. RELNOTES: n/a PiperOrigin-RevId: 160329002 * Minor cleanup: remove unused dependencies and inclusions PiperOrigin-RevId: 160334030 * Add name_scopes to mnist_deep.py for a cleaner graph layout. PiperOrigin-RevId: 160338775 * Add note about `tf.test.mock` to docs for `tf.test` PiperOrigin-RevId: 160338811 * Internal change. PiperOrigin-RevId: 160339087 * Fix bugs in ScatterNd and add ScatterNdNonAliasingAdd. tf.scatter_nd_non_aliasing_add acts similarly to tf.scatter_nd_add but works on non-ref objects (i.e., Tensors -- not Variables). This means it has a gradient with respect to the primary input as well as the updates. It does its best to avoid making extra copies of the input. PiperOrigin-RevId: 160339328 * Update ops-related pbtxt files. PiperOrigin-RevId: 160340888 * Add checkpoint conversion for models that use the attention mechanism implemented in tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py. PiperOrigin-RevId: 160340994 * Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 160341769 * Merge changes from github. PiperOrigin-RevId: 160344052 * Update ops-related pbtxt files. PiperOrigin-RevId: 160346151 * Load py_test in tensorflow/contrib/boosted_trees/BUILD to fix pip test visibility failures. * Disable boosted_trees tests on mac while they are being debugged.
This commit is contained in:
parent
b183be3b4d
commit
cf7c008ab1
tensorflow
BUILD
c
cc/gradients
compiler
aot
BUILDcompile.cctest_graph_tfunknownop.config.pbtxttest_graph_tfunknownop.pbtxttfcompile_main.cctfcompile_util.cctfcompile_util.htfcompile_util_test.cc
jit
plugin/executor
tests
tf2xla
BUILD
kernels
test_util.ccxla_context.ccxla_context.hxla_op_kernel.ccxla_op_kernel.hxla_op_registry.ccxla
array4d.h
legacy_flags
BUILDcompiler_functor_flags.cccompiler_functor_flags.hconvolution_thunk_flags.ccconvolution_thunk_flags.hcpu_runtime_flags.cccpu_runtime_flags.hdebug_options_flags.cchlo_test_base_flags.cchlo_test_base_flags.hservice_flags.ccservice_flags.h
literal_util.ccliteral_util.hliteral_util_test.ccpacked_literal_reader.ccreference_util.ccreference_util.hservice
BUILDbuffer_assignment.ccbuffer_assignment_test.cccompile_only_service.cccompiler.h
cpu
BUILDcompiler_functor.cccompiler_functor.hconv_canonicalization.cccpu_compiler.cccpu_compiler.hdot_op_emitter.ccdot_op_emitter.hir_emission_utils.ccir_emitter.ccsimple_orc_jit.ccsimple_orc_jit.h
dfs_hlo_visitor.helemental_ir_emitter.ccexecutable.ccexecutable.hgpu
BUILDconvolution_thunk.ccconvolution_thunk.hgpu_compiler.ccgpu_compiler.hhlo_schedule.cchlo_schedule.h
hlo_alias_analysis.cchlo_alias_analysis.hhlo_buffer.cchlo_buffer.hhlo_cost_analysis.cchlo_cost_analysis.hhlo_cost_analysis_test.cchlo_dataflow_analysis.cchlo_dataflow_analysis.hhlo_evaluator.cchlo_graph_dumper.cchlo_graph_dumper.hhlo_instruction.cchlo_instruction.hhlo_module.cchlo_module.hhlo_module_test.cchlo_ordering.cchlo_ordering.hhlo_ordering_test.cchlo_pass_pipeline.cc@ -216,6 +216,7 @@ filegroup(
|
||||
"//tensorflow/compiler/jit/kernels:all_files",
|
||||
"//tensorflow/compiler/jit/legacy_flags:all_files",
|
||||
"//tensorflow/compiler/jit/ops:all_files",
|
||||
"//tensorflow/compiler/plugin/executor:all_files",
|
||||
"//tensorflow/compiler/tests:all_files",
|
||||
"//tensorflow/compiler/tf2xla:all_files",
|
||||
"//tensorflow/compiler/tf2xla/cc:all_files",
|
||||
@ -288,6 +289,7 @@ filegroup(
|
||||
"//tensorflow/contrib/opt:all_files",
|
||||
"//tensorflow/contrib/predictor:all_files",
|
||||
"//tensorflow/contrib/remote_fused_graph/pylib:all_files",
|
||||
"//tensorflow/contrib/resampler:all_files",
|
||||
"//tensorflow/contrib/rnn:all_files",
|
||||
"//tensorflow/contrib/saved_model:all_files",
|
||||
"//tensorflow/contrib/saved_model/cc/saved_model:all_files",
|
||||
|
@ -466,15 +466,6 @@ TF_Tensor* TF_Tensor_EncodeStrings(const Tensor& src) {
|
||||
dimvec.size(), base, size, DeleteArray, base);
|
||||
}
|
||||
|
||||
class TensorCApi {
|
||||
public:
|
||||
static TensorBuffer* Buffer(const Tensor& tensor) { return tensor.buf_; }
|
||||
static Tensor MakeTensor(TF_DataType type, const TensorShape& shape,
|
||||
TensorBuffer* buf) {
|
||||
return Tensor(static_cast<DataType>(type), shape, buf);
|
||||
}
|
||||
};
|
||||
|
||||
// Create an empty tensor of type 'dtype'. 'shape' can be arbitrary, but has to
|
||||
// result in a zero-sized tensor.
|
||||
static TF_Tensor* EmptyTensor(TF_DataType dtype, const TensorShape& shape) {
|
||||
|
@ -117,3 +117,16 @@ struct TF_ImportGraphDefOptions {
|
||||
struct TF_DeviceList {
|
||||
std::vector<tensorflow::DeviceAttributes> response;
|
||||
};
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class TensorCApi {
|
||||
public:
|
||||
static TensorBuffer* Buffer(const Tensor& tensor) { return tensor.buf_; }
|
||||
static Tensor MakeTensor(TF_DataType type, const TensorShape& shape,
|
||||
TensorBuffer* buf) {
|
||||
return Tensor(static_cast<DataType>(type), shape, buf);
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
@ -247,6 +247,17 @@ Status ScatterNdGrad(const Scope& scope, const Operation& op,
|
||||
}
|
||||
REGISTER_GRADIENT_OP("ScatterNd", ScatterNdGrad);
|
||||
|
||||
Status ScatterNdNonAliasingAddGrad(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
auto indices = op.input(1);
|
||||
grad_outputs->push_back(Identity(scope, grad_inputs[0]));
|
||||
grad_outputs->push_back(NoGradient());
|
||||
grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices));
|
||||
return scope.status();
|
||||
}
|
||||
REGISTER_GRADIENT_OP("ScatterNdNonAliasingAdd", ScatterNdNonAliasingAddGrad);
|
||||
|
||||
Status PadGrad(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
|
@ -233,6 +233,28 @@ TEST_F(ArrayGradTest, ScatterNdGrad_SliceIndexing) {
|
||||
RunTest(updates, updates_shape, y, y_shape);
|
||||
}
|
||||
|
||||
TEST_F(ArrayGradTest, ScatterNdNonAliasingAddGrad_SimpleIndexing) {
|
||||
TensorShape updates_shape({4});
|
||||
TensorShape input_shape({8});
|
||||
auto input = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(input_shape));
|
||||
auto updates =
|
||||
Placeholder(scope_, DT_FLOAT, Placeholder::Shape(updates_shape));
|
||||
auto indices = Const(scope_, {{4}, {3}, {1}, {7}});
|
||||
auto y = ScatterNdNonAliasingAdd(scope_, input, indices, updates);
|
||||
RunTest({input, updates}, {input_shape, updates_shape}, {y}, {input_shape});
|
||||
}
|
||||
|
||||
TEST_F(ArrayGradTest, ScatterNdNonAliasingAddGrad_SliceIndexing) {
|
||||
TensorShape updates_shape({2, 4, 4});
|
||||
TensorShape input_shape({4, 4, 4});
|
||||
auto input = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(input_shape));
|
||||
auto updates =
|
||||
Placeholder(scope_, DT_FLOAT, Placeholder::Shape(updates_shape));
|
||||
auto indices = Const(scope_, {{0}, {2}});
|
||||
auto y = ScatterNdNonAliasingAdd(scope_, input, indices, updates);
|
||||
RunTest({input, updates}, {input_shape, updates_shape}, {y}, {input_shape});
|
||||
}
|
||||
|
||||
TEST_F(ArrayGradTest, PadGrad) {
|
||||
TensorShape x_shape({2, 3});
|
||||
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
|
||||
|
@ -108,6 +108,7 @@ cc_test(
|
||||
deps = [
|
||||
":tfcompile_lib",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
@ -127,8 +128,6 @@ cc_library(
|
||||
":tfcompile_lib",
|
||||
":tfcompile_proto",
|
||||
"//tensorflow/compiler/xla/legacy_flags:buffer_assignment_flags",
|
||||
"//tensorflow/compiler/xla/legacy_flags:compiler_functor_flags",
|
||||
"//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
|
||||
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
|
||||
"//tensorflow/compiler/xla/legacy_flags:hlo_graph_dumper_flags",
|
||||
"//tensorflow/compiler/xla/legacy_flags:service_flags",
|
||||
@ -158,6 +157,17 @@ tf_library(
|
||||
tags = ["manual"],
|
||||
)
|
||||
|
||||
# A test of tf_library that includes a graph with an unknown op, but where
|
||||
# the compilation works because the the unknown op is not needed for the fetches.
|
||||
tf_library(
|
||||
name = "test_graph_tfunknownop",
|
||||
testonly = 1,
|
||||
config = "test_graph_tfunknownop.config.pbtxt",
|
||||
cpp_class = "UnknownOpAddComp",
|
||||
graph = "test_graph_tfunknownop.pbtxt",
|
||||
tags = ["manual"],
|
||||
)
|
||||
|
||||
# Utility library for benchmark binaries, used by the *_benchmark rules that are
|
||||
# added by the tfcompile bazel macro.
|
||||
cc_library(
|
||||
@ -201,6 +211,7 @@ test_suite(
|
||||
tests = [
|
||||
":benchmark_test",
|
||||
":test_graph_tfadd_test",
|
||||
":test_graph_tfunknownop_test",
|
||||
"//tensorflow/compiler/aot/tests:all_tests",
|
||||
],
|
||||
)
|
||||
|
@ -378,9 +378,16 @@ Status CompileXla(xla::CompileOnlyClient* client,
|
||||
Status InitGraph(const GraphDef& graph_def, const Config& config,
|
||||
const MainFlags& flags, std::unique_ptr<Graph>* graph) {
|
||||
TF_RETURN_IF_ERROR(ValidateConfig(config));
|
||||
|
||||
FunctionLibraryDefinition flib_def(OpRegistry::Global(), graph_def.library());
|
||||
std::unique_ptr<Graph> g(new Graph(flib_def));
|
||||
GraphDef copy_def(graph_def);
|
||||
|
||||
GraphDef copy_def;
|
||||
|
||||
// Prune the GraphDef first so that unknown ops that we aren't compiling get
|
||||
// filtered out.
|
||||
TF_RETURN_IF_ERROR(PruneGraphDefInto(config, graph_def, ©_def));
|
||||
|
||||
TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(©_def, *g->op_registry(),
|
||||
0 /*node_offset*/));
|
||||
TF_RETURN_IF_ERROR(
|
||||
|
16
tensorflow/compiler/aot/test_graph_tfunknownop.config.pbtxt
Normal file
16
tensorflow/compiler/aot/test_graph_tfunknownop.config.pbtxt
Normal file
@ -0,0 +1,16 @@
|
||||
# Text form of tensorflow.tfcompile.Config proto.
|
||||
feed {
|
||||
id { node_name: "x_const" }
|
||||
shape {
|
||||
dim { size: 1 }
|
||||
}
|
||||
}
|
||||
feed {
|
||||
id { node_name: "y_const" }
|
||||
shape {
|
||||
dim { size: 1 }
|
||||
}
|
||||
}
|
||||
fetch {
|
||||
id { node_name: "x_y_sum" }
|
||||
}
|
86
tensorflow/compiler/aot/test_graph_tfunknownop.pbtxt
Normal file
86
tensorflow/compiler/aot/test_graph_tfunknownop.pbtxt
Normal file
@ -0,0 +1,86 @@
|
||||
node {
|
||||
name : "x_const"
|
||||
op : "Const"
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_INT32
|
||||
tensor_shape {
|
||||
dim {
|
||||
size: 1
|
||||
}
|
||||
}
|
||||
int_val: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key : "dtype"
|
||||
value {
|
||||
type : DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name : "y_const"
|
||||
op : "Const"
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_INT32
|
||||
tensor_shape {
|
||||
dim {
|
||||
size: 1
|
||||
}
|
||||
}
|
||||
int_val: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name : "x_y_sum"
|
||||
op : "Add"
|
||||
input : "x_const"
|
||||
input : "y_const"
|
||||
attr {
|
||||
key : "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name : "z"
|
||||
op : "SomeUnknownOp"
|
||||
input : "x_const"
|
||||
attr {
|
||||
key : "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name : "x_z_sum"
|
||||
op : "Add"
|
||||
input : "x_const"
|
||||
input : "z"
|
||||
attr {
|
||||
key : "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
versions {
|
||||
producer: 15
|
||||
}
|
@ -24,8 +24,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/aot/tfcompile.pb.h"
|
||||
#include "tensorflow/compiler/aot/tfcompile_util.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/service_flags.h"
|
||||
@ -134,8 +132,6 @@ int main(int argc, char** argv) {
|
||||
std::vector<tensorflow::Flag> flag_list;
|
||||
AppendMainFlags(&flag_list, &flags);
|
||||
xla::legacy_flags::AppendBufferAssignmentFlags(&flag_list);
|
||||
xla::legacy_flags::AppendCompilerFunctorFlags(&flag_list);
|
||||
xla::legacy_flags::AppendCpuRuntimeFlags(&flag_list);
|
||||
xla::legacy_flags::AppendHloGraphDumperFlags(&flag_list);
|
||||
xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
|
||||
xla::legacy_flags::AppendServiceFlags(&flag_list);
|
||||
|
@ -15,10 +15,14 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/aot/tfcompile_util.h"
|
||||
|
||||
#include <queue>
|
||||
#include <set>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/compiler/aot/tfcompile.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/graph/tensor_id.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
@ -115,5 +119,51 @@ Status ValidateConfig(const Config& config) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PruneGraphDefInto(const Config& config, const GraphDef& in,
|
||||
GraphDef* out) {
|
||||
*out = in;
|
||||
out->clear_node();
|
||||
|
||||
// Maps node name to reachability.
|
||||
std::unordered_map<string, std::pair<bool, const NodeDef*>> node_by_name;
|
||||
for (const NodeDef& node : in.node()) {
|
||||
node_by_name[node.name()] = std::pair<bool, const NodeDef*>(false, &node);
|
||||
}
|
||||
|
||||
std::queue<string> name_queue;
|
||||
for (int i = 0; i < config.fetch_size(); ++i) {
|
||||
name_queue.push(config.fetch(i).id().node_name());
|
||||
}
|
||||
while (!name_queue.empty()) {
|
||||
const string name = name_queue.front();
|
||||
name_queue.pop();
|
||||
|
||||
auto find_it = node_by_name.find(name);
|
||||
if (find_it == node_by_name.end()) {
|
||||
return errors::InvalidArgument("While pruning graph, node ", name,
|
||||
" needed but not found in the graph.");
|
||||
}
|
||||
auto& map_entry = find_it->second;
|
||||
if (map_entry.first) {
|
||||
continue;
|
||||
}
|
||||
map_entry.first = true;
|
||||
|
||||
for (const string& in_edge : map_entry.second->input()) {
|
||||
name_queue.push(ParseTensorName(in_edge).first.ToString());
|
||||
}
|
||||
}
|
||||
|
||||
// Copy over, preserving order of original and only nodes that are reachable
|
||||
// from the fetches.
|
||||
out->mutable_node()->Reserve(in.node_size());
|
||||
for (const NodeDef& node : in.node()) {
|
||||
if (node_by_name[node.name()].first) {
|
||||
*out->add_node() = node;
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tfcompile
|
||||
} // namespace tensorflow
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_
|
||||
|
||||
#include "tensorflow/compiler/aot/tfcompile.pb.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
|
||||
@ -30,6 +31,11 @@ Status ValidateCppIdent(StringPiece ident, StringPiece msg);
|
||||
// ValidateConfig returns OK iff config is valid.
|
||||
Status ValidateConfig(const Config& config);
|
||||
|
||||
// Returns in <out> a copy of <in>, pruned to only include fetches from
|
||||
// <config>.
|
||||
Status PruneGraphDefInto(const Config& config, const GraphDef& in,
|
||||
GraphDef* out);
|
||||
|
||||
} // namespace tfcompile
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -15,9 +15,11 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/aot/tfcompile_util.h"
|
||||
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -180,6 +182,65 @@ TEST(ValidateConfig, ConflictingFetchName) {
|
||||
ExpectErrorContains(ValidateConfig(config), "conflicting fetch name");
|
||||
}
|
||||
|
||||
static Config FetchesConfig(std::vector<string> fetches) {
|
||||
Config config;
|
||||
for (const auto& fetch_node_name : fetches) {
|
||||
auto* fetch = config.add_fetch();
|
||||
fetch->set_name(strings::StrCat("fetch_", fetch_node_name));
|
||||
fetch->mutable_id()->set_node_name(fetch_node_name);
|
||||
}
|
||||
return config;
|
||||
}
|
||||
|
||||
TEST(PruneGraphDefInto, Basic) {
|
||||
GraphDef def;
|
||||
auto* n = def.add_node();
|
||||
n->set_name("a");
|
||||
n->add_input("b:0");
|
||||
n->add_input("^c");
|
||||
|
||||
GraphDef copy;
|
||||
ExpectErrorContains(PruneGraphDefInto(FetchesConfig({"missing"}), def, ©),
|
||||
"node missing needed");
|
||||
ExpectErrorContains(PruneGraphDefInto(FetchesConfig({"a"}), def, ©),
|
||||
"node b needed");
|
||||
|
||||
n = def.add_node();
|
||||
n->set_name("b");
|
||||
ExpectErrorContains(PruneGraphDefInto(FetchesConfig({"a"}), def, ©),
|
||||
"node c needed");
|
||||
n->add_input("d:1");
|
||||
|
||||
n = def.add_node();
|
||||
n->set_name("c");
|
||||
n->add_input("d:1");
|
||||
|
||||
n = def.add_node();
|
||||
n->set_name("d");
|
||||
|
||||
// Graph is full, no pruning done.
|
||||
// Graph right now has diamond from d:
|
||||
// d --> b --> a
|
||||
// d --> c --> a
|
||||
TF_EXPECT_OK(PruneGraphDefInto(FetchesConfig({"a"}), def, ©));
|
||||
EXPECT_EQ(def.DebugString(), copy.DebugString());
|
||||
GraphDef pruned_a = copy;
|
||||
|
||||
// Add some unrelated fields that use b and c, but are not needed for a.
|
||||
n = def.add_node();
|
||||
n->set_name("e");
|
||||
n->add_input("^d");
|
||||
n->add_input("b:2");
|
||||
copy.Clear();
|
||||
TF_EXPECT_OK(PruneGraphDefInto(FetchesConfig({"a"}), def, ©));
|
||||
EXPECT_EQ(pruned_a.DebugString(), copy.DebugString());
|
||||
|
||||
// Fetch "a" and "e" to get the original graph.
|
||||
copy.Clear();
|
||||
TF_EXPECT_OK(PruneGraphDefInto(FetchesConfig({"a", "e"}), def, ©));
|
||||
EXPECT_EQ(def.DebugString(), copy.DebugString());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tfcompile
|
||||
} // namespace tensorflow
|
||||
|
@ -15,7 +15,10 @@ package_group(
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = [":internal"],
|
||||
default_visibility = [
|
||||
":internal",
|
||||
"//tensorflow/compiler/plugin/executor:__pkg__",
|
||||
],
|
||||
)
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
|
||||
|
@ -2,6 +2,7 @@ licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//tensorflow/compiler/plugin/executor:__pkg__",
|
||||
"//tensorflow/compiler/tf2xla:internal",
|
||||
],
|
||||
)
|
||||
|
@ -64,7 +64,7 @@ class ParallelCheckOp : public OpKernel {
|
||||
ok = (diff <= tolerance);
|
||||
}
|
||||
if (ok) continue;
|
||||
LOG(ERROR) << "Op " << def().name() << " fails equality at output "
|
||||
LOG(ERROR) << "Op " << name() << " fails equality at output "
|
||||
<< input_idx << " type " << DataTypeString(dtype)
|
||||
<< " element " << i << ": std_val=" << p0[i]
|
||||
<< " test_val=" << p1[i] << " diff=" << (p0[i] - p1[i]);
|
||||
@ -75,7 +75,7 @@ class ParallelCheckOp : public OpKernel {
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
VLOG(1) << "Compute " << def().name();
|
||||
VLOG(1) << "Compute " << name();
|
||||
const int num_pairs = ctx->num_inputs() / 2;
|
||||
for (int i = 0; i < num_pairs; ++i) {
|
||||
CHECK_EQ(ctx->input_dtype(i), ctx->input_dtype(i + num_pairs));
|
||||
@ -113,7 +113,7 @@ class ParallelCheckOp : public OpKernel {
|
||||
LOG(FATAL) << "unimpl: " << ctx->input_dtype(i);
|
||||
}
|
||||
if (failed > 0) {
|
||||
LOG(ERROR) << "check failed for " << def().name() << " output " << i
|
||||
LOG(ERROR) << "check failed for " << name() << " output " << i
|
||||
<< " num_elts: " << num_elts;
|
||||
legacy_flags::ParallelCheckOpFlags* flags =
|
||||
legacy_flags::GetParallelCheckOpFlags();
|
||||
@ -121,7 +121,7 @@ class ParallelCheckOp : public OpKernel {
|
||||
LOG(QFATAL) << "failfast on first parallel-check failure";
|
||||
}
|
||||
} else {
|
||||
VLOG(1) << "check passed for " << def().name() << " output " << i
|
||||
VLOG(1) << "check passed for " << name() << " output " << i
|
||||
<< " num_elts: " << num_elts;
|
||||
}
|
||||
|
||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/graph_def_util.h"
|
||||
#include "tensorflow/core/framework/memory_types.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
@ -162,10 +163,12 @@ Status DeviceTypeOfDevice(const string& device, DeviceType* device_type) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Does `node` have a DT_RESOURCE typed argument?
|
||||
bool HasResourceArgument(const Node& node) {
|
||||
// Tests whether `node` has a DT_RESOURCE typed input or output.
|
||||
bool HasResourceInputOrOutput(const Node& node) {
|
||||
return std::find(node.input_types().begin(), node.input_types().end(),
|
||||
DT_RESOURCE) != node.input_types().end();
|
||||
DT_RESOURCE) != node.input_types().end() ||
|
||||
std::find(node.output_types().begin(), node.output_types().end(),
|
||||
DT_RESOURCE) != node.output_types().end();
|
||||
}
|
||||
|
||||
Status FindCompilationCandidates(
|
||||
@ -193,9 +196,10 @@ Status FindCompilationCandidates(
|
||||
<< ": " << node->type_string();
|
||||
continue;
|
||||
}
|
||||
if (!registration->compile_resource_ops && HasResourceArgument(*node)) {
|
||||
VLOG(2) << "Compilation rejected node: resource argument " << node->name()
|
||||
<< ": " << node->type_string();
|
||||
if (!registration->compile_resource_ops &&
|
||||
HasResourceInputOrOutput(*node)) {
|
||||
VLOG(2) << "Compilation rejected node: resource input/output "
|
||||
<< node->name() << ": " << node->type_string();
|
||||
continue;
|
||||
}
|
||||
if (node->type_string() == "While" &&
|
||||
|
@ -14,11 +14,13 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
@ -455,5 +457,39 @@ TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) {
|
||||
EXPECT_EQ(clusters["B"], clusters["C"]);
|
||||
}
|
||||
|
||||
REGISTER_OP("ResourceInput").Input("a: resource").Output("o: float");
|
||||
REGISTER_OP("ResourceOutput").Input("a: float").Output("o: resource");
|
||||
|
||||
namespace {
|
||||
|
||||
class DummyOp : public XlaOpKernel {
|
||||
using XlaOpKernel::XlaOpKernel;
|
||||
void Compile(XlaOpKernelContext* ctx) override {}
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("ResourceInput"), DummyOp);
|
||||
REGISTER_XLA_OP(Name("ResourceOutput"), DummyOp);
|
||||
|
||||
} // namespace
|
||||
|
||||
TEST(XlaCompilationTest, Resources) {
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
GraphDef graphdef;
|
||||
{
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a =
|
||||
ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
|
||||
Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
|
||||
// We should not form clusters with resource ops by default.
|
||||
Node* c = ops::UnaryOp("ResourceOutput", b, builder.opts().WithName("C"));
|
||||
Node* d = ops::UnaryOp("ResourceInput", c, builder.opts().WithName("D"));
|
||||
ops::UnaryOp("Relu", d, builder.opts().WithName("E"));
|
||||
TF_EXPECT_OK(builder.ToGraph(graph.get()));
|
||||
}
|
||||
MarkForCompilation(&graph);
|
||||
auto clusters = GetClusters(*graph);
|
||||
EXPECT_EQ(0, clusters.size()); // Nothing should be compiled.
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -11,9 +11,11 @@ cc_library(
|
||||
"*.h",
|
||||
]),
|
||||
deps = [
|
||||
"//tensorflow/compiler/jit:xla_device",
|
||||
"//tensorflow/compiler/jit:xla_jit_headers_lib",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/xla:xla_headers_lib",
|
||||
"//tensorflow/compiler/xla/service:hlo_evaluator",
|
||||
"//tensorflow/compiler/xla/service",
|
||||
"//third_party/eigen3",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@protobuf//:protobuf_headers",
|
||||
|
@ -48,9 +48,8 @@ namespace executorplugin {
|
||||
* each pass in the optimization pipeline. The service subdirectory
|
||||
* contains useful optimization passes.
|
||||
*/
|
||||
Status ExecutorCompiler::RunHloOptimization(HloModule* hlo_module,
|
||||
HloDumper dump_hlo) {
|
||||
HloPassPipeline pipeline("Executor", dump_hlo);
|
||||
Status ExecutorCompiler::RunHloOptimization(HloModule* hlo_module) {
|
||||
HloPassPipeline pipeline("Executor");
|
||||
pipeline.AddPass<Inliner>();
|
||||
pipeline.AddPass<HloSubcomputationUnification>();
|
||||
pipeline.AddPass<HloCSE>(false);
|
||||
@ -67,13 +66,13 @@ Status ExecutorCompiler::RunHloOptimization(HloModule* hlo_module,
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<Executable>> ExecutorCompiler::Compile(
|
||||
std::unique_ptr<HloModule> hlo_module, HloDumper dump_hlo,
|
||||
std::unique_ptr<HloModule> hlo_module,
|
||||
se::StreamExecutor* stream_exec) {
|
||||
TF_RET_CHECK(stream_exec != nullptr);
|
||||
|
||||
VLOG(1) << "Generate graph " << hlo_module->name();
|
||||
|
||||
TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get(), dump_hlo));
|
||||
TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get()));
|
||||
|
||||
// Typically you would visit the HLO graph, building up a compiled equivalent
|
||||
// In this case we are using an Hlo evaluator at execution time, so we don't
|
||||
@ -88,7 +87,7 @@ StatusOr<std::unique_ptr<Executable>> ExecutorCompiler::Compile(
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<Executable>>> ExecutorCompiler::Compile(
|
||||
std::vector<std::unique_ptr<HloModule>> hlo_modules,
|
||||
HloDumper dump_hlos, std::vector<se::StreamExecutor*> stream_execs) {
|
||||
std::vector<se::StreamExecutor*> stream_execs) {
|
||||
|
||||
return tensorflow::errors::Unimplemented(
|
||||
"Compilation of multiple HLO modules is not supported on Executor.");
|
||||
@ -97,7 +96,7 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> ExecutorCompiler::Compile(
|
||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||
ExecutorCompiler::CompileAheadOfTime(
|
||||
std::vector<std::unique_ptr<HloModule>> hlo_modules,
|
||||
HloDumper dump_hlo, const AotCompilationOptions& aot_options) {
|
||||
const AotCompilationOptions& aot_options) {
|
||||
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
"AOT compilation not supported on Executor");
|
||||
|
@ -35,25 +35,23 @@ class ExecutorCompiler : public Compiler {
|
||||
|
||||
StatusOr<std::unique_ptr<Executable>> Compile(
|
||||
std::unique_ptr<HloModule> hlo_module,
|
||||
HloDumper dump_hlo,
|
||||
perftools::gputools::StreamExecutor* stream_exec) override;
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
|
||||
std::vector<std::unique_ptr<HloModule>> hlo_module,
|
||||
HloDumper dump_hlo,
|
||||
std::vector<perftools::gputools::StreamExecutor*> stream_exec) override;
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||
CompileAheadOfTime(
|
||||
std::vector<std::unique_ptr<HloModule>> module,
|
||||
HloDumper dump_hlo, const AotCompilationOptions& options) override;
|
||||
const AotCompilationOptions& options) override;
|
||||
|
||||
HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override;
|
||||
|
||||
perftools::gputools::Platform::Id PlatformId() const override;
|
||||
|
||||
private:
|
||||
Status RunHloOptimization(HloModule* hlo_module, HloDumper dump_hlo);
|
||||
Status RunHloOptimization(HloModule* hlo_module);
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(ExecutorCompiler);
|
||||
};
|
||||
|
@ -175,6 +175,11 @@ tf_xla_py_test(
|
||||
name = "slice_ops_test",
|
||||
size = "small",
|
||||
srcs = ["slice_ops_test.py"],
|
||||
# TODO(b/62962492): Test fails with assertion error.
|
||||
tags = [
|
||||
"manual",
|
||||
"notap",
|
||||
],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
@ -456,6 +461,11 @@ cuda_py_test(
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:nn_ops",
|
||||
],
|
||||
# TODO(b/62961789): Test fails with SIGABRT
|
||||
tags = [
|
||||
"manual",
|
||||
"notap",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
@ -249,6 +249,7 @@ cc_library(
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
@ -63,17 +63,13 @@ class StridedSliceOp : public XlaOpKernel {
|
||||
&strides_tensor));
|
||||
|
||||
TensorShape dummy_processing_shape;
|
||||
ShapeReadWriteFromTensorShape wrapped_final_shape(&final_shape);
|
||||
ShapeReadWriteFromTensorShape wrapped_dummy_processing_shape(
|
||||
&dummy_processing_shape);
|
||||
bool dummy = false;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ValidateStridedSliceOp(
|
||||
&begin_tensor, &end_tensor, strides_tensor,
|
||||
ShapeReadWriteFromTensorShape(&input_shape), begin_mask_,
|
||||
end_mask_, ellipsis_mask_, new_axis_mask_, shrink_axis_mask_,
|
||||
&wrapped_dummy_processing_shape, &wrapped_final_shape, &dummy,
|
||||
&dummy, &dummy, &begin, &end, &strides));
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ValidateStridedSliceOp(
|
||||
&begin_tensor, &end_tensor, strides_tensor, input_shape,
|
||||
begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_,
|
||||
shrink_axis_mask_, &dummy_processing_shape, &final_shape,
|
||||
&dummy, &dummy, &dummy, &begin, &end, &strides));
|
||||
|
||||
gtl::InlinedVector<int64, 4> dimensions_to_reverse;
|
||||
gtl::InlinedVector<int64, 4> slice_begin, slice_end, slice_strides;
|
||||
@ -146,14 +142,11 @@ class StridedSliceGradOp : public XlaOpKernel {
|
||||
&strides_tensor));
|
||||
|
||||
bool dummy = false;
|
||||
ShapeReadWriteFromTensorShape wrapped_final_shape(&final_shape);
|
||||
ShapeReadWriteFromTensorShape wrapped_processing_shape(&processing_shape);
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ValidateStridedSliceOp(
|
||||
&begin_tensor, &end_tensor, strides_tensor,
|
||||
ShapeReadWriteFromTensorShape(&input_shape), begin_mask_,
|
||||
end_mask_, ellipsis_mask_, new_axis_mask_, shrink_axis_mask_,
|
||||
&wrapped_processing_shape, &wrapped_final_shape, &dummy,
|
||||
&begin_tensor, &end_tensor, strides_tensor, input_shape,
|
||||
begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_,
|
||||
shrink_axis_mask_, &processing_shape, &final_shape, &dummy,
|
||||
&dummy, &dummy, &begin, &end, &strides));
|
||||
|
||||
// Check to make sure dy is consistent with the original slice
|
||||
@ -257,17 +250,13 @@ class StridedSliceAssignOp : public XlaOpKernel {
|
||||
const TensorShape rhs_shape = ctx->InputShape(4);
|
||||
|
||||
TensorShape dummy_processing_shape;
|
||||
ShapeReadWriteFromTensorShape wrapped_final_shape(&final_shape);
|
||||
ShapeReadWriteFromTensorShape wrapped_dummy_processing_shape(
|
||||
&dummy_processing_shape);
|
||||
bool dummy = false;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ValidateStridedSliceOp(
|
||||
&begin_tensor, &end_tensor, strides_tensor,
|
||||
ShapeReadWriteFromTensorShape(&lhs_shape), begin_mask_,
|
||||
end_mask_, ellipsis_mask_, new_axis_mask_, shrink_axis_mask_,
|
||||
&wrapped_dummy_processing_shape, &wrapped_final_shape, &dummy,
|
||||
&dummy, &dummy, &begin, &end, &strides));
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ValidateStridedSliceOp(
|
||||
&begin_tensor, &end_tensor, strides_tensor, lhs_shape,
|
||||
begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_,
|
||||
shrink_axis_mask_, &dummy_processing_shape, &final_shape,
|
||||
&dummy, &dummy, &dummy, &begin, &end, &strides));
|
||||
|
||||
if (final_shape.num_elements() == 0 && rhs_shape.num_elements() == 0) {
|
||||
// DynamicUpdateSlice does not allow 0-element updates. We should probably
|
||||
|
@ -78,12 +78,19 @@ static xla::ComputationDataHandle Round(xla::ComputationBuilder* b,
|
||||
b->LogicalAnd(b->Eq(fraction, half), is_odd)),
|
||||
b->Add(round_val, one), round_val);
|
||||
}
|
||||
XLAJIT_MAKE_UNARY(Round, Round(b, input_type(0), x));
|
||||
|
||||
// Expresses sigmoid as a rescaled tanh: sigmoid(x) == (tanh(x/2) + 1) / 2.
|
||||
static xla::ComputationDataHandle Sigmoid(xla::ComputationBuilder* b,
|
||||
DataType dtype,
|
||||
const xla::ComputationDataHandle& x) {
|
||||
auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5);
|
||||
return b->Add(half, b->Mul(half, b->Tanh(b->Mul(half, x))));
|
||||
}
|
||||
|
||||
XLAJIT_MAKE_UNARY(Round, Round(b, input_type(0), x));
|
||||
XLAJIT_MAKE_UNARY(Rsqrt,
|
||||
b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), -0.5)));
|
||||
XLAJIT_MAKE_UNARY(Sigmoid,
|
||||
b->Map({x}, *ctx->GetOrCreateSigmoid(input_type(0))));
|
||||
XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(b, input_type(0), x));
|
||||
XLAJIT_MAKE_UNARY(Softplus,
|
||||
b->Log(b->Add(b->Exp(x), XlaHelpers::One(b, input_type(0)))));
|
||||
XLAJIT_MAKE_UNARY(Sqrt,
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/test_util.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
|
@ -172,27 +172,6 @@ const xla::Computation* XlaContext::GetOrCreateAdd(const DataType type) {
|
||||
});
|
||||
}
|
||||
|
||||
const xla::Computation* XlaContext::GetOrCreateSigmoid(const DataType type) {
|
||||
return LookupOrCreate(type, &sigmoid_func_, [this, type] {
|
||||
const string type_string = DataTypeString(type);
|
||||
VLOG(1) << "Building Sigmoid() for " << type_string;
|
||||
xla::ComputationBuilder b(builder()->client(),
|
||||
"sigmoid<" + type_string + ">");
|
||||
xla::PrimitiveType xla_type;
|
||||
TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
|
||||
auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
|
||||
// Clamp the inputs to the range [-18, 18] since anything outside
|
||||
// this range is 0.0f or 1.0f in single-precision. We must clamp the range
|
||||
// of x to avoid incorrect outputs due to fast-math optimizations for large
|
||||
// negative x.
|
||||
x = b.Clamp(XlaHelpers::IntegerLiteral(&b, type, -18), x,
|
||||
XlaHelpers::IntegerLiteral(&b, type, 18));
|
||||
auto one = XlaHelpers::One(&b, type);
|
||||
b.Div(one, b.Add(b.Exp(b.Neg(x)), one));
|
||||
return b.Build().ConsumeValueOrDie();
|
||||
});
|
||||
}
|
||||
|
||||
const xla::Computation* XlaContext::LookupOrCreate(
|
||||
DataType type, ComputationMap* out,
|
||||
const std::function<xla::Computation()>& create) {
|
||||
|
@ -129,11 +129,6 @@ class XlaContext : public ResourceBase {
|
||||
// separate specialization of the computation for each DataType.
|
||||
const xla::Computation* GetOrCreateAdd(const DataType type);
|
||||
|
||||
// Get an XLA lambda to compute Sigmoid. This is cached in the
|
||||
// XlaContext since it may be used by multiple Ops. There is a
|
||||
// separate specialization of the computation for each DataType.
|
||||
const xla::Computation* GetOrCreateSigmoid(const DataType type);
|
||||
|
||||
// The name of the XlaContext resource during symbolic graph execution.
|
||||
static const char kXlaContextResourceName[];
|
||||
|
||||
|
@ -391,11 +391,6 @@ const xla::Computation* XlaOpKernelContext::GetOrCreateAdd(
|
||||
return XlaContext::Get(context_).GetOrCreateAdd(type);
|
||||
}
|
||||
|
||||
const xla::Computation* XlaOpKernelContext::GetOrCreateSigmoid(
|
||||
const DataType type) {
|
||||
return XlaContext::Get(context_).GetOrCreateSigmoid(type);
|
||||
}
|
||||
|
||||
XlaOpKernel::XlaOpKernel(OpKernelConstruction* context) : OpKernel(context) {}
|
||||
|
||||
void XlaOpKernel::Compute(OpKernelContext* context) {
|
||||
|
@ -201,11 +201,6 @@ class XlaOpKernelContext {
|
||||
// separate specialization of the computation for each DataType.
|
||||
const xla::Computation* GetOrCreateAdd(const DataType type);
|
||||
|
||||
// Get an XLA lambda to compute Sigmoid. This is cached in the
|
||||
// XlaContext since it may be used by multiple Ops. There is a
|
||||
// separate specialization of the computation for each DataType.
|
||||
const xla::Computation* GetOrCreateSigmoid(const DataType type);
|
||||
|
||||
private:
|
||||
OpKernelContext* const context_;
|
||||
};
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/local_device.h"
|
||||
#include "tensorflow/core/framework/device_base.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/platform/mem.h"
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
|
||||
|
@ -207,6 +207,18 @@ class Array4D {
|
||||
}
|
||||
}
|
||||
|
||||
// Invokes a callback with the (indices, value) for each cell in the 4D array.
|
||||
void Each(
|
||||
std::function<void(tensorflow::gtl::ArraySlice<int64>, T)> f) const {
|
||||
// We const_cast to be able to use the common non-const implementation,
|
||||
// but prevent modification of the data by passing it by-value to the
|
||||
// caller.
|
||||
const_cast<Array4D*>(this)->Each(
|
||||
[&f](tensorflow::gtl::ArraySlice<int64> indices, T* value) {
|
||||
f(indices, *value);
|
||||
});
|
||||
}
|
||||
|
||||
// Fills all of the {p,z} with the array provided, which specifies {y,x}.
|
||||
void FillWithYX(const Array2D<T>& value) {
|
||||
CHECK_EQ(value.height(), height());
|
||||
|
@ -79,41 +79,6 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cpu_runtime_flags",
|
||||
srcs = ["cpu_runtime_flags.cc"],
|
||||
hdrs = ["cpu_runtime_flags.h"],
|
||||
deps =
|
||||
[
|
||||
":parse_flags_from_env",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "compiler_functor_flags",
|
||||
srcs = ["compiler_functor_flags.cc"],
|
||||
hdrs = ["compiler_functor_flags.h"],
|
||||
deps = [
|
||||
":parse_flags_from_env",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "convolution_thunk_flags",
|
||||
srcs = ["convolution_thunk_flags.cc"],
|
||||
hdrs = ["convolution_thunk_flags.h"],
|
||||
deps = [
|
||||
":parse_flags_from_env",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "stream_assignment_flags",
|
||||
srcs = ["stream_assignment_flags.cc"],
|
||||
@ -160,17 +125,6 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hlo_test_base_flags",
|
||||
srcs = ["hlo_test_base_flags.cc"],
|
||||
hdrs = ["hlo_test_base_flags.h"],
|
||||
deps = [
|
||||
":parse_flags_from_env",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "backend_flags",
|
||||
srcs = ["backend_flags.cc"],
|
||||
|
@ -1,61 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Legacy flags for XLA's compiler_functor module.
|
||||
|
||||
#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
namespace xla {
|
||||
namespace legacy_flags {
|
||||
|
||||
// Pointers to the parsed value of the flags and flag descriptors, initialized
|
||||
// via flags_init.
|
||||
static CompilerFunctorFlags* flags;
|
||||
static std::vector<tensorflow::Flag>* flag_list;
|
||||
static std::once_flag flags_init;
|
||||
|
||||
// Allocate *flags. Called via call_once(&flags_init,...).
|
||||
static void AllocateFlags() {
|
||||
flags = new CompilerFunctorFlags;
|
||||
flag_list = new std::vector<tensorflow::Flag>({
|
||||
tensorflow::Flag("xla_debug_cpu_dump_ir", &flags->xla_debug_cpu_dump_ir,
|
||||
"Dump IR, before optimizations to a path"),
|
||||
});
|
||||
ParseFlagsFromEnv(*flag_list);
|
||||
}
|
||||
|
||||
// Append to *append_to flag definitions associated with XLA's compiler_functor
|
||||
// module.
|
||||
void AppendCompilerFunctorFlags(std::vector<tensorflow::Flag>* append_to) {
|
||||
std::call_once(flags_init, &AllocateFlags);
|
||||
append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
|
||||
}
|
||||
|
||||
// Return a pointer to the CompilerFunctorFlags struct;
|
||||
// repeated calls return the same pointer.
|
||||
// This should be called only after Flags::Parse() has returned.
|
||||
CompilerFunctorFlags* GetCompilerFunctorFlags() {
|
||||
std::call_once(flags_init, &AllocateFlags);
|
||||
return flags;
|
||||
}
|
||||
|
||||
} // namespace legacy_flags
|
||||
} // namespace xla
|
@ -1,47 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_COMPILER_FUNCTOR_FLAGS_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_COMPILER_FUNCTOR_FLAGS_H_
|
||||
|
||||
// Legacy flags for the XLA's compiler_functor module.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
namespace xla {
|
||||
namespace legacy_flags {
|
||||
|
||||
// Append to *flag_list flag definitions associated with XLA's compiler_functor
|
||||
// module.
|
||||
void AppendCompilerFunctorFlags(std::vector<tensorflow::Flag>* flag_list);
|
||||
|
||||
// The values of flags associated with XLA's compiler_functor module.
|
||||
typedef struct {
|
||||
string xla_debug_cpu_dump_ir; // Dump IR, before optimizations to a path
|
||||
} CompilerFunctorFlags;
|
||||
|
||||
// Return a pointer to the CompilerFunctorFlags struct;
|
||||
// repeated calls return the same pointer.
|
||||
// This should be called only after Flags::Parse() has returned.
|
||||
CompilerFunctorFlags* GetCompilerFunctorFlags();
|
||||
|
||||
} // namespace legacy_flags
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_COMPILER_FUNCTOR_FLAGS_H_
|
@ -1,63 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Legacy flags for XLA's convolution_thunk module.
|
||||
|
||||
#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
namespace xla {
|
||||
namespace legacy_flags {
|
||||
|
||||
// Pointers to the parsed value of the flags and flag descriptors, initialized
|
||||
// via flags_init.
|
||||
static ConvolutionThunkFlags* flags;
|
||||
static std::vector<tensorflow::Flag>* flag_list;
|
||||
static std::once_flag flags_init;
|
||||
|
||||
// Allocate *flags. Called via call_once(&flags_init,...).
|
||||
static void AllocateFlags() {
|
||||
flags = new ConvolutionThunkFlags;
|
||||
flags->xla_gpu_autotune_convolution_algorithm = true;
|
||||
flag_list = new std::vector<tensorflow::Flag>({
|
||||
tensorflow::Flag("xla_gpu_autotune_convolution_algorithm",
|
||||
&flags->xla_gpu_autotune_convolution_algorithm,
|
||||
"Auto-tune the algorithm used by convolution"),
|
||||
});
|
||||
ParseFlagsFromEnv(*flag_list);
|
||||
}
|
||||
|
||||
// Append to *append_to flag definitions associated with XLA's convolution_thunk
|
||||
// module.
|
||||
void AppendConvolutionThunkFlags(std::vector<tensorflow::Flag>* append_to) {
|
||||
std::call_once(flags_init, &AllocateFlags);
|
||||
append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
|
||||
}
|
||||
|
||||
// Return a pointer to the ConvolutionThunkFlags struct;
|
||||
// repeated calls return the same pointer.
|
||||
// This should be called only after Flags::Parse() has returned.
|
||||
ConvolutionThunkFlags* GetConvolutionThunkFlags() {
|
||||
std::call_once(flags_init, &AllocateFlags);
|
||||
return flags;
|
||||
}
|
||||
|
||||
} // namespace legacy_flags
|
||||
} // namespace xla
|
@ -1,47 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CONVOLUTION_THUNK_FLAGS_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CONVOLUTION_THUNK_FLAGS_H_
|
||||
|
||||
// Legacy flags for XLA's convolution_thunk module.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
namespace xla {
|
||||
namespace legacy_flags {
|
||||
|
||||
// Append to *flag_list flag definitions associated with XLA's convolution_thunk
|
||||
// module.
|
||||
void AppendConvolutionThunkFlags(std::vector<tensorflow::Flag>* flag_list);
|
||||
|
||||
// The values of flags associated with XLA's convolution_thunk module.
|
||||
typedef struct {
|
||||
// Auto-tune the algorithm used by convolution
|
||||
bool xla_gpu_autotune_convolution_algorithm;
|
||||
} ConvolutionThunkFlags;
|
||||
|
||||
// Return a pointer to the ConvolutionThunkFlags struct;
|
||||
// repeated calls return the same pointer.
|
||||
// This should be called only after Flags::Parse() has returned.
|
||||
ConvolutionThunkFlags* GetConvolutionThunkFlags();
|
||||
|
||||
} // namespace legacy_flags
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CONVOLUTION_THUNK_FLAGS_H_
|
@ -1,71 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Legacy flags for XLA's cpu_runtime module.
|
||||
|
||||
#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
namespace xla {
|
||||
namespace legacy_flags {
|
||||
|
||||
// Pointers to the parsed value of the flags and flag descriptors, initialized
|
||||
// via flags_init.
|
||||
static CpuRuntimeFlags* flags;
|
||||
static std::vector<tensorflow::Flag>* flag_list;
|
||||
static std::once_flag flags_init;
|
||||
|
||||
// Allocate *flags. Called via call_once(&flags_init,...).
|
||||
static void AllocateFlags() {
|
||||
flags = new CpuRuntimeFlags;
|
||||
flags->xla_cpu_use_eigen = true;
|
||||
flags->xla_cpu_multi_thread_eigen = true;
|
||||
flag_list = new std::vector<tensorflow::Flag>({
|
||||
tensorflow::Flag(
|
||||
"xla_cpu_use_eigen", &flags->xla_cpu_use_eigen,
|
||||
"Use Eigen for matrix multiply on the CPU platform. This "
|
||||
"is a useful hack for performance comparisons against "
|
||||
"XLA's implementation."),
|
||||
tensorflow::Flag(
|
||||
"xla_cpu_multi_thread_eigen", &flags->xla_cpu_multi_thread_eigen,
|
||||
"When generating calls to Eigen for matmul and conv, should "
|
||||
"single or multi-threaded eigen be used? "
|
||||
"Only used when --xla_cpu_use_eigen is true."),
|
||||
});
|
||||
ParseFlagsFromEnv(*flag_list);
|
||||
}
|
||||
|
||||
// Append to *append_to flag definitions associated with XLA's cpu_runtime
|
||||
// module.
|
||||
void AppendCpuRuntimeFlags(std::vector<tensorflow::Flag>* append_to) {
|
||||
std::call_once(flags_init, &AllocateFlags);
|
||||
append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
|
||||
}
|
||||
|
||||
// Return a pointer to the CpuRuntimeFlags struct;
|
||||
// repeated calls return the same pointer.
|
||||
// This should be called only after Flags::Parse() has returned.
|
||||
CpuRuntimeFlags* GetCpuRuntimeFlags() {
|
||||
std::call_once(flags_init, &AllocateFlags);
|
||||
return flags;
|
||||
}
|
||||
|
||||
} // namespace legacy_flags
|
||||
} // namespace xla
|
@ -1,51 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_RUNTIME_FLAGS_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_RUNTIME_FLAGS_H_
|
||||
|
||||
// Legacy flags for the XLA's cpu_runtime module.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
namespace xla {
|
||||
namespace legacy_flags {
|
||||
|
||||
// Append to *flag_list flag definitions associated with XLA's cpu_runtime
|
||||
// module.
|
||||
void AppendCpuRuntimeFlags(std::vector<tensorflow::Flag>* flag_list);
|
||||
|
||||
// The values of flags associated with XLA's cpu_runtime module.
|
||||
typedef struct {
|
||||
// Use Eigen for matrix multiply on the CPU platform. This is a useful hack
|
||||
// for performance comparisons against XLA's implementation.
|
||||
bool xla_cpu_use_eigen;
|
||||
// When generating calls to Eigen for matmul and conv, should single or
|
||||
// multi-threaded eigen be used? Only used when --xla_cpu_use_eigen is true.
|
||||
bool xla_cpu_multi_thread_eigen;
|
||||
} CpuRuntimeFlags;
|
||||
|
||||
// Return a pointer to the CpuRuntimeFlags struct;
|
||||
// repeated calls return the same pointer.
|
||||
// This should be called only after Flags::Parse() has returned.
|
||||
CpuRuntimeFlags* GetCpuRuntimeFlags();
|
||||
|
||||
} // namespace legacy_flags
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_RUNTIME_FLAGS_H_
|
@ -25,12 +25,23 @@ namespace legacy_flags {
|
||||
|
||||
struct DebugOptionsFlags {
|
||||
string xla_generate_hlo_graph;
|
||||
bool xla_hlo_graph_addresses;
|
||||
bool xla_hlo_graph_layout;
|
||||
string xla_log_hlo_text;
|
||||
string xla_generate_hlo_text_to;
|
||||
|
||||
string xla_disable_hlo_passes;
|
||||
bool xla_enable_fast_math;
|
||||
bool xla_llvm_enable_alias_scope_metadata;
|
||||
bool xla_llvm_enable_noalias_metadata;
|
||||
bool xla_llvm_enable_invariant_load_metadata;
|
||||
int32 xla_backend_optimization_level;
|
||||
bool xla_embed_ir_in_executable;
|
||||
string xla_dump_ir_to;
|
||||
string xla_dump_debug_json_to;
|
||||
|
||||
bool xla_cpu_multi_thread_eigen;
|
||||
|
||||
string xla_gpu_cuda_data_dir;
|
||||
bool xla_gpu_ftz;
|
||||
|
||||
@ -48,11 +59,20 @@ std::once_flag flags_init;
|
||||
void AllocateFlags() {
|
||||
flag_values = new DebugOptionsFlags;
|
||||
flag_values->xla_generate_hlo_graph = "";
|
||||
flag_values->xla_hlo_graph_addresses = false;
|
||||
flag_values->xla_hlo_graph_layout = false;
|
||||
flag_values->xla_log_hlo_text = "";
|
||||
flag_values->xla_generate_hlo_text_to = "";
|
||||
flag_values->xla_disable_hlo_passes = "";
|
||||
flag_values->xla_enable_fast_math = true;
|
||||
flag_values->xla_llvm_enable_alias_scope_metadata = true;
|
||||
flag_values->xla_llvm_enable_noalias_metadata = true;
|
||||
flag_values->xla_llvm_enable_invariant_load_metadata = true;
|
||||
flag_values->xla_backend_optimization_level = 3;
|
||||
flag_values->xla_embed_ir_in_executable = false;
|
||||
flag_values->xla_dump_ir_to = "";
|
||||
flag_values->xla_dump_debug_json_to = "";
|
||||
flag_values->xla_cpu_multi_thread_eigen = true;
|
||||
flag_values->xla_gpu_cuda_data_dir = "./cuda_sdk_lib";
|
||||
flag_values->xla_gpu_ftz = false;
|
||||
flag_values->xla_backend_extra_options = "";
|
||||
@ -62,10 +82,37 @@ void AllocateFlags() {
|
||||
"xla_generate_hlo_graph", &flag_values->xla_generate_hlo_graph,
|
||||
"HLO modules matching this regex will be dumped to a .dot file "
|
||||
"throughout various stages in compilation."),
|
||||
tensorflow::Flag(
|
||||
"xla_hlo_graph_addresses", &flag_values->xla_hlo_graph_addresses,
|
||||
"With xla_generate_hlo_graph, show addresses of HLO ops in "
|
||||
"graph dump."),
|
||||
tensorflow::Flag(
|
||||
"xla_hlo_graph_layout", &flag_values->xla_hlo_graph_layout,
|
||||
"With xla_generate_hlo_graph, show layout of HLO ops in "
|
||||
"graph dump."),
|
||||
tensorflow::Flag(
|
||||
"xla_log_hlo_text", &flag_values->xla_log_hlo_text,
|
||||
"HLO modules matching this regex will be dumped to LOG(INFO). "),
|
||||
tensorflow::Flag(
|
||||
"xla_generate_hlo_text_to", &flag_values->xla_generate_hlo_text_to,
|
||||
"Dump all HLO modules as text into the provided directory path."),
|
||||
tensorflow::Flag(
|
||||
"xla_enable_fast_math", &flag_values->xla_enable_fast_math,
|
||||
"Enable unsafe fast-math optimizations in the compiler; "
|
||||
"this may produce faster code at the expense of some accuracy."),
|
||||
tensorflow::Flag("xla_llvm_enable_alias_scope_metadata",
|
||||
&flag_values->xla_llvm_enable_alias_scope_metadata,
|
||||
"In LLVM-based backends, enable the emission of "
|
||||
"!alias.scope metadata in the generated IR."),
|
||||
tensorflow::Flag("xla_llvm_enable_noalias_metadata",
|
||||
&flag_values->xla_llvm_enable_noalias_metadata,
|
||||
"In LLVM-based backends, enable the emission of "
|
||||
"!noalias metadata in the generated IR."),
|
||||
tensorflow::Flag("xla_llvm_enable_invariant_load_metadata",
|
||||
&flag_values->xla_llvm_enable_invariant_load_metadata,
|
||||
"In LLVM-based backends, enable the emission of "
|
||||
"!invariant.load metadata in "
|
||||
"the generated IR."),
|
||||
tensorflow::Flag(
|
||||
"xla_backend_optimization_level",
|
||||
&flag_values->xla_backend_optimization_level,
|
||||
@ -78,6 +125,12 @@ void AllocateFlags() {
|
||||
tensorflow::Flag("xla_embed_ir_in_executable",
|
||||
&flag_values->xla_embed_ir_in_executable,
|
||||
"Embed the compiler IR as a string in the executable."),
|
||||
tensorflow::Flag("xla_dump_ir_to", &flag_values->xla_dump_ir_to,
|
||||
"Dump the compiler IR into this file/path."),
|
||||
tensorflow::Flag("xla_cpu_multi_thread_eigen",
|
||||
&flag_values->xla_cpu_multi_thread_eigen,
|
||||
"When generating calls to Eigen in the CPU backend, "
|
||||
"use multi-threaded Eigen mode."),
|
||||
tensorflow::Flag("xla_gpu_cuda_data_dir",
|
||||
&flag_values->xla_gpu_cuda_data_dir,
|
||||
"If non-empty, speficies a local directory containing "
|
||||
@ -111,6 +164,10 @@ xla::DebugOptions GetDebugOptionsFromFlags() {
|
||||
|
||||
DebugOptions options;
|
||||
options.set_xla_generate_hlo_graph(flag_values->xla_generate_hlo_graph);
|
||||
options.set_xla_hlo_graph_addresses(flag_values->xla_hlo_graph_addresses);
|
||||
options.set_xla_hlo_graph_layout(flag_values->xla_hlo_graph_layout);
|
||||
options.set_xla_log_hlo_text(flag_values->xla_log_hlo_text);
|
||||
options.set_xla_generate_hlo_text_to(flag_values->xla_generate_hlo_text_to);
|
||||
|
||||
std::vector<string> disabled_passes =
|
||||
tensorflow::str_util::Split(flag_values->xla_disable_hlo_passes, ',');
|
||||
@ -123,9 +180,18 @@ xla::DebugOptions GetDebugOptionsFromFlags() {
|
||||
flag_values->xla_backend_optimization_level);
|
||||
options.set_xla_embed_ir_in_executable(
|
||||
flag_values->xla_embed_ir_in_executable);
|
||||
options.set_xla_dump_ir_to(flag_values->xla_dump_ir_to);
|
||||
options.set_xla_dump_debug_json_to(flag_values->xla_dump_debug_json_to);
|
||||
options.set_xla_cpu_multi_thread_eigen(
|
||||
flag_values->xla_cpu_multi_thread_eigen);
|
||||
options.set_xla_gpu_cuda_data_dir(flag_values->xla_gpu_cuda_data_dir);
|
||||
options.set_xla_gpu_ftz(flag_values->xla_gpu_ftz);
|
||||
options.set_xla_llvm_enable_alias_scope_metadata(
|
||||
flag_values->xla_llvm_enable_alias_scope_metadata);
|
||||
options.set_xla_llvm_enable_noalias_metadata(
|
||||
flag_values->xla_llvm_enable_noalias_metadata);
|
||||
options.set_xla_llvm_enable_invariant_load_metadata(
|
||||
flag_values->xla_llvm_enable_invariant_load_metadata);
|
||||
|
||||
std::vector<string> extra_options_parts =
|
||||
tensorflow::str_util::Split(flag_values->xla_backend_extra_options, ',');
|
||||
|
@ -1,63 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Legacy flags for XLA's hlo_test_base module.
|
||||
|
||||
#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
namespace xla {
|
||||
namespace legacy_flags {
|
||||
|
||||
// Pointers to the parsed value of the flags and flag descriptors, initialized
|
||||
// via flags_init.
|
||||
static HloTestBaseFlags* flags;
|
||||
static std::vector<tensorflow::Flag>* flag_list;
|
||||
static std::once_flag flags_init;
|
||||
|
||||
// Allocate *flags. Called via call_once(&flags_init,...).
|
||||
static void AllocateFlags() {
|
||||
flags = new HloTestBaseFlags;
|
||||
flags->xla_hlo_test_generate_hlo_graph = false;
|
||||
flag_list = new std::vector<tensorflow::Flag>({
|
||||
tensorflow::Flag("xla_hlo_test_generate_hlo_graph",
|
||||
&flags->xla_hlo_test_generate_hlo_graph,
|
||||
"Generate graph output of HLO instructions"),
|
||||
});
|
||||
ParseFlagsFromEnv(*flag_list);
|
||||
}
|
||||
|
||||
// Append to *append_to flag definitions associated with XLA's hlo_test_base
|
||||
// module.
|
||||
void AppendHloTestBaseFlags(std::vector<tensorflow::Flag>* append_to) {
|
||||
std::call_once(flags_init, &AllocateFlags);
|
||||
append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
|
||||
}
|
||||
|
||||
// Return a pointer to the HloTestBaseFlags struct;
|
||||
// repeated calls return the same pointer.
|
||||
// This should be called only after Flags::Parse() has returned.
|
||||
HloTestBaseFlags* GetHloTestBaseFlags() {
|
||||
std::call_once(flags_init, &AllocateFlags);
|
||||
return flags;
|
||||
}
|
||||
|
||||
} // namespace legacy_flags
|
||||
} // namespace xla
|
@ -1,47 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_TEST_BASE_FLAGS_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_TEST_BASE_FLAGS_H_
|
||||
|
||||
// Legacy flags for XLA's hlo_test_base module.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
namespace xla {
|
||||
namespace legacy_flags {
|
||||
|
||||
// Append to *flag_list flag definitions associated with XLA's hlo_test_base
|
||||
// module.
|
||||
void AppendHloTestBaseFlags(std::vector<tensorflow::Flag>* flag_list);
|
||||
|
||||
// The values of flags associated with XLA's hlo_test_base module.
|
||||
typedef struct {
|
||||
bool xla_hlo_test_generate_hlo_graph; // Generate graph output of HLO
|
||||
// instructions
|
||||
} HloTestBaseFlags;
|
||||
|
||||
// Return a pointer to the HloTestBaseFlags struct;
|
||||
// repeated calls return the same pointer.
|
||||
// This should be called only after Flags::Parse() has returned.
|
||||
HloTestBaseFlags* GetHloTestBaseFlags();
|
||||
|
||||
} // namespace legacy_flags
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_TEST_BASE_FLAGS_H_
|
@ -36,33 +36,13 @@ static std::once_flag flags_init;
|
||||
static void AllocateFlags() {
|
||||
flags = new ServiceFlags;
|
||||
flags->xla_hlo_profile = false;
|
||||
flags->xla_log_hlo_text = "";
|
||||
flags->xla_generate_hlo_graph = "";
|
||||
flags->xla_hlo_graph_addresses = false;
|
||||
flags->xla_hlo_graph_layout = false;
|
||||
flags->xla_hlo_graph_for_compute_constant = false;
|
||||
flags->xla_dump_computations_to = "";
|
||||
flags->xla_dump_hlo_text_to = "";
|
||||
flags->xla_dump_executions_to = "";
|
||||
flag_list = new std::vector<tensorflow::Flag>({
|
||||
tensorflow::Flag(
|
||||
"xla_hlo_profile", &flags->xla_hlo_profile,
|
||||
"Instrument the computation to collect per-HLO cycle counts"),
|
||||
tensorflow::Flag(
|
||||
"xla_log_hlo_text", &flags->xla_log_hlo_text,
|
||||
"If non-empty, print the text format of "
|
||||
"HLO modules whose name partially matches this regex. E.g. "
|
||||
"xla_log_hlo_text=.* will dump the text for every module."),
|
||||
tensorflow::Flag(
|
||||
"xla_generate_hlo_graph", &flags->xla_generate_hlo_graph,
|
||||
"If non-empty, dump graph of HLO modules whose name partially "
|
||||
"matches this regex. E.g. --xla_generate_hlo_graph=.* will dump "
|
||||
"the graph of every module."),
|
||||
tensorflow::Flag("xla_hlo_graph_addresses",
|
||||
&flags->xla_hlo_graph_addresses,
|
||||
"Show addresses of HLO ops in graph"),
|
||||
tensorflow::Flag("xla_hlo_graph_layout", &flags->xla_hlo_graph_layout,
|
||||
"Show layout of HLO ops in graph"),
|
||||
tensorflow::Flag(
|
||||
"xla_hlo_graph_for_compute_constant",
|
||||
&flags->xla_hlo_graph_for_compute_constant,
|
||||
@ -72,9 +52,6 @@ static void AllocateFlags() {
|
||||
&flags->xla_dump_computations_to,
|
||||
"Dumps computations that XLA executes into the provided "
|
||||
"directory path"),
|
||||
tensorflow::Flag("xla_dump_hlo_text_to", &flags->xla_dump_hlo_text_to,
|
||||
"Dumps HLO modules that XLA executes into the provided "
|
||||
"directory path"),
|
||||
tensorflow::Flag("xla_dump_executions_to", &flags->xla_dump_executions_to,
|
||||
"Dumps parameters and results of computations that XLA "
|
||||
"executes into the provided directory path"),
|
||||
|
@ -34,23 +34,11 @@ void AppendServiceFlags(std::vector<tensorflow::Flag>* flag_list);
|
||||
typedef struct {
|
||||
bool xla_hlo_profile; // Instrument the computation to collect per-HLO cycle
|
||||
// counts
|
||||
string xla_log_hlo_text; // If non-empty, print the text format of the HLO
|
||||
// modules whose name partially
|
||||
// matches this regex. E.g. xla_log_hlo_text=.*
|
||||
// will dump the text for every module.
|
||||
string xla_generate_hlo_graph; // If non-empty, dump graph of HLO modules
|
||||
// whose name partially matches this regex.
|
||||
// E.g. --xla_generate_hlo_graph=.* will dump
|
||||
// the graph of every module.
|
||||
bool xla_hlo_graph_addresses; // Show addresses of HLO ops in graph
|
||||
bool xla_hlo_graph_layout; // Show layout of HLO ops in graph
|
||||
bool xla_hlo_graph_for_compute_constant; // If true, include hlo dumps of
|
||||
// graphs from ComputeConstant.
|
||||
// Such graphs still need to be
|
||||
// matched via
|
||||
// xla_generate_hlo_graph.
|
||||
string xla_dump_hlo_text_to; // Dumps HLO text for each HLO module that is
|
||||
// executed into the provided directory path
|
||||
string xla_dump_computations_to; // Dumps computations that XLA executes
|
||||
// into the provided directory path
|
||||
// Dumps parameters and results of computations that XLA executes into
|
||||
|
@ -62,7 +62,17 @@ Literal::StrideConfig::StrideConfig(
|
||||
std::unique_ptr<Literal> Literal::CreateFromShape(const Shape& shape) {
|
||||
auto literal = MakeUnique<Literal>();
|
||||
*literal->mutable_shape() = shape;
|
||||
literal->Reserve(ShapeUtil::ElementsIn(literal->shape()));
|
||||
if (ShapeUtil::IsTuple(shape)) {
|
||||
int64 num_elements = ShapeUtil::TupleElementCount(shape);
|
||||
literal->tuple_literals_.resize(num_elements);
|
||||
for (int i = 0; i < num_elements; ++i) {
|
||||
std::unique_ptr<Literal> elem =
|
||||
CreateFromShape(ShapeUtil::GetTupleElementShape(shape, i));
|
||||
literal->tuple_literals_[i] = std::move(*elem);
|
||||
}
|
||||
} else {
|
||||
literal->Reserve(ShapeUtil::ElementsIn(literal->shape()));
|
||||
}
|
||||
return literal;
|
||||
}
|
||||
|
||||
|
@ -68,11 +68,13 @@ class BoolVector {
|
||||
}
|
||||
|
||||
BoolVector(const BoolVector& other) { CopyFrom(other); }
|
||||
BoolVector(BoolVector&&) = default;
|
||||
|
||||
BoolVector& operator=(const BoolVector& other) {
|
||||
CopyFrom(other);
|
||||
return *this;
|
||||
}
|
||||
BoolVector& operator=(BoolVector&&) = default;
|
||||
|
||||
void push_back(const bool& value) {
|
||||
resize(size_ + 1);
|
||||
@ -147,10 +149,12 @@ class Literal {
|
||||
Literal() {}
|
||||
|
||||
Literal(const Literal& other) = default;
|
||||
Literal(Literal&&) = default;
|
||||
|
||||
explicit Literal(const LiteralProto& other) { CopyFromProto(other); }
|
||||
|
||||
Literal& operator=(const Literal& other) = default;
|
||||
Literal& operator=(Literal&&) = default;
|
||||
|
||||
LiteralProto ToProto() const;
|
||||
|
||||
|
@ -291,20 +291,20 @@ TEST_F(LiteralUtilTest, DifferentLayoutEquality) {
|
||||
auto colmajor = MakeUnique<Literal>();
|
||||
*colmajor->mutable_shape() = ShapeUtil::MakeShape(F32, {2, 2});
|
||||
*colmajor->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
|
||||
colmajor.get()->Reserve(4);
|
||||
colmajor.get()->Set<float>({0, 0}, 1.0);
|
||||
colmajor.get()->Set<float>({0, 1}, 2.0);
|
||||
colmajor.get()->Set<float>({1, 0}, 3.0);
|
||||
colmajor.get()->Set<float>({1, 1}, 4.0);
|
||||
colmajor->Reserve(4);
|
||||
colmajor->Set<float>({0, 0}, 1.0);
|
||||
colmajor->Set<float>({0, 1}, 2.0);
|
||||
colmajor->Set<float>({1, 0}, 3.0);
|
||||
colmajor->Set<float>({1, 1}, 4.0);
|
||||
|
||||
auto rowmajor = MakeUnique<Literal>();
|
||||
*rowmajor->mutable_shape() = ShapeUtil::MakeShape(F32, {2, 2});
|
||||
*rowmajor->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0});
|
||||
rowmajor.get()->Reserve(4);
|
||||
rowmajor.get()->Set<float>({0, 0}, 1.0);
|
||||
rowmajor.get()->Set<float>({0, 1}, 2.0);
|
||||
rowmajor.get()->Set<float>({1, 0}, 3.0);
|
||||
rowmajor.get()->Set<float>({1, 1}, 4.0);
|
||||
rowmajor->Reserve(4);
|
||||
rowmajor->Set<float>({0, 0}, 1.0);
|
||||
rowmajor->Set<float>({0, 1}, 2.0);
|
||||
rowmajor->Set<float>({1, 0}, 3.0);
|
||||
rowmajor->Set<float>({1, 1}, 4.0);
|
||||
|
||||
EXPECT_TRUE(rowmajor->Equal(*colmajor));
|
||||
}
|
||||
@ -341,6 +341,16 @@ TEST_F(LiteralUtilTest, IsAllTuple) {
|
||||
EXPECT_FALSE(tuple->IsAll(1));
|
||||
}
|
||||
|
||||
// Verifies that CreateFromShape works for tuples.
|
||||
TEST_F(LiteralUtilTest, CreateFromShapeTuple) {
|
||||
auto scalar = Literal::CreateR0<float>(0.0);
|
||||
auto matrix = Literal::CreateR2<int32>({{0, 0}, {0, 0}});
|
||||
auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
|
||||
|
||||
auto x = Literal::CreateFromShape(tuple->shape());
|
||||
EXPECT_TRUE(tuple->Equal(*x));
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, IsAll) {
|
||||
EXPECT_TRUE(Literal::CreateR0<bool>(false)->IsAll(0));
|
||||
EXPECT_TRUE(Literal::CreateR0<bool>(true)->IsAll(1));
|
||||
@ -694,7 +704,7 @@ TEST_F(LiteralUtilTest, Copy) {
|
||||
const int64 step[] = {1, 1, 1, 1};
|
||||
uint32 seqnr = 0;
|
||||
auto init_proc = [&](const std::vector<int64>& indexes) {
|
||||
source.get()->Set(indexes, ++seqnr);
|
||||
source->Set(indexes, ++seqnr);
|
||||
return true;
|
||||
};
|
||||
|
||||
@ -705,7 +715,7 @@ TEST_F(LiteralUtilTest, Copy) {
|
||||
const int64 dest_base[] = {6, 4, 12, 2};
|
||||
const int64 copy_size[] = {7, 8, 11, 9};
|
||||
|
||||
TF_EXPECT_OK(blank.get()->Copy(*source, src_base, dest_base, copy_size));
|
||||
TF_EXPECT_OK(blank->Copy(*source, src_base, dest_base, copy_size));
|
||||
std::vector<int64> source_indexes(TF_ARRAYSIZE(dimensions), 0);
|
||||
std::vector<int64> blank_indexes(TF_ARRAYSIZE(dimensions), 0);
|
||||
bool matched = true;
|
||||
@ -729,13 +739,13 @@ TEST_F(LiteralUtilTest, Copy) {
|
||||
TEST_F(LiteralUtilTest, CopyScalars) {
|
||||
auto zero = Literal::CreateR0<uint32>(0);
|
||||
auto nine = Literal::CreateR0<uint32>(9);
|
||||
TF_EXPECT_OK(zero.get()->Copy(*nine, {}, {}, {}));
|
||||
TF_EXPECT_OK(zero->Copy(*nine, {}, {}, {}));
|
||||
EXPECT_TRUE(zero->Equal(*nine));
|
||||
|
||||
auto vect = Literal::CreateR1<uint32>({3, 4, 9, 12, 5, 17, 21});
|
||||
TF_EXPECT_OK(zero.get()->Copy(*vect, {5}, {}, {}));
|
||||
TF_EXPECT_OK(zero->Copy(*vect, {5}, {}, {}));
|
||||
EXPECT_EQ(zero->Get<uint32>({}), 17);
|
||||
TF_EXPECT_OK(vect.get()->Copy(*zero, {}, {4}, {}));
|
||||
TF_EXPECT_OK(vect->Copy(*zero, {}, {4}, {}));
|
||||
EXPECT_EQ(vect->Get<uint32>({4}), 17);
|
||||
}
|
||||
|
||||
@ -796,7 +806,7 @@ TEST_F(LiteralUtilTest, Populate) {
|
||||
// with zero.
|
||||
return literal->LinearIndex(indexes) + 17;
|
||||
};
|
||||
TF_EXPECT_OK(literal.get()->Populate<uint32>(generator));
|
||||
TF_EXPECT_OK(literal->Populate<uint32>(generator));
|
||||
|
||||
std::vector<int64> zero_base(data.dimensions.size(), 0);
|
||||
std::vector<int64> step(data.dimensions.size(), 1);
|
||||
|
@ -58,7 +58,7 @@ StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read(
|
||||
}
|
||||
|
||||
int64 elements = ShapeUtil::ElementsIn(shape);
|
||||
result.get()->Resize(elements, std::numeric_limits<float>::quiet_NaN());
|
||||
result->Resize(elements, std::numeric_limits<float>::quiet_NaN());
|
||||
std::vector<float>* field = result->mutable_f32s();
|
||||
char* data = tensorflow::bit_cast<char*>(field->data());
|
||||
uint64 bytes = elements * sizeof(float);
|
||||
|
@ -252,6 +252,20 @@ ReferenceUtil::ReduceWindow4DGeneric(
|
||||
padding);
|
||||
}
|
||||
|
||||
/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::BatchNorm4D(
|
||||
const Array4D<float>& input, const Array4D<float>& mean,
|
||||
const Array4D<float>& var, const Array4D<float>& scale,
|
||||
const Array4D<float>& offset, float epsilon) {
|
||||
auto normalized =
|
||||
*MapArray4D(input, mean, [](float a, float b) { return a - b; });
|
||||
normalized = *MapArray4D(normalized, var, [&](float a, float b) {
|
||||
return a / std::sqrt(b + epsilon);
|
||||
});
|
||||
normalized =
|
||||
*MapArray4D(normalized, scale, [](float a, float b) { return a * b; });
|
||||
return MapArray4D(normalized, offset, [](float a, float b) { return a + b; });
|
||||
}
|
||||
|
||||
/* static */ std::unique_ptr<Array4D<float>>
|
||||
ReferenceUtil::SelectAndScatter4DGePlus(
|
||||
const Array4D<float>& operand, const Array4D<float>& source, float init,
|
||||
@ -491,6 +505,30 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
|
||||
}
|
||||
}
|
||||
}
|
||||
if (samples == 0 || kx == 0 || ky == 0 || ox == 0 || oy == 0 || oz == 0 ||
|
||||
iz == 0) {
|
||||
LOG(INFO) << "Output will be trivially empty because one of these "
|
||||
"dimensions is 0: samples: "
|
||||
<< samples << " kx: " << kx << " ky: " << ky << " ox: " << ox
|
||||
<< " oy: " << oy << " oz: " << oz << " iz: " << iz;
|
||||
return result;
|
||||
}
|
||||
bool trivial = true;
|
||||
auto check_trivial = [&trivial](tensorflow::gtl::ArraySlice<int64> indices,
|
||||
float value) {
|
||||
if (value != 0.0) {
|
||||
trivial = false;
|
||||
}
|
||||
};
|
||||
lhs.Each(check_trivial);
|
||||
if (trivial) {
|
||||
LOG(FATAL) << "LHS is all 0.0.";
|
||||
}
|
||||
trivial = true;
|
||||
rhs.Each(check_trivial);
|
||||
if (trivial) {
|
||||
LOG(FATAL) << "RHS is all 0.0.";
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
@ -566,6 +604,38 @@ ReferenceUtil::ReduceToRowArray2D(
|
||||
return result;
|
||||
}
|
||||
|
||||
/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::Broadcast1DTo4D(
|
||||
const std::vector<float>& array, const std::vector<int64>& bounds,
|
||||
int64 broadcast_from_dim) {
|
||||
auto result =
|
||||
MakeUnique<Array4D<float>>(bounds[0], bounds[1], bounds[2], bounds[3]);
|
||||
for (int64 i = 0; i < result->n1(); ++i) {
|
||||
for (int64 j = 0; j < result->n2(); ++j) {
|
||||
for (int64 k = 0; k < result->n3(); ++k) {
|
||||
for (int64 l = 0; l < result->n4(); ++l) {
|
||||
switch (broadcast_from_dim) {
|
||||
case 0:
|
||||
(*result)(i, j, k, l) = array[i];
|
||||
break;
|
||||
case 1:
|
||||
(*result)(i, j, k, l) = array[j];
|
||||
break;
|
||||
case 2:
|
||||
(*result)(i, j, k, l) = array[k];
|
||||
break;
|
||||
case 3:
|
||||
(*result)(i, j, k, l) = array[l];
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::Reduce3DTo2D(
|
||||
const Array3D<float>& array, float init,
|
||||
tensorflow::gtl::ArraySlice<int64> dims,
|
||||
|
@ -120,6 +120,11 @@ class ReferenceUtil {
|
||||
tensorflow::gtl::ArraySlice<int64> dims,
|
||||
std::function<float(float, float)> reduce_function);
|
||||
|
||||
// Broadcast 1D dimension to 4D, from the dimension `broadcast_from_dim`.
|
||||
static std::unique_ptr<Array4D<float>> Broadcast1DTo4D(
|
||||
const std::vector<float>& array, const std::vector<int64>& bounds,
|
||||
int64 broadcast_from_dim);
|
||||
|
||||
// Returns the result of reducing the 3D array to a 2D array, reducing away
|
||||
// the dimensions specified in dims.
|
||||
static std::unique_ptr<Array2D<float>> Reduce3DTo2D(
|
||||
@ -169,6 +174,12 @@ class ReferenceUtil {
|
||||
const tensorflow::gtl::ArraySlice<int64>& stride,
|
||||
const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding);
|
||||
|
||||
// Batch normalize data.
|
||||
static std::unique_ptr<Array4D<float>> BatchNorm4D(
|
||||
const Array4D<float>& input, const Array4D<float>& mean,
|
||||
const Array4D<float>& var, const Array4D<float>& scale,
|
||||
const Array4D<float>& offset, float epsilon);
|
||||
|
||||
// Performs select and scatter with Greater Than or equal as the select, plus
|
||||
// as the scatter, and Same Padding.
|
||||
static std::unique_ptr<Array4D<float>> SelectAndScatter4DGePlus(
|
||||
@ -396,6 +407,41 @@ class ReferenceUtil {
|
||||
return result;
|
||||
}
|
||||
|
||||
// Applies map_function to each pair of elements in the input lhs and rhs
|
||||
// (4D array) and returns the result.
|
||||
template <typename F>
|
||||
static std::unique_ptr<Array4D<float>> MapArray4D(const Array4D<float>& lhs,
|
||||
const Array4D<float>& rhs,
|
||||
F&& map_function) {
|
||||
return MapWithIndexArray4D(
|
||||
lhs, rhs, [&](float lhs, float rhs, int64, int64, int64, int64) {
|
||||
return map_function(lhs, rhs);
|
||||
});
|
||||
}
|
||||
|
||||
// Applies map_function to each pair of element in lhs and rhs (4D array) and
|
||||
// returns the result.
|
||||
// (plane, depth, height, width) index of each element is also provided as
|
||||
// arguments to map_function.
|
||||
template <typename F>
|
||||
static std::unique_ptr<Array4D<float>> MapWithIndexArray4D(
|
||||
const Array4D<float>& lhs, const Array4D<float>& rhs, F&& map_function) {
|
||||
auto result = MakeUnique<Array4D<float>>(lhs.planes(), lhs.depth(),
|
||||
lhs.height(), lhs.width());
|
||||
for (int64 plane = 0; plane < lhs.planes(); ++plane) {
|
||||
for (int64 depth = 0; depth < lhs.depth(); ++depth) {
|
||||
for (int64 height = 0; height < lhs.height(); ++height) {
|
||||
for (int64 width = 0; width < lhs.width(); ++width) {
|
||||
(*result)(plane, depth, height, width) = map_function(
|
||||
lhs(plane, depth, height, width),
|
||||
rhs(plane, depth, height, width), plane, depth, height, width);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns the result of a 2D pad on an input matrix.
|
||||
static std::unique_ptr<Array2D<float>> PadArray2D(
|
||||
const Array2D<float>& operand, const PaddingConfig& padding,
|
||||
|
@ -710,9 +710,10 @@ cc_library(
|
||||
],
|
||||
deps = [
|
||||
":buffer_liveness",
|
||||
":heap_simulator",
|
||||
":hlo",
|
||||
":hlo_ordering",
|
||||
":hlo_proto",
|
||||
":hlo_scheduling",
|
||||
":logical_buffer",
|
||||
":tuple_points_to_analysis",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
@ -739,6 +740,7 @@ cc_test(
|
||||
":flatten_call_graph",
|
||||
":hlo",
|
||||
":hlo_ordering",
|
||||
":hlo_scheduling",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
@ -751,42 +753,14 @@ cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "heap_simulator_test",
|
||||
size = "small",
|
||||
srcs = ["heap_simulator_test.cc"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_ordering",
|
||||
":logical_buffer",
|
||||
":tuple_points_to_analysis",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
# The hlo_ordering library contains both hlo_ordering and heap_simulator because
|
||||
# they are mutually dependent.
|
||||
cc_library(
|
||||
name = "hlo_ordering",
|
||||
srcs = [
|
||||
"heap_simulator.cc",
|
||||
"hlo_ordering.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"heap_simulator.h",
|
||||
"hlo_ordering.h",
|
||||
],
|
||||
srcs = ["hlo_ordering.cc"],
|
||||
hdrs = ["hlo_ordering.h"],
|
||||
deps = [
|
||||
":call_graph",
|
||||
":hlo",
|
||||
":hlo_proto",
|
||||
":liveness_util",
|
||||
":logical_buffer",
|
||||
":tuple_points_to_analysis",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
@ -803,6 +777,77 @@ cc_test(
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_ordering",
|
||||
":hlo_scheduling",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "heap_simulator",
|
||||
srcs = ["heap_simulator.cc"],
|
||||
hdrs = ["heap_simulator.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_ordering",
|
||||
":hlo_proto",
|
||||
":liveness_util",
|
||||
":logical_buffer",
|
||||
":tuple_points_to_analysis",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "heap_simulator_test",
|
||||
size = "small",
|
||||
srcs = ["heap_simulator_test.cc"],
|
||||
deps = [
|
||||
":heap_simulator",
|
||||
":hlo",
|
||||
":hlo_ordering",
|
||||
":logical_buffer",
|
||||
":tuple_points_to_analysis",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hlo_scheduling",
|
||||
srcs = ["hlo_scheduling.cc"],
|
||||
hdrs = ["hlo_scheduling.h"],
|
||||
deps = [
|
||||
":heap_simulator",
|
||||
":hlo",
|
||||
":hlo_ordering",
|
||||
":hlo_proto",
|
||||
":logical_buffer",
|
||||
":tuple_points_to_analysis",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "hlo_scheduling_test",
|
||||
size = "small",
|
||||
srcs = ["hlo_scheduling_test.cc"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_ordering",
|
||||
":hlo_scheduling",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
@ -1177,6 +1222,22 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hlo_value",
|
||||
srcs = ["hlo_value.cc"],
|
||||
hdrs = ["hlo_value.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
"//tensorflow/compiler/xla:shape_tree",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hlo_dataflow_analysis",
|
||||
srcs = [
|
||||
@ -1189,8 +1250,8 @@ cc_library(
|
||||
":call_graph",
|
||||
":hlo",
|
||||
":hlo_ordering",
|
||||
":hlo_value",
|
||||
":liveness_util",
|
||||
"//tensorflow/compiler/xla:shape_tree",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
@ -1198,7 +1259,6 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
],
|
||||
)
|
||||
|
||||
@ -1225,20 +1285,32 @@ cc_test(
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hlo_alias_analysis",
|
||||
srcs = [
|
||||
"hlo_alias_analysis.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"hlo_alias_analysis.h",
|
||||
],
|
||||
name = "hlo_buffer",
|
||||
srcs = ["hlo_buffer.cc"],
|
||||
hdrs = ["hlo_buffer.h"],
|
||||
deps = [
|
||||
":call_graph",
|
||||
":hlo",
|
||||
":hlo_dataflow_analysis",
|
||||
":logical_buffer",
|
||||
":hlo_value",
|
||||
"//tensorflow/compiler/xla:shape_tree",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hlo_alias_analysis",
|
||||
srcs = ["hlo_alias_analysis.cc"],
|
||||
hdrs = ["hlo_alias_analysis.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_buffer",
|
||||
":hlo_dataflow_analysis",
|
||||
":hlo_value",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
@ -1424,6 +1496,7 @@ cc_library(
|
||||
":hlo",
|
||||
":hlo_dce",
|
||||
":hlo_ordering",
|
||||
":hlo_scheduling",
|
||||
":liveness_util",
|
||||
":logical_buffer",
|
||||
":tuple_points_to_analysis",
|
||||
@ -1520,8 +1593,8 @@ cc_library(
|
||||
"hlo_pass_pipeline.h",
|
||||
],
|
||||
deps = [
|
||||
":compiler",
|
||||
":hlo",
|
||||
":hlo_graph_dumper",
|
||||
":hlo_pass",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
@ -1730,6 +1803,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:window_util",
|
||||
"//tensorflow/compiler/xla/legacy_flags:hlo_graph_dumper_flags",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:regexp_internal",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/heap_simulator.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
|
@ -32,6 +32,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
|
@ -62,9 +62,7 @@ CompileOnlyService::CompileOnlyService(
|
||||
std::unique_ptr<Backend> compute_constant_backend)
|
||||
: Service(options, /*backend=*/nullptr,
|
||||
std::move(compute_constant_backend)),
|
||||
compiler_(compiler) {
|
||||
runs_in_client_process_ = true;
|
||||
}
|
||||
compiler_(compiler) {}
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||
CompileOnlyService::CompileAheadOfTime(
|
||||
@ -124,8 +122,7 @@ CompileOnlyService::CompileAheadOfTime(
|
||||
hlo_modules.push_back(std::move(hlo_module));
|
||||
}
|
||||
|
||||
return compiler_->CompileAheadOfTime(std::move(hlo_modules),
|
||||
MakeHloDumper(), options);
|
||||
return compiler_->CompileAheadOfTime(std::move(hlo_modules), options);
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -92,13 +92,6 @@ class AotCompilationOptions {
|
||||
// platform.
|
||||
class Compiler {
|
||||
public:
|
||||
// Callback signature used to dump the HLO graph during compilation.
|
||||
// Different compiler backends will call this as they please, providing
|
||||
// a view of the HLO at different points in compilation -- context for the
|
||||
// dump is indicated by the label string.
|
||||
using HloDumper =
|
||||
std::function<void(const HloModule& module, const string& label)>;
|
||||
|
||||
virtual ~Compiler() {}
|
||||
|
||||
// Returns the ID of the platform that this compiler targets.
|
||||
@ -113,21 +106,20 @@ class Compiler {
|
||||
//
|
||||
// Use the overload below to compile computations that run in parallel.
|
||||
virtual StatusOr<std::unique_ptr<Executable>> Compile(
|
||||
std::unique_ptr<HloModule> module, HloDumper dump_hlo,
|
||||
std::unique_ptr<HloModule> module,
|
||||
perftools::gputools::StreamExecutor* executor) = 0;
|
||||
|
||||
// Compiles a set of HLO modules that can run in parallel, potentially
|
||||
// communicating data between the modules, and returns a corresponding
|
||||
// sequence of executable objects.
|
||||
virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
|
||||
std::vector<std::unique_ptr<HloModule>> modules, HloDumper dump_hlo,
|
||||
std::vector<std::unique_ptr<HloModule>> modules,
|
||||
std::vector<perftools::gputools::StreamExecutor*> stream_exec) = 0;
|
||||
|
||||
// Compiles the HLO module for ahead-of-time execution. This is intended for
|
||||
// use in static compilation.
|
||||
virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||
CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
|
||||
HloDumper dump_hlo,
|
||||
const AotCompilationOptions& options) = 0;
|
||||
|
||||
/////
|
||||
|
@ -68,6 +68,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
|
||||
"//tensorflow/compiler/xla/service:hlo_proto",
|
||||
"//tensorflow/compiler/xla/service:hlo_proto_util",
|
||||
"//tensorflow/compiler/xla/service:hlo_scheduling",
|
||||
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
|
||||
"//tensorflow/compiler/xla/service:hlo_verifier",
|
||||
"//tensorflow/compiler/xla/service:inliner",
|
||||
@ -187,7 +188,6 @@ cc_library(
|
||||
":dot_op_emitter",
|
||||
":elemental_ir_emitter",
|
||||
":ir_emission_utils",
|
||||
":shape_partition",
|
||||
":simple_orc_jit",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
@ -196,7 +196,6 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:window_util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
|
||||
"//tensorflow/compiler/xla/service:buffer_assignment",
|
||||
"//tensorflow/compiler/xla/service:elemental_ir_emitter",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
@ -227,7 +226,6 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",
|
||||
"//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
|
||||
@ -288,8 +286,6 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/legacy_flags:compiler_functor_flags",
|
||||
"//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
|
||||
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
|
||||
"//tensorflow/core:lib",
|
||||
"@llvm//:analysis",
|
||||
@ -484,7 +480,6 @@ cc_library(
|
||||
":cpu_runtime",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:window_util",
|
||||
"//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
],
|
||||
)
|
||||
@ -511,7 +506,6 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_pass",
|
||||
"//tensorflow/core:lib",
|
||||
@ -548,7 +542,6 @@ cc_test(
|
||||
":shape_partition",
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
|
@ -35,8 +35,6 @@ limitations under the License.
|
||||
#include "external/llvm/include/llvm/Transforms/IPO.h"
|
||||
#include "external/llvm/include/llvm/Transforms/IPO/AlwaysInliner.h"
|
||||
#include "external/llvm/include/llvm/Transforms/IPO/PassManagerBuilder.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h"
|
||||
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h"
|
||||
@ -45,7 +43,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace xla {
|
||||
@ -66,14 +63,9 @@ operator()(llvm::Module& module) const {
|
||||
|
||||
VLOG(2) << "IR before optimizations";
|
||||
XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(module));
|
||||
legacy_flags::CompilerFunctorFlags* flags =
|
||||
legacy_flags::GetCompilerFunctorFlags();
|
||||
string dump_path = flags->xla_debug_cpu_dump_ir;
|
||||
if (!dump_path.empty()) {
|
||||
std::unique_ptr<tensorflow::WritableFile> f;
|
||||
TF_CHECK_OK(tensorflow::Env::Default()->NewAppendableFile(dump_path, &f));
|
||||
TF_CHECK_OK(f->Append(llvm_ir::DumpModuleToString(module)));
|
||||
TF_CHECK_OK(f->Close());
|
||||
|
||||
if (pre_optimization_callback_) {
|
||||
TF_CHECK_OK(pre_optimization_callback_(module));
|
||||
}
|
||||
|
||||
// Build up optimization pipeline.
|
||||
@ -99,6 +91,10 @@ operator()(llvm::Module& module) const {
|
||||
VLOG(2) << "IR after optimizations";
|
||||
XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(module));
|
||||
|
||||
if (post_optimization_callback_) {
|
||||
TF_CHECK_OK(post_optimization_callback_(module));
|
||||
}
|
||||
|
||||
// Generate code.
|
||||
llvm::MCContext* mc_context;
|
||||
llvm::legacy::PassManager codegen_passes;
|
||||
@ -156,12 +152,7 @@ std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl(
|
||||
{"llvm.tanh.f32", runtime::kTanhV8F32, 8},
|
||||
};
|
||||
|
||||
// Our vectorized library calls are currently implement by calling into Eigen.
|
||||
// As such, only emit calls to these routines if --xla_cpu_use_eigen is
|
||||
// enabled.
|
||||
legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags();
|
||||
if (flags->xla_cpu_use_eigen &&
|
||||
(arch == llvm::Triple::x86 || llvm::Triple::x86_64)) {
|
||||
if (arch == llvm::Triple::x86 || llvm::Triple::x86_64) {
|
||||
llvm::SmallVector<llvm::StringRef, 32> features;
|
||||
feature_string.split(features, ',', -1, /*KeepEmpty=*/false);
|
||||
if (std::find(features.begin(), features.end(), "+sse4.1") !=
|
||||
|
@ -39,13 +39,22 @@ class CompilerFunctor {
|
||||
// Returns a VectorIntrinsics where all intrinsics are available.
|
||||
static VectorIntrinsics AllIntrinsics();
|
||||
|
||||
explicit CompilerFunctor(llvm::TargetMachine* target_machine,
|
||||
const Disassembler* disassembler, int opt_level,
|
||||
const VectorIntrinsics& available_intrinsics)
|
||||
// A callback of this type can be run before and/or after IR-level
|
||||
// optimization to e.g. dump out the generated IR to disk or gather some
|
||||
// statistics.
|
||||
using OptimizationCallback = std::function<Status(const llvm::Module&)>;
|
||||
|
||||
explicit CompilerFunctor(
|
||||
llvm::TargetMachine* target_machine, const Disassembler* disassembler,
|
||||
int opt_level, const VectorIntrinsics& available_intrinsics,
|
||||
OptimizationCallback pre_optimization_callback = nullptr,
|
||||
OptimizationCallback post_optimization_callback = nullptr)
|
||||
: target_machine_(target_machine),
|
||||
disassembler_(CHECK_NOTNULL(disassembler)),
|
||||
opt_level_(opt_level),
|
||||
available_intrinsics_(available_intrinsics) {}
|
||||
available_intrinsics_(available_intrinsics),
|
||||
pre_optimization_callback_(pre_optimization_callback),
|
||||
post_optimization_callback_(post_optimization_callback) {}
|
||||
|
||||
// Compile a Module to an ObjectFile.
|
||||
llvm::object::OwningBinary<llvm::object::ObjectFile> operator()(
|
||||
@ -61,6 +70,8 @@ class CompilerFunctor {
|
||||
const Disassembler* disassembler_;
|
||||
const unsigned opt_level_;
|
||||
const VectorIntrinsics available_intrinsics_;
|
||||
OptimizationCallback pre_optimization_callback_;
|
||||
OptimizationCallback post_optimization_callback_;
|
||||
};
|
||||
|
||||
} // namespace cpu
|
||||
|
@ -15,7 +15,6 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
@ -30,11 +29,6 @@ namespace xla {
|
||||
namespace cpu {
|
||||
|
||||
StatusOr<bool> ConvCanonicalization::Run(HloModule* module) {
|
||||
legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags();
|
||||
if (!flags->xla_cpu_use_eigen) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool changed = false;
|
||||
for (HloInstruction* hlo :
|
||||
module->entry_computation()->MakeInstructionPostOrder()) {
|
||||
|
@ -69,6 +69,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
|
||||
#include "tensorflow/compiler/xla/service/inliner.h"
|
||||
@ -81,6 +82,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
|
||||
namespace se = ::perftools::gputools;
|
||||
|
||||
@ -244,9 +246,9 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault {
|
||||
};
|
||||
} // namespace
|
||||
|
||||
Status CpuCompiler::RunHloPasses(HloModule* module, HloDumper dump_hlo) {
|
||||
Status CpuCompiler::RunHloPasses(HloModule* module) {
|
||||
// Optimization pipeline.
|
||||
HloPassPipeline pipeline("CPU", dump_hlo);
|
||||
HloPassPipeline pipeline("CPU");
|
||||
pipeline.AddInvariantChecker<HloVerifier>();
|
||||
|
||||
// TODO(b/35786417): Re-enable inliner pass after fixing the bug and deciding
|
||||
@ -255,8 +257,8 @@ Status CpuCompiler::RunHloPasses(HloModule* module, HloDumper dump_hlo) {
|
||||
|
||||
pipeline.AddPass<ConvCanonicalization>();
|
||||
{
|
||||
auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification",
|
||||
dump_hlo);
|
||||
auto& pass =
|
||||
pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
|
||||
pass.AddPass<AlgebraicSimplifier>(
|
||||
/*is_layout_sensitive=*/false,
|
||||
[](const Shape&, const Shape&) { return false; },
|
||||
@ -343,25 +345,45 @@ llvm::CodeGenOpt::Level CodeGenOptLevel(const HloModuleConfig& module_config) {
|
||||
}
|
||||
}
|
||||
|
||||
Status AppendIRToFile(const string& file_name, const string& ir_module_string) {
|
||||
std::unique_ptr<tensorflow::WritableFile> f;
|
||||
TF_RETURN_IF_ERROR(
|
||||
tensorflow::Env::Default()->NewAppendableFile(file_name, &f));
|
||||
TF_RETURN_IF_ERROR(f->Append(ir_module_string));
|
||||
TF_RETURN_IF_ERROR(f->Close());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
|
||||
std::unique_ptr<HloModule> module, HloDumper dump_hlo,
|
||||
se::StreamExecutor* stream_exec) {
|
||||
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec) {
|
||||
VLOG(1) << "Compiling: " << module->name();
|
||||
TF_RET_CHECK(stream_exec != nullptr);
|
||||
std::call_once(llvm_command_line_options_initialized,
|
||||
&InitializeLLVMCommandLineOptions, module->config());
|
||||
|
||||
const string dump_ir_to = module->config().debug_options().xla_dump_ir_to();
|
||||
|
||||
auto dump_ir_to_disk = [dump_ir_to](const llvm::Module& module) {
|
||||
if (!dump_ir_to.empty()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
AppendIRToFile(dump_ir_to, llvm_ir::DumpModuleToString(module)));
|
||||
}
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
// Compile must be thread-safe so create a new LLVM context for the module.
|
||||
auto llvm_context = MakeUnique<llvm::LLVMContext>();
|
||||
auto llvm_module =
|
||||
MakeUnique<llvm::Module>("__compute_module", *llvm_context);
|
||||
auto jit = MakeUnique<SimpleOrcJIT>(CompilerTargetOptions(module->config()),
|
||||
CodeGenOptLevel(module->config()));
|
||||
CodeGenOptLevel(module->config()),
|
||||
dump_ir_to_disk, dump_ir_to_disk);
|
||||
llvm_module->setDataLayout(jit->data_layout());
|
||||
llvm_module->setTargetTriple(jit->target_triple().getTriple());
|
||||
|
||||
TF_RETURN_IF_ERROR(RunHloPasses(module.get(), dump_hlo));
|
||||
TF_RETURN_IF_ERROR(RunHloPasses(module.get()));
|
||||
|
||||
HloComputation* computation = module->entry_computation();
|
||||
std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx;
|
||||
@ -373,7 +395,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
|
||||
|
||||
std::unique_ptr<Executable> cpu_executable;
|
||||
|
||||
// Cache this flag here since we'll want to access it after the module's
|
||||
// Cache these flags here since we'll want to access them after the module's
|
||||
// ownership is std::moved.
|
||||
const bool embed_ir_in_executable =
|
||||
module->config().debug_options().xla_embed_ir_in_executable();
|
||||
@ -381,6 +403,8 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
|
||||
module->config().debug_options().xla_dump_debug_json_to();
|
||||
|
||||
if (CpuParallelBackendRequested(module->config())) {
|
||||
VLOG(1) << "Using parallel cpu backend";
|
||||
|
||||
// Run buffer analysis on the HLO graph. This analysis figures out which
|
||||
// temporary buffers are required to run the computation.
|
||||
// DependencyHloOrdering is used for the parallel emitter because the order
|
||||
@ -475,6 +499,8 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
|
||||
.set_ir_module_string(ir_module_string);
|
||||
}
|
||||
} else {
|
||||
VLOG(1) << "Using sequential cpu backend";
|
||||
|
||||
// Select an order for emitting the HLO instructions for each
|
||||
// computation. Using this sequence enables tighter buffer liveness analysis
|
||||
// and reduced memory usage (as compared to using DependencyHloOrdering).
|
||||
@ -540,11 +566,12 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
|
||||
}
|
||||
}
|
||||
|
||||
VLOG(1) << "Compilation finished";
|
||||
return std::move(cpu_executable);
|
||||
}
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<Executable>>> CpuCompiler::Compile(
|
||||
std::vector<std::unique_ptr<HloModule>> modules, HloDumper dump_hlos,
|
||||
std::vector<std::unique_ptr<HloModule>> modules,
|
||||
std::vector<se::StreamExecutor*> stream_execs) {
|
||||
return Unimplemented(
|
||||
"Compilation of multiple HLO modules is not yet supported on CPU.");
|
||||
@ -552,7 +579,6 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> CpuCompiler::Compile(
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||
CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
|
||||
HloDumper dump_hlo,
|
||||
const AotCompilationOptions& aot_options) {
|
||||
TF_RET_CHECK(!modules.empty());
|
||||
std::call_once(llvm_command_line_options_initialized,
|
||||
@ -641,8 +667,9 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
|
||||
std::vector<std::unique_ptr<AotCompilationResult>> results;
|
||||
for (size_t i = 0; i < modules.size(); ++i) {
|
||||
HloModule* module = modules[i].get();
|
||||
VLOG(1) << "Compiling ahead-of-time: " << module->name();
|
||||
|
||||
TF_RETURN_IF_ERROR(RunHloPasses(module, dump_hlo));
|
||||
TF_RETURN_IF_ERROR(RunHloPasses(module));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
SequentialHloOrdering::HloModuleSequence module_sequence,
|
||||
@ -719,6 +746,8 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
|
||||
std::move(object_file_data), std::move(buffer_sizes),
|
||||
result_slice.index()));
|
||||
}
|
||||
|
||||
VLOG(1) << "Compilation finished";
|
||||
return std::move(results);
|
||||
}
|
||||
|
||||
|
@ -110,16 +110,15 @@ class CpuCompiler : public Compiler {
|
||||
~CpuCompiler() override {}
|
||||
|
||||
StatusOr<std::unique_ptr<Executable>> Compile(
|
||||
std::unique_ptr<HloModule> module, HloDumper dump_hlo,
|
||||
std::unique_ptr<HloModule> module,
|
||||
perftools::gputools::StreamExecutor* stream_exec) override;
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
|
||||
std::vector<std::unique_ptr<HloModule>> modules, HloDumper dump_hlo,
|
||||
std::vector<std::unique_ptr<HloModule>> modules,
|
||||
std::vector<perftools::gputools::StreamExecutor*> stream_exec) override;
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||
CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
|
||||
HloDumper dump_hlo,
|
||||
const AotCompilationOptions& options) override;
|
||||
|
||||
perftools::gputools::Platform::Id PlatformId() const override;
|
||||
@ -132,7 +131,7 @@ class CpuCompiler : public Compiler {
|
||||
|
||||
// Runs the HLO passes which are necessary for both optimizations and
|
||||
// correctness.
|
||||
Status RunHloPasses(HloModule* module, HloDumper dump_hlo);
|
||||
Status RunHloPasses(HloModule* module);
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(CpuCompiler);
|
||||
};
|
||||
|
@ -22,9 +22,9 @@ limitations under the License.
|
||||
#include "external/llvm/include/llvm/IR/Instructions.h"
|
||||
#include "external/llvm/include/llvm/IR/Module.h"
|
||||
#include "external/llvm/include/llvm/IR/Value.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
@ -44,7 +44,8 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs,
|
||||
const llvm_ir::IrArray& lhs_array,
|
||||
const llvm_ir::IrArray& rhs_array,
|
||||
llvm::Value* executable_run_options_value,
|
||||
llvm::IRBuilder<>* ir_builder)
|
||||
llvm::IRBuilder<>* ir_builder,
|
||||
const HloModuleConfig& hlo_module_config)
|
||||
: dot_(dot),
|
||||
transpose_lhs_(transpose_lhs),
|
||||
transpose_rhs_(transpose_rhs),
|
||||
@ -52,18 +53,20 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs,
|
||||
lhs_array_(lhs_array),
|
||||
rhs_array_(rhs_array),
|
||||
executable_run_options_value_(executable_run_options_value),
|
||||
ir_builder_(ir_builder) {}
|
||||
ir_builder_(ir_builder),
|
||||
hlo_module_config_(hlo_module_config) {}
|
||||
|
||||
/* static */ tensorflow::Status DotOpEmitter::EmitDotOperation(
|
||||
const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs,
|
||||
const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array,
|
||||
const llvm_ir::IrArray& rhs_array,
|
||||
llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder) {
|
||||
llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder,
|
||||
const HloModuleConfig& hlo_module_config) {
|
||||
PrimitiveType type = target_array.GetShape().element_type();
|
||||
TF_RET_CHECK(F32 == type || F64 == type);
|
||||
DotOpEmitter dot_emitter(dot, transpose_lhs, transpose_rhs, target_array,
|
||||
lhs_array, rhs_array, executable_run_options_value,
|
||||
ir_builder);
|
||||
ir_builder, hlo_module_config);
|
||||
return dot_emitter.Emit();
|
||||
}
|
||||
|
||||
@ -233,20 +236,20 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() {
|
||||
// The two transpose_... parameters are actually booleans, but we use int32
|
||||
// to avoid target-dependent calling convention details.
|
||||
|
||||
legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags();
|
||||
bool multi_threaded = flags->xla_cpu_multi_thread_eigen;
|
||||
bool multi_threaded_eigen =
|
||||
hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
|
||||
PrimitiveType type = target_array_.GetShape().element_type();
|
||||
llvm::Type* float_type;
|
||||
const char* fn_name;
|
||||
switch (type) {
|
||||
case F32:
|
||||
fn_name = multi_threaded
|
||||
fn_name = multi_threaded_eigen
|
||||
? runtime::kEigenMatmulF32SymbolName
|
||||
: runtime::kEigenSingleThreadedMatmulF32SymbolName;
|
||||
float_type = ir_builder_->getFloatTy();
|
||||
break;
|
||||
case F64:
|
||||
fn_name = multi_threaded
|
||||
fn_name = multi_threaded_eigen
|
||||
? runtime::kEigenMatmulF64SymbolName
|
||||
: runtime::kEigenSingleThreadedMatmulF64SymbolName;
|
||||
float_type = ir_builder_->getDoubleTy();
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include "external/llvm/include/llvm/IR/IRBuilder.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
@ -39,7 +40,8 @@ class DotOpEmitter {
|
||||
const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs,
|
||||
const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array,
|
||||
const llvm_ir::IrArray& rhs_array,
|
||||
llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder);
|
||||
llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder,
|
||||
const HloModuleConfig& hlo_module_config);
|
||||
|
||||
private:
|
||||
DotOpEmitter(const HloInstruction& dot, bool transpose_lhs,
|
||||
@ -47,7 +49,8 @@ class DotOpEmitter {
|
||||
const llvm_ir::IrArray& lhs_array,
|
||||
const llvm_ir::IrArray& rhs_array,
|
||||
llvm::Value* executable_run_options_value,
|
||||
llvm::IRBuilder<>* ir_builder);
|
||||
llvm::IRBuilder<>* ir_builder,
|
||||
const HloModuleConfig& hlo_module_config);
|
||||
|
||||
// Emits the IR to perform the dot operation.
|
||||
tensorflow::Status Emit();
|
||||
@ -82,6 +85,7 @@ class DotOpEmitter {
|
||||
const llvm_ir::IrArray& rhs_array_;
|
||||
llvm::Value* executable_run_options_value_;
|
||||
llvm::IRBuilder<>* ir_builder_;
|
||||
const HloModuleConfig& hlo_module_config_;
|
||||
};
|
||||
|
||||
} // namespace cpu
|
||||
|
@ -16,7 +16,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/window_util.h"
|
||||
@ -26,11 +25,6 @@ namespace cpu {
|
||||
|
||||
bool PotentiallyImplementedAsEigenConvolution(
|
||||
const HloInstruction& convolution) {
|
||||
legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags();
|
||||
if (!flags->xla_cpu_use_eigen) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// The following conditions are necessary (but not sufficient) for
|
||||
// implementing `convolution` with Eigen convolution:
|
||||
// - the input and kernel have a non-zero number of elements.
|
||||
@ -82,11 +76,6 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
|
||||
} // namespace
|
||||
|
||||
bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) {
|
||||
legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags();
|
||||
if (!flags->xla_cpu_use_eigen) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// For certain types of Dot, we can call Eigen
|
||||
if (hlo.opcode() == HloOpcode::kDot) {
|
||||
const Shape& lhs_shape = hlo.operand(0)->shape();
|
||||
|
@ -33,7 +33,6 @@ limitations under the License.
|
||||
#include "external/llvm/include/llvm/IR/Intrinsics.h"
|
||||
#include "external/llvm/include/llvm/IR/LLVMContext.h"
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h"
|
||||
#include "tensorflow/compiler/xla/map_util.h"
|
||||
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
|
||||
@ -777,7 +776,8 @@ Status IrEmitter::HandleDot(HloInstruction* dot, HloInstruction* lhs,
|
||||
// Dot operation is complicated so we delegate to a helper class.
|
||||
TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation(
|
||||
*dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array,
|
||||
lhs_array, rhs_array, GetExecutableRunOptionsArgument(), &ir_builder_));
|
||||
lhs_array, rhs_array, GetExecutableRunOptionsArgument(), &ir_builder_,
|
||||
hlo_module_config_));
|
||||
|
||||
emitted_value_[dot] = target_address;
|
||||
return Status::OK();
|
||||
@ -862,9 +862,10 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution,
|
||||
int64_type, int64_type, int64_type, int64_type,
|
||||
int64_type, int64_type, int64_type, int64_type},
|
||||
/*isVarArg=*/false);
|
||||
legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags();
|
||||
bool multi_threaded_eigen =
|
||||
hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
|
||||
const char* fn_name =
|
||||
(flags->xla_cpu_multi_thread_eigen
|
||||
(multi_threaded_eigen
|
||||
? runtime::kEigenConvF32SymbolName
|
||||
: runtime::kEigenSingleThreadedConvF32SymbolName);
|
||||
llvm::Function* conv_func = llvm::cast<llvm::Function>(
|
||||
@ -1525,7 +1526,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
|
||||
TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation(
|
||||
*dot, dot->operand(0)->IsRank2Transpose(),
|
||||
dot->operand(1)->IsRank2Transpose(), target_array, lhs_array, rhs_array,
|
||||
GetExecutableRunOptionsArgument(), &ir_builder_));
|
||||
GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_));
|
||||
|
||||
emitted_value_[fusion] = target_address;
|
||||
return Status::OK();
|
||||
@ -1898,11 +1899,14 @@ llvm::Value* IrEmitter::EmitTempBufferPointer(
|
||||
GetTempBuffersArgument(), slice.index(), &ir_builder_);
|
||||
llvm::LoadInst* tempbuf_address_base =
|
||||
ir_builder_.CreateLoad(tempbuf_address_ptr);
|
||||
// Loading the address of a buffer is invariant of the point at which the
|
||||
// load is executed in the program because we never reassign buffers.
|
||||
tempbuf_address_base->setMetadata(
|
||||
llvm::LLVMContext::MD_invariant_load,
|
||||
llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{}));
|
||||
if (hlo_module_config_.debug_options()
|
||||
.xla_llvm_enable_invariant_load_metadata()) {
|
||||
// Loading the address of a buffer is invariant of the point at which the
|
||||
// load is executed in the program because we never reassign buffers.
|
||||
tempbuf_address_base->setMetadata(
|
||||
llvm::LLVMContext::MD_invariant_load,
|
||||
llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{}));
|
||||
}
|
||||
llvm_ir::SetTbaaForInstruction(tempbuf_address_base, target_shape,
|
||||
/*is_pointer_to=*/true);
|
||||
AttachAlignmentMetadataForLoad(tempbuf_address_base, allocation.size());
|
||||
|
@ -27,7 +27,6 @@ limitations under the License.
|
||||
#include "external/llvm/include/llvm/Support/CodeGen.h"
|
||||
#include "external/llvm/include/llvm/Support/Host.h"
|
||||
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h"
|
||||
@ -143,7 +142,9 @@ CompilerFunctor::VectorIntrinsics GetAvailableIntrinsics() {
|
||||
} // namespace
|
||||
|
||||
SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions &target_options,
|
||||
llvm::CodeGenOpt::Level opt_level)
|
||||
llvm::CodeGenOpt::Level opt_level,
|
||||
OptimizationCallback pre_optimization_callback,
|
||||
OptimizationCallback post_optimization_callback)
|
||||
: target_machine_(
|
||||
CHECK_NOTNULL(llvm::EngineBuilder()
|
||||
.setTargetOptions(target_options)
|
||||
@ -154,21 +155,19 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions &target_options,
|
||||
/*MAttrs=*/DetectMachineAttributes()))),
|
||||
disassembler_(*target_machine_),
|
||||
data_layout_(target_machine_->createDataLayout()),
|
||||
compile_layer_(object_layer_,
|
||||
CompilerFunctor(target_machine_.get(), &disassembler_,
|
||||
opt_level, GetAvailableIntrinsics())) {
|
||||
compile_layer_(
|
||||
object_layer_,
|
||||
CompilerFunctor(target_machine_.get(), &disassembler_, opt_level,
|
||||
GetAvailableIntrinsics(), pre_optimization_callback,
|
||||
post_optimization_callback)) {
|
||||
VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str()
|
||||
<< " features: " << target_machine_->getTargetFeatureString().str();
|
||||
}
|
||||
|
||||
SimpleOrcJIT::ModuleHandleT SimpleOrcJIT::AddModule(
|
||||
std::unique_ptr<llvm::Module> module) {
|
||||
// The Orc API adds a whole iterable "set" of modules, so we wrap the module
|
||||
// in a vector.
|
||||
std::vector<std::unique_ptr<llvm::Module>> module_set;
|
||||
module_set.push_back(std::move(module));
|
||||
auto handle = compile_layer_.addModuleSet(
|
||||
std::move(module_set), MakeUnique<llvm::SectionMemoryManager>(),
|
||||
auto handle = compile_layer_.addModule(
|
||||
std::move(module), MakeUnique<llvm::SectionMemoryManager>(),
|
||||
MakeUnique<SimpleResolver>());
|
||||
module_handles_.push_back(handle);
|
||||
return handle;
|
||||
@ -176,8 +175,9 @@ SimpleOrcJIT::ModuleHandleT SimpleOrcJIT::AddModule(
|
||||
|
||||
void SimpleOrcJIT::RemoveModule(SimpleOrcJIT::ModuleHandleT handle) {
|
||||
module_handles_.erase(
|
||||
std::remove(module_handles_.begin(), module_handles_.end(), handle));
|
||||
compile_layer_.removeModuleSet(handle);
|
||||
std::remove(module_handles_.begin(), module_handles_.end(), handle),
|
||||
module_handles_.end());
|
||||
compile_layer_.removeModule(handle);
|
||||
}
|
||||
|
||||
llvm::JITSymbol SimpleOrcJIT::FindSymbol(const std::string &name) {
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "external/llvm/include/llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
|
||||
#include "external/llvm/include/llvm/IR/Module.h"
|
||||
#include "external/llvm/include/llvm/Target/TargetMachine.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/disassembler.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
|
||||
@ -41,9 +42,13 @@ namespace cpu {
|
||||
// it's added to the JIT.
|
||||
class SimpleOrcJIT {
|
||||
public:
|
||||
using ObjLayerT = llvm::orc::RTDyldObjectLinkingLayer<>;
|
||||
using CompileLayerT = llvm::orc::IRCompileLayer<ObjLayerT>;
|
||||
using ModuleHandleT = CompileLayerT::ModuleSetHandleT;
|
||||
using ObjLayerT = llvm::orc::RTDyldObjectLinkingLayer;
|
||||
using CompileFtor =
|
||||
std::function<llvm::object::OwningBinary<llvm::object::ObjectFile>(
|
||||
llvm::Module&)>;
|
||||
using CompileLayerT = llvm::orc::IRCompileLayer<ObjLayerT, CompileFtor>;
|
||||
using ModuleHandleT = CompileLayerT::ModuleHandleT;
|
||||
using OptimizationCallback = CompilerFunctor::OptimizationCallback;
|
||||
|
||||
// Create a new JIT, targeting the host architecture.
|
||||
// The |target_options| parameter allows customization of certain code
|
||||
@ -51,8 +56,14 @@ class SimpleOrcJIT {
|
||||
// can be reassociated, etc.).
|
||||
// The |opt_level| parameter controls the optimization level of the code
|
||||
// generator.
|
||||
// The |pre_optimization_callback| is invoked on the module before any IR
|
||||
// level optimizations are applied.
|
||||
// The |post_optimization_callback| is invoked on the module after all IR
|
||||
// level optimizations are applied.
|
||||
SimpleOrcJIT(const llvm::TargetOptions& target_options,
|
||||
llvm::CodeGenOpt::Level opt_level);
|
||||
llvm::CodeGenOpt::Level opt_level,
|
||||
OptimizationCallback pre_optimization_callback,
|
||||
OptimizationCallback post_optimization_callback);
|
||||
|
||||
// Data layout this JIT was created with.
|
||||
const llvm::DataLayout& data_layout() const { return data_layout_; }
|
||||
|
@ -164,8 +164,7 @@ class DfsHloVisitor {
|
||||
HloInstruction* lhs, HloInstruction* rhs) {
|
||||
return HandleElementwiseBinary(logical_or, HloOpcode::kLogicalOr);
|
||||
}
|
||||
virtual Status HandleReducePrecision(HloInstruction* reduce_precision,
|
||||
HloInstruction* operand) {
|
||||
virtual Status HandleReducePrecision(HloInstruction* reduce_precision) {
|
||||
return HandleElementwiseUnary(reduce_precision,
|
||||
HloOpcode::kReducePrecision);
|
||||
}
|
||||
|
@ -390,17 +390,111 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision(
|
||||
if (hlo->operand(0)->shape().element_type() != F32) {
|
||||
return Unimplemented("reduce-precision only implemented for F32");
|
||||
}
|
||||
// As a preliminary implementation, we only implement this for the case
|
||||
// where it is a no-op -- that is, where the exponent and mantissa bit
|
||||
// counts are equal to the (IEEE f32) bit counts for the input values.
|
||||
if (hlo->exponent_bits() != 8) {
|
||||
return Unimplemented("reduce-precision requires 8 exponent bits");
|
||||
}
|
||||
if (hlo->mantissa_bits() != 23) {
|
||||
return Unimplemented("reduce-precision requires 23 mantissa bits");
|
||||
|
||||
// Integer and float types for casting and constant generation.
|
||||
llvm::Type* float_type = x->getType();
|
||||
llvm::IntegerType* int_type = ir_builder_->getInt32Ty();
|
||||
|
||||
// Cast the input value to an integer for bitwise manipulation.
|
||||
llvm::Value* x_as_int = ir_builder_->CreateBitCast(x, int_type);
|
||||
|
||||
if (hlo->mantissa_bits() < 23) {
|
||||
// Last remaining mantissa bit.
|
||||
const uint32_t last_mantissa_bit_mask = 1u << (23 - hlo->mantissa_bits());
|
||||
|
||||
// Compute rounding bias for round-to-nearest with ties to even. This is
|
||||
// equal to a base value of 0111... plus one bit if the last remaining
|
||||
// mantissa bit is 1.
|
||||
const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1;
|
||||
llvm::Value* x_last_mantissa_bit = ir_builder_->CreateLShr(
|
||||
ir_builder_->CreateAnd(
|
||||
x_as_int, llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)),
|
||||
(23 - hlo->mantissa_bits()));
|
||||
llvm::Value* x_rounding_bias = ir_builder_->CreateAdd(
|
||||
x_last_mantissa_bit,
|
||||
llvm::ConstantInt::get(int_type, base_rounding_bias));
|
||||
|
||||
// Add rounding bias, and mask out truncated bits. Note that the case
|
||||
// where adding the rounding bias overflows into the exponent bits is
|
||||
// correct; the non-masked mantissa bits will all be zero, and the
|
||||
// exponent will be incremented by one.
|
||||
const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1);
|
||||
x_as_int = ir_builder_->CreateAdd(x_as_int, x_rounding_bias);
|
||||
x_as_int = ir_builder_->CreateAnd(
|
||||
x_as_int, llvm::ConstantInt::get(int_type, truncation_mask));
|
||||
}
|
||||
|
||||
return x;
|
||||
if (hlo->exponent_bits() < 8) {
|
||||
// Masks for f32 values.
|
||||
const uint32_t f32_sign_bit_mask = 1u << 31;
|
||||
const uint32_t f32_exp_bits_mask = 0xffu << 23;
|
||||
|
||||
// An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most-
|
||||
// significant bit -- is equal to 1.0f for all exponent sizes. Adding
|
||||
// 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit-
|
||||
// size of n, and subtracting 2^(n-1)-1 from this gives us the lowest'
|
||||
// exponent (corresponding to 0.0f).
|
||||
//
|
||||
// Thus, the f32 exponent corresponding to the highest non-infinite
|
||||
// exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32
|
||||
// exponent corresponding to the lowest exponent for a bit size of n is
|
||||
// (2^7-1) - 2^(n-1)-1.
|
||||
//
|
||||
// Note that we have already checked that exponents_bits >= 1.
|
||||
const uint32_t f32_exponent_bias = (1 << 7) - 1;
|
||||
const uint32_t reduced_exponent_bias =
|
||||
(1 << (hlo->exponent_bits() - 1)) - 1;
|
||||
const uint32_t reduced_max_exponent =
|
||||
f32_exponent_bias + reduced_exponent_bias;
|
||||
const uint32_t reduced_min_exponent =
|
||||
f32_exponent_bias - reduced_exponent_bias;
|
||||
|
||||
// Do we overflow or underflow?
|
||||
llvm::Value* x_exponent = ir_builder_->CreateAnd(
|
||||
x_as_int, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));
|
||||
llvm::Value* x_overflows = ir_builder_->CreateICmpUGT(
|
||||
x_exponent,
|
||||
llvm::ConstantInt::get(int_type, reduced_max_exponent << 23));
|
||||
llvm::Value* x_underflows = ir_builder_->CreateICmpULE(
|
||||
x_exponent,
|
||||
llvm::ConstantInt::get(int_type, reduced_min_exponent << 23));
|
||||
|
||||
// Compute appropriately-signed values of zero and infinity.
|
||||
llvm::Value* x_signed_zero = ir_builder_->CreateAnd(
|
||||
x_as_int, llvm::ConstantInt::get(int_type, f32_sign_bit_mask));
|
||||
llvm::Value* x_signed_inf = ir_builder_->CreateOr(
|
||||
x_signed_zero, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));
|
||||
|
||||
// Force to zero or infinity if overflow or underflow. (Note that this
|
||||
// truncates all denormal values to zero, rather than rounding them.)
|
||||
x_as_int = ir_builder_->CreateSelect(x_overflows, x_signed_inf, x_as_int);
|
||||
x_as_int = ir_builder_->CreateSelect(x_underflows, x_signed_zero, x_as_int);
|
||||
}
|
||||
|
||||
// Cast the result back to a floating-point type.
|
||||
llvm::Value* result = ir_builder_->CreateBitCast(x_as_int, float_type);
|
||||
|
||||
// Correct result for NaN inputs.
|
||||
//
|
||||
// The exponent handling will "normalize" NaN values to infinities, which is
|
||||
// undesirable (except in the case with no mantissa bits, in which case it
|
||||
// is mandatory). This logic also handles cases where mantissa-rounding
|
||||
// causes a NaN's mantissa to overflow into the exponent bits, which would
|
||||
// otherwise create an erroneous zero value.
|
||||
//
|
||||
// If the fast-math flags are set to assume no NaNs, the comparison is likely
|
||||
// to be optimized away, so there's no point in even emitting it.
|
||||
if (!ir_builder_->getFastMathFlags().noNaNs()) {
|
||||
llvm::Value* x_is_nan = ir_builder_->CreateFCmpUNO(x, x);
|
||||
|
||||
if (hlo->mantissa_bits() > 0) {
|
||||
result = ir_builder_->CreateSelect(x_is_nan, x, result);
|
||||
} else {
|
||||
result = ir_builder_->CreateSelect(
|
||||
x_is_nan, llvm::ConstantFP::getInfinity(float_type), result);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
|
||||
|
@ -21,39 +21,9 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/regexp.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
/* static */ void Executable::DumpExecutedHlo(
|
||||
const HloModule& module, const string& label,
|
||||
const HloExecutionProfile* profile) {
|
||||
VLOG(2) << "module name = " << module.name();
|
||||
legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags();
|
||||
string generate_hlo_graph_regex;
|
||||
if (!flags->xla_generate_hlo_graph.empty()) {
|
||||
generate_hlo_graph_regex = flags->xla_generate_hlo_graph;
|
||||
} else {
|
||||
generate_hlo_graph_regex =
|
||||
module.config().debug_options().xla_generate_hlo_graph();
|
||||
}
|
||||
if (!generate_hlo_graph_regex.empty() &&
|
||||
RE2::PartialMatch(module.name(), generate_hlo_graph_regex)) {
|
||||
hlo_graph_dumper::DumpGraph(*module.entry_computation(), label,
|
||||
flags->xla_hlo_graph_addresses,
|
||||
flags->xla_hlo_graph_layout, profile);
|
||||
}
|
||||
if (!flags->xla_log_hlo_text.empty() &&
|
||||
RE2::PartialMatch(module.name(), flags->xla_log_hlo_text)) {
|
||||
LOG(INFO) << "HLO for module " << module.name();
|
||||
LOG(INFO) << "Label: " << label;
|
||||
XLA_LOG_LINES(2, module.ToString());
|
||||
}
|
||||
if (!flags->xla_dump_hlo_text_to.empty()) {
|
||||
hlo_graph_dumper::DumpText(module, label, flags->xla_dump_hlo_text_to);
|
||||
}
|
||||
}
|
||||
|
||||
StatusOr<std::vector<perftools::gputools::DeviceMemoryBase>>
|
||||
Executable::ExecuteOnStreams(
|
||||
tensorflow::gtl::ArraySlice<const ServiceExecutableRunOptions> run_options,
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
|
||||
#include "tensorflow/compiler/xla/service/session.pb.h"
|
||||
@ -49,10 +50,6 @@ class Executable {
|
||||
shape_size_function_(std::move(shape_size_function)) {}
|
||||
virtual ~Executable() {}
|
||||
|
||||
// Dumps the executed HLO according to service-associated flags.
|
||||
static void DumpExecutedHlo(const HloModule& module, const string& label,
|
||||
const HloExecutionProfile* profile);
|
||||
|
||||
// Enqueues the compilation result on the provided stream, passing the given
|
||||
// arguments. This call is blocking and returns after the execution is done.
|
||||
//
|
||||
@ -240,7 +237,8 @@ StatusOr<ReturnT> Executable::ExecuteOnStreamWrapper(
|
||||
}
|
||||
}
|
||||
}
|
||||
DumpExecutedHlo(module(), "Service::Execute", profile_ptr);
|
||||
hlo_graph_dumper::MaybeDumpHloModule(module(), "Service::Execute",
|
||||
profile_ptr);
|
||||
}
|
||||
|
||||
return return_value;
|
||||
|
@ -253,7 +253,6 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/legacy_flags:convolution_thunk_flags",
|
||||
"//tensorflow/compiler/xla/service:buffer_assignment",
|
||||
"//tensorflow/compiler/xla/service:device_memory_allocator",
|
||||
"//tensorflow/compiler/xla/service:executable",
|
||||
@ -498,8 +497,9 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/service:buffer_liveness",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_ordering",
|
||||
"//tensorflow/compiler/xla/service:hlo_scheduling",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -17,7 +17,6 @@ limitations under the License.
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
@ -287,10 +286,7 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune(
|
||||
const ConvolutionDescriptor& convolution_descriptor,
|
||||
const BufferAllocations& buffer_allocations, se::Stream* stream) {
|
||||
// TODO(b/29126320): Try cudnn v5's new auto-tuner when it's rolled out.
|
||||
legacy_flags::ConvolutionThunkFlags* flags =
|
||||
legacy_flags::GetConvolutionThunkFlags();
|
||||
if (flags->xla_gpu_autotune_convolution_algorithm &&
|
||||
best_algorithm_.algorithm() == se::dnn::kDefaultAlgorithm) {
|
||||
if (best_algorithm_.algorithm() == se::dnn::kDefaultAlgorithm) {
|
||||
// Auto-tuning either is disabled or only happens in the first run of this
|
||||
// function.
|
||||
VLOG(2) << "Profiling for best convolution algorithm used for "
|
||||
|
@ -81,9 +81,8 @@ class ConvolutionThunk : public Thunk {
|
||||
ConvolutionThunk(const ConvolutionThunk&) = delete;
|
||||
ConvolutionThunk& operator=(const ConvolutionThunk&) = delete;
|
||||
|
||||
// Does the convolution for the thunk on "stream". If the
|
||||
// xla_gpu_autotune_convolution_algorithm is turned on, auto-tuning happens on
|
||||
// the first run of this function.
|
||||
// Does the convolution for the thunk on "stream". Auto-tuning happens on the
|
||||
// first run of this function.
|
||||
tensorflow::Status ExecuteOnStream(
|
||||
const BufferAllocations& buffer_allocations,
|
||||
perftools::gputools::Stream* stream) override;
|
||||
|
@ -119,14 +119,13 @@ string GetLibdeviceDir(const HloModuleConfig& config) {
|
||||
|
||||
// Runs optimization passes on the given HLO module.
|
||||
tensorflow::Status OptimizeHloModule(HloModule* hlo_module,
|
||||
const Compiler::HloDumper& dump_hlo,
|
||||
const se::DeviceDescription& device_desc) {
|
||||
{
|
||||
HloPassPipeline pipeline("optimization", dump_hlo);
|
||||
HloPassPipeline pipeline("optimization");
|
||||
pipeline.AddInvariantChecker<HloVerifier>();
|
||||
{
|
||||
auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
|
||||
"simplification", dump_hlo);
|
||||
auto& pass =
|
||||
pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
|
||||
pass.AddPass<AlgebraicSimplifier>(
|
||||
/*is_layout_sensitive=*/false,
|
||||
[](const Shape&, const Shape&) { return false; });
|
||||
@ -146,7 +145,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module,
|
||||
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
|
||||
}
|
||||
{
|
||||
HloPassFix<HloPassPipeline> fusion("fusion", dump_hlo);
|
||||
HloPassFix<HloPassPipeline> fusion("fusion");
|
||||
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false);
|
||||
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
|
||||
fusion.AddPass<FusionMerger>();
|
||||
@ -156,14 +155,13 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module,
|
||||
|
||||
// Modifies the given HLO module so that it will be accepted by IrEmitter.
|
||||
// Unlike optimization passes, the passes are necessary for correctness.
|
||||
tensorflow::Status PrepareHloModuleForIrEmitting(
|
||||
const Compiler::HloDumper& dump_hlo, HloModule* hlo_module) {
|
||||
tensorflow::Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) {
|
||||
// In some cases, we have to place the result of an instruction in a temporary
|
||||
// buffer. For instance, the buffer that holds an external parameter is
|
||||
// assumed immutable at this point, and should not be reused for output
|
||||
// (b/27180329). Therefore, in that case, we set the output to be a copy of
|
||||
// the parameter.
|
||||
HloPassPipeline pipeline("GPU-ir-emit-prepare", dump_hlo);
|
||||
HloPassPipeline pipeline("GPU-ir-emit-prepare");
|
||||
pipeline.AddInvariantChecker<HloVerifier>();
|
||||
pipeline.AddPass<PadInsertion>();
|
||||
pipeline.AddPass<GpuLayoutAssignment>(
|
||||
@ -230,13 +228,12 @@ GpuCompiler::GpuCompiler()
|
||||
: pointer_size_(llvm::DataLayout(kDataLayout).getPointerSize()) {}
|
||||
|
||||
StatusOr<std::unique_ptr<Executable>> GpuCompiler::Compile(
|
||||
std::unique_ptr<HloModule> module, HloDumper dump_hlo,
|
||||
se::StreamExecutor* stream_exec) {
|
||||
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec) {
|
||||
TF_RET_CHECK(stream_exec != nullptr);
|
||||
|
||||
TF_RETURN_IF_ERROR(OptimizeHloModule(module.get(), dump_hlo,
|
||||
stream_exec->GetDeviceDescription()));
|
||||
TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(dump_hlo, module.get()));
|
||||
TF_RETURN_IF_ERROR(
|
||||
OptimizeHloModule(module.get(), stream_exec->GetDeviceDescription()));
|
||||
TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get()));
|
||||
|
||||
llvm::LLVMContext llvm_context;
|
||||
std::string buffer;
|
||||
@ -344,16 +341,15 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::Compile(
|
||||
}
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<Executable>>> GpuCompiler::Compile(
|
||||
std::vector<std::unique_ptr<HloModule>> modules, HloDumper dump_hlos,
|
||||
std::vector<std::unique_ptr<HloModule>> modules,
|
||||
std::vector<se::StreamExecutor*> stream_execs) {
|
||||
return Unimplemented(
|
||||
"Compilation of multiple HLO modules is not yet supported on GPU.");
|
||||
}
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||
GpuCompiler::CompileAheadOfTime(
|
||||
std::vector<std::unique_ptr<HloModule>> module,
|
||||
HloDumper dump_hlo, const AotCompilationOptions& options) {
|
||||
GpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> module,
|
||||
const AotCompilationOptions& options) {
|
||||
return Unimplemented("not yet implemented: GpuCompiler::CompileAheadOfTime");
|
||||
}
|
||||
|
||||
|
@ -41,17 +41,16 @@ class GpuCompiler : public Compiler {
|
||||
~GpuCompiler() override {}
|
||||
|
||||
StatusOr<std::unique_ptr<Executable>> Compile(
|
||||
std::unique_ptr<HloModule> module, HloDumper dump_hlo,
|
||||
std::unique_ptr<HloModule> module,
|
||||
perftools::gputools::StreamExecutor* stream_exec) override;
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
|
||||
std::vector<std::unique_ptr<HloModule>> modules, HloDumper dump_hlo,
|
||||
std::vector<std::unique_ptr<HloModule>> modules,
|
||||
std::vector<perftools::gputools::StreamExecutor*> stream_exec) override;
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||
CompileAheadOfTime(
|
||||
std::vector<std::unique_ptr<HloModule>> module,
|
||||
HloDumper dump_hlo, AotCompilationOptions const& options) override;
|
||||
CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> module,
|
||||
AotCompilationOptions const& options) override;
|
||||
|
||||
perftools::gputools::Platform::Id PlatformId() const override;
|
||||
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
|
||||
namespace xla {
|
||||
|
@ -19,9 +19,9 @@ limitations under the License.
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
|
||||
namespace xla {
|
||||
|
@ -15,21 +15,21 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
|
||||
|
||||
#include <ostream>
|
||||
#include <queue>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/map_util.h"
|
||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_buffer.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_value.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace xla {
|
||||
@ -38,105 +38,6 @@ using ::tensorflow::str_util::Join;
|
||||
using ::tensorflow::strings::StrAppend;
|
||||
using ::tensorflow::strings::StrCat;
|
||||
|
||||
void HloBuffer::AddValue(const HloValue& value) {
|
||||
// If the value is already contained in this buffer, just return.
|
||||
if (std::find(value_ids_.begin(), value_ids_.end(), value.id()) !=
|
||||
value_ids_.end()) {
|
||||
return;
|
||||
}
|
||||
|
||||
value_ids_.push_back(value.id());
|
||||
|
||||
// Add all of the locations of the HloValue to this buffer.
|
||||
for (const HloLocation& location : value.locations()) {
|
||||
if (std::find(locations_.begin(), locations_.end(), location) ==
|
||||
locations_.end()) {
|
||||
locations_.push_back(location);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool HloBuffer::operator==(const HloBuffer& other) const {
|
||||
bool equal = id() == other.id();
|
||||
if (equal) {
|
||||
// DCHECK because these comparisons are expensive (linear time).
|
||||
DCHECK(value_ids() == other.value_ids());
|
||||
DCHECK(locations() == other.locations());
|
||||
}
|
||||
return equal;
|
||||
}
|
||||
|
||||
string HloBuffer::ToString() const {
|
||||
return StrCat("HloBuffer ", id_, ", values: ", Join(value_ids_, ", "));
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer) {
|
||||
out << buffer.ToString();
|
||||
return out;
|
||||
}
|
||||
|
||||
void HloBufferSet::AddBuffer(HloBuffer::Id buffer_id) {
|
||||
if (std::find(buffer_ids_.begin(), buffer_ids_.end(), buffer_id) ==
|
||||
buffer_ids_.end()) {
|
||||
buffer_ids_.push_back(buffer_id);
|
||||
}
|
||||
}
|
||||
|
||||
void HloBufferSet::RemoveBufferOrDie(HloBuffer::Id buffer_id) {
|
||||
auto it = std::find(buffer_ids_.begin(), buffer_ids_.end(), buffer_id);
|
||||
CHECK(it != buffer_ids_.end());
|
||||
buffer_ids_.erase(it);
|
||||
}
|
||||
|
||||
string HloBufferSet::ToString() const {
|
||||
return StrCat("HloBufferSet, buffers: ", Join(buffer_ids_, ", "));
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const HloBufferSet& buffer_set) {
|
||||
out << buffer_set.ToString();
|
||||
return out;
|
||||
}
|
||||
|
||||
bool InstructionBufferSet::IsAmbiguous() const {
|
||||
bool is_ambiguous = false;
|
||||
ForEachElement(
|
||||
[&is_ambiguous](const ShapeIndex& index, const HloBufferSet& buffer_set) {
|
||||
is_ambiguous |= buffer_set.buffer_ids().size() > 1;
|
||||
});
|
||||
return is_ambiguous;
|
||||
}
|
||||
|
||||
bool InstructionBufferSet::IsDistinct() const {
|
||||
bool is_distinct = true;
|
||||
tensorflow::gtl::FlatSet<HloBuffer::Id> seen_ids;
|
||||
ForEachElement([&is_distinct, &seen_ids](const ShapeIndex& index,
|
||||
const HloBufferSet& buffer_set) {
|
||||
for (HloBuffer::Id buffer_id : buffer_set.buffer_ids()) {
|
||||
auto pair = seen_ids.insert(buffer_id);
|
||||
if (!pair.second) {
|
||||
is_distinct = false;
|
||||
}
|
||||
}
|
||||
});
|
||||
return is_distinct;
|
||||
}
|
||||
|
||||
string InstructionBufferSet::ToString() const {
|
||||
string out =
|
||||
StrCat("InstructionBufferSet(", ShapeUtil::HumanString(shape()), ")\n");
|
||||
ForEachElement([this, &out](const ShapeIndex& index,
|
||||
const HloBufferSet& value_set) {
|
||||
StrAppend(&out, " ", index.ToString(), " : ", value_set.ToString(), "\n");
|
||||
});
|
||||
return out;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& out,
|
||||
const InstructionBufferSet& buffer_set) {
|
||||
out << buffer_set.ToString();
|
||||
return out;
|
||||
}
|
||||
|
||||
HloAliasAnalysis::HloAliasAnalysis(HloModule* module) : module_(module) {}
|
||||
|
||||
void HloAliasAnalysis::InitializeBufferSets() {
|
||||
|
@ -16,182 +16,23 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ALIAS_ANALYSIS_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ALIAS_ANALYSIS_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <iosfwd>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/service/call_graph.h"
|
||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_buffer.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/logical_buffer.h"
|
||||
#include "tensorflow/compiler/xla/shape_tree.h"
|
||||
#include "tensorflow/compiler/xla/status.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// A container which can hold one or more HloValues. An HLO buffer abstractly
|
||||
// represents the allocation which HLO instructions write into and read
|
||||
// from. Generally there is a one-to-one correspondence between HloBuffers and
|
||||
// HloValue where each HloValue in the module is held in a unique HloBuffer. An
|
||||
// exception is the while instruction which updates the loop state in-place. In
|
||||
// this case, we have a single HloBuffer for each HloLocation in the loop state,
|
||||
// but multiple HloValues. For example:
|
||||
//
|
||||
// %init = ...
|
||||
// %while = While(%init, body, condition)
|
||||
//
|
||||
// body:
|
||||
// %body_param = Param(0)
|
||||
// ...
|
||||
// %body_root = ...
|
||||
//
|
||||
// condition:
|
||||
// %cond_param = Param(0)
|
||||
// ...
|
||||
//
|
||||
// For simplicity, assume that %while is array-shaped. In this case, we have a
|
||||
// single HloBuffer which holds the following HloValues: HloValue{%init},
|
||||
// HloValue{%while}, HloValue{%body_param}, HloValue{%body_root}, and
|
||||
// HloValue{%cond_param}.
|
||||
//
|
||||
// HloBuffers may appear at different HloLocations in the module mirroring the
|
||||
// same propery of HloValues. For example:
|
||||
//
|
||||
// %sub = Sub(...)
|
||||
// %add = Add(...)
|
||||
// %tuple = Tuple(%add, %sub)
|
||||
// %gte = GetTupleElement(%tuple, 0)
|
||||
//
|
||||
// In this case, the HloBuffer containing %add appears at the following
|
||||
// locations: HloLocation{%add, {}}, HloLocation{%tuple, {0}}, and
|
||||
// HloLocation{%gte, {}}.
|
||||
//
|
||||
// Different HloLocations which share the same HloBuffer indicate mandatory
|
||||
// aliasing in the HLO module. These locations must share the same memory
|
||||
// allocation for correctness (the backends rely on this property). This differs
|
||||
// from incidental aliasing introduced by memory reuse in BufferAssignment where
|
||||
// different instructions may happen to get the same allocation.
|
||||
class HloBuffer {
|
||||
public:
|
||||
using Id = int64;
|
||||
|
||||
HloBuffer(int64 id) : id_(id) {}
|
||||
|
||||
// Return the unique identifier for this HloBuffer.
|
||||
int64 id() const { return id_; }
|
||||
|
||||
// Add a value to the set of values held by this buffer. Also adds the
|
||||
// HloLocations of the value to the locations vector of the buffer. If the
|
||||
// buffer already contains this value, then this method is a nop.
|
||||
void AddValue(const HloValue& value);
|
||||
|
||||
// Return the IDs of all values contained in this buffer.
|
||||
const std::vector<HloValue::Id>& value_ids() const { return value_ids_; }
|
||||
|
||||
// Return the locations (output of which instruction and at what index) where
|
||||
// the buffer is used. This is exactly the union of the locations of the
|
||||
// HloValues contained by the buffer.
|
||||
const std::vector<HloLocation>& locations() const { return locations_; }
|
||||
|
||||
string ToString() const;
|
||||
|
||||
bool operator==(const HloBuffer& other) const;
|
||||
bool operator!=(const HloBuffer& other) const { return !(*this == other); }
|
||||
|
||||
private:
|
||||
// Unique identifier for this HloBuffer.
|
||||
const Id id_;
|
||||
|
||||
// The set of values contained in the this buffer.
|
||||
std::vector<HloValue::Id> value_ids_;
|
||||
|
||||
// The set of locations where this buffer is used.
|
||||
std::vector<HloLocation> locations_;
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer);
|
||||
|
||||
// A class representing the set of possible HloBuffers at a particular
|
||||
// HloLocation (shape index in the output of an instruction) in the XLA
|
||||
// graph. In most cases, the buffer set will have a single HloBuffer indicating
|
||||
// that the HloBuffer which appears at that particular location is known
|
||||
// unambiguously at compile-time. However, tuple-shaped Select instructions can
|
||||
// introduce ambiguity as the tuple elements of the operands are passed by
|
||||
// reference into the output of the Select. For example:
|
||||
//
|
||||
// %pred = ...
|
||||
// %tuple0 = Tuple(%a, %b)
|
||||
// %tuple1 = Tuple(%x, %y)
|
||||
// %select = Select(%pred, %tuple0, %tuple1)
|
||||
//
|
||||
// In this case the HloBufferSet at HloLocation{%select, {0}} contains the
|
||||
// HloBuffer holding %a and the HloBuffer holding %x.
|
||||
class HloBufferSet {
|
||||
public:
|
||||
HloBufferSet() = default;
|
||||
|
||||
// Add the given buffer to this buffer set. If the buffer already exists in
|
||||
// the set, then this is a NOP.
|
||||
void AddBuffer(HloBuffer::Id buffer_id);
|
||||
|
||||
// Removes the given buffer from this buffer set. CHECK fails in the buffer is
|
||||
// not contained in this set.
|
||||
void RemoveBufferOrDie(HloBuffer::Id buffer_id);
|
||||
|
||||
// Returns the unique buffer in this set. CHECK fails if the set does not
|
||||
// contain exactly one buffer.
|
||||
HloBuffer::Id GetUniqueBufferId() const {
|
||||
CHECK_EQ(buffer_ids().size(), 1);
|
||||
return buffer_ids()[0];
|
||||
}
|
||||
|
||||
// Returns the IDs of the HloBuffers contained in this buffer set.
|
||||
const std::vector<HloBuffer::Id>& buffer_ids() const { return buffer_ids_; }
|
||||
|
||||
string ToString() const;
|
||||
|
||||
private:
|
||||
// The IDs of the HloBuffers containted in this buffer set.
|
||||
std::vector<HloBuffer::Id> buffer_ids_;
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const HloBufferSet& buffer_set);
|
||||
|
||||
// A class collecting the HloBuffers in the output of an HLO instruction. For
|
||||
// array-shaped instructions, an InstructionBufferSet trivially holds a single
|
||||
// HloBufferSet. Tuple-shaped InstructionBufferSets hold multiple
|
||||
// HloBufferSets.
|
||||
class InstructionBufferSet : public ShapeTree<HloBufferSet> {
|
||||
public:
|
||||
InstructionBufferSet(const Shape& shape) : ShapeTree<HloBufferSet>(shape) {}
|
||||
|
||||
// Returns true if any HloBufferSet contained in this InstructionBufferSet
|
||||
// is not a singleton.
|
||||
bool IsAmbiguous() const;
|
||||
|
||||
// Returns true if any HloBuffer appears in more than one HloBufferSet
|
||||
// contained in this InstructionBufferSet.
|
||||
bool IsDistinct() const;
|
||||
|
||||
string ToString() const;
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& out,
|
||||
const InstructionBufferSet& buffer_set);
|
||||
|
||||
class HloAliasAnalysis {
|
||||
public:
|
||||
static StatusOr<std::unique_ptr<HloAliasAnalysis>> Run(HloModule* module);
|
||||
|
139
tensorflow/compiler/xla/service/hlo_buffer.cc
Normal file
139
tensorflow/compiler/xla/service/hlo_buffer.cc
Normal file
@ -0,0 +1,139 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_buffer.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <ostream>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/map_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
using ::tensorflow::str_util::Join;
|
||||
using ::tensorflow::strings::StrAppend;
|
||||
using ::tensorflow::strings::StrCat;
|
||||
|
||||
void HloBuffer::AddValue(const HloValue& value) {
|
||||
// If the value is already contained in this buffer, just return.
|
||||
if (std::find(value_ids_.begin(), value_ids_.end(), value.id()) !=
|
||||
value_ids_.end()) {
|
||||
return;
|
||||
}
|
||||
|
||||
value_ids_.push_back(value.id());
|
||||
|
||||
// Add all of the locations of the HloValue to this buffer.
|
||||
for (const HloLocation& location : value.locations()) {
|
||||
if (std::find(locations_.begin(), locations_.end(), location) ==
|
||||
locations_.end()) {
|
||||
locations_.push_back(location);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool HloBuffer::operator==(const HloBuffer& other) const {
|
||||
bool equal = id() == other.id();
|
||||
if (equal) {
|
||||
// DCHECK because these comparisons are expensive (linear time).
|
||||
DCHECK(value_ids() == other.value_ids());
|
||||
DCHECK(locations() == other.locations());
|
||||
}
|
||||
return equal;
|
||||
}
|
||||
|
||||
string HloBuffer::ToString() const {
|
||||
return StrCat("HloBuffer ", id_, ", values: ", Join(value_ids_, ", "));
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer) {
|
||||
out << buffer.ToString();
|
||||
return out;
|
||||
}
|
||||
|
||||
void HloBufferSet::AddBuffer(HloBuffer::Id buffer_id) {
|
||||
if (std::find(buffer_ids_.begin(), buffer_ids_.end(), buffer_id) ==
|
||||
buffer_ids_.end()) {
|
||||
buffer_ids_.push_back(buffer_id);
|
||||
}
|
||||
}
|
||||
|
||||
void HloBufferSet::RemoveBufferOrDie(HloBuffer::Id buffer_id) {
|
||||
auto it = std::find(buffer_ids_.begin(), buffer_ids_.end(), buffer_id);
|
||||
CHECK(it != buffer_ids_.end());
|
||||
buffer_ids_.erase(it);
|
||||
}
|
||||
|
||||
string HloBufferSet::ToString() const {
|
||||
return StrCat("HloBufferSet, buffers: ", Join(buffer_ids_, ", "));
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const HloBufferSet& buffer_set) {
|
||||
out << buffer_set.ToString();
|
||||
return out;
|
||||
}
|
||||
|
||||
bool InstructionBufferSet::IsAmbiguous() const {
|
||||
bool is_ambiguous = false;
|
||||
ForEachElement(
|
||||
[&is_ambiguous](const ShapeIndex& index, const HloBufferSet& buffer_set) {
|
||||
is_ambiguous |= buffer_set.buffer_ids().size() > 1;
|
||||
});
|
||||
return is_ambiguous;
|
||||
}
|
||||
|
||||
bool InstructionBufferSet::IsDistinct() const {
|
||||
bool is_distinct = true;
|
||||
tensorflow::gtl::FlatSet<HloBuffer::Id> seen_ids;
|
||||
ForEachElement([&is_distinct, &seen_ids](const ShapeIndex& index,
|
||||
const HloBufferSet& buffer_set) {
|
||||
for (HloBuffer::Id buffer_id : buffer_set.buffer_ids()) {
|
||||
auto pair = seen_ids.insert(buffer_id);
|
||||
if (!pair.second) {
|
||||
is_distinct = false;
|
||||
}
|
||||
}
|
||||
});
|
||||
return is_distinct;
|
||||
}
|
||||
|
||||
string InstructionBufferSet::ToString() const {
|
||||
string out =
|
||||
StrCat("InstructionBufferSet(", ShapeUtil::HumanString(shape()), ")\n");
|
||||
ForEachElement([this, &out](const ShapeIndex& index,
|
||||
const HloBufferSet& value_set) {
|
||||
StrAppend(&out, " ", index.ToString(), " : ", value_set.ToString(), "\n");
|
||||
});
|
||||
return out;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& out,
|
||||
const InstructionBufferSet& buffer_set) {
|
||||
out << buffer_set.ToString();
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace xla
|
183
tensorflow/compiler/xla/service/hlo_buffer.h
Normal file
183
tensorflow/compiler/xla/service/hlo_buffer.h
Normal file
@ -0,0 +1,183 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_BUFFER_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_BUFFER_H_
|
||||
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_value.h"
|
||||
#include "tensorflow/compiler/xla/shape_tree.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// A container which can hold one or more HloValues. An HLO buffer abstractly
|
||||
// represents the allocation which HLO instructions write into and read
|
||||
// from. Generally there is a one-to-one correspondence between HloBuffers and
|
||||
// HloValue where each HloValue in the module is held in a unique HloBuffer. An
|
||||
// exception is the while instruction which updates the loop state in-place. In
|
||||
// this case, we have a single HloBuffer for each HloLocation in the loop state,
|
||||
// but multiple HloValues. For example:
|
||||
//
|
||||
// %init = ...
|
||||
// %while = While(%init, body, condition)
|
||||
//
|
||||
// body:
|
||||
// %body_param = Param(0)
|
||||
// ...
|
||||
// %body_root = ...
|
||||
//
|
||||
// condition:
|
||||
// %cond_param = Param(0)
|
||||
// ...
|
||||
//
|
||||
// For simplicity, assume that %while is array-shaped. In this case, we have a
|
||||
// single HloBuffer which holds the following HloValues: HloValue{%init},
|
||||
// HloValue{%while}, HloValue{%body_param}, HloValue{%body_root}, and
|
||||
// HloValue{%cond_param}.
|
||||
//
|
||||
// HloBuffers may appear at different HloLocations in the module mirroring the
|
||||
// same propery of HloValues. For example:
|
||||
//
|
||||
// %sub = Sub(...)
|
||||
// %add = Add(...)
|
||||
// %tuple = Tuple(%add, %sub)
|
||||
// %gte = GetTupleElement(%tuple, 0)
|
||||
//
|
||||
// In this case, the HloBuffer containing %add appears at the following
|
||||
// locations: HloLocation{%add, {}}, HloLocation{%tuple, {0}}, and
|
||||
// HloLocation{%gte, {}}.
|
||||
//
|
||||
// Different HloLocations which share the same HloBuffer indicate mandatory
|
||||
// aliasing in the HLO module. These locations must share the same memory
|
||||
// allocation for correctness (the backends rely on this property). This differs
|
||||
// from incidental aliasing introduced by memory reuse in BufferAssignment where
|
||||
// different instructions may happen to get the same allocation.
|
||||
class HloBuffer {
|
||||
public:
|
||||
using Id = int64;
|
||||
|
||||
HloBuffer(int64 id) : id_(id) {}
|
||||
|
||||
// Return the unique identifier for this HloBuffer.
|
||||
int64 id() const { return id_; }
|
||||
|
||||
// Add a value to the set of values held by this buffer. Also adds the
|
||||
// HloLocations of the value to the locations vector of the buffer. If the
|
||||
// buffer already contains this value, then this method is a nop.
|
||||
void AddValue(const HloValue& value);
|
||||
|
||||
// Return the IDs of all values contained in this buffer.
|
||||
const std::vector<HloValue::Id>& value_ids() const { return value_ids_; }
|
||||
|
||||
// Return the locations (output of which instruction and at what index) where
|
||||
// the buffer is used. This is exactly the union of the locations of the
|
||||
// HloValues contained by the buffer.
|
||||
const std::vector<HloLocation>& locations() const { return locations_; }
|
||||
|
||||
string ToString() const;
|
||||
|
||||
bool operator==(const HloBuffer& other) const;
|
||||
bool operator!=(const HloBuffer& other) const { return !(*this == other); }
|
||||
|
||||
private:
|
||||
// Unique identifier for this HloBuffer.
|
||||
const Id id_;
|
||||
|
||||
// The set of values contained in the this buffer.
|
||||
std::vector<HloValue::Id> value_ids_;
|
||||
|
||||
// The set of locations where this buffer is used.
|
||||
std::vector<HloLocation> locations_;
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer);
|
||||
|
||||
// A class representing the set of possible HloBuffers at a particular
|
||||
// HloLocation (shape index in the output of an instruction) in the XLA
|
||||
// graph. In most cases, the buffer set will have a single HloBuffer indicating
|
||||
// that the HloBuffer which appears at that particular location is known
|
||||
// unambiguously at compile-time. However, tuple-shaped Select instructions can
|
||||
// introduce ambiguity as the tuple elements of the operands are passed by
|
||||
// reference into the output of the Select. For example:
|
||||
//
|
||||
// %pred = ...
|
||||
// %tuple0 = Tuple(%a, %b)
|
||||
// %tuple1 = Tuple(%x, %y)
|
||||
// %select = Select(%pred, %tuple0, %tuple1)
|
||||
//
|
||||
// In this case the HloBufferSet at HloLocation{%select, {0}} contains the
|
||||
// HloBuffer holding %a and the HloBuffer holding %x.
|
||||
class HloBufferSet {
|
||||
public:
|
||||
HloBufferSet() = default;
|
||||
|
||||
// Add the given buffer to this buffer set. If the buffer already exists in
|
||||
// the set, then this is a NOP.
|
||||
void AddBuffer(HloBuffer::Id buffer_id);
|
||||
|
||||
// Removes the given buffer from this buffer set. CHECK fails in the buffer is
|
||||
// not contained in this set.
|
||||
void RemoveBufferOrDie(HloBuffer::Id buffer_id);
|
||||
|
||||
// Returns the unique buffer in this set. CHECK fails if the set does not
|
||||
// contain exactly one buffer.
|
||||
HloBuffer::Id GetUniqueBufferId() const {
|
||||
CHECK_EQ(buffer_ids().size(), 1);
|
||||
return buffer_ids()[0];
|
||||
}
|
||||
|
||||
// Returns the IDs of the HloBuffers contained in this buffer set.
|
||||
const std::vector<HloBuffer::Id>& buffer_ids() const { return buffer_ids_; }
|
||||
|
||||
string ToString() const;
|
||||
|
||||
private:
|
||||
// The IDs of the HloBuffers containted in this buffer set.
|
||||
std::vector<HloBuffer::Id> buffer_ids_;
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const HloBufferSet& buffer_set);
|
||||
|
||||
// A class collecting the HloBuffers in the output of an HLO instruction. For
|
||||
// array-shaped instructions, an InstructionBufferSet trivially holds a single
|
||||
// HloBufferSet. Tuple-shaped InstructionBufferSets hold multiple
|
||||
// HloBufferSets.
|
||||
class InstructionBufferSet : public ShapeTree<HloBufferSet> {
|
||||
public:
|
||||
InstructionBufferSet(const Shape& shape) : ShapeTree<HloBufferSet>(shape) {}
|
||||
|
||||
// Returns true if any HloBufferSet contained in this InstructionBufferSet
|
||||
// is not a singleton.
|
||||
bool IsAmbiguous() const;
|
||||
|
||||
// Returns true if any HloBuffer appears in more than one HloBufferSet
|
||||
// contained in this InstructionBufferSet.
|
||||
bool IsDistinct() const;
|
||||
|
||||
string ToString() const;
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& out,
|
||||
const InstructionBufferSet& buffer_set);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_BUFFER_H_
|
@ -25,34 +25,56 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
constexpr char HloCostAnalysis::kFlopsKey[];
|
||||
constexpr char HloCostAnalysis::kTranscendentalsKey[];
|
||||
constexpr char HloCostAnalysis::kBytesAccessedKey[];
|
||||
constexpr char HloCostAnalysis::kSecondsKey[];
|
||||
|
||||
HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size)
|
||||
: HloCostAnalysis(shape_size, {}) {}
|
||||
|
||||
HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size,
|
||||
const Properties& per_second_rates)
|
||||
: shape_size_(shape_size), per_second_rates_(per_second_rates) {}
|
||||
|
||||
Status HloCostAnalysis::Preprocess(HloInstruction* hlo) {
|
||||
// Set current instruction cost values to reasonable default values. Each
|
||||
// handler can overwrite these values. In Postprocess, these value are
|
||||
// handler can overwrite these values. In Postprocess, these values are
|
||||
// accumulated and written to the per-instruction maps.
|
||||
current_flop_count_ = 0;
|
||||
current_transcendental_count_ = 0;
|
||||
current_properties_.clear();
|
||||
current_should_compute_bottleneck_time_ = true;
|
||||
|
||||
// The default element count for an instruction is the sum of elements in the
|
||||
// operands and output. The default ShapeUtil::ByteSizeOf does not handle
|
||||
// opaque types.
|
||||
current_bytes_accessed_ = shape_size_(hlo->shape());
|
||||
// The default number of bytes accessed for an instruction is the sum of the
|
||||
// sizes of the inputs and outputs. The default ShapeUtil::ByteSizeOf does not
|
||||
// handle opaque types.
|
||||
float bytes_accessed = shape_size_(hlo->shape());
|
||||
for (const HloInstruction* operand : hlo->operands()) {
|
||||
current_bytes_accessed_ += shape_size_(operand->shape());
|
||||
bytes_accessed += shape_size_(operand->shape());
|
||||
}
|
||||
current_properties_[kBytesAccessedKey] = bytes_accessed;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HloCostAnalysis::Postprocess(HloInstruction* hlo) {
|
||||
// Accumulate cost values and write into per-instruction maps.
|
||||
flop_count_ += current_flop_count_;
|
||||
hlo_to_flop_count_[hlo] = current_flop_count_;
|
||||
if (current_should_compute_bottleneck_time_) {
|
||||
// Compute the time as the time of the bottleneck, i.e. the slowest property
|
||||
// given the per-second rate of each property.
|
||||
float max_seconds = 0.0f;
|
||||
for (const auto& property : current_properties_) {
|
||||
if (property.first != kSecondsKey) {
|
||||
max_seconds = std::max(
|
||||
max_seconds,
|
||||
property.second / GetProperty(property.first, per_second_rates_));
|
||||
}
|
||||
}
|
||||
current_properties_[kSecondsKey] = max_seconds;
|
||||
}
|
||||
|
||||
transcendental_count_ += current_transcendental_count_;
|
||||
hlo_to_transcendental_count_[hlo] = current_transcendental_count_;
|
||||
|
||||
bytes_accessed_ += current_bytes_accessed_;
|
||||
hlo_to_bytes_accessed_[hlo] = current_bytes_accessed_;
|
||||
TF_RET_CHECK(hlo_properties_.emplace(hlo, current_properties_).second);
|
||||
for (const auto& property : current_properties_) {
|
||||
properties_sum_[property.first] += property.second;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
@ -65,15 +87,32 @@ Status HloCostAnalysis::HandleElementwiseOp(HloInstruction* hlo_instruction) {
|
||||
auto opcode = hlo_instruction->opcode();
|
||||
// We treat the two opcodes (kExp, kPower) as transcendental operations.
|
||||
if (opcode == HloOpcode::kExp || opcode == HloOpcode::kPower) {
|
||||
current_transcendental_count_ = computation_count;
|
||||
current_properties_[kTranscendentalsKey] = computation_count;
|
||||
} else {
|
||||
// Note: transcendental operations are considered a separate category from
|
||||
// FLOPs.
|
||||
current_flop_count_ = computation_count;
|
||||
current_properties_[kFlopsKey] = computation_count;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/*static*/ float HloCostAnalysis::GetProperty(const string& key,
|
||||
const Properties& properties) {
|
||||
auto key_value = properties.find(key);
|
||||
return key_value == properties.end() ? 0.0f : key_value->second;
|
||||
}
|
||||
|
||||
/*static*/ float HloCostAnalysis::GetPropertyForHlo(
|
||||
const HloInstruction& hlo, const string& key,
|
||||
const HloToProperties& hlo_to_properties) {
|
||||
auto it = hlo_to_properties.find(&hlo);
|
||||
if (it == hlo_to_properties.end()) {
|
||||
return 0.0f;
|
||||
} else {
|
||||
return GetProperty(key, it->second);
|
||||
}
|
||||
}
|
||||
|
||||
Status HloCostAnalysis::HandleElementwiseUnary(HloInstruction* hlo,
|
||||
HloOpcode opcode) {
|
||||
return HandleElementwiseOp(hlo);
|
||||
@ -97,19 +136,18 @@ Status HloCostAnalysis::HandleClamp(HloInstruction* clamp,
|
||||
return HandleElementwiseOp(clamp);
|
||||
}
|
||||
|
||||
Status HloCostAnalysis::HandleReducePrecision(HloInstruction* hlo,
|
||||
HloInstruction* operand) {
|
||||
Status HloCostAnalysis::HandleReducePrecision(HloInstruction* hlo) {
|
||||
return HandleElementwiseOp(hlo);
|
||||
}
|
||||
|
||||
Status HloCostAnalysis::HandleParameter(HloInstruction* parameter) {
|
||||
current_bytes_accessed_ = 0;
|
||||
current_properties_[kBytesAccessedKey] = 0;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HloCostAnalysis::HandleConstant(HloInstruction* constant,
|
||||
const Literal& literal) {
|
||||
current_bytes_accessed_ = 0;
|
||||
current_properties_[kBytesAccessedKey] = 0;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -117,7 +155,7 @@ Status HloCostAnalysis::HandleGetTupleElement(HloInstruction* get_tuple_element,
|
||||
HloInstruction* operand) {
|
||||
// GetTupleElement forwards a pointer and does not touch each element in the
|
||||
// output.
|
||||
current_bytes_accessed_ = 0;
|
||||
current_properties_[kBytesAccessedKey] = 0;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -155,8 +193,9 @@ Status HloCostAnalysis::HandleTuple(
|
||||
tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
|
||||
// The tuple instruction only gathers pointers from inputs (it doesn't iterate
|
||||
// through them). The memory touched is then only the size of the output
|
||||
// buffer.
|
||||
current_bytes_accessed_ = shape_size_(tuple->shape());
|
||||
// index table of the tuple.
|
||||
|
||||
current_properties_[kBytesAccessedKey] = shape_size_(tuple->shape());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -194,7 +233,7 @@ Status HloCostAnalysis::HandleDot(HloInstruction* dot,
|
||||
}
|
||||
|
||||
// We count an FMA operation as 2 floating point operations.
|
||||
current_flop_count_ = kFmaFlops * fma_count;
|
||||
current_properties_[kFlopsKey] = kFmaFlops * fma_count;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -210,16 +249,17 @@ Status HloCostAnalysis::HandleMap(
|
||||
HloInstruction* map, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
|
||||
HloComputation* function,
|
||||
tensorflow::gtl::ArraySlice<HloInstruction*> /*static_operands*/) {
|
||||
// Compute the cost of the user function.
|
||||
HloInstruction* function_instruction = function->root_instruction();
|
||||
HloCostAnalysis visitor(shape_size_);
|
||||
TF_RETURN_IF_ERROR(function_instruction->Accept(&visitor));
|
||||
// Compute properties of the mapped function.
|
||||
TF_ASSIGN_OR_RETURN(const Properties sub_properties,
|
||||
ProcessSubcomputation(function));
|
||||
|
||||
// Compute the cost of all elements for this Map operation.
|
||||
int64 element_count = ShapeUtil::ElementsIn(map->shape());
|
||||
current_transcendental_count_ =
|
||||
element_count * visitor.transcendental_count();
|
||||
current_flop_count_ = element_count * visitor.flop_count();
|
||||
const int64 element_count = ShapeUtil::ElementsIn(map->shape());
|
||||
for (const auto& property : sub_properties) {
|
||||
if (property.first != kBytesAccessedKey) {
|
||||
current_properties_[property.first] = property.second * element_count;
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -227,16 +267,17 @@ Status HloCostAnalysis::HandleReduce(
|
||||
HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
|
||||
tensorflow::gtl::ArraySlice<int64> dimensions, HloComputation* function) {
|
||||
// Compute the cost of the user function.
|
||||
HloInstruction* function_instruction = function->root_instruction();
|
||||
HloCostAnalysis visitor(shape_size_);
|
||||
TF_RETURN_IF_ERROR(function_instruction->Accept(&visitor));
|
||||
TF_ASSIGN_OR_RETURN(const Properties sub_properties,
|
||||
ProcessSubcomputation(function));
|
||||
|
||||
// Compute the cost of all elements for this Reduce operation.
|
||||
int64 reduction_count = ShapeUtil::ElementsIn(arg->shape()) -
|
||||
ShapeUtil::ElementsIn(reduce->shape());
|
||||
current_flop_count_ = reduction_count * visitor.flop_count();
|
||||
current_transcendental_count_ =
|
||||
reduction_count * visitor.transcendental_count();
|
||||
for (const auto& property : sub_properties) {
|
||||
if (property.first != kBytesAccessedKey) {
|
||||
current_properties_[property.first] = property.second * reduction_count;
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -244,55 +285,63 @@ Status HloCostAnalysis::HandleReduceWindow(HloInstruction* reduce_window,
|
||||
HloInstruction* operand,
|
||||
const Window& window,
|
||||
HloComputation* function) {
|
||||
// Compute the cost of the user function.
|
||||
HloInstruction* function_instruction = function->root_instruction();
|
||||
HloCostAnalysis visitor(shape_size_);
|
||||
TF_RETURN_IF_ERROR(function_instruction->Accept(&visitor));
|
||||
// Compute the properties of the reduction function.
|
||||
TF_ASSIGN_OR_RETURN(const Properties sub_properties,
|
||||
ProcessSubcomputation(function));
|
||||
|
||||
// Compute the cost of all elements for this ReduceWindow operation. For each
|
||||
// output element, (window_size - 1) number of user computations are applied.
|
||||
auto output_size = ShapeUtil::ElementsIn(reduce_window->shape());
|
||||
int64 window_size = 1;
|
||||
// output element there are window_size - 1 reductions to perform.
|
||||
int64 window_element_count = 1;
|
||||
for (const auto& dimension : window.dimensions()) {
|
||||
window_size *= dimension.size();
|
||||
window_element_count *= dimension.size();
|
||||
}
|
||||
const int64 output_element_count =
|
||||
ShapeUtil::ElementsIn(reduce_window->shape());
|
||||
const int64 reduction_count =
|
||||
(window_element_count - 1) * output_element_count;
|
||||
for (const auto& property : sub_properties) {
|
||||
if (property.first != kBytesAccessedKey) {
|
||||
current_properties_[property.first] = property.second * reduction_count;
|
||||
}
|
||||
}
|
||||
current_flop_count_ = output_size * (window_size - 1) * visitor.flop_count();
|
||||
current_transcendental_count_ =
|
||||
output_size * (window_size - 1) * visitor.transcendental_count();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HloCostAnalysis::HandleSelectAndScatter(HloInstruction* instruction) {
|
||||
// Compute the cost of the select and scatter function.
|
||||
HloInstruction* select = instruction->select()->root_instruction();
|
||||
HloCostAnalysis select_visitor(shape_size_);
|
||||
TF_RETURN_IF_ERROR(select->Accept(&select_visitor));
|
||||
HloInstruction* scatter = instruction->scatter()->root_instruction();
|
||||
HloCostAnalysis scatter_visitor(shape_size_);
|
||||
TF_RETURN_IF_ERROR(scatter->Accept(&scatter_visitor));
|
||||
// Compute the properties of the select and scatter function.
|
||||
// Compute the properties of the reduction function.
|
||||
TF_ASSIGN_OR_RETURN(const Properties select_properties,
|
||||
ProcessSubcomputation(instruction->select()));
|
||||
TF_ASSIGN_OR_RETURN(const Properties scatter_properties,
|
||||
ProcessSubcomputation(instruction->scatter()));
|
||||
|
||||
// Compute the cost of all elements for this operation. For each scatter
|
||||
// source element, (window_size - 1) number of select computations and 1
|
||||
// scatter computation are applied.
|
||||
// source element there are window_size - 1 select computations to perform and
|
||||
// 1 scatter computation to perform.
|
||||
const auto source = instruction->operand(1);
|
||||
const auto source_element_count = ShapeUtil::ElementsIn(source->shape());
|
||||
int64 window_size = 1;
|
||||
int64 window_element_count = 1;
|
||||
for (const auto& dimension : instruction->window().dimensions()) {
|
||||
window_size *= dimension.size();
|
||||
window_element_count *= dimension.size();
|
||||
}
|
||||
const int64 select_count = source_element_count * (window_element_count - 1);
|
||||
for (const auto& property : select_properties) {
|
||||
if (property.first != kBytesAccessedKey) {
|
||||
current_properties_[property.first] += property.second * select_count;
|
||||
}
|
||||
}
|
||||
for (const auto& property : scatter_properties) {
|
||||
if (property.first != kBytesAccessedKey) {
|
||||
current_properties_[property.first] +=
|
||||
property.second * source_element_count;
|
||||
}
|
||||
}
|
||||
current_flop_count_ =
|
||||
source_element_count * ((window_size - 1) * select_visitor.flop_count() +
|
||||
scatter_visitor.flop_count());
|
||||
current_transcendental_count_ =
|
||||
source_element_count *
|
||||
((window_size - 1) * select_visitor.transcendental_count() +
|
||||
scatter_visitor.transcendental_count());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HloCostAnalysis::HandleBitcast(HloInstruction* bitcast) {
|
||||
// A bitcast does no computation and touches no memory.
|
||||
current_bytes_accessed_ = 0;
|
||||
current_properties_[kBytesAccessedKey] = 0;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -332,12 +381,13 @@ Status HloCostAnalysis::HandleConvolution(HloInstruction* convolution,
|
||||
const int64 output_features =
|
||||
convolution->shape().dimensions(dnums.feature_dimension());
|
||||
|
||||
// For each output element, we do one fma per element in the
|
||||
// kernel at some given output feature index.
|
||||
// For each output element, we do one fma per element in the kernel at some
|
||||
// given output feature index.
|
||||
const int64 fmas_per_output_element =
|
||||
ShapeUtil::ElementsIn(rhs_instruction->shape()) / output_features;
|
||||
const int64 output_elements = ShapeUtil::ElementsIn(convolution->shape());
|
||||
current_flop_count_ = output_elements * fmas_per_output_element * kFmaFlops;
|
||||
current_properties_[kFlopsKey] =
|
||||
output_elements * fmas_per_output_element * kFmaFlops;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -347,7 +397,7 @@ Status HloCostAnalysis::HandleCrossReplicaSum(HloInstruction* crs) {
|
||||
//
|
||||
// TODO(b/33004697): Compute correct cost here, taking the actual number of
|
||||
// replicas into account.
|
||||
current_flop_count_ = ShapeUtil::ElementsIn(crs->shape());
|
||||
current_properties_[kFlopsKey] = ShapeUtil::ElementsIn(crs->shape());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -356,44 +406,43 @@ Status HloCostAnalysis::HandleRng(HloInstruction* random,
|
||||
// TODO(b/26346211): Implement better estimates for the RNG cost, since the
|
||||
// cost changes with the implementation and the distribution. For now, assume
|
||||
// the cost of each RNG is same as a transcendental operation.
|
||||
current_transcendental_count_ = ShapeUtil::ElementsIn(random->shape());
|
||||
current_properties_[kTranscendentalsKey] =
|
||||
ShapeUtil::ElementsIn(random->shape());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HloCostAnalysis::HandleFusion(HloInstruction* fusion) {
|
||||
// Compute the cost of the fused expression.
|
||||
HloInstruction* fused_expression_root = fusion->fused_expression_root();
|
||||
// Don't compute sizes inside of fused ops. We don't use the size here and the
|
||||
// operations inside might not have a layout.
|
||||
HloCostAnalysis visitor([](const Shape&) { return 0; });
|
||||
TF_RETURN_IF_ERROR(fused_expression_root->Accept(&visitor));
|
||||
// Compute the properties of the fused expression and attribute them to the
|
||||
// fusion node. Use a dummy shape_size to avoid any errors from trying to
|
||||
// calculate the size of a shape that does not have a layout, since nodes
|
||||
// inside fusion nodes do not necessarily have a layout assigned.
|
||||
ShapeSizeFunction shape_size = [](const Shape& shape) { return 0; };
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
current_properties_,
|
||||
ProcessSubcomputation(fusion->fused_instructions_computation(),
|
||||
&shape_size));
|
||||
|
||||
// If a fusion node produces a tuple, it also produces the operands of that
|
||||
// tuple.
|
||||
current_bytes_accessed_ = 0;
|
||||
// Fusion nodes that produce a tuple also produce the entries in the tuple.
|
||||
// Ignore the memory accessed inside fused ops, since fusion is supposed to
|
||||
// prevent intermediate data from touching slow memory.
|
||||
current_properties_[kBytesAccessedKey] = 0;
|
||||
ShapeUtil::ForEachSubshape(
|
||||
fusion->shape(),
|
||||
[this](const Shape& subshape, const ShapeIndex& /*shape_index*/) {
|
||||
current_bytes_accessed_ += shape_size_(subshape);
|
||||
current_properties_[kBytesAccessedKey] += shape_size_(subshape);
|
||||
});
|
||||
|
||||
for (const HloInstruction* operand : fusion->operands()) {
|
||||
current_bytes_accessed_ += shape_size_(operand->shape());
|
||||
current_properties_[kBytesAccessedKey] += shape_size_(operand->shape());
|
||||
}
|
||||
|
||||
// Attribute the cost of the fused expression to the fusion node.
|
||||
current_transcendental_count_ = visitor.transcendental_count();
|
||||
current_flop_count_ = visitor.flop_count();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HloCostAnalysis::HandleCall(HloInstruction* call) {
|
||||
HloCostAnalysis computation_visitor(shape_size_);
|
||||
TF_RETURN_IF_ERROR(call->to_apply()->Accept(&computation_visitor));
|
||||
|
||||
current_flop_count_ = computation_visitor.flop_count();
|
||||
current_transcendental_count_ = computation_visitor.transcendental_count();
|
||||
current_bytes_accessed_ = computation_visitor.bytes_accessed();
|
||||
TF_ASSIGN_OR_RETURN(current_properties_,
|
||||
ProcessSubcomputation(call->to_apply()));
|
||||
current_should_compute_bottleneck_time_ = false;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -401,34 +450,38 @@ Status HloCostAnalysis::HandleCustomCall(
|
||||
HloInstruction* custom_call,
|
||||
tensorflow::gtl::ArraySlice<HloInstruction*> operands,
|
||||
tensorflow::StringPiece custom_call_target) {
|
||||
return Unimplemented("custom-call");
|
||||
return Unimplemented("Custom-call is not implemented for HLO cost analysis.");
|
||||
}
|
||||
|
||||
Status HloCostAnalysis::HandleSort(HloInstruction* sort,
|
||||
HloInstruction* operand_instruction) {
|
||||
// The cost of sort is implementation dependent, so cannot determine at HLO
|
||||
// level. Assume comparison based N*log(N) sorting.
|
||||
// This assumes a comparison based N*log(N) algorithm. As for all ops, the
|
||||
// actual properties of the op depend on the backend implementation.
|
||||
int64 elements = ShapeUtil::ElementsIn(operand_instruction->shape());
|
||||
current_flop_count_ = elements * tensorflow::Log2Ceiling(elements);
|
||||
current_properties_[kFlopsKey] = elements * tensorflow::Log2Ceiling(elements);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HloCostAnalysis::HandleWhile(HloInstruction* xla_while) {
|
||||
// Since the number of iterations of the while node is not statically
|
||||
// determined, we cannot precisely compute the cost of a while node. For now
|
||||
// compute the cost of a single iteration.
|
||||
// TODO(b/26346211): Improve the cost analysis for while node.
|
||||
HloCostAnalysis body_visitor(shape_size_);
|
||||
TF_RETURN_IF_ERROR(xla_while->while_body()->Accept(&body_visitor));
|
||||
HloCostAnalysis condition_visitor(shape_size_);
|
||||
TF_RETURN_IF_ERROR(xla_while->while_condition()->Accept(&condition_visitor));
|
||||
// Since the number of iterations of the while node will not always be
|
||||
// something that we can statically analyze, we cannot precisely compute the
|
||||
// cost of a while node. For now compute the cost of a single iteration.
|
||||
//
|
||||
// TODO(b/26346211): Improve the cost analysis for while nodes.
|
||||
TF_ASSIGN_OR_RETURN(const Properties body_properties,
|
||||
ProcessSubcomputation(xla_while->while_body()));
|
||||
|
||||
current_flop_count_ =
|
||||
body_visitor.flop_count() + condition_visitor.flop_count();
|
||||
current_transcendental_count_ = body_visitor.transcendental_count() +
|
||||
condition_visitor.transcendental_count();
|
||||
current_bytes_accessed_ =
|
||||
body_visitor.bytes_accessed() + condition_visitor.bytes_accessed();
|
||||
TF_ASSIGN_OR_RETURN(const Properties condition_properties,
|
||||
ProcessSubcomputation(xla_while->while_condition()));
|
||||
|
||||
current_properties_.clear();
|
||||
for (const auto& property : body_properties) {
|
||||
current_properties_[property.first] += property.second;
|
||||
}
|
||||
for (const auto& property : condition_properties) {
|
||||
current_properties_[property.first] += property.second;
|
||||
}
|
||||
current_should_compute_bottleneck_time_ = false;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
@ -437,19 +490,42 @@ Status HloCostAnalysis::FinishVisit(HloInstruction* root) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
float HloCostAnalysis::flop_count() const {
|
||||
return GetProperty(kFlopsKey, properties_sum_);
|
||||
}
|
||||
|
||||
float HloCostAnalysis::transcendental_count() const {
|
||||
return GetProperty(kTranscendentalsKey, properties_sum_);
|
||||
}
|
||||
|
||||
float HloCostAnalysis::bytes_accessed() const {
|
||||
return GetProperty(kBytesAccessedKey, properties_sum_);
|
||||
}
|
||||
|
||||
float HloCostAnalysis::seconds() const {
|
||||
return GetProperty(kSecondsKey, properties_sum_);
|
||||
}
|
||||
|
||||
int64 HloCostAnalysis::flop_count(const HloInstruction& hlo) const {
|
||||
auto it = hlo_to_flop_count_.find(&hlo);
|
||||
return it == hlo_to_flop_count_.end() ? 0 : it->second;
|
||||
return GetPropertyForHlo(hlo, kFlopsKey, hlo_properties_);
|
||||
}
|
||||
|
||||
int64 HloCostAnalysis::transcendental_count(const HloInstruction& hlo) const {
|
||||
auto it = hlo_to_transcendental_count_.find(&hlo);
|
||||
return it == hlo_to_transcendental_count_.end() ? 0 : it->second;
|
||||
return GetPropertyForHlo(hlo, kTranscendentalsKey, hlo_properties_);
|
||||
}
|
||||
|
||||
int64 HloCostAnalysis::bytes_accessed(const HloInstruction& hlo) const {
|
||||
auto it = hlo_to_bytes_accessed_.find(&hlo);
|
||||
return it == hlo_to_bytes_accessed_.end() ? 0 : it->second;
|
||||
return GetPropertyForHlo(hlo, kBytesAccessedKey, hlo_properties_);
|
||||
}
|
||||
|
||||
StatusOr<HloCostAnalysis::Properties> HloCostAnalysis::ProcessSubcomputation(
|
||||
HloComputation* computation, const ShapeSizeFunction* shape_size) {
|
||||
if (shape_size == nullptr) {
|
||||
shape_size = &shape_size_;
|
||||
}
|
||||
HloCostAnalysis visitor(*shape_size, per_second_rates_);
|
||||
TF_RETURN_IF_ERROR(computation->Accept(&visitor));
|
||||
return visitor.properties();
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -36,11 +36,18 @@ namespace xla {
|
||||
// operations separately from transcendental operations.
|
||||
class HloCostAnalysis : public DfsHloVisitor {
|
||||
public:
|
||||
// Each HLO is associated to a vector of properties with the indices given
|
||||
// below. Sub-classes can add further properties.
|
||||
typedef std::map<string, float> Properties;
|
||||
static constexpr char kFlopsKey[] = "flops";
|
||||
static constexpr char kTranscendentalsKey[] = "transcendentals";
|
||||
static constexpr char kBytesAccessedKey[] = "bytes accessed";
|
||||
static constexpr char kSecondsKey[] = "seconds";
|
||||
|
||||
// shape_size is a function which returns the size in bytes of the top-level
|
||||
// buffer of a shape.
|
||||
using ShapeSizeFunction = std::function<int64(const Shape&)>;
|
||||
explicit HloCostAnalysis(const ShapeSizeFunction& shape_size)
|
||||
: shape_size_(shape_size) {}
|
||||
explicit HloCostAnalysis(const ShapeSizeFunction& shape_size);
|
||||
|
||||
Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode) override;
|
||||
Status HandleElementwiseBinary(HloInstruction* hlo,
|
||||
@ -56,8 +63,7 @@ class HloCostAnalysis : public DfsHloVisitor {
|
||||
HloInstruction* lhs, HloInstruction* rhs) override;
|
||||
Status HandleClamp(HloInstruction* clamp, HloInstruction* min,
|
||||
HloInstruction* arg, HloInstruction* max) override;
|
||||
Status HandleReducePrecision(HloInstruction* hlo,
|
||||
HloInstruction* operand) override;
|
||||
Status HandleReducePrecision(HloInstruction* hlo) override;
|
||||
Status HandleConcatenate(
|
||||
HloInstruction* concatenate,
|
||||
tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
|
||||
@ -119,48 +125,88 @@ class HloCostAnalysis : public DfsHloVisitor {
|
||||
Status Preprocess(HloInstruction* hlo) override;
|
||||
Status Postprocess(HloInstruction* hlo) override;
|
||||
|
||||
// Returns the amount of computations in the graph.
|
||||
int64 flop_count() const { return flop_count_; }
|
||||
int64 transcendental_count() const { return transcendental_count_; }
|
||||
// Set the rates used to calculate the time taken by the computation. These
|
||||
// need to be set before visiting starts.
|
||||
void set_flops_per_second(float value) {
|
||||
per_second_rates_[kFlopsKey] = value;
|
||||
}
|
||||
void set_transcendentals_per_second(float value) {
|
||||
per_second_rates_[kTranscendentalsKey] = value;
|
||||
}
|
||||
void set_bytes_per_second(float value) {
|
||||
per_second_rates_[kBytesAccessedKey] = value;
|
||||
}
|
||||
|
||||
// Returns properties for the computation.
|
||||
float flop_count() const;
|
||||
float transcendental_count() const;
|
||||
float bytes_accessed() const;
|
||||
float seconds() const;
|
||||
|
||||
// Returns the respective cost computed for a particular HLO instruction, or 0
|
||||
// if the HLO was not found to have a cost in the analysis.
|
||||
int64 flop_count(const HloInstruction& hlo) const;
|
||||
int64 transcendental_count(const HloInstruction& hlo) const;
|
||||
|
||||
// Returns the number of bytes read/written.
|
||||
int64 bytes_accessed(const HloInstruction& hlo) const;
|
||||
int64 bytes_accessed() const { return bytes_accessed_; }
|
||||
float seconds(const HloInstruction& hlo) const;
|
||||
|
||||
const Properties& properties() const { return properties_sum_; }
|
||||
const float property(const string& key) const {
|
||||
return GetProperty(key, properties());
|
||||
}
|
||||
|
||||
protected:
|
||||
typedef std::unordered_map<const HloInstruction*, Properties> HloToProperties;
|
||||
|
||||
private:
|
||||
// An FMA counts as two floating point operations in these analyzes.
|
||||
static constexpr int64 kFmaFlops = 2;
|
||||
|
||||
HloCostAnalysis(const ShapeSizeFunction& shape_size,
|
||||
const Properties& per_second_rates);
|
||||
|
||||
// Returns the properties computed from visiting the computation rooted at the
|
||||
// given hlo. Uses shape_size_ to calculate shape sizes if shape_size is null,
|
||||
// otherwise uses shape_size_.
|
||||
StatusOr<Properties> ProcessSubcomputation(
|
||||
HloComputation* computation,
|
||||
const ShapeSizeFunction* shape_size = nullptr);
|
||||
|
||||
// Utility function to handle all element-wise operations.
|
||||
Status HandleElementwiseOp(HloInstruction* hlo_instruction);
|
||||
|
||||
// Returns 0.0f if the key is not present in the properties. Otherwise,
|
||||
// returns the value that the key maps to from the properties parameter.
|
||||
static float GetProperty(const string& key, const Properties& properties);
|
||||
|
||||
// Returns 0.0f if the hlo is not present in hlo_to_properties or if the key
|
||||
// is not present in hlo_to_properties[hlo]. Otherwise, returns the value that
|
||||
// the key maps to in the properties of the given hlo.
|
||||
static float GetPropertyForHlo(const HloInstruction& hlo, const string& key,
|
||||
const HloToProperties& hlo_to_properties);
|
||||
|
||||
// Function which computes the size of the top-level of a given shape (not
|
||||
// including nested elements, if any). If null then bytes_accessed methods
|
||||
// return an error.
|
||||
const ShapeSizeFunction shape_size_;
|
||||
|
||||
// The total number of floating point operations, transcendental operations,
|
||||
// and bytes accesses (read or written) in the computation.
|
||||
int64 flop_count_ = 0;
|
||||
int64 transcendental_count_ = 0;
|
||||
int64 bytes_accessed_ = 0;
|
||||
HloToProperties hlo_properties_;
|
||||
|
||||
// Cost counts of the current instruction. These should be set by each
|
||||
// handlers if different from the default values computed in Preprocess.
|
||||
int64 current_flop_count_;
|
||||
int64 current_transcendental_count_;
|
||||
int64 current_bytes_accessed_;
|
||||
// If true, the time taken will be computed from the rates for each property
|
||||
// and the total time will be the maximum time, which is the time of the
|
||||
// bottleneck.
|
||||
bool current_should_compute_bottleneck_time_;
|
||||
|
||||
// Mapping from HLO instructions to the cost we computed for them in the
|
||||
// course of the graph analysis.
|
||||
std::map<const HloInstruction*, int64> hlo_to_flop_count_;
|
||||
std::map<const HloInstruction*, int64> hlo_to_transcendental_count_;
|
||||
std::map<const HloInstruction*, int64> hlo_to_bytes_accessed_;
|
||||
// The properties of the currently visited instruction. A HandleFoo method can
|
||||
// modify these to change the default values computed in Preprocess.
|
||||
Properties current_properties_;
|
||||
|
||||
// The sum of the properties of all HLOs in the computation.
|
||||
Properties properties_sum_;
|
||||
|
||||
// How much of each property can be processed per second. E.g. if the property
|
||||
// is bytes accessed, this is the number of bytes that can be processed per
|
||||
// second. Is empty if no rates have been set.
|
||||
Properties per_second_rates_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(HloCostAnalysis);
|
||||
};
|
||||
|
@ -332,48 +332,64 @@ TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) {
|
||||
using FusionCostAnalysis = ::testing::Test;
|
||||
|
||||
TEST_F(FusionCostAnalysis, LoopFusion) {
|
||||
Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
|
||||
// Do this 4 times with different per-second rates to test the computation of
|
||||
// bottleneck time on fusion nodes.
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
|
||||
|
||||
// Fuse all instructions in complicated expression:
|
||||
//
|
||||
// add = Add(C1, C2)
|
||||
// clamp = Clamp(C2, add, add)
|
||||
// exp = Exp(add)
|
||||
// mul = Mul(exp, C3)
|
||||
// sub = Sub(mul, clamp)
|
||||
// tuple = Tuple({sub, sub, mul, C1})
|
||||
auto c1 = HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
|
||||
/*from=*/0.0f, /*to=*/1.0f, /*rows=*/2, /*cols=*/2));
|
||||
auto c2 = HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
|
||||
/*from=*/1.0f, /*to=*/2.0f, /*rows=*/2, /*cols=*/2));
|
||||
auto c3 = HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
|
||||
/*from=*/2.0f, /*to=*/3.0f, /*rows=*/2, /*cols=*/2));
|
||||
// Fuse all instructions in complicated expression:
|
||||
//
|
||||
// add = Add(C1, C2)
|
||||
// clamp = Clamp(C2, add, add)
|
||||
// exp = Exp(add)
|
||||
// mul = Mul(exp, C3)
|
||||
// sub = Sub(mul, clamp)
|
||||
// tuple = Tuple({sub, sub, mul, C1})
|
||||
auto c1 = HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
|
||||
/*from=*/0.0f, /*to=*/1.0f, /*rows=*/2, /*cols=*/2));
|
||||
auto c2 = HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
|
||||
/*from=*/1.0f, /*to=*/2.0f, /*rows=*/2, /*cols=*/2));
|
||||
auto c3 = HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
|
||||
/*from=*/2.0f, /*to=*/3.0f, /*rows=*/2, /*cols=*/2));
|
||||
|
||||
auto add =
|
||||
HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, c1.get(), c2.get());
|
||||
auto clamp = HloInstruction::CreateTernary(r2f32, HloOpcode::kClamp, c2.get(),
|
||||
add.get(), add.get());
|
||||
auto exp = HloInstruction::CreateUnary(r2f32, HloOpcode::kExp, add.get());
|
||||
auto mul = HloInstruction::CreateBinary(r2f32, HloOpcode::kMultiply,
|
||||
exp.get(), c3.get());
|
||||
auto sub = HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract,
|
||||
mul.get(), clamp.get());
|
||||
auto tuple =
|
||||
HloInstruction::CreateTuple({sub.get(), sub.get(), mul.get(), c1.get()});
|
||||
auto add = HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, c1.get(),
|
||||
c2.get());
|
||||
auto clamp = HloInstruction::CreateTernary(r2f32, HloOpcode::kClamp,
|
||||
c2.get(), add.get(), add.get());
|
||||
auto exp = HloInstruction::CreateUnary(r2f32, HloOpcode::kExp, add.get());
|
||||
auto mul = HloInstruction::CreateBinary(r2f32, HloOpcode::kMultiply,
|
||||
exp.get(), c3.get());
|
||||
auto sub = HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract,
|
||||
mul.get(), clamp.get());
|
||||
auto tuple = HloInstruction::CreateTuple(
|
||||
{sub.get(), sub.get(), mul.get(), c1.get()});
|
||||
|
||||
auto fusion = HloInstruction::CreateFusion(
|
||||
r2f32, HloInstruction::FusionKind::kLoop, tuple.get());
|
||||
fusion->FuseInstruction(sub.get());
|
||||
fusion->FuseInstruction(mul.get());
|
||||
fusion->FuseInstruction(exp.get());
|
||||
fusion->FuseInstruction(clamp.get());
|
||||
fusion->FuseInstruction(add.get());
|
||||
auto fusion = HloInstruction::CreateFusion(
|
||||
r2f32, HloInstruction::FusionKind::kLoop, tuple.get());
|
||||
fusion->FuseInstruction(sub.get());
|
||||
fusion->FuseInstruction(mul.get());
|
||||
fusion->FuseInstruction(exp.get());
|
||||
fusion->FuseInstruction(clamp.get());
|
||||
fusion->FuseInstruction(add.get());
|
||||
|
||||
HloCostAnalysis fusion_analysis(ShapeSize);
|
||||
ASSERT_IS_OK(fusion->Accept(&fusion_analysis));
|
||||
// The time given these rates at i == 0 is exactly even among the properties
|
||||
// at 1.0 seconds. For other values, one of the rates is slower so that it
|
||||
// becomes the bottleneck.
|
||||
HloCostAnalysis fusion_analysis(ShapeSize);
|
||||
fusion_analysis.set_flops_per_second(16 * (i == 1 ? 1 / 2.0 : 1.0));
|
||||
fusion_analysis.set_transcendentals_per_second(4 *
|
||||
(i == 2 ? 1 / 4.0 : 1.0));
|
||||
fusion_analysis.set_bytes_per_second(64 * (i == 3 ? 1 / 8.0 : 1.0));
|
||||
ASSERT_IS_OK(fusion->Accept(&fusion_analysis));
|
||||
|
||||
EXPECT_EQ(fusion_analysis.flop_count(), 16);
|
||||
EXPECT_EQ(fusion_analysis.transcendental_count(), 4);
|
||||
EXPECT_EQ(fusion_analysis.flop_count(), 16);
|
||||
EXPECT_EQ(fusion_analysis.transcendental_count(), 4);
|
||||
constexpr int64 bytes_accessed = sizeof(float) * 4 * 2 * 2;
|
||||
static_assert(bytes_accessed == 64, "");
|
||||
EXPECT_EQ(fusion_analysis.bytes_accessed(), bytes_accessed);
|
||||
|
||||
EXPECT_EQ(fusion_analysis.seconds(), 1 << i);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(FusionCostAnalysis, NoLayout) {
|
||||
|
@ -16,14 +16,11 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <iosfwd>
|
||||
#include <queue>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/map_util.h"
|
||||
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
@ -35,7 +32,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace xla {
|
||||
@ -43,267 +39,6 @@ namespace xla {
|
||||
using ::tensorflow::strings::StrAppend;
|
||||
using ::tensorflow::strings::StrCat;
|
||||
|
||||
string HloLocation::ToString() const {
|
||||
string index_str =
|
||||
ShapeUtil::IsTuple(instruction->shape()) ? (" " + index.ToString()) : "";
|
||||
return StrCat(instruction->FullyQualifiedName(), index_str);
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const HloLocation& location) {
|
||||
out << location.ToString();
|
||||
return out;
|
||||
}
|
||||
|
||||
string HloUse::ToString() const {
|
||||
string index_str =
|
||||
ShapeUtil::IsTuple(instruction->operand(operand_number)->shape())
|
||||
? (" " + operand_index.ToString())
|
||||
: "";
|
||||
return StrCat(instruction->FullyQualifiedName(), ", operand ", operand_number,
|
||||
index_str);
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const HloUse& use) {
|
||||
out << use.ToString();
|
||||
return out;
|
||||
}
|
||||
|
||||
HloValue::HloValue(HloValue::Id id, HloInstruction* instruction,
|
||||
const ShapeIndex& index, bool is_phi)
|
||||
: id_(id), is_phi_(is_phi) {
|
||||
// The defining location is always the first element in the locations_ vector.
|
||||
AddLocation(instruction, index);
|
||||
}
|
||||
|
||||
bool HloValue::operator==(const HloValue& other) const {
|
||||
bool equal = defining_instruction() == other.defining_instruction() &&
|
||||
defining_index() == other.defining_index();
|
||||
// If the values are equal they most both be phi (or non phi).
|
||||
CHECK(!(equal && is_phi() != other.is_phi()));
|
||||
return equal;
|
||||
}
|
||||
|
||||
bool HloValue::operator!=(const HloValue& other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
|
||||
string HloValue::ToShortString() const {
|
||||
string index_str = ShapeUtil::IsTuple(defining_instruction()->shape())
|
||||
? defining_index().ToString()
|
||||
: "";
|
||||
return StrCat(is_phi_ ? "PHI " : "",
|
||||
defining_instruction()->FullyQualifiedName(), index_str);
|
||||
}
|
||||
|
||||
string HloValue::ToString(int indent) const {
|
||||
string indentation(indent, ' ');
|
||||
string out = StrCat(indentation, ToShortString(), ", locations:\n");
|
||||
for (const HloLocation& location : locations()) {
|
||||
StrAppend(&out, indentation, " ", location.ToString(), "\n");
|
||||
}
|
||||
StrAppend(&out, indentation, " uses:\n");
|
||||
for (const HloUse& use : uses()) {
|
||||
StrAppend(&out, indentation, " ", use.ToString(), "\n");
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Returns true if the instruction 'user' may use the value at the given
|
||||
// ShapeIndex in the given operand. Generally, instruction which pass through
|
||||
// values transparently without reading the value are not considered to use the
|
||||
// value.
|
||||
bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index,
|
||||
const HloInstruction* user) {
|
||||
switch (user->opcode()) {
|
||||
case HloOpcode::kGetTupleElement:
|
||||
case HloOpcode::kCopy:
|
||||
// These instructions only access the top-level values of their
|
||||
// operand. Non-top-level (nested) values are passed through
|
||||
// transparently.
|
||||
CHECK_EQ(operand_number, 0);
|
||||
return index.empty();
|
||||
case HloOpcode::kSelect:
|
||||
// Select does not use any nested elements of its selected-from operands
|
||||
// (operand 1 and 2)
|
||||
CHECK_GE(operand_number, 0);
|
||||
CHECK_LE(operand_number, 2);
|
||||
return operand_number == 0 || index.empty();
|
||||
|
||||
case HloOpcode::kCall:
|
||||
case HloOpcode::kTuple:
|
||||
// These instructions always pass through their operands transparently.
|
||||
return false;
|
||||
|
||||
case HloOpcode::kWhile:
|
||||
// Though the while instructions passes through its operands, we return
|
||||
// true because in SSA form there may be a Phi at the parameter of the
|
||||
// while which is considered a use of its incoming value because the Phi
|
||||
// input values are not passed through into the body computation. Because
|
||||
// this function is used in both SSA and non-SSA forms of the analysis
|
||||
// conservatively return true.
|
||||
return true;
|
||||
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void HloValue::AddLocation(HloInstruction* instruction,
|
||||
const ShapeIndex& index) {
|
||||
// The given location should not already exist in locations_.
|
||||
for (const HloLocation& location : locations_) {
|
||||
DCHECK(!(location.instruction == instruction && location.index == index));
|
||||
}
|
||||
|
||||
locations_.push_back(HloLocation{instruction, index});
|
||||
|
||||
// Update uses.
|
||||
for (HloInstruction* user : instruction->users()) {
|
||||
for (int64 operand_number : user->OperandIndices(instruction)) {
|
||||
if (MayUseOperandValue(operand_number, index, user)) {
|
||||
for (const HloUse& use : uses_) {
|
||||
// Verify that this use does not already exist.
|
||||
DCHECK(!(use.instruction == user &&
|
||||
use.operand_number == operand_number &&
|
||||
use.operand_index == index));
|
||||
}
|
||||
|
||||
uses_.push_back(HloUse{user, operand_number, index});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update liveout status of this HloValue.
|
||||
const HloModule& module = *instruction->parent()->parent();
|
||||
if (instruction == module.entry_computation()->root_instruction()) {
|
||||
live_out_of_module_ = true;
|
||||
}
|
||||
|
||||
if (instruction == instruction->parent()->root_instruction()) {
|
||||
live_out_of_computation_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
void HloValue::RemoveLocation(HloInstruction* instruction,
|
||||
const ShapeIndex& index) {
|
||||
// The defining location cannot be removed.
|
||||
CHECK(!(instruction == defining_instruction() && index == defining_index()));
|
||||
|
||||
int64 size_before = locations_.size();
|
||||
locations_.erase(
|
||||
std::remove_if(locations_.begin(), locations_.end(),
|
||||
[instruction, &index](const HloLocation& location) {
|
||||
return location.instruction == instruction &&
|
||||
location.index == index;
|
||||
}),
|
||||
locations_.end());
|
||||
// Only a single location should have been removed.
|
||||
CHECK_EQ(locations_.size(), size_before - 1);
|
||||
|
||||
// Update uses which referred to this location.
|
||||
uses_.erase(std::remove_if(uses_.begin(), uses_.end(),
|
||||
[instruction, &index](const HloUse& use) {
|
||||
return use.instruction->operand(
|
||||
use.operand_number) == instruction &&
|
||||
use.operand_index == index;
|
||||
}),
|
||||
uses_.end());
|
||||
|
||||
// Returns whether this value is contained in the given instruction's output.
|
||||
auto is_contained_in = [this](const HloInstruction* instruction) {
|
||||
for (const HloLocation& location : locations()) {
|
||||
if (location.instruction == instruction) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
const HloModule& module = *instruction->parent()->parent();
|
||||
if (instruction == module.entry_computation()->root_instruction()) {
|
||||
// Value has been removed from a location in the entry root instruction.
|
||||
live_out_of_module_ =
|
||||
is_contained_in(module.entry_computation()->root_instruction());
|
||||
}
|
||||
if (instruction == defining_instruction()->parent()->root_instruction()) {
|
||||
// Value has been removed from the root of the computation the value has
|
||||
// been defined in.
|
||||
live_out_of_computation_ =
|
||||
is_contained_in(defining_instruction()->parent()->root_instruction());
|
||||
}
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const HloValue& value) {
|
||||
out << value.ToShortString();
|
||||
return out;
|
||||
}
|
||||
|
||||
void HloValueSet::SortAndUniquifyValues() {
|
||||
std::sort(value_ids_.begin(), value_ids_.end());
|
||||
value_ids_.erase(std::unique(value_ids_.begin(), value_ids_.end()),
|
||||
value_ids_.end());
|
||||
}
|
||||
|
||||
string HloValueSet::ToString() const {
|
||||
return StrCat("HloValueSet: ", tensorflow::str_util::Join(value_ids_, ", "));
|
||||
}
|
||||
|
||||
/*static */
|
||||
HloValueSet HloValueSet::Union(
|
||||
tensorflow::gtl::ArraySlice<const HloValueSet*> inputs) {
|
||||
HloValueSet union_set;
|
||||
for (const HloValueSet* input : inputs) {
|
||||
for (HloValue::Id value_id : input->value_ids()) {
|
||||
union_set.value_ids_.push_back(value_id);
|
||||
}
|
||||
}
|
||||
union_set.SortAndUniquifyValues();
|
||||
return union_set;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const HloValueSet& value_set) {
|
||||
out << value_set.ToString();
|
||||
return out;
|
||||
}
|
||||
|
||||
InstructionValueSet InstructionValueSet::Union(
|
||||
tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs) {
|
||||
CHECK_GT(inputs.size(), 0);
|
||||
for (int i = 1; i < inputs.size(); ++i) {
|
||||
CHECK(ShapeUtil::Compatible(inputs[0]->shape(), inputs[i]->shape()));
|
||||
}
|
||||
InstructionValueSet union_set(inputs[0]->shape());
|
||||
union_set.ForEachMutableElement(
|
||||
[&inputs](const ShapeIndex& index, HloValueSet* value_set) {
|
||||
std::vector<const HloValueSet*> input_sets;
|
||||
for (const InstructionValueSet* input : inputs) {
|
||||
input_sets.push_back(&input->element(index));
|
||||
}
|
||||
*value_set = HloValueSet::Union(input_sets);
|
||||
});
|
||||
return union_set;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& out,
|
||||
const InstructionValueSet& instruction_value_set) {
|
||||
out << instruction_value_set.ToString();
|
||||
return out;
|
||||
}
|
||||
|
||||
string InstructionValueSet::ToString() const {
|
||||
string out =
|
||||
StrCat("InstructionValueSet(", ShapeUtil::HumanString(shape()), ")\n");
|
||||
ForEachElement([this, &out](const ShapeIndex& index,
|
||||
const HloValueSet& value_set) {
|
||||
StrAppend(&out, " ", index.ToString(), " : ", value_set.ToString(), "\n");
|
||||
});
|
||||
return out;
|
||||
}
|
||||
|
||||
HloDataflowAnalysis::HloDataflowAnalysis(HloModule* module, bool ssa_form,
|
||||
bool bitcast_defines_value)
|
||||
: module_(module),
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
@ -29,228 +29,17 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
|
||||
#include "tensorflow/compiler/xla/shape_tree.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_value.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Abstraction which identifies a specific point in the XLA graph. An
|
||||
// HloLocation specifies a ShapeIndex within the output of a specific
|
||||
// instruction.
|
||||
struct HloLocation {
|
||||
HloInstruction* instruction;
|
||||
ShapeIndex index;
|
||||
|
||||
string ToString() const;
|
||||
|
||||
bool operator==(const HloLocation& other) const {
|
||||
return instruction == other.instruction && index == other.index;
|
||||
}
|
||||
bool operator!=(const HloLocation& other) const { return !(*this == other); }
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const HloLocation& location);
|
||||
|
||||
// Defines a single use of an HLO value.
|
||||
struct HloUse {
|
||||
// Instruction at which the value is used.
|
||||
HloInstruction* instruction;
|
||||
|
||||
// The operand number in which the value is appears.
|
||||
int64 operand_number;
|
||||
|
||||
// The shape index within the operand in which the value appears.
|
||||
ShapeIndex operand_index;
|
||||
|
||||
string ToString() const;
|
||||
|
||||
bool operator==(const HloUse& other) const {
|
||||
return instruction == other.instruction &&
|
||||
operand_number == other.operand_number &&
|
||||
operand_index == other.operand_index;
|
||||
}
|
||||
|
||||
bool operator!=(const HloUse& other) const { return !(*this == other); }
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const HloUse& use);
|
||||
|
||||
// Class describing a value used by the dataflow analysis. XLA arrays are
|
||||
// trivially a single HloValue. Tuples are made up of more than one HloValue: an
|
||||
// HloValue for the pointer vector, and an HloValue for each child element.
|
||||
//
|
||||
// Every HloValue is defined by a particular instruction and most instructions
|
||||
// define only a single HloValue. Instructions which define a single HloValue
|
||||
// include array-shaped instructions such as Add but also includes Tuple-shaped
|
||||
// instructions such as Tuple. The Tuple instruction defines a single HloValue
|
||||
// which is a vector of pointers to the values containing the Tuple
|
||||
// instruction's operands. Though the result of the Tuple instruction includes
|
||||
// multiple values only the top-level HloValue (the vector of pointers) is
|
||||
// defined by the Tuple instruction. The values containing the tuple elements
|
||||
// are defined by earlier instructions, usually the operands of the Tuple
|
||||
// instruction.
|
||||
//
|
||||
// Instructions which construct both the tuple *and* the tuple elements define
|
||||
// more than one HloValue. This includes (at least) tuple-shaped Constant,
|
||||
// Parameter, Infeed and While instructions. These tuple-shaped instructions do
|
||||
// not assemble a tuple from existing HloValues like the Tuple instruction does,
|
||||
// but rather define all the HloValues in the tuple.
|
||||
class HloValue {
|
||||
public:
|
||||
using Id = int64;
|
||||
|
||||
// Construct an HloValue defined by 'instruction' at shape index 'index'. If
|
||||
// is_phi is true, then this value is a phi value, for example, at the
|
||||
// parameter of a while body computation. Phi values are only used in the SSA
|
||||
// dataflow analysis (HloDataflowAnalysis::ssa_form_ is true).
|
||||
HloValue(HloValue::Id id, HloInstruction* instruction,
|
||||
const ShapeIndex& index, bool is_phi = false);
|
||||
|
||||
// Return a unique identifier for this HloValue. This value is used for stable
|
||||
// sorting and iteration
|
||||
Id id() const { return id_; }
|
||||
|
||||
// Returns whether this value is a phi value.
|
||||
bool is_phi() const { return is_phi_; }
|
||||
|
||||
// Return the location where this value is defined.
|
||||
const HloLocation& defining_location() const { return locations_[0]; }
|
||||
|
||||
// Return the instruction which defines this HloValue.
|
||||
HloInstruction* defining_instruction() const {
|
||||
return defining_location().instruction;
|
||||
}
|
||||
|
||||
// Return the shape index at which this HloValue is defined in the output of
|
||||
// its defining instruction.
|
||||
const ShapeIndex& defining_index() const { return defining_location().index; }
|
||||
|
||||
// Add or remove a location at which the HloValue appears. The definition
|
||||
// location can not be removed. The uses of the HloValue are updated.
|
||||
void AddLocation(HloInstruction* instruction, const ShapeIndex& index);
|
||||
void RemoveLocation(HloInstruction* instruction, const ShapeIndex& index);
|
||||
|
||||
// Return all locations of the HloValue in the module.
|
||||
const std::vector<HloLocation>& locations() const { return locations_; }
|
||||
|
||||
// Return all uses of the HloValue.
|
||||
const std::vector<HloUse>& uses() const { return uses_; }
|
||||
|
||||
// Get whether this HloValue is live out of the module.
|
||||
bool live_out_of_module() const { return live_out_of_module_; }
|
||||
|
||||
// Get whether this HloValue is live out of the computation it is defined in.
|
||||
bool live_out_of_computation() const { return live_out_of_computation_; }
|
||||
|
||||
bool operator==(const HloValue& other) const;
|
||||
bool operator!=(const HloValue& other) const;
|
||||
|
||||
// Return a single-line string representation of the value.
|
||||
string ToShortString() const;
|
||||
|
||||
string ToString(int indent = 0) const;
|
||||
|
||||
private:
|
||||
// Unique identifier for this HloValue. Used for stable sorting and iteration.
|
||||
const Id id_;
|
||||
|
||||
// Whether this instruction is a phi value.
|
||||
const bool is_phi_;
|
||||
|
||||
// The set of locations of this HloValue. The first element is always the
|
||||
// location of the definition.
|
||||
std::vector<HloLocation> locations_;
|
||||
|
||||
// The set of uses of this HloValue.
|
||||
std::vector<HloUse> uses_;
|
||||
|
||||
// Whether this value is live out of the HLO module.
|
||||
bool live_out_of_module_ = false;
|
||||
|
||||
// Whether this value is live out of its computation.
|
||||
bool live_out_of_computation_ = false;
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const HloValue& hlo_value);
|
||||
|
||||
// A class representing the possible set of HloValues at a particular point
|
||||
// (shape index in the output of an instruction) in the XLA graph. This set
|
||||
// contains the set of reaching HloValue definitions. For a simple array-shaped
|
||||
// instruction like Add, the HloValueSet of the top-level of the instruction's
|
||||
// output trivially contains only the HloValue defined by the instruction. For
|
||||
// instructions which have non-trivial dataflow such as Tuple or Select, the
|
||||
// HloValueSets of the instruction's output contains one or more HloValues
|
||||
// defined by the instruction's operands or defined further up in the XLA graph.
|
||||
class HloValueSet {
|
||||
public:
|
||||
HloValueSet() = default;
|
||||
|
||||
explicit HloValueSet(tensorflow::gtl::ArraySlice<HloValue::Id> value_ids)
|
||||
: value_ids_(value_ids.begin(), value_ids.end()) {
|
||||
SortAndUniquifyValues();
|
||||
}
|
||||
|
||||
// Return the union of the given HloValueSets.
|
||||
static HloValueSet Union(
|
||||
tensorflow::gtl::ArraySlice<const HloValueSet*> inputs);
|
||||
|
||||
// Return the vector of the IDs of all HloValues in the set. Values in the
|
||||
// vector are unique and sorted.
|
||||
const std::vector<HloValue::Id>& value_ids() const { return value_ids_; }
|
||||
|
||||
// Return the unique HLO value in the set. CHECKs if the set does not contain
|
||||
// exactly one value.
|
||||
HloValue::Id GetUniqueValueId() const {
|
||||
CHECK_EQ(value_ids().size(), 1);
|
||||
return value_ids()[0];
|
||||
}
|
||||
|
||||
bool operator==(const HloValueSet& other) const {
|
||||
return value_ids() == other.value_ids();
|
||||
}
|
||||
bool operator!=(const HloValueSet& other) const { return !(*this == other); }
|
||||
|
||||
string ToString() const;
|
||||
|
||||
private:
|
||||
// Sorts value_ and removes duplicates. This should be called after adding any
|
||||
// elements to values_.
|
||||
void SortAndUniquifyValues();
|
||||
|
||||
// HloValues sorted by HloValue::Id.
|
||||
std::vector<HloValue::Id> value_ids_;
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const HloValueSet& hlo_value);
|
||||
|
||||
// A class collecting the HloValues which might be contained in the output of
|
||||
// an HLO instruction. For array-shaped instructions, an InstructionValueSet
|
||||
// trivially holds a single HloValueSet. Tuple-shaped InstructionValueSets
|
||||
// hold multiple HloValueSets.
|
||||
class InstructionValueSet : public ShapeTree<HloValueSet> {
|
||||
public:
|
||||
InstructionValueSet(const Shape& shape) : ShapeTree<HloValueSet>(shape) {}
|
||||
|
||||
// Return the union of the given InstructionValueSets.
|
||||
static InstructionValueSet Union(
|
||||
tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs);
|
||||
|
||||
string ToString() const;
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& out,
|
||||
const InstructionValueSet& instruction_value_set);
|
||||
|
||||
// Analysis which identifies all HLO values and their uses in an HLO module.
|
||||
class HloDataflowAnalysis {
|
||||
public:
|
||||
|
@ -90,7 +90,7 @@ StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
|
||||
}
|
||||
|
||||
auto result = Literal::CreateFromShape(shape);
|
||||
TF_RETURN_IF_ERROR(result.get()->Populate<bool>(
|
||||
TF_RETURN_IF_ERROR(result->Populate<bool>(
|
||||
[&](tensorflow::gtl::ArraySlice<int64> multi_index) {
|
||||
return compare_op(lhs_literal.Get<OperandT>(multi_index),
|
||||
rhs_literal.Get<OperandT>(multi_index));
|
||||
@ -119,7 +119,7 @@ StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOpImpl(
|
||||
|
||||
auto result = Literal::CreateFromShape(shape);
|
||||
|
||||
TF_RETURN_IF_ERROR(result.get()->Populate<ReturnT>(
|
||||
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
|
||||
[&](tensorflow::gtl::ArraySlice<int64> multi_index) {
|
||||
return unary_op(operand_literal.Get<NativeT>(multi_index));
|
||||
}));
|
||||
@ -439,7 +439,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||
|
||||
auto result = Literal::CreateFromShape(shape);
|
||||
|
||||
TF_RETURN_IF_ERROR(result.get()->Populate<ReturnT>(
|
||||
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
|
||||
[&](tensorflow::gtl::ArraySlice<int64> multi_index) {
|
||||
return binary_op(lhs_literal.Get<ReturnT>(multi_index),
|
||||
rhs_literal.Get<ReturnT>(multi_index));
|
||||
@ -476,7 +476,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||
|
||||
auto result = Literal::CreateFromShape(shape);
|
||||
|
||||
TF_RETURN_IF_ERROR(result.get()->Populate<ReturnT>(
|
||||
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
|
||||
[&](tensorflow::gtl::ArraySlice<int64> multi_index) {
|
||||
return ternary_op(lhs_literal.Get<LhsType>(multi_index),
|
||||
rhs_literal.Get<RhsType>(multi_index),
|
||||
@ -637,7 +637,7 @@ Status HloEvaluator::HandleConcatenate(
|
||||
|
||||
for (auto operand : operands) {
|
||||
const Shape& operand_shape = operand->shape();
|
||||
TF_RETURN_IF_ERROR(result_literal.get()->Copy(
|
||||
TF_RETURN_IF_ERROR(result_literal->Copy(
|
||||
GetEvaluatedLiteralFor(operand), source_indices, dest_indices,
|
||||
AsInt64Slice(operand_shape.dimensions())));
|
||||
dest_indices[concat_dim] +=
|
||||
@ -769,9 +769,9 @@ Status HloEvaluator::HandleSlice(HloInstruction* slice,
|
||||
|
||||
DimensionVector dest_indices(slice->slice_starts().size(), 0);
|
||||
|
||||
TF_RETURN_IF_ERROR(literal.get()->Copy(GetEvaluatedLiteralFor(operand),
|
||||
slice->slice_starts(), dest_indices,
|
||||
AsInt64Slice(shape.dimensions())));
|
||||
TF_RETURN_IF_ERROR(literal->Copy(GetEvaluatedLiteralFor(operand),
|
||||
slice->slice_starts(), dest_indices,
|
||||
AsInt64Slice(shape.dimensions())));
|
||||
|
||||
evaluated_[slice] = std::move(literal);
|
||||
return Status::OK();
|
||||
|
@ -34,6 +34,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/regexp.h"
|
||||
|
||||
using ::tensorflow::Env;
|
||||
using ::tensorflow::WriteStringToFile;
|
||||
@ -593,6 +594,31 @@ void DumpText(const HloModule& module, const string& label,
|
||||
do_prefix ? StrCat(prefix, "-", label, ".txt") : StrCat(label, ".txt");
|
||||
string path = JoinPath(directory_path, filename);
|
||||
TF_CHECK_OK(WriteStringToFile(env, path, module.ToString()));
|
||||
LOG(INFO) << "dumping module '" << module.name() << "' to " << path;
|
||||
}
|
||||
|
||||
string MaybeDumpHloModule(const HloModule& module, const string& label,
|
||||
const HloExecutionProfile* profile) {
|
||||
VLOG(2) << "MaybeDumpHloModule called on module " << module.name();
|
||||
string graph_url;
|
||||
const DebugOptions& debug_options = module.config().debug_options();
|
||||
if (!debug_options.xla_generate_hlo_graph().empty() &&
|
||||
RE2::PartialMatch(module.name(),
|
||||
debug_options.xla_generate_hlo_graph())) {
|
||||
graph_url = DumpGraph(*module.entry_computation(), label,
|
||||
debug_options.xla_hlo_graph_addresses(),
|
||||
debug_options.xla_hlo_graph_layout(), profile);
|
||||
}
|
||||
if (!debug_options.xla_log_hlo_text().empty() &&
|
||||
RE2::PartialMatch(module.name(), debug_options.xla_log_hlo_text())) {
|
||||
LOG(INFO) << "HLO for module " << module.name();
|
||||
LOG(INFO) << "Label: " << label;
|
||||
XLA_LOG_LINES(2, module.ToString());
|
||||
}
|
||||
if (!debug_options.xla_generate_hlo_text_to().empty()) {
|
||||
DumpText(module, label, debug_options.xla_generate_hlo_text_to());
|
||||
}
|
||||
return graph_url;
|
||||
}
|
||||
|
||||
} // namespace hlo_graph_dumper
|
||||
|
@ -41,6 +41,13 @@ class GraphRendererInterface {
|
||||
virtual string RenderGraph(const string& graph, GraphKind graph_kind) = 0;
|
||||
};
|
||||
|
||||
// Dump the given HLO module if a dump is requested in its debug options. Based
|
||||
// on the debug options, either a graph dump, a text dump or both may be
|
||||
// generated. If a graph dump is generated, the description (e.g. an URL) is
|
||||
// returned; otherwise an empty string is returned.
|
||||
string MaybeDumpHloModule(const HloModule& module, const string& label,
|
||||
const HloExecutionProfile* profile = nullptr);
|
||||
|
||||
// Dumps a graph of the computation and returns a description of the rendered
|
||||
// graph (e.g., a URL) based on the renderer. The "best" renderer in the
|
||||
// registry is used.
|
||||
|
@ -1883,7 +1883,7 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) {
|
||||
case HloOpcode::kReverse:
|
||||
return visitor->HandleReverse(this, operands_[0]);
|
||||
case HloOpcode::kReducePrecision:
|
||||
return visitor->HandleReducePrecision(this, operands_[0]);
|
||||
return visitor->HandleReducePrecision(this);
|
||||
case HloOpcode::kSlice:
|
||||
return visitor->HandleSlice(this, operands_[0]);
|
||||
case HloOpcode::kDynamicSlice:
|
||||
|
@ -553,7 +553,7 @@ class HloInstruction {
|
||||
// number added to the variance to avoid divide-by-zero error.
|
||||
//
|
||||
// Precondition: opcode() == HloOpcode::kBatchNormTraining
|
||||
int64 epsilon() const { return epsilon_; }
|
||||
float epsilon() const { return epsilon_; }
|
||||
|
||||
// Returns the infeed configuration string. The infeed configuration includes
|
||||
// any metadata needed for the backend compiler (e.g., infeed buffer address)
|
||||
@ -751,6 +751,16 @@ class HloInstruction {
|
||||
return called_computations_;
|
||||
}
|
||||
|
||||
// Replaces all called computations based on a map function. This is needed
|
||||
// when we clone hlo_computations and want to let the instructions to point
|
||||
// to the newly cloned nodes.
|
||||
void ReplaceCalledComputations(
|
||||
std::function<HloComputation*(HloComputation*)> map_function) {
|
||||
for (int64 i = 0; i < called_computations_.size(); ++i) {
|
||||
called_computations_[i] = map_function(called_computations_[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Returns true if this instruction performs an elementwise operation on
|
||||
// `operand_idx`-th operand. An instruction is elementwise on an operand iff,
|
||||
// after performing necessary implicit broadcast
|
||||
|
@ -301,6 +301,36 @@ std::list<HloComputation*> HloModule::MakeComputationPostOrder() const {
|
||||
return post_order;
|
||||
}
|
||||
|
||||
std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) {
|
||||
VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n";
|
||||
auto module = MakeUnique<HloModule>(name_ + "-" + suffix);
|
||||
module->config_ = config_;
|
||||
module->entry_computation_handle_ = entry_computation_handle_;
|
||||
module->has_entry_computation_handle_ = has_entry_computation_handle_;
|
||||
|
||||
std::unordered_map<HloComputation*, HloComputation*> clone_map;
|
||||
for (auto& computation : computations_) {
|
||||
auto cloned_computation = computation->Clone(suffix);
|
||||
InsertOrDie(&clone_map, computation.get(), cloned_computation.get());
|
||||
|
||||
if (entry_computation_ == computation.get()) {
|
||||
module->AddEntryComputation(std::move(cloned_computation));
|
||||
} else {
|
||||
module->AddEmbeddedComputation(std::move(cloned_computation));
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& cloned_computation : module->computations_) {
|
||||
for (auto& instruction : cloned_computation->instructions()) {
|
||||
// Rewrite instruction's called_computation to point to the cloned
|
||||
// computations.
|
||||
instruction->ReplaceCalledComputations(
|
||||
[&](HloComputation* hlo) { return FindOrDie(clone_map, hlo); });
|
||||
}
|
||||
}
|
||||
return module;
|
||||
}
|
||||
|
||||
uint64 HloModule::RandomNew64() const {
|
||||
tensorflow::mutex_lock l(rng_mutex_);
|
||||
return rng_();
|
||||
|
@ -75,6 +75,9 @@ class HloModule {
|
||||
|
||||
const string& name() const { return name_; }
|
||||
|
||||
// Returns a deep copy of this module including all computations.
|
||||
std::unique_ptr<HloModule> Clone(const string& suffix = "clone");
|
||||
|
||||
// Return a pointer to the entry computation of the module..
|
||||
HloComputation* entry_computation() const {
|
||||
CHECK_NE(nullptr, entry_computation_);
|
||||
|
@ -81,6 +81,30 @@ TEST_F(HloModuleTest, TwoComputationsPostOrder) {
|
||||
EXPECT_EQ(computation2->name(), "Constant.1");
|
||||
}
|
||||
|
||||
TEST_F(HloModuleTest, CloneTest) {
|
||||
// Create and copy a module with a diamond call graph of computations.
|
||||
auto module = CreateNewModule();
|
||||
auto computation1 =
|
||||
module->AddEmbeddedComputation(CreateConstantComputation());
|
||||
auto computation2 =
|
||||
module->AddEmbeddedComputation(CreateCallComputation({computation1}));
|
||||
auto computation3 =
|
||||
module->AddEmbeddedComputation(CreateCallComputation({computation1}));
|
||||
module->AddEntryComputation(
|
||||
CreateCallComputation({computation2, computation3}));
|
||||
|
||||
auto post_order = module->MakeComputationPostOrder();
|
||||
auto cloned_module = module->Clone("copy");
|
||||
auto post_order_copied = cloned_module->MakeComputationPostOrder();
|
||||
|
||||
EXPECT_EQ(post_order.size(), post_order_copied.size());
|
||||
for (auto origin = post_order.begin(), copied = post_order_copied.begin();
|
||||
origin != post_order.end() && copied != post_order_copied.end();
|
||||
++origin, ++copied) {
|
||||
EXPECT_EQ((*origin)->name() + "copy", (*copied)->name());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(HloModuleTest, DiamondComputationsPostOrder) {
|
||||
// Create a module with a diamond call graph of computations.
|
||||
auto module = CreateNewModule();
|
||||
|
@ -15,13 +15,10 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
|
||||
|
||||
#include <set>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/service/heap_simulator.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/logical_buffer.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
@ -252,358 +249,6 @@ string SequentialHloOrdering::ToString() const {
|
||||
return tensorflow::str_util::Join(pieces, "\n");
|
||||
}
|
||||
|
||||
StatusOr<int64> MinimumMemoryForSequence(
|
||||
const SequentialHloOrdering::HloModuleSequence& module_sequence,
|
||||
const LogicalBuffer::SizeFunction& size_function) {
|
||||
if (module_sequence.empty()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const HloModule* module = module_sequence.begin()->first->parent();
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
|
||||
TuplePointsToAnalysis::Run(module));
|
||||
|
||||
// The absolute minimum memory required for a given sequence of instructions
|
||||
// is determined by the sequence of Alloc and Free calls on a simulated heap,
|
||||
// ignoring fragmentation. We run the heap simulation on the whole module,
|
||||
// rather than summing each computation, since it gives us a better lower
|
||||
// bound, by minimizing the liveness of sub-computations.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
HeapSimulator::Result result,
|
||||
HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), *module,
|
||||
module_sequence, *points_to_analysis, size_function));
|
||||
return result.heap_size;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Class implementing a list scheduler of HLO instructions which produces a
|
||||
// sequence which minimizes memory usage.
|
||||
class ListScheduler {
|
||||
public:
|
||||
// Construct and return a memory-minimizing sequence of HLO instructions
|
||||
// containing the given HLO computation.
|
||||
static StatusOr<std::vector<const HloInstruction*>> Run(
|
||||
const HloComputation& computation,
|
||||
const TuplePointsToAnalysis& points_to_analysis,
|
||||
const LogicalBuffer::SizeFunction& size_function) {
|
||||
ListScheduler scheduler(computation, points_to_analysis, size_function);
|
||||
return scheduler.CreateSchedule();
|
||||
}
|
||||
|
||||
private:
|
||||
// The scheduling priority of an instruction is first the number of bytes
|
||||
// freed by scheduling the instruction, and second (tie-breaker) by the number
|
||||
// of users. This is represented as a std::pair containing these two values
|
||||
// (first element is the bytes freed). std::pair provides the necessary
|
||||
// comparison operators.
|
||||
using Priority = std::pair<int64, int64>;
|
||||
|
||||
ListScheduler(const HloComputation& computation,
|
||||
const TuplePointsToAnalysis& points_to_analysis,
|
||||
const LogicalBuffer::SizeFunction& size_function)
|
||||
: computation_(computation),
|
||||
points_to_analysis_(points_to_analysis),
|
||||
size_function_(size_function) {
|
||||
// Create a map containing the LogicalBuffer uses for each HLO
|
||||
// instruction. An HLO instruction "uses" a LogicalBuffer if the
|
||||
// LogicalBuffer is in an operand of the instruction as indicated by
|
||||
// points-to analysis.
|
||||
for (auto& instruction : computation.instructions()) {
|
||||
buffer_uses_.insert(
|
||||
{instruction.get(), std::unordered_set<const LogicalBuffer*>()});
|
||||
for (auto* operand : instruction->operands()) {
|
||||
for (const LogicalBuffer* buffer :
|
||||
points_to_analysis.GetBuffersDefinedByInstruction(operand)) {
|
||||
buffer_uses_[instruction.get()].insert(buffer);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create map containing the number of unscheduled uses (hlo instructions)
|
||||
// of each logical buffer.
|
||||
for (auto& instruction : computation.instructions()) {
|
||||
for (auto* buffer : points_to_analysis.GetBuffersDefinedByInstruction(
|
||||
instruction.get())) {
|
||||
unscheduled_use_count_[buffer] = 0;
|
||||
}
|
||||
}
|
||||
for (auto& instruction : computation.instructions()) {
|
||||
for (const LogicalBuffer* buffer : buffer_uses_.at(instruction.get())) {
|
||||
++unscheduled_use_count_[buffer];
|
||||
}
|
||||
}
|
||||
|
||||
// Buffers live out of the computation have an implicit use at the end of
|
||||
// the computation.
|
||||
for (const LogicalBuffer* live_out_buffer :
|
||||
points_to_analysis.GetPointsToSet(computation.root_instruction())
|
||||
.CreateFlattenedSet()) {
|
||||
++unscheduled_use_count_[live_out_buffer];
|
||||
}
|
||||
}
|
||||
|
||||
// Returns whether the memory used by the given buffer should be ignored by
|
||||
// the scheduling heuristic.
|
||||
bool IgnoreBuffer(const LogicalBuffer& buffer) {
|
||||
return buffer.instruction()->opcode() == HloOpcode::kParameter ||
|
||||
buffer.instruction()->opcode() == HloOpcode::kConstant;
|
||||
}
|
||||
|
||||
// Return the number of bytes freed if the HLO instruction is scheduled.
|
||||
int64 BytesFreedIfScheduled(const HloInstruction* instruction) {
|
||||
int64 freed_bytes = 0;
|
||||
// Sum the total size of the values last used by this instruction.
|
||||
for (auto* buffer : buffer_uses_.at(instruction)) {
|
||||
if (IgnoreBuffer(*buffer)) {
|
||||
continue;
|
||||
}
|
||||
CHECK_GE(unscheduled_use_count_.at(buffer), 1);
|
||||
if (unscheduled_use_count_.at(buffer) == 1) {
|
||||
// This is the last use of the logical buffer.
|
||||
freed_bytes += size_function_(*buffer);
|
||||
}
|
||||
}
|
||||
// Then subtract the size of the value(s) defined by this instruction.
|
||||
for (auto* buffer :
|
||||
points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) {
|
||||
if (!IgnoreBuffer(*buffer)) {
|
||||
freed_bytes -= size_function_(*buffer);
|
||||
}
|
||||
}
|
||||
return freed_bytes;
|
||||
}
|
||||
|
||||
// Construct the scheduling priority of the given instruction.
|
||||
Priority GetPriority(const HloInstruction* instruction) {
|
||||
return {BytesFreedIfScheduled(instruction), instruction->user_count()};
|
||||
}
|
||||
|
||||
std::vector<const HloInstruction*> CreateSchedule() {
|
||||
std::vector<const HloInstruction*> schedule;
|
||||
|
||||
// Populate the ready list with instructions which have no operands or
|
||||
// control predecessors.
|
||||
std::unordered_map<const HloInstruction*, int64> unscheduled_pred_count;
|
||||
std::list<const HloInstruction*> ready_list;
|
||||
for (auto& instruction : computation_.instructions()) {
|
||||
// TODO(b/34466113): Replace this and above with successors() or
|
||||
// predecessors() when these methods are added to HloInstruction.
|
||||
for (const HloInstruction* user : instruction->users()) {
|
||||
unscheduled_pred_count[user]++;
|
||||
}
|
||||
for (const HloInstruction* succ : instruction->control_successors()) {
|
||||
unscheduled_pred_count[succ]++;
|
||||
}
|
||||
}
|
||||
for (auto& instruction : computation_.instructions()) {
|
||||
// Instruction with no operands or control predecessors will
|
||||
// not be in the map.
|
||||
if (unscheduled_pred_count.count(instruction.get()) == 0) {
|
||||
ready_list.push_back(instruction.get());
|
||||
}
|
||||
}
|
||||
|
||||
while (!ready_list.empty()) {
|
||||
// Select the highest priority HLO instruction from the ready list.
|
||||
auto best_it = ready_list.begin();
|
||||
Priority best_priority = GetPriority(*best_it);
|
||||
for (auto ready_it = std::next(ready_list.begin());
|
||||
ready_it != ready_list.end(); ++ready_it) {
|
||||
Priority priority = GetPriority(*ready_it);
|
||||
if (priority > best_priority) {
|
||||
best_it = ready_it;
|
||||
best_priority = priority;
|
||||
}
|
||||
}
|
||||
|
||||
// Remove the selected instruction from the ready list and add it to the
|
||||
// schedule.
|
||||
const HloInstruction* best = *best_it;
|
||||
ready_list.erase(best_it);
|
||||
schedule.push_back(best);
|
||||
scheduled_instructions_.insert(best);
|
||||
|
||||
// Update the unscheduled uses of the logical buffers.
|
||||
for (const LogicalBuffer* buffer : buffer_uses_.at(best)) {
|
||||
CHECK_GT(unscheduled_use_count_.at(buffer), 0);
|
||||
--unscheduled_use_count_[buffer];
|
||||
}
|
||||
|
||||
// Add new instructions to ready list.
|
||||
auto update_pred_count = [&unscheduled_pred_count,
|
||||
&ready_list](HloInstruction* inst) {
|
||||
int64 pred_count = --unscheduled_pred_count.at(inst);
|
||||
CHECK_GE(pred_count, 0);
|
||||
if (pred_count == 0) {
|
||||
ready_list.push_back(inst);
|
||||
}
|
||||
};
|
||||
// TODO(b/34466113): Replace this and above with successors() or
|
||||
// predecessors() when these methods are added to HloInstruction.
|
||||
for (HloInstruction* user : best->users()) {
|
||||
update_pred_count(user);
|
||||
}
|
||||
for (HloInstruction* succ : best->control_successors()) {
|
||||
update_pred_count(succ);
|
||||
}
|
||||
}
|
||||
CHECK_EQ(schedule.size(), computation_.instructions().size());
|
||||
CHECK_EQ(scheduled_instructions_.size(),
|
||||
computation_.instructions().size());
|
||||
|
||||
return schedule;
|
||||
}
|
||||
|
||||
const HloComputation& computation_;
|
||||
const TuplePointsToAnalysis& points_to_analysis_;
|
||||
const LogicalBuffer::SizeFunction& size_function_;
|
||||
|
||||
// A map containing the LogicalBuffers that each instruction uses.
|
||||
std::unordered_map<const HloInstruction*,
|
||||
std::unordered_set<const LogicalBuffer*>>
|
||||
buffer_uses_;
|
||||
|
||||
// A map containing the count of unscheduled HLOs which using a particular
|
||||
// LogicalBuffer.
|
||||
std::unordered_map<const LogicalBuffer*, int64> unscheduled_use_count_;
|
||||
|
||||
// Set of instructions which have been scheduled.
|
||||
std::unordered_set<const HloInstruction*> scheduled_instructions_;
|
||||
};
|
||||
|
||||
int64 SumLogicalBufferSizes(const std::vector<const LogicalBuffer*>& buffers,
|
||||
const LogicalBuffer::SizeFunction& size_function) {
|
||||
int64 size = 0;
|
||||
for (const LogicalBuffer* buffer : buffers) {
|
||||
size += size_function(*buffer);
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
StatusOr<std::vector<const HloInstruction*>> RunDFSMemoryScheduler(
|
||||
const HloComputation& computation,
|
||||
const TuplePointsToAnalysis& points_to_analysis,
|
||||
const LogicalBuffer::SizeFunction& size_function) {
|
||||
// This ordering is based on DFS post-order, with a heuristic to decide which
|
||||
// operand to visit first. The heuristic is based on 'extra_users', which is
|
||||
// simply users-1 for each instruction. By subtracting 1, we're saying that
|
||||
// instructions with no users or a single user don't count; instructions with
|
||||
// lots of fan-out will be visited earlier.
|
||||
tensorflow::gtl::FlatMap<const HloInstruction*, int64> extra_users;
|
||||
tensorflow::gtl::FlatMap<const HloInstruction*, int64> total_sizes;
|
||||
for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) {
|
||||
extra_users[hlo] = hlo->users().empty() ? 0 : hlo->users().size() - 1;
|
||||
total_sizes[hlo] = SumLogicalBufferSizes(
|
||||
points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function);
|
||||
tensorflow::gtl::FlatSet<const HloInstruction*> unique_operands(
|
||||
hlo->operands().begin(), hlo->operands().end());
|
||||
for (const HloInstruction* operand : unique_operands) {
|
||||
extra_users[hlo] += extra_users[operand];
|
||||
total_sizes[hlo] += total_sizes[operand];
|
||||
}
|
||||
}
|
||||
CHECK_EQ(extra_users.size(), computation.instructions().size());
|
||||
CHECK_EQ(total_sizes.size(), computation.instructions().size());
|
||||
|
||||
// Construct a total order based on DFS post-order, visiting operands in
|
||||
// decreasing cumulative extra user order, and next by cumulative size, with a
|
||||
// tiebreaker by name for determinism.
|
||||
std::vector<const HloInstruction*> sequence;
|
||||
FunctionVisitor visitor([&sequence](HloInstruction* hlo) {
|
||||
sequence.push_back(hlo);
|
||||
return Status::OK();
|
||||
});
|
||||
TF_RETURN_IF_ERROR(computation.AcceptWithOperandOrder(
|
||||
&visitor, [&extra_users, &total_sizes](const HloInstruction* a,
|
||||
const HloInstruction* b) {
|
||||
if (extra_users[a] != extra_users[b]) {
|
||||
return extra_users[a] > extra_users[b];
|
||||
}
|
||||
if (total_sizes[a] != total_sizes[b]) {
|
||||
return total_sizes[a] > total_sizes[b];
|
||||
}
|
||||
return a->name() < b->name();
|
||||
}));
|
||||
CHECK_EQ(sequence.size(), computation.instructions().size());
|
||||
return sequence;
|
||||
}
|
||||
|
||||
StatusOr<int64> MinimumMemoryForComputation(
|
||||
const HloComputation& computation,
|
||||
const std::vector<const HloInstruction*>& sequence,
|
||||
const TuplePointsToAnalysis& points_to_analysis,
|
||||
const LogicalBuffer::SizeFunction& size_function) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
HeapSimulator::Result result,
|
||||
HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), computation,
|
||||
sequence, points_to_analysis, size_function));
|
||||
return result.heap_size;
|
||||
}
|
||||
|
||||
StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
|
||||
const HloComputation& computation,
|
||||
const TuplePointsToAnalysis& points_to_analysis,
|
||||
const LogicalBuffer::SizeFunction& size_function) {
|
||||
// We try both a list-scheduler based ordering and a DFS based ordering, and
|
||||
// choose whichever returns a lower min-memory, not accounting for
|
||||
// fragmentation.
|
||||
//
|
||||
// Note that this is just a heuristic. One obvious inaccuracy is that the
|
||||
// memory required for sub-computations might be different when considered
|
||||
// within the caller's context. But it's good enough for now.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::vector<const HloInstruction*> list_sequence,
|
||||
ListScheduler::Run(computation, points_to_analysis, size_function));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
const int64 list_memory,
|
||||
MinimumMemoryForComputation(computation, list_sequence,
|
||||
points_to_analysis, size_function));
|
||||
VLOG(2) << "Min-memory list sequence: " << list_memory << " bytes";
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::vector<const HloInstruction*> dfs_sequence,
|
||||
RunDFSMemoryScheduler(computation, points_to_analysis, size_function));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
const int64 dfs_memory,
|
||||
MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis,
|
||||
size_function));
|
||||
VLOG(2) << "Min-memory dfs sequence: " << dfs_memory << " bytes";
|
||||
|
||||
if (list_memory <= dfs_memory) {
|
||||
VLOG(2) << "Chose min-memory list sequence: " << list_memory << " bytes";
|
||||
return list_sequence;
|
||||
} else {
|
||||
VLOG(2) << "Chose min-memory dfs sequence: " << dfs_memory << " bytes";
|
||||
return dfs_sequence;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
StatusOr<SequentialHloOrdering::HloModuleSequence>
|
||||
CreateMemoryMinimizingSequence(
|
||||
const HloModule& module, const LogicalBuffer::SizeFunction& size_function) {
|
||||
SequentialHloOrdering::HloModuleSequence sequence;
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
|
||||
TuplePointsToAnalysis::Run(&module));
|
||||
for (const auto& computation : module.computations()) {
|
||||
TF_ASSIGN_OR_RETURN(sequence[computation.get()],
|
||||
CreateMemoryMinimizingSequence(
|
||||
*computation, *points_to_analysis, size_function));
|
||||
}
|
||||
return sequence;
|
||||
}
|
||||
|
||||
StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
|
||||
const HloComputation& computation,
|
||||
const LogicalBuffer::SizeFunction& size_function) {
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
|
||||
TuplePointsToAnalysis::Run(computation.parent()));
|
||||
return CreateMemoryMinimizingSequence(computation, *points_to_analysis,
|
||||
size_function);
|
||||
}
|
||||
|
||||
std::ostream& operator<<(
|
||||
std::ostream& out,
|
||||
const SequentialHloOrdering::HloModuleSequence& module_sequence) {
|
||||
|
@ -24,12 +24,8 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
@ -191,24 +187,6 @@ std::ostream& operator<<(
|
||||
std::ostream& out,
|
||||
const SequentialHloOrdering::HloModuleSequence& module_sequence);
|
||||
|
||||
// Returns the minimum memory required to compute the given module sequence,
|
||||
// assuming no fragmentation.
|
||||
StatusOr<int64> MinimumMemoryForSequence(
|
||||
const SequentialHloOrdering::HloModuleSequence& module_sequence,
|
||||
const LogicalBuffer::SizeFunction& size_function);
|
||||
|
||||
// Returns an HloModuleSequence which seeks to minimize the memory required for
|
||||
// the computation. size_function is the function returning the number of bytes
|
||||
// required for a LogicalBuffer.
|
||||
StatusOr<SequentialHloOrdering::HloModuleSequence>
|
||||
CreateMemoryMinimizingSequence(
|
||||
const HloModule& module, const LogicalBuffer::SizeFunction& size_function);
|
||||
|
||||
// Overload of above that computes the sequence for a single computation.
|
||||
StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
|
||||
const HloComputation& computation,
|
||||
const LogicalBuffer::SizeFunction& size_function);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ORDERING_H_
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
@ -217,67 +218,7 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) {
|
||||
EXPECT_FALSE(ordering.ExecutesBefore(body_param, cond_param));
|
||||
}
|
||||
|
||||
class MinimumMemoryForSequenceTest : public HloTestBase {};
|
||||
|
||||
TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
|
||||
auto module = CreateNewModule();
|
||||
const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
|
||||
const Shape tuple_shape =
|
||||
ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape});
|
||||
|
||||
auto cond_builder = HloComputation::Builder("WhileCond");
|
||||
// Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
|
||||
HloInstruction* cond_param = cond_builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
|
||||
HloInstruction* cond_iter = cond_builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0));
|
||||
HloInstruction* cond_data = cond_builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
|
||||
// Free cond_param[] (16 bytes), Alloc PRED[] (1 byte)
|
||||
HloInstruction* cond_lt = cond_builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
|
||||
HloOpcode::kLt, cond_iter, cond_data));
|
||||
HloComputation* cond_computation =
|
||||
module->AddEmbeddedComputation(cond_builder.Build());
|
||||
|
||||
auto body_builder = HloComputation::Builder("WhileBody");
|
||||
// Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
|
||||
HloInstruction* body_param = body_builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
|
||||
HloComputation* body_computation =
|
||||
module->AddEmbeddedComputation(body_builder.Build());
|
||||
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
// Entry params: 8 bytes (4 bytes per param), TOTAL=8
|
||||
HloInstruction* iter = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, scalar_shape, "param_iter"));
|
||||
HloInstruction* data = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(1, scalar_shape, "param_data"));
|
||||
// Tuple: 16 bytes (8 bytes per pointer), TOTAL=24
|
||||
HloInstruction* tuple =
|
||||
builder.AddInstruction(HloInstruction::CreateTuple({iter, data}));
|
||||
// While: 8 bytes (4 bytes per element), TOTAL=32
|
||||
// Both cond and body use a max of 24 bytes, TOTAL=56
|
||||
HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
|
||||
tuple_shape, cond_computation, body_computation, tuple));
|
||||
HloComputation* entry_computation =
|
||||
module->AddEntryComputation(builder.Build());
|
||||
|
||||
auto size_fn = [](const LogicalBuffer& buffer) {
|
||||
return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
|
||||
};
|
||||
|
||||
SequentialHloOrdering::HloModuleSequence module_sequence;
|
||||
module_sequence[cond_computation] = {cond_param, cond_iter, cond_data,
|
||||
cond_lt};
|
||||
module_sequence[body_computation] = {body_param};
|
||||
module_sequence[entry_computation] = {iter, data, tuple, while_op};
|
||||
EXPECT_EQ(56,
|
||||
MinimumMemoryForSequence(module_sequence, size_fn).ValueOrDie());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace xla
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
@ -30,9 +31,10 @@ using ::tensorflow::strings::StrAppend;
|
||||
namespace xla {
|
||||
|
||||
namespace {
|
||||
void DumpModule(const Compiler::HloDumper& dumper_, const HloModule& module,
|
||||
void DumpModule(const HloModule& module,
|
||||
|
||||
const string& message) {
|
||||
dumper_(module, message);
|
||||
hlo_graph_dumper::MaybeDumpHloModule(module, message);
|
||||
VLOG(2) << "HLO " << message << ":";
|
||||
XLA_VLOG_LINES(2, module.ToString());
|
||||
}
|
||||
@ -75,7 +77,7 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
|
||||
// Emit label containing: "after foo-pass, before bar-pass".
|
||||
message.clear();
|
||||
StrAppend(&message, prefix, ", before ", pass->name());
|
||||
DumpModule(dumper_, *module, message);
|
||||
DumpModule(*module, message);
|
||||
|
||||
TF_RETURN_IF_ERROR(run_invariant_checkers());
|
||||
TF_ASSIGN_OR_RETURN(bool changed_this_pass, pass->Run(module));
|
||||
@ -85,7 +87,7 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
|
||||
StrAppend(&prefix, name(), ": after ", pass->name());
|
||||
}
|
||||
TF_RETURN_IF_ERROR(run_invariant_checkers());
|
||||
DumpModule(dumper_, *module, prefix + ", pipeline end");
|
||||
DumpModule(*module, prefix + ", pipeline end");
|
||||
return changed;
|
||||
}
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user