- Lower dynamic convolutions into kCustomCalls. - Transform those kCustomCalls into static convolutions in dynamic padder. PiperOrigin-RevId: 339275495 Change-Id: I0e1a6c0ff7f539e63482f1de7d564dca23ab81bc
143 lines
5.6 KiB
C++
143 lines
5.6 KiB
C++
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_
|
|
#define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_
|
|
|
|
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/macros.h"
|
|
|
|
namespace xla {
|
|
|
|
// DynamicDimensionInference analyzes each HLO instruction in a graph and
|
|
// inferences which dimensions are dynamic and which scalar instructions
|
|
// represent the runtime real size of those dynamic dimensions.
|
|
class DynamicDimensionInference {
|
|
public:
|
|
using CustomCallInferenceHandler =
|
|
std::function<Status(HloInstruction*, DynamicDimensionInference*)>;
|
|
|
|
static StatusOr<DynamicDimensionInference> Run(
|
|
HloModule* module,
|
|
CustomCallInferenceHandler custom_call_handler = nullptr);
|
|
|
|
string ToString() const;
|
|
|
|
// If the dimension `dim` of instruction `inst` at `index` has a dynamic size,
|
|
// returns a scalar HloInstruction that represents the runtime size of that
|
|
// dimension. Otherwise returns nullptr.
|
|
HloInstruction* GetDynamicSize(HloInstruction* inst, const ShapeIndex& index,
|
|
int64 dim) const;
|
|
|
|
// Returns dynamic sizes of all dimensions of `inst`'s leaf node at `index`.
|
|
// Static sizes are represented by nullptr.
|
|
std::vector<HloInstruction*> GetDynamicSizes(HloInstruction* inst,
|
|
const ShapeIndex& index) const;
|
|
|
|
// Returns if current instruction contains any dynamic dimension.
|
|
// Recursively go into tuples.
|
|
bool HasDynamicDimension(HloInstruction* inst) const;
|
|
|
|
// Forward dynamic dimension size at `dim` from `inst` to `new_inst`.
|
|
Status ForwardDynamicSize(HloInstruction* inst, HloInstruction* new_inst,
|
|
const ShapeIndex& index);
|
|
|
|
// Update the dynamic mapping so that we know dimension `dim` of instruction
|
|
// `inst` at `index` has a dynamic size, and its runtime size is represented
|
|
// by a scalar instruction `size`.
|
|
void SetDynamicSize(HloInstruction* inst, const ShapeIndex& index, int64 dim,
|
|
HloInstruction* size);
|
|
|
|
// For all tensors whose dynamic dimension is `replace`, replace them with
|
|
// `with`.
|
|
void ReplaceAllDynamicDimensionUsesWith(HloInstruction* replace,
|
|
HloInstruction* with);
|
|
|
|
friend class DynamicDimensionInferenceVisitor;
|
|
|
|
private:
|
|
explicit DynamicDimensionInference(
|
|
HloModule* module, CustomCallInferenceHandler custom_call_handler);
|
|
|
|
// DynamicDimension is used as a key in the dynamic key-value mapping. It
|
|
// unambiguously represents a dynamic dimension of a instruction at a given
|
|
// index.
|
|
struct DynamicDimension {
|
|
// HloInstruction that holds the dimension.
|
|
HloInstruction* inst;
|
|
// Subshape of the instruction that holds the dimension.
|
|
ShapeIndex index;
|
|
// The dimension number of the dynamic dimension at given index of a given
|
|
// instruction.
|
|
int64 dim;
|
|
|
|
// Artifacts needed to make this struct able to be used as a `key` in absl
|
|
// maps. "friend" keywords are added so these functions can be found through
|
|
// ADL.
|
|
template <typename H>
|
|
friend H AbslHashValue(H h, const DynamicDimension& m) {
|
|
return H::combine(std::move(h), m.inst, m.index, m.dim);
|
|
}
|
|
|
|
friend bool operator==(const DynamicDimension& lhs,
|
|
const DynamicDimension& rhs) {
|
|
return lhs.inst == rhs.inst && lhs.index == rhs.index &&
|
|
lhs.dim == rhs.dim;
|
|
}
|
|
};
|
|
|
|
// Copies the internal mapping from instruction `from` to instruction `to`.
|
|
// This is useful when an instruction is replaced by the other during the
|
|
// inferencing process.
|
|
void CopyMapping(HloInstruction* from, HloInstruction* to);
|
|
|
|
// AnalyzeDynamicDimensions starts the analysis of the dynamic dimensions in
|
|
// module_.
|
|
Status AnalyzeDynamicDimensions();
|
|
|
|
// HloModule being analyzed.
|
|
HloModule* module_;
|
|
|
|
// dynamic_mapping_ holds the result of the analysis. It maps a dynamic
|
|
// dimension to a scalar HloInstruction that represents the real dynamic size
|
|
// of the dynamic dimension.
|
|
using DynamicMapping = absl::flat_hash_map<DynamicDimension, HloInstruction*>;
|
|
DynamicMapping dynamic_mapping_;
|
|
|
|
// A convenient mapping from an hlo to the set of dynamic dimensions that it
|
|
// holds.
|
|
using PerHloDynamicDimensions =
|
|
absl::flat_hash_map<HloInstruction*,
|
|
absl::flat_hash_set<DynamicDimension>>;
|
|
PerHloDynamicDimensions per_hlo_dynamic_dimensions_;
|
|
|
|
// A handler for custom calls.
|
|
CustomCallInferenceHandler custom_call_handler_;
|
|
};
|
|
|
|
} // namespace xla
|
|
|
|
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_
|