merge with master
commit 333c864732

RELEASE.md (13 changes)
@@ -88,6 +88,7 @@
  dataset when it is safe to do so. The optimization can be disabled via
  the `experimental_optimization.reorder_data_discarding_ops` dataset
  option.
* `tf.data.Options` were previously immutable and can now be overridden.
* `tf.image`:
  * Added deterministic `tf.image.stateless_random_*` functions for each
    `tf.image.random_*` function. Added a new op
@@ -106,7 +107,8 @@
* Error messages when Functional API construction goes wrong (and when ops cannot be converted to Keras layers automatically) should be clearer and easier to understand.
* `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape`
  as an alternative to accepting a `callable` loss.
* Added `beta` parameter to FTRL optimizer to match paper.
* Added `beta` hyperparameter to FTRL optimizer classes (Keras and others)
  to match FTRL paper (https://research.google.com/pubs/archive/41159.pdf).
* Added `mobilenet_v3` to keras application model.
* `Optimizer.__init__` now accepts a `gradient_aggregator` to allow for
  customization of how gradients are aggregated across devices, as well as
@@ -155,6 +157,14 @@
* <ADD RELEASE NOTES HERE>
* Tracing and Debugging:
  * <ADD RELEASE NOTES HERE>
* `tf.train.Checkpoint`:
  * Now accepts a `root` argument in the initialization, which generates a
    checkpoint with a root object. This allows users to create a `Checkpoint`
    object that is compatible with Keras `model.save_weights()` and
    `model.load_weights()`. The checkpoint is also compatible with the
    checkpoint saved in the `variables/` folder in the SavedModel.
  * When restoring, `save_path` can be a path to a SavedModel. The function
    will automatically find the checkpoint in the SavedModel.
* Other:
  * We have replaced uses of "whitelist" and "blacklist" with "allowlist"
    and "denylist" where possible. Please see
@@ -251,6 +261,7 @@ stjohnso98, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
* Mutable tables now restore checkpointed values when loaded from SavedModel.
* GPU
  * TF 2.3 includes PTX kernels only for [compute capability](https://developer.nvidia.com/cuda-gpus) 7.0 to reduce the TF pip binary size. Earlier releases included PTX for a variety of older compute capabilities.
  * Remove environmental variable `TF_USE_CUDNN`.
* Others
  * Retain parent namescope for ops added inside `tf.while_loop`/`tf.cond`/`tf.switch_case`.
  * Update `tf.vectorized_map` to support vectorizing `tf.while_loop` and TensorList operations.

@@ -220,6 +220,7 @@ cc_library(
    name = "logging",
    srcs = ["logging.cc"],
    hdrs = ["logging.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":c_api_macros",
        "//tensorflow/core/platform:logging",

@@ -240,6 +240,7 @@ tf_cuda_cc_test(
        "//tensorflow/c:c_api",
        "//tensorflow/c:c_test_util",
        "//tensorflow/c:tf_status_helper",
        "//tensorflow/c/experimental/gradients:array_grad",
        "//tensorflow/c/experimental/gradients:math_grad",
        "//tensorflow/c/experimental/ops:array_ops",
        "//tensorflow/cc/profiler",

@@ -30,18 +30,26 @@ TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) {
  TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/true,
                                    /*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesFuncRemoteOutputs) {
  TestRemoteExecuteSilentCopiesFunc(/*async=*/false, /*remote=*/true,
                                    /*heavy_load_on_streaming_rpc=*/false,
                                    /*remote_func_outputs=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsyncFuncRemoteOutputs) {
  TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/true,
                                    /*heavy_load_on_streaming_rpc=*/false,
                                    /*remote_func_outputs=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) {
  TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false,
                                    /*heavy_load_on_streaming_rpc=*/false);
}
// TODO(b/162618595): Enable this test once we remove the check of remote
// outputs in ProcessFunctionLibraryRuntime.
TEST(CAPI, DISABLED_RemoteExecuteSilentCopiesLocalFuncRemoteOutputs) {
TEST(CAPI, RemoteExecuteSilentCopiesLocalFuncRemoteOutputs) {
  TestRemoteExecuteSilentCopiesFunc(/*async=*/false, /*remote=*/false,
                                    /*heavy_load_on_streaming_rpc=*/false,
                                    /*remote_func_outputs=*/true);
}
TEST(CAPI, DISABLED_RemoteExecuteSilentCopiesLocalAsyncFuncRemoteOutputs) {
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncRemoteOutputs) {
  TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false,
                                    /*heavy_load_on_streaming_rpc=*/false,
                                    /*remote_func_outputs=*/true);

@@ -169,6 +169,13 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
    ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr));
  }

  if (remote_func_outputs) {
    const string backing_device =
        TFE_TensorHandleBackingDeviceName(retvals[0], status);
    ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
    EXPECT_EQ(backing_device, task2_name);
  }

  auto* retval_task0 = TFE_TensorHandleCopyToDevice(
      retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

@@ -85,7 +85,11 @@ class GraphOperation : public TracingOperation {
      return errors::FailedPrecondition(
          "GraphOperation::Reset must be called before calling SetOpName.");
    }
    op_.reset(TF_NewOperation(g_, op_type_.c_str(), op_name));
    // TODO(b/145674566): We use Graph::NewName to get a unique name here but
    // this may not be consistent with python's naming policy.
    mutex_lock l(g_->mu);
    op_.reset(new TF_OperationDescription(g_, op_type_.c_str(),
                                          g_->graph.NewName(op_name).c_str()));
    return Status::OK();
  }
  const string& Name() const override { return op_type_; }

@@ -557,7 +557,7 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
    auto* add_op = TF_NewAbstractOp(graph_ctx);
    TF_AbstractOpSetOpType(add_op, "Add", s);
    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
    TF_AbstractOpSetOpName(add_op, "my_add1", s);
    TF_AbstractOpSetOpName(add_op, "my_add", s);
    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
    TF_AbstractTensor* inputs[2] = {arg0, arg1};
    TF_OutputList* add_outputs = TF_NewOutputList();
@@ -579,7 +579,7 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
    auto* add_op = TF_NewAbstractOp(graph_ctx);
    TF_AbstractOpSetOpType(add_op, "Add", s);
    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
    TF_AbstractOpSetOpName(add_op, "my_add2", s);
    TF_AbstractOpSetOpName(add_op, "my_add", s);
    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
    TF_AbstractTensor* inputs[2] = {arg1, arg1};
    TF_OutputList* add_outputs = TF_NewOutputList();

@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/c/eager/gradients.h"

#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
@@ -23,25 +24,97 @@ limitations under the License.
namespace tensorflow {
namespace gradients {

Status GradientRegistry::Register(const string& op_name,
                                  GradientFunctionFactory factory) {
namespace {
Status ZerosLike(AbstractContext* ctx, AbstractTensorHandle* t,
                 AbstractTensorHandle** result) {
  AbstractOperationPtr op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(op->Reset("ZerosLike", /*raw_device_name=*/nullptr));
  if (isa<tracing::TracingOperation>(op.get())) {
    TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
        absl::StrCat("ZerosLike", ToId(t)).c_str()));
  }
  TF_RETURN_IF_ERROR(op->AddInput(t));
  int num_outputs = 1;
  std::vector<AbstractTensorHandle*> outputs(num_outputs);
  TF_RETURN_IF_ERROR(
      op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs));
  *result = outputs[0];
  return Status::OK();
}
}  // namespace

class IncomingGradientsImpl : public IncomingGradients {
 public:
  explicit IncomingGradientsImpl(
      absl::Span<AbstractTensorHandle* const> grad_inputs, Context* ctx,
      DefaultGradientFunction* default_gradients)
      : grad_inputs_(grad_inputs),
        ctx_(ctx),
        default_gradients_(default_gradients) {}
  AbstractTensorHandle* operator[](int i) const override {
    return default_gradients_->get(ctx_, grad_inputs_, i);
  }
  size_t size() const override { return grad_inputs_.size(); }

 private:
  absl::Span<AbstractTensorHandle* const> grad_inputs_;
  Context* ctx_;
  DefaultGradientFunction* default_gradients_;
};

AllZerosDefaultGradients::AllZerosDefaultGradients(const ForwardOperation& op)
    : outputs_(op.outputs) {
  for (auto output : outputs_) {
    output->Ref();
  }
}
AbstractTensorHandle* AllZerosDefaultGradients::get(
    Context* ctx, absl::Span<AbstractTensorHandle* const> grad_inputs, int i) {
  if (grad_inputs[i]) {
    return grad_inputs[i];
  }
  if (cached_default_grads_[i]) {
    return cached_default_grads_[i].get();
  }
  AbstractTensorHandle* result = nullptr;
  Status s = ZerosLike(ctx->ctx, outputs_[i], &result);
  if (!s.ok()) {
    if (result) {
      result->Unref();
    }
    VLOG(1) << "Failed to create ZerosLike for index " << i;
    return nullptr;
  }
  cached_default_grads_[i].reset(result);
  return result;
}

PassThroughDefaultGradients::PassThroughDefaultGradients(
    const ForwardOperation& op) {}
AbstractTensorHandle* PassThroughDefaultGradients::get(
    Context* ctx, absl::Span<AbstractTensorHandle* const> grad_inputs, int i) {
  return grad_inputs[i];
}

Status GradientRegistry::Register(
    const string& op_name, BackwardFunctionFactory backward_function_factory) {
  auto iter = registry_.find(op_name);
  if (iter != registry_.end()) {
    const string error_msg = "Gradient already exists for op: " + op_name + ".";
    return errors::AlreadyExists(error_msg);
  }
  registry_.insert({op_name, factory});
  registry_.insert({op_name, backward_function_factory});
  return Status::OK();
}
Status GradientRegistry::Lookup(
    const ForwardOperation& op,
    std::unique_ptr<GradientFunction>* grad_fn) const {
    std::unique_ptr<BackwardFunction>* backward_function) const {
  auto iter = registry_.find(op.op_name);
  if (iter == registry_.end()) {
    const string error_msg = "No gradient defined for op: " + op.op_name + ".";
    return errors::NotFound(error_msg);
  }
  grad_fn->reset(iter->second(op));
  backward_function->reset(iter->second(op));
  return Status::OK();
}

@@ -92,33 +165,8 @@ AbstractTensorHandle* TapeTensor::OnesLike() const {
  }
  return outputs[0];
}
AbstractTensorHandle* TapeTensor::ZerosLike() const {
  AbstractOperationPtr op(ctx_->CreateOperation());
  // TODO(srbs): Consider adding a TF_RETURN_NULLPTR_IF_ERROR.
  Status s = op->Reset("ZerosLike", /*raw_device_name=*/nullptr);
  if (!s.ok()) {
    return nullptr;
  }
  if (isa<tracing::TracingOperation>(op.get())) {
    s = dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
        absl::StrCat("ZerosLike", ToId(handle_)).c_str());
    if (!s.ok()) {
      return nullptr;
    }
  }
  s = op->AddInput(handle_);
  if (!s.ok()) {
    return nullptr;
  }
  int num_outputs = 1;
  // TODO(srbs): Figure out who is in charge of releasing this.
  std::vector<AbstractTensorHandle*> outputs(num_outputs);
  s = op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs);
  if (!s.ok()) {
    return nullptr;
  }
  return outputs[0];
}

AbstractTensorHandle* TapeTensor::ZerosLike() const { return nullptr; }

// Returns the number of elements in the gradient tensor.
int64 TapeVSpace::NumElements(AbstractTensorHandle* tensor) const {
@@ -159,13 +207,16 @@ AbstractTensorHandle* TapeVSpace::AggregateGradients(

// Calls the passed-in backward function.
Status TapeVSpace::CallBackwardFunction(
    GradientFunction* backward_function,
    BackwardFunction* backward_function,
    const std::vector<int64>& unneeded_gradients,
    gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
    std::vector<AbstractTensorHandle*>* result) const {
  if (backward_function == nullptr) return Status::OK();
  Context ctx = {ctx_};
  return backward_function->Compute(&ctx, output_gradients, result);
  IncomingGradientsImpl incoming_gradients(
      output_gradients, &ctx, backward_function->GetDefaultGradientFunction());
  return backward_function->GetGradientFunction()->Compute(
      &ctx, incoming_gradients, result);
}

// Looks up the ID of a Gradient.
@@ -373,15 +424,15 @@ Status Execute(AbstractOperation* op_, AbstractContext* ctx,
  }
  tape->RecordOperation(
      op_->Name(), tape_tensors, input_ids, input_dtypes,
      [registry, forward_op_]() -> GradientFunction* {
        std::unique_ptr<GradientFunction> grad_fn;
        Status s = registry.Lookup(*forward_op_, &grad_fn);
      [registry, forward_op_]() -> BackwardFunction* {
        std::unique_ptr<BackwardFunction> backward_fn;
        Status s = registry.Lookup(*forward_op_, &backward_fn);
        if (!s.ok()) {
          return nullptr;
        }
        return grad_fn.release();
        return backward_fn.release();
      },
      [](GradientFunction* ptr) {
      [](BackwardFunction* ptr) {
        if (ptr) {
          delete ptr;
        }

@@ -55,18 +55,25 @@ struct Context {
 public:
  AbstractContext* ctx;
};

class IncomingGradients {
 public:
  virtual AbstractTensorHandle* operator[](int i) const = 0;
  virtual size_t size() const = 0;
  virtual ~IncomingGradients() {}
};

class GradientFunction {
 public:
  // TODO(srbs): How do we support CompositeTensors, e.g. IndexedSlices, in
  // `grad_inputs`?
  virtual Status Compute(Context* ctx,
                         absl::Span<AbstractTensorHandle* const> grad_inputs,
  virtual Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
                         std::vector<AbstractTensorHandle*>* grad_outputs) = 0;
  virtual ~GradientFunction() {}
};

// Metadata from the forward operation that is made available to the
// gradient registerer to instantiate a GradientFunction.
// gradient registerer to instantiate a BackwardFunction.
struct ForwardOperation {
 public:
  string op_name;
@@ -76,18 +83,86 @@ struct ForwardOperation {
  AbstractContext* ctx;
};

using GradientFunctionFactory =
    std::function<GradientFunction*(const ForwardOperation& op)>;

// Map from op name to a `GradientFunctionFactory`.
class GradientRegistry {
// Interface for building default zeros gradients for op outputs which are
// missing incoming gradients. Custom implementations of this can be used to
// control which of the forward op's output tensors/their metadata needs to
// be kept around in memory to build the default zeros grad.
//
// Some common helper implementations are provided below.
class DefaultGradientFunction {
 public:
  Status Register(const string& op, GradientFunctionFactory factory);
  Status Lookup(const ForwardOperation& op,
                std::unique_ptr<GradientFunction>* grad_fn) const;
  virtual AbstractTensorHandle* get(
      Context* ctx, absl::Span<AbstractTensorHandle* const> grad_inputs,
      int i) = 0;
  virtual ~DefaultGradientFunction() {}
};

// Returns zeros for any `nullptr` in `grad_inputs`.
//
// This may require keeping track of all of the forward op's output
// tensors and hence may incur a higher memory footprint. Use sparingly.
//
// Multiple calls to `AllZerosDefaultGradients::get` return the same tensor
// handle.
//
// The destructor of this class `Unref`'s any cached tensor handles so users of
// those tensor handles should `Ref` them in order to keep them alive if needed.
class AllZerosDefaultGradients : public DefaultGradientFunction {
 public:
  explicit AllZerosDefaultGradients(const ForwardOperation& op);
  AbstractTensorHandle* get(Context* ctx,
                            absl::Span<AbstractTensorHandle* const> grad_inputs,
                            int i) override;

 private:
  absl::flat_hash_map<string, GradientFunctionFactory> registry_;
  // TODO(srbs): We do not always need to keep the tensors around. In immediate
  // execution mode we just need to store the shape and dtype. During tracing
  // we may need to keep the tensor around if the shape is not fully defined.
  std::vector<AbstractTensorHandle*> outputs_;
  std::vector<AbstractTensorHandlePtr> cached_default_grads_;
};

// Passes through `grad_inputs` as-is. The `GradientFunction`
// will be expected to deal with nullptr in `grad_inputs` if any.
class PassThroughDefaultGradients : public DefaultGradientFunction {
 public:
  explicit PassThroughDefaultGradients(const ForwardOperation& op);
  AbstractTensorHandle* get(Context* ctx,
                            absl::Span<AbstractTensorHandle* const> grad_inputs,
                            int i) override;
};

// A `BackwardFunction` wraps a `GradientFunction` and a
// `DefaultGradientFunction`. Both are owned by this class' instance.
class BackwardFunction {
 public:
  BackwardFunction(GradientFunction* gradient_function,
                   DefaultGradientFunction* default_gradients)
      : gradient_function_(gradient_function),
        default_gradients_(default_gradients) {}
  GradientFunction* GetGradientFunction() { return gradient_function_.get(); }
  DefaultGradientFunction* GetDefaultGradientFunction() {
    return default_gradients_.get();
  }

 private:
  std::unique_ptr<GradientFunction> gradient_function_;
  std::unique_ptr<DefaultGradientFunction> default_gradients_;
};

using BackwardFunctionFactory =
    std::function<BackwardFunction*(const ForwardOperation& op)>;

// Map from op name to a `BackwardFunctionFactory`.
class GradientRegistry {
 public:
  Status Register(const string& op,
                  BackwardFunctionFactory backward_function_factory);
  Status Lookup(const ForwardOperation& op,
                std::unique_ptr<BackwardFunction>* backward_function) const;

 private:
  absl::flat_hash_map<string, BackwardFunctionFactory> registry_;
};
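As an aside, a minimal sketch of how a custom gradient might be wired into this registry API. The op name "MyOp", the class MyOpGradientFunction, and the two helper functions below are hypothetical; the sketch simply mirrors the BackwardFunction/PassThroughDefaultGradients pattern used by the registerers elsewhere in this diff.

// Hypothetical example, not part of this commit.
#include "tensorflow/c/eager/gradients.h"

namespace tensorflow {
namespace gradients {

// A gradient function that passes each incoming gradient straight through.
class MyOpGradientFunction : public GradientFunction {
 public:
  Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
                 std::vector<AbstractTensorHandle*>* grad_outputs) override {
    grad_outputs->resize(grad_inputs.size(), nullptr);
    for (int i = 0; i < grad_inputs.size(); i++) {
      if (grad_inputs[i]) grad_inputs[i]->Ref();  // caller owns the outputs
      (*grad_outputs)[i] = grad_inputs[i];
    }
    return Status::OK();
  }
};

// The registerer pairs the gradient function with a default-gradient policy.
BackwardFunction* MyOpRegisterer(const ForwardOperation& op) {
  return new BackwardFunction(new MyOpGradientFunction,
                              new PassThroughDefaultGradients(op));
}

Status RegisterMyOpGradient(GradientRegistry* registry) {
  return registry->Register("MyOp", MyOpRegisterer);
}

}  // namespace gradients
}  // namespace tensorflow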

// Returns a unique id for the tensor which is used by the tape to build
@@ -106,9 +181,16 @@ int64 ToId(AbstractTensorHandle* t);
// allow us to trace the data dependencies between operations and hence compute
// gradients.
//
// This also implements `ZerosLike` and `OnesLike` to create the default
// This also implements `OnesLike` to create the default
// incoming gradients for tensors which do not already have an incoming
// gradient.
//
// `ZerosLike` is not expected to be called and returns a nullptr. The creation
// of default zeros grads is handled by the `DefaultGradientFunction` registered
// for each op.
// TODO(srbs): We need to define `ZerosLike` here to keep the compiler happy.
// Figure out a way to avoid this.
// TODO(srbs): Should ZerosLike check-fail instead of returning nullptr?
class TapeTensor {
 public:
  TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx);
@@ -123,7 +205,7 @@ class TapeTensor {

 private:
  AbstractTensorHandle* handle_;
  // The context where OnesLike and ZerosLike ops are to be created.
  // The context where OnesLike ops are to be created.
  AbstractContext* ctx_;
};

@@ -132,7 +214,7 @@ class TapeTensor {
// gradient and for performing gradient aggregation.
// See `tensorflow::eager::VSpace` for more details.
class TapeVSpace
    : public eager::VSpace<AbstractTensorHandle, GradientFunction, TapeTensor> {
    : public eager::VSpace<AbstractTensorHandle, BackwardFunction, TapeTensor> {
 public:
  explicit TapeVSpace(AbstractContext* ctx) : ctx_(ctx) {}
  ~TapeVSpace() override {}
@@ -147,7 +229,7 @@ class TapeVSpace

  // Calls the passed-in backward function.
  Status CallBackwardFunction(
      GradientFunction* backward_function,
      BackwardFunction* backward_function,
      const std::vector<int64>& unneeded_gradients,
      gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
      std::vector<AbstractTensorHandle*>* result) const override;
@@ -168,8 +250,14 @@ class TapeVSpace
};

// A tracing/immediate-execution agnostic tape.
//
// Gradient functions defined for this library support handling null incoming
// gradients. `Tape::ComputeGradient` should be called with
// `build_default_zeros_grads=false`. Calling with
// `build_default_zeros_grads=true` (the default) is equivalent but just results
// in extra work because `TapeTensor::ZerosLike` returns a `nullptr` anyway.
using Tape = tensorflow::eager::GradientTape<AbstractTensorHandle,
                                             GradientFunction, TapeTensor>;
                                             BackwardFunction, TapeTensor>;

}  // namespace gradients
}  // namespace tensorflow
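For orientation, a hedged caller-side sketch of the `build_default_zeros_grads=false` contract described above. It mirrors AddGradModel/IdentityNGradModel from the gradients test hunks later in this diff; the function name is illustrative, and it assumes the code lives in tensorflow::gradients with a tape that already watched ToId(x) while y was computed.

// Hypothetical sketch: compute dy/dx without materializing default zeros.
Status ComputeDyDx(AbstractContext* ctx, Tape* tape, AbstractTensorHandle* x,
                   AbstractTensorHandle* y, AbstractTensorHandle** dy_dx) {
  TapeVSpace vspace(ctx);
  std::unordered_map<tensorflow::int64, TapeTensor>
      source_tensors_that_are_targets;
  std::vector<AbstractTensorHandle*> out_grads;
  TF_RETURN_IF_ERROR(tape->ComputeGradient(
      vspace, /*target_tensor_ids=*/{ToId(y)},
      /*source_tensor_ids=*/{ToId(x)}, source_tensors_that_are_targets,
      /*output_gradients=*/{}, &out_grads,
      /*build_default_zeros_grads=*/false));
  // With build_default_zeros_grads=false, a source that does not influence the
  // target comes back as nullptr rather than an explicit zeros tensor.
  *dy_dx = out_grads[0];
  return Status::OK();
}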

@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/gradients/array_grad.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h"
@@ -50,6 +51,7 @@ class CppGradients
Status RegisterGradients(GradientRegistry* registry) {
  TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));
  return Status::OK();
}

@@ -94,6 +96,26 @@ Status Exp(AbstractContext* ctx, Tape* tape,
                 registry);
}

// Computes `IdentityN(inputs)` and records it on the tape.
Status IdentityN(AbstractContext* ctx, Tape* tape,
                 absl::Span<AbstractTensorHandle* const> inputs,
                 absl::Span<AbstractTensorHandle*> outputs,
                 const GradientRegistry& registry) {
  AbstractOperationPtr identity_n_op(ctx->CreateOperation());
  ForwardOperation forward_op;
  forward_op.ctx = ctx;
  TF_RETURN_IF_ERROR(Reset(identity_n_op.get(), "IdentityN",
                           /*raw_device_name=*/nullptr, &forward_op));
  if (isa<TracingOperation>(identity_n_op.get())) {
    TF_RETURN_IF_ERROR(dyn_cast<TracingOperation>(identity_n_op.get())
                           ->SetOpName("my_identity_n"));
  }
  TF_RETURN_IF_ERROR(AddInputList(identity_n_op.get(), inputs, &forward_op));
  int num_retvals = outputs.size();
  return Execute(identity_n_op.get(), ctx, outputs, &num_retvals, &forward_op,
                 tape, registry);
}

// Computes
// y = inputs[0] + inputs[1]
// return grad(y, {inputs[0], inputs[1]})
@@ -116,7 +138,8 @@ Status AddGradModel(AbstractContext* ctx,
      vspace, /*target_tensor_ids=*/{ToId(add_outputs[0])},
      /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
      source_tensors_that_are_targets,
      /*output_gradients=*/{}, &out_grads));
      /*output_gradients=*/{}, &out_grads,
      /*build_default_zeros_grads=*/false));
  for (auto add_output : add_outputs) {
    add_output->Unref();
  }
@@ -146,7 +169,8 @@ Status ExpGradModel(AbstractContext* ctx,
  TF_RETURN_IF_ERROR(tape->ComputeGradient(
      vspace, /*target_tensor_ids=*/{ToId(exp_outputs[0])},
      /*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
      /*output_gradients=*/{}, &out_grads));
      /*output_gradients=*/{}, &out_grads,
      /*build_default_zeros_grads=*/false));
  for (auto exp_output : exp_outputs) {
    exp_output->Unref();
  }
@@ -155,6 +179,41 @@ Status ExpGradModel(AbstractContext* ctx,
  return Status::OK();
}

// Computes
// ignored, y = IdentityN(inputs[0], inputs[1])
// return grad(y, {inputs[0], inputs[1]})
// This should return [nullptr, 1].
Status IdentityNGradModel(AbstractContext* ctx,
                          absl::Span<AbstractTensorHandle* const> inputs,
                          absl::Span<AbstractTensorHandle*> outputs,
                          const GradientRegistry& registry) {
  TapeVSpace vspace(ctx);
  auto tape = new Tape(/*persistent=*/false);
  tape->Watch(ToId(inputs[0]));
  tape->Watch(ToId(inputs[1]));

  vector<AbstractTensorHandle*> identity_n_outputs(2);
  TF_RETURN_IF_ERROR(IdentityN(ctx, tape, inputs,
                               absl::MakeSpan(identity_n_outputs), registry));

  std::unordered_map<tensorflow::int64, TapeTensor>
      source_tensors_that_are_targets;
  vector<AbstractTensorHandle*> out_grads;
  TF_RETURN_IF_ERROR(tape->ComputeGradient(
      vspace, /*target_tensor_ids=*/{ToId(identity_n_outputs[1])},
      /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
      source_tensors_that_are_targets,
      /*output_gradients=*/{}, &out_grads,
      /*build_default_zeros_grads=*/false));
  for (auto identity_n_output : identity_n_outputs) {
    identity_n_output->Unref();
  }
  outputs[0] = out_grads[0];
  outputs[1] = out_grads[1];
  delete tape;
  return Status::OK();
}

AbstractContext* BuildFunction(const char* fn_name) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
@@ -389,18 +448,79 @@ TEST_P(CppGradients, TestExpGrad) {
  result_tensor = nullptr;
}

TEST_P(CppGradients, TestIdentityNGrad) {
  // Pseudo-code:
  //
  // tape.watch(x1)
  // tape.watch(x2)
  // unused, y = IdentityN([x1, x2])
  // outputs = tape.gradient(y, [x1, x2])
  // Expected: [nullptr, 1]
  //
  // This test is interesting because the current implementation of GradientTape
  // would return [0, 1] whereas we use build_default_zeros_grads=false here
  // so we get back [nullptr, 1].
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  AbstractTensorHandlePtr x1;
  {
    AbstractTensorHandle* x_raw = nullptr;
    Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    x1.reset(x_raw);
  }
  AbstractTensorHandlePtr x2;
  {
    AbstractTensorHandle* x_raw = nullptr;
    Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    x2.reset(x_raw);
  }

  GradientRegistry registry;
  Status s = RegisterGradients(&registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  std::vector<AbstractTensorHandle*> outputs(2);
  s = RunModel(IdentityNGradModel, ctx.get(), {x1.get(), x2.get()},
               absl::MakeSpan(outputs),
               /*use_function=*/!std::get<2>(GetParam()), registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  EXPECT_EQ(outputs[0], nullptr);
  TF_Tensor* result_tensor;
  s = getValue(outputs[1], &result_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
  auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
  EXPECT_EQ(*result_value, 1.0);
  outputs[1]->Unref();
  TF_DeleteTensor(result_tensor);
  result_tensor = nullptr;
}

// TODO(b/160888630): Enable this test with mlir after AddInputList is
// supported. It is needed for AddN op which is used for gradient aggregation.
// supported. It is needed for IdentityN.
// TODO(b/164171226): Enable this test with tfrt after AddInputList is
// supported. It is needed for IdentityN.
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
    UnifiedCAPI, CppGradients,
    ::testing::Combine(::testing::Values("graphdef", "mlir"),
                       /*tfrt*/ ::testing::Values(true, false),
    ::testing::Combine(::testing::Values("graphdef"),
                       /*tfrt*/ ::testing::Values(false),
                       /*executing_eagerly*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
    UnifiedCAPI, CppGradients,
    ::testing::Combine(::testing::Values("graphdef", "mlir"),
    ::testing::Combine(::testing::Values("graphdef"),
                       /*tfrt*/ ::testing::Values(false),
                       /*executing_eagerly*/ ::testing::Values(true, false)));
#endif

@@ -80,6 +80,21 @@ cc_library(
    ],
)

tf_cc_test(
    name = "parallel_device_lib_test",
    srcs = ["parallel_device_lib_test.cc"],
    deps = [
        ":parallel_device_lib",
        "//tensorflow/c:c_api",
        "//tensorflow/c:c_api_experimental",
        "//tensorflow/c/eager:c_api",
        "//tensorflow/c/eager:c_api_experimental",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)

cc_library(
    name = "parallel_device_testlib",
    testonly = 1,

@@ -118,6 +118,9 @@ class DeviceThread {
  int expected_max_outputs_ TF_GUARDED_BY(execution_mutex_);
  // Outputs
  std::vector<TensorHandlePtr> op_outputs_ TF_GUARDED_BY(execution_mutex_);
  // TF_Status is an incomplete type and so can't be stack allocated. To avoid
  // unnecessary allocations each Execute call, we keep one heap-allocated
  // version for the thread.
  StatusPtr status_ TF_GUARDED_BY(execution_mutex_);

  const std::string device_;
@@ -188,6 +191,9 @@ std::vector<TensorHandlePtr> DeviceThread::Join(TF_Status* status) {
  if (TF_GetCode(status_.get()) != TF_OK) {
    TF_SetStatus(status, TF_GetCode(status_.get()),
                 TF_Message(status_.get()));
    // Reset the member `status_` so future op executions (after recovery from
    // the bad `status`) start with an OK status.
    TF_SetStatus(status_.get(), TF_OK, "");
  }
  execution_state_ = ExecutionState::kIdle;
  result = std::move(op_outputs_);
@@ -319,21 +325,36 @@ ParallelDevice::Execute(TFE_Context* context,
                                          std::move(device_inputs), attributes,
                                          expected_max_outputs);
  }
  StatusPtr first_bad_status(nullptr);
  for (int device_index = 0; device_index < underlying_devices_.size();
       ++device_index) {
    DeviceThread* device_thread = device_threads_[device_index].get();
    per_device_output_tensors.push_back(device_thread->Join(status));
    if (TF_GetCode(status) != TF_OK) return result;
    // We will run every Join even if there are bad statuses in case the user
    // wants to recover and continue running ops on the parallel device (which
    // would otherwise deadlock).
    if (TF_GetCode(status) != TF_OK && first_bad_status == nullptr) {
      first_bad_status.reset(TF_NewStatus());
      TF_SetStatus(first_bad_status.get(), TF_GetCode(status),
                   TF_Message(status));
    }

    if (device_index == 0) {
      first_op_output_count = per_device_output_tensors.rbegin()->size();
    } else {
      if (per_device_output_tensors.rbegin()->size() != first_op_output_count) {
        TF_SetStatus(status, TF_INTERNAL,
      if (first_bad_status == nullptr &&
          per_device_output_tensors.rbegin()->size() != first_op_output_count) {
        first_bad_status.reset(TF_NewStatus());
        TF_SetStatus(first_bad_status.get(), TF_INTERNAL,
                     "Parallel ops produced different numbers of tensors.");
        return result;
      }
    }
  }
  if (first_bad_status != nullptr) {
    TF_SetStatus(status, TF_GetCode(first_bad_status.get()),
                 TF_Message(first_bad_status.get()));
    return result;
  }
  // For each output of the original operation, pack the per-device
  // TensorHandles we've computed into a single parallel TensorHandle.
  std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;

@@ -0,0 +1,84 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace parallel_device {

TEST(PARALLEL_DEVICE_LIB, TestOpWithError) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*xla*/ false,
          /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
          2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  std::vector<std::string> devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  ParallelDevice parallel_device(std::move(devices));
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> handle_op(
      TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_OpSetAttrType(handle_op.get(), "dtype", TF_FLOAT);
  TFE_OpSetAttrShape(handle_op.get(), "shape", /*dims=*/nullptr, /*num_dims=*/0,
                     status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  auto outputs =
      parallel_device.Execute(context.get(), std::vector<ParallelTensor*>(),
                              "VarHandleOp", TFE_OpGetAttrs(handle_op.get()),
                              /*expected_max_outputs=*/1, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  const std::vector<std::unique_ptr<ParallelTensor>>& handles = *outputs;
  std::vector<ParallelTensor*> handle_inputs;
  handle_inputs.reserve(handles.size());
  for (auto& handle : handles) {
    handle_inputs.push_back(handle.get());
  }
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> read_op(
      TFE_NewOp(context.get(), "ReadVariableOp", status.get()), TFE_DeleteOp);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_OpSetAttrType(read_op.get(), "dtype", TF_FLOAT);
  parallel_device.Execute(context.get(), handle_inputs, "ReadVariableOp",
                          TFE_OpGetAttrs(read_op.get()),
                          /*expected_max_outputs=*/1, status.get());
  ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK);
  TF_SetStatus(status.get(), TF_OK, "");

  // Check that ops still run successfully on the device.
  parallel_device.Execute(context.get(), std::vector<ParallelTensor*>(),
                          "VarHandleOp", TFE_OpGetAttrs(handle_op.get()),
                          /*expected_max_outputs=*/1, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}

}  // namespace parallel_device
}  // namespace tensorflow
@@ -146,13 +146,16 @@ class GradientTape {
  // once) and produces the gradient of the target tensors with respect to the
  // source tensors. The output gradients are used if not empty and not
  // null. The result is populated with one tensor per target element.
  // When running backward functions, builds zeros-like tensors for
  // incoming grads which are nullptrs, unless `build_default_zeros_grads`
  // is set to false.
  Status ComputeGradient(
      const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
      const gtl::ArraySlice<int64> target_tensor_ids,
      const gtl::ArraySlice<int64> source_tensor_ids,
      const std::unordered_map<int64, TapeTensor>& sources_that_are_targets,
      gtl::ArraySlice<Gradient*> output_gradients,
      std::vector<Gradient*>* result);
      std::vector<Gradient*>* result, bool build_default_zeros_grads = true);

  bool IsPersistent() const { return persistent_; }

@@ -655,8 +658,8 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
    const gtl::ArraySlice<int64> target_tensor_ids,
    const gtl::ArraySlice<int64> source_tensor_ids,
    const std::unordered_map<int64, TapeTensor>& sources_that_are_targets,
    gtl::ArraySlice<Gradient*> output_gradients,
    std::vector<Gradient*>* result) {
    gtl::ArraySlice<Gradient*> output_gradients, std::vector<Gradient*>* result,
    bool build_default_zeros_grads) {
  std::unordered_set<int64> sources_set(source_tensor_ids.begin(),
                                        source_tensor_ids.end());
  BackpropInitialState<BackwardFunction, TapeTensor> state = PrepareBackprop(
@@ -717,14 +720,14 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
      const int64 id = trace.output_tensor_info[i].GetID();
      auto grad_it = gradients.find(id);
      if (grad_it == gradients.end()) {
        auto func_name_it =
            FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type);
        if (func_name_it != FunctionsAcceptingNoneForIndicesMap()->end() &&
            func_name_it->second.find(i) != func_name_it->second.end()) {
          out_gradients.push_back(nullptr);
        } else {
          out_gradients.push_back(nullptr);
          zero_indices.push_back(i);
        out_gradients.push_back(nullptr);
        if (build_default_zeros_grads) {
          auto func_name_it =
              FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type);
          if (func_name_it == FunctionsAcceptingNoneForIndicesMap()->end() ||
              func_name_it->second.find(i) == func_name_it->second.end()) {
            zero_indices.push_back(i);
          }
        }
      } else {
        any_gradient_nonzero = true;
@@ -745,6 +748,7 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
      }
    }
    std::vector<Gradient*> in_gradients;
    DCHECK(build_default_zeros_grads || zero_indices.empty());
    if (any_gradient_nonzero) {
      for (const auto i : zero_indices) {
        out_gradients[i] = trace.output_tensor_info[i].ZerosLike();

@@ -26,6 +26,7 @@ cc_library(
    }),
    deps = [
        ":aws_crypto",
        "//tensorflow/c:logging",
        "//tensorflow/c:tf_status",
        "//tensorflow/c/experimental/filesystem:filesystem_interface",
        "@aws",

@@ -38,6 +38,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.h"
#include "tensorflow/c/logging.h"
#include "tensorflow/c/tf_status.h"

// Implementation of a filesystem for S3 environments.
@@ -281,6 +282,7 @@ void Cleanup(TF_RandomAccessFile* file) {

static int64_t ReadS3Client(S3File* s3_file, uint64_t offset, size_t n,
                            char* buffer, TF_Status* status) {
  TF_VLog(3, "ReadFile using S3Client\n");
  Aws::S3::Model::GetObjectRequest get_object_request;
  get_object_request.WithBucket(s3_file->bucket).WithKey(s3_file->object);
  Aws::String bytes =
@@ -306,12 +308,14 @@ static int64_t ReadS3Client(S3File* s3_file, uint64_t offset, size_t n,

static int64_t ReadS3TransferManager(S3File* s3_file, uint64_t offset, size_t n,
                                     char* buffer, TF_Status* status) {
  TF_VLog(3, "Using TransferManager\n");
  auto create_download_stream = [&]() {
    return Aws::New<TFS3UnderlyingStream>(
        "S3ReadStream",
        Aws::New<Aws::Utils::Stream::PreallocatedStreamBuf>(
            "S3ReadStream", reinterpret_cast<unsigned char*>(buffer), n));
  };
  TF_VLog(3, "Created stream to read with transferManager\n");
  auto handle = s3_file->transfer_manager->DownloadFile(
      s3_file->bucket, s3_file->object, offset, n, create_download_stream);
  handle->WaitUntilFinished();
@@ -322,6 +326,10 @@ static int64_t ReadS3TransferManager(S3File* s3_file, uint64_t offset, size_t n,
             Aws::Http::HttpResponseCode::REQUESTED_RANGE_NOT_SATISFIABLE &&
         retries++ < kDownloadRetries) {
    // Only failed parts will be downloaded again.
    TF_VLog(
        1,
        "Retrying read of s3://%s/%s after failure. Current retry count: %u\n",
        s3_file->bucket.c_str(), s3_file->object.c_str(), retries);
    s3_file->transfer_manager->RetryDownload(handle);
    handle->WaitUntilFinished();
  }
@@ -341,6 +349,8 @@ static int64_t ReadS3TransferManager(S3File* s3_file, uint64_t offset, size_t n,
int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
             char* buffer, TF_Status* status) {
  auto s3_file = static_cast<S3File*>(file->plugin_file);
  TF_VLog(1, "ReadFilefromS3 s3://%s/%s from %u for n: %u\n",
          s3_file->bucket.c_str(), s3_file->object.c_str(), offset, n);
  if (s3_file->use_multi_part_download)
    return ReadS3TransferManager(s3_file, offset, n, buffer, status);
  else
@@ -416,6 +426,8 @@ void Sync(const TF_WritableFile* file, TF_Status* status) {
    TF_SetStatus(status, TF_OK, "");
    return;
  }
  TF_VLog(1, "WriteFileToS3: s3://%s/%s\n", s3_file->bucket.c_str(),
          s3_file->object.c_str());
  auto position = static_cast<int64_t>(s3_file->outfile->tellp());
  auto handle = s3_file->transfer_manager->UploadFile(
      s3_file->outfile, s3_file->bucket, s3_file->object,
@@ -426,6 +438,10 @@ void Sync(const TF_WritableFile* file, TF_Status* status) {
  while (handle->GetStatus() == Aws::Transfer::TransferStatus::FAILED &&
         retries++ < kUploadRetries) {
    // if multipart upload was used, only the failed parts will be re-sent
    TF_VLog(1,
            "Retrying upload of s3://%s/%s after failure. Current retry count: "
            "%u\n",
            s3_file->bucket.c_str(), s3_file->object.c_str(), retries);
    s3_file->transfer_manager->RetryUpload(s3_file->outfile, handle);
    handle->WaitUntilFinished();
  }
@@ -613,6 +629,7 @@ void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,

void Stat(const TF_Filesystem* filesystem, const char* path,
          TF_FileStatistics* stats, TF_Status* status) {
  TF_VLog(1, "Stat on path: %s\n", path);
  Aws::String bucket, object;
  ParseS3Path(path, true, &bucket, &object, status);
  if (TF_GetCode(status) != TF_OK) return;
@@ -737,6 +754,8 @@ static void SimpleCopyFile(const Aws::String& source,
                           const Aws::String& bucket_dst,
                           const Aws::String& object_dst, S3File* s3_file,
                           TF_Status* status) {
  TF_VLog(1, "SimpleCopyFile from %s to %s/%s\n", bucket_dst.c_str(),
          object_dst.c_str());
  Aws::S3::Model::CopyObjectRequest copy_object_request;
  copy_object_request.WithCopySource(source)
      .WithBucket(bucket_dst)
@@ -801,6 +820,8 @@ static void MultiPartCopy(const Aws::String& source,
                          const Aws::String& object_dst, const size_t num_parts,
                          const uint64_t file_size, S3File* s3_file,
                          TF_Status* status) {
  TF_VLog(1, "MultiPartCopy from %s to %s/%s\n", bucket_dst.c_str(),
          object_dst.c_str());
  Aws::S3::Model::CreateMultipartUploadRequest create_multipart_upload_request;
  create_multipart_upload_request.WithBucket(bucket_dst).WithKey(object_dst);

@@ -827,6 +848,8 @@ static void MultiPartCopy(const Aws::String& source,
  auto chunk_size =
      s3_file->multi_part_chunk_sizes[Aws::Transfer::TransferDirection::UPLOAD];

  TF_VLog(1, "Copying from %s in %u parts of size %u each\n", source.c_str(),
          num_parts, chunk_size);
  size_t retries = 0;
  while (retries++ < 3) {
    // Queue up parts.
@@ -891,6 +914,9 @@ static void MultiPartCopy(const Aws::String& source,
                   status);
      } else {
        // Retry.
        TF_Log(TF_ERROR,
               "Retrying failed copy of part %u due to an error with S3\n",
               part_number);
        num_finished_parts--;
      }
    }
@@ -967,6 +993,7 @@ void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst,

void DeleteFile(const TF_Filesystem* filesystem, const char* path,
                TF_Status* status) {
  TF_VLog(1, "DeleteFile: %s\n", path);
  Aws::String bucket, object;
  ParseS3Path(path, false, &bucket, &object, status);
  if (TF_GetCode(status) != TF_OK) return;
@@ -985,6 +1012,7 @@ void DeleteFile(const TF_Filesystem* filesystem, const char* path,

void CreateDir(const TF_Filesystem* filesystem, const char* path,
               TF_Status* status) {
  TF_VLog(1, "CreateDir: %s\n", path);
  Aws::String bucket, object;
  ParseS3Path(path, true, &bucket, &object, status);
  if (TF_GetCode(status) != TF_OK) return;
@@ -1026,6 +1054,7 @@ void CreateDir(const TF_Filesystem* filesystem, const char* path,

void DeleteDir(const TF_Filesystem* filesystem, const char* path,
               TF_Status* status) {
  TF_VLog(1, "DeleteDir: %s\n", path);
  Aws::String bucket, object;
  ParseS3Path(path, false, &bucket, &object, status);
  if (TF_GetCode(status) != TF_OK) return;
@@ -1060,6 +1089,7 @@ void DeleteDir(const TF_Filesystem* filesystem, const char* path,

void RenameFile(const TF_Filesystem* filesystem, const char* src,
                const char* dst, TF_Status* status) {
  TF_VLog(1, "RenameFile from: %s to %s\n", src, dst);
  Aws::String bucket_src, object_src;
  ParseS3Path(src, false, &bucket_src, &object_src, status);
  if (TF_GetCode(status) != TF_OK) return;
@@ -1120,6 +1150,7 @@ void RenameFile(const TF_Filesystem* filesystem, const char* src,

int GetChildren(const TF_Filesystem* filesystem, const char* path,
                char*** entries, TF_Status* status) {
  TF_VLog(1, "GetChildren for path: %s\n", path);
  Aws::String bucket, prefix;
  ParseS3Path(path, true, &bucket, &prefix, status);
  if (TF_GetCode(status) != TF_OK) return -1;
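A brief, hedged sketch of the C logging API that the S3 plugin starts using in the hunks above. Only TF_VLog, TF_Log with TF_ERROR, TF_SetStatus, and the two headers are taken from this diff; the function name and messages below are illustrative.

// Illustrative only, not part of this commit.
#include "tensorflow/c/logging.h"
#include "tensorflow/c/tf_status.h"

static void LogAndFail(const char* path, TF_Status* status) {
  TF_VLog(1, "Operating on path: %s\n", path);             // verbose log, level 1
  TF_Log(TF_ERROR, "Operation failed for: %s\n", path);    // error-level log
  TF_SetStatus(status, TF_INTERNAL, "illustrative failure");
}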

@@ -3,6 +3,24 @@ package(
    licenses = ["notice"],  # Apache 2.0
)

cc_library(
    name = "array_grad",
    srcs = ["array_grad.cc"],
    hdrs = [
        "array_grad.h",
    ],
    visibility = [
        "//tensorflow:internal",
    ],
    deps = [
        "//tensorflow/c/eager:abstract_operation",
        "//tensorflow/c/eager:abstract_tensor_handle",
        "//tensorflow/c/eager:c_api_unified_internal",
        "//tensorflow/c/eager:gradients",
        "//tensorflow/core/lib/llvm_rtti",
    ],
)

cc_library(
    name = "math_grad",
    srcs = ["math_grad.cc"],

tensorflow/c/experimental/gradients/array_grad.cc (new file, 48 lines)
@@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/array_grad.h"

namespace tensorflow {
namespace gradients {
namespace {
using std::vector;
class IdentityNGradientFunction : public GradientFunction {
 public:
  Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
                 vector<AbstractTensorHandle*>* grad_outputs) override {
    grad_outputs->resize(grad_inputs.size(), nullptr);
    for (int i = 0; i < grad_inputs.size(); i++) {
      auto grad_input = grad_inputs[i];
      // TODO(srbs): Should we add a copy constructor to AbstractTensorHandle
      // that takes care of this similar to `Tensor`?
      if (grad_input) {
        grad_input->Ref();
      }
      (*grad_outputs)[i] = grad_input;
    }
    return Status::OK();
  }
  ~IdentityNGradientFunction() override {}
};
}  // namespace

BackwardFunction* IdentityNRegisterer(const ForwardOperation& op) {
  auto gradient_function = new IdentityNGradientFunction;
  auto default_gradients = new PassThroughDefaultGradients(op);
  return new BackwardFunction(gradient_function, default_gradients);
}

}  // namespace gradients
}  // namespace tensorflow
tensorflow/c/experimental/gradients/array_grad.h (new file, 26 lines)
@@ -0,0 +1,26 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_ARRAY_GRAD_H_
#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_ARRAY_GRAD_H_

#include "tensorflow/c/eager/gradients.h"

namespace tensorflow {
namespace gradients {
BackwardFunction* IdentityNRegisterer(const ForwardOperation& op);
}  // namespace gradients
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_ARRAY_GRAD_H_
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/c/experimental/gradients/math_grad.h"

#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"

@@ -29,8 +30,7 @@ namespace {

class AddGradientFunction : public GradientFunction {
 public:
  Status Compute(Context* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_inputs,
  Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
                 vector<AbstractTensorHandle*>* grad_outputs) override {
    grad_outputs->resize(2);
    vector<AbstractTensorHandle*> identity_outputs(1);
@@ -54,8 +54,7 @@ class ExpGradientFunction : public GradientFunction {
  explicit ExpGradientFunction(AbstractTensorHandle* exp) : exp_(exp) {
    exp->Ref();
  }
  Status Compute(Context* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_inputs,
  Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
                 vector<AbstractTensorHandle*>* grad_outputs) override {
    vector<AbstractTensorHandle*> conj_outputs(1);
    TF_RETURN_IF_ERROR(
@@ -74,12 +73,22 @@ class ExpGradientFunction : public GradientFunction {

}  // namespace

GradientFunction* AddRegisterer(const ForwardOperation& op) {
  return new AddGradientFunction;
BackwardFunction* AddRegisterer(const ForwardOperation& op) {
  auto gradient_function = new AddGradientFunction;
  // For ops with a single output, the gradient function is not called if there
  // is no incoming gradient. So we do not need to worry about creating zeros
  // grads in this case.
  auto default_gradients = new PassThroughDefaultGradients(op);
  return new BackwardFunction(gradient_function, default_gradients);
}

GradientFunction* ExpRegisterer(const ForwardOperation& op) {
  return new ExpGradientFunction(op.outputs[0]);
BackwardFunction* ExpRegisterer(const ForwardOperation& op) {
  auto gradient_function = new ExpGradientFunction(op.outputs[0]);
  // For ops with a single output, the gradient function is not called if there
  // is no incoming gradient. So we do not need to worry about creating zeros
  // grads in this case.
  auto default_gradients = new PassThroughDefaultGradients(op);
  return new BackwardFunction(gradient_function, default_gradients);
}

}  // namespace gradients

@@ -19,8 +19,8 @@ limitations under the License.

namespace tensorflow {
namespace gradients {
GradientFunction* AddRegisterer(const ForwardOperation& op);
GradientFunction* ExpRegisterer(const ForwardOperation& op);
BackwardFunction* AddRegisterer(const ForwardOperation& op);
BackwardFunction* ExpRegisterer(const ForwardOperation& op);
}  // namespace gradients
}  // namespace tensorflow

@ -28,7 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
|
||||
#include "tensorflow/core/platform/bfloat16.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
@ -38,6 +38,7 @@ tf_kernel_library(
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "histogram_summary_op",
|
||||
prefix = "histogram_summary_op",
|
||||
@ -52,7 +53,6 @@ tf_kernel_library(
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
tf_kernel_library(
|
||||
name = "merge_summary_op",
|
||||
prefix = "merge_summary_op",
|
||||
|
||||
@ -20,8 +20,8 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/selective_registration.h"
|
||||
#include "tensorflow/core/framework/summary.pb.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
|
||||
#include "tensorflow/core/lib/histogram/histogram.h"
|
||||
#include "tensorflow/core/platform/bfloat16.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
|
||||
@ -23,7 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/tstring.h"
|
||||
#include "tensorflow/core/platform/default/logging.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace {
|
||||
|
||||
|
||||
@ -25,7 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/selective_registration.h"
|
||||
#include "tensorflow/core/framework/summary.pb.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
|
||||
#include "tensorflow/core/platform/bfloat16.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
|
||||
@ -65,9 +65,24 @@ static ElementsAttr getSplat(Builder* b, Value val, T constant) {
|
||||
|
||||
// Returns DenseElementsAttr of rank zero with the given element type and the
|
||||
// value.
|
||||
// Requires `ty` to be either FloatType of IntegerType.
|
||||
// Requires `ty` to be either FloatType, IntegerType, or ComplexType.
|
||||
DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value);
|
||||
|
||||
// Enum type used to specify scalar argument to GetScalarLimitOfType.
|
||||
enum ScalarLimit {
|
||||
kLowest, // The scalar corresponding to numeric_limits<T>::lowest.
|
||||
kInfinityLowest, // Like kMax, but returns -infinity where available.
|
||||
kMax, // The scalar corresponding to numeric_limits<T>::max.
|
||||
kInfinityMax, // Like kMax, but returns infinity where available.
|
||||
};
|
||||
|
||||
// Returns a scalar limit value for the given type.
|
||||
//
|
||||
// The argument 'limit' describes which scalar value to return.
|
||||
//
|
||||
// Requires `ty` to be either FloatType or IntegerType.
|
||||
DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit);
|
||||
|
||||
} // namespace hlo
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@ -60,10 +60,76 @@ DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
|
||||
if (auto float_ty = ty.dyn_cast<FloatType>()) {
|
||||
APFloat value(float_ty.getFloatSemantics(), raw_value);
|
||||
return DenseElementsAttr::get(scalar_ty, value);
|
||||
} else if (auto int_ty = ty.dyn_cast<IntegerType>()) {
|
||||
APInt value(int_ty.getWidth(), static_cast<int64_t>(raw_value), true);
|
||||
return DenseElementsAttr::get(scalar_ty, value);
|
||||
} else if (auto complex_ty = ty.dyn_cast<ComplexType>()) {
|
||||
Type complex_element_ty = complex_ty.getElementType();
|
||||
if (complex_element_ty.isF32()) {
|
||||
return DenseElementsAttr::get(
|
||||
scalar_ty, static_cast<std::complex<float>>(raw_value));
|
||||
} else if (complex_element_ty.isF64()) {
|
||||
return DenseElementsAttr::get(
|
||||
scalar_ty, static_cast<std::complex<double>>(raw_value));
|
||||
}
|
||||
}
|
||||
auto int_ty = ty.cast<IntegerType>();
|
||||
APInt value(int_ty.getWidth(), static_cast<int64_t>(raw_value), true);
|
||||
return DenseElementsAttr::get(scalar_ty, value);
|
||||
llvm_unreachable("unsupported type");
|
||||
}
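// Illustrative examples (not part of this change): GetScalarOfType now also
// accepts complex element types, so GetScalarOfType(complex<f32>, 1) yields a
// rank-0 DenseElementsAttr holding 1+0i, while GetScalarOfType(f32, 1) and
// GetScalarOfType(i32, 1) keep their previous float/integer behavior.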
|
||||
|
||||
static APFloat GetScalarLimitOfFloatType(FloatType float_ty,
|
||||
ScalarLimit limit) {
|
||||
auto &semantics = float_ty.getFloatSemantics();
|
||||
switch (limit) {
|
||||
case kLowest:
|
||||
return APFloat::getLargest(semantics, /*negative=*/true);
|
||||
case kInfinityLowest:
|
||||
return APFloat::getInf(semantics, /*negative=*/true);
|
||||
case kMax:
|
||||
return APFloat::getLargest(semantics, /*negative=*/false);
|
||||
case kInfinityMax:
|
||||
return APFloat::getInf(semantics, /*negative=*/false);
|
||||
}
|
||||
llvm_unreachable("invalid limit");
|
||||
}
|
||||
|
||||
// Returns a scalar value for the given integer type.
|
||||
//
|
||||
// The argument 'limit' describes which scalar limit value to return.
|
||||
static APInt GetScalarLimitOfIntegerType(IntegerType integer_ty,
|
||||
ScalarLimit limit) {
|
||||
unsigned width = integer_ty.getWidth();
|
||||
switch (limit) {
|
||||
case kLowest:
|
||||
case kInfinityLowest:
|
||||
if (integer_ty.isUnsigned()) {
|
||||
return APInt::getMinValue(width);
|
||||
} else {
|
||||
return APInt::getSignedMinValue(width);
|
||||
}
|
||||
|
||||
case kMax:
|
||||
case kInfinityMax:
|
||||
if (integer_ty.isUnsigned()) {
|
||||
return APInt::getMaxValue(width);
|
||||
} else {
|
||||
return APInt::getSignedMaxValue(width);
|
||||
}
|
||||
}
|
||||
llvm_unreachable("invalid limit");
|
||||
}
|
||||
|
||||
DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit) {
|
||||
RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
|
||||
if (auto float_ty = ty.dyn_cast<FloatType>()) {
|
||||
return DenseElementsAttr::get(scalar_ty,
|
||||
GetScalarLimitOfFloatType(float_ty, limit));
|
||||
} else if (auto integer_ty = ty.dyn_cast<IntegerType>()) {
|
||||
return DenseElementsAttr::get(
|
||||
scalar_ty, GetScalarLimitOfIntegerType(integer_ty, limit));
|
||||
}
|
||||
llvm_unreachable("unsupported type");
|
||||
}
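// Illustrative examples (not part of this change):
//   GetScalarLimitOfType(f32, kLowest)         -> most negative finite f32
//   GetScalarLimitOfType(f32, kInfinityLowest) -> -inf
//   GetScalarLimitOfType(i8,  kMax)            -> 127
//   GetScalarLimitOfType(ui8, kMax)            -> 255
// each wrapped in a rank-0 DenseElementsAttr of the requested type.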
|
||||
|
||||
} // namespace hlo
|
||||
|
||||
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/register.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
|
||||
#include "mlir/IR/AsmState.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/InitAllDialects.h"
|
||||
@ -80,6 +81,8 @@ int main(int argc, char **argv) {
|
||||
llvm::InitLLVM y(argc, argv);
|
||||
|
||||
// Register any pass manager command line options.
|
||||
mlir::registerAsmPrinterCLOptions();
|
||||
mlir::registerMLIRContextCLOptions();
|
||||
mlir::registerPassManagerCLOptions();
|
||||
mlir::PassPipelineCLParser passPipeline("", "Compiler passes to run");
|
||||
|
||||
|
||||
@ -29,6 +29,7 @@ filegroup(
|
||||
"ir/tfl_ops.td",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
|
||||
"@llvm-project//mlir:OpBaseTdFiles",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
|
||||
],
|
||||
@ -227,6 +228,7 @@ cc_library(
|
||||
"@llvm-project//mlir:DerivedAttributeOpInterface",
|
||||
"@llvm-project//mlir:Dialect",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:InferTypeOpInterface",
|
||||
"@llvm-project//mlir:LoopLikeInterface",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:SideEffects",
|
||||
@ -500,6 +502,7 @@ gentbl(
|
||||
tblgen = "//tensorflow/compiler/mlir/lite/quantization:op_quant_spec_getters_gen",
|
||||
td_file = "ir/tfl_ops.td",
|
||||
td_srcs = [
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
|
||||
"ir/tfl_op_interfaces.td",
|
||||
|
||||
@ -133,63 +133,59 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
|
||||
return Status(error::INVALID_ARGUMENT,
|
||||
"'isSigned' can only be set for 8-bits integer type");
|
||||
}
|
||||
switch (type.getKind()) {
|
||||
case mlir::StandardTypes::F32:
|
||||
return tflite::TensorType_FLOAT32;
|
||||
case mlir::StandardTypes::F16:
|
||||
return tflite::TensorType_FLOAT16;
|
||||
case mlir::StandardTypes::F64:
|
||||
return tflite::TensorType_FLOAT64;
|
||||
case mlir::TF::TensorFlowTypes::STRING:
|
||||
return tflite::TensorType_STRING;
|
||||
case mlir::TF::TensorFlowTypes::QUINT8:
|
||||
return tflite::TensorType_UINT8;
|
||||
case mlir::StandardTypes::Complex: {
|
||||
auto ftype = type.cast<mlir::ComplexType>().getElementType();
|
||||
if (ftype && ftype.isF32()) {
|
||||
return tflite::TensorType_COMPLEX64;
|
||||
}
|
||||
if (ftype && ftype.isF64()) {
|
||||
return tflite::TensorType_COMPLEX128;
|
||||
}
|
||||
return Status(error::INVALID_ARGUMENT, "Unsupported type");
|
||||
|
||||
if (type.isF32()) {
|
||||
return tflite::TensorType_FLOAT32;
|
||||
} else if (type.isF16()) {
|
||||
return tflite::TensorType_FLOAT16;
|
||||
} else if (type.isF64()) {
|
||||
return tflite::TensorType_FLOAT64;
|
||||
} else if (type.isa<mlir::TF::StringType>()) {
|
||||
return tflite::TensorType_STRING;
|
||||
} else if (type.isa<mlir::TF::Quint8Type>()) {
|
||||
return tflite::TensorType_UINT8;
|
||||
} else if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
|
||||
auto ftype = complex_type.getElementType();
|
||||
if (ftype.isF32()) {
|
||||
return tflite::TensorType_COMPLEX64;
|
||||
}
|
||||
case mlir::StandardTypes::Integer: {
|
||||
const auto& itype = type.cast<mlir::IntegerType>();
|
||||
switch (itype.getWidth()) {
|
||||
case 1:
|
||||
return tflite::TensorType_BOOL;
|
||||
case 8:
|
||||
return itype.isUnsigned() ? tflite::TensorType_UINT8
|
||||
: tflite::TensorType_INT8;
|
||||
case 16:
|
||||
return tflite::TensorType_INT16;
|
||||
case 32:
|
||||
return tflite::TensorType_INT32;
|
||||
case 64:
|
||||
return tflite::TensorType_INT64;
|
||||
}
|
||||
if (ftype.isF64()) {
|
||||
return tflite::TensorType_COMPLEX128;
|
||||
}
|
||||
case mlir::quant::QuantizationTypes::UniformQuantized: {
|
||||
auto qtype = type.cast<mlir::quant::UniformQuantizedType>();
|
||||
return GetTFLiteType(qtype.getStorageType(), qtype.isSigned());
|
||||
return Status(error::INVALID_ARGUMENT, "Unsupported type");
|
||||
} else if (auto itype = type.dyn_cast<mlir::IntegerType>()) {
|
||||
switch (itype.getWidth()) {
|
||||
case 1:
|
||||
return tflite::TensorType_BOOL;
|
||||
case 8:
|
||||
return itype.isUnsigned() ? tflite::TensorType_UINT8
|
||||
: tflite::TensorType_INT8;
|
||||
case 16:
|
||||
return tflite::TensorType_INT16;
|
||||
case 32:
|
||||
return tflite::TensorType_INT32;
|
||||
case 64:
|
||||
return tflite::TensorType_INT64;
|
||||
}
|
||||
case mlir::quant::QuantizationTypes::UniformQuantizedPerAxis: {
|
||||
auto qtype = type.cast<mlir::quant::UniformQuantizedPerAxisType>();
|
||||
return GetTFLiteType(qtype.getStorageType(), qtype.isSigned());
|
||||
}
|
||||
case mlir::TF::TensorFlowTypes::RESOURCE: {
|
||||
// Treat tf.resource values as integer values in flatbuffer.
|
||||
// TODO(b/146131919): Maybe need to have a detailed design for supporting
|
||||
// other resource types beyond hash table resources and resource
// variables.
return tflite::TensorType_INT32;
|
||||
}
|
||||
default:
|
||||
// TFLite export fills FLOAT32 for unknown data types. Returning an error
|
||||
// for now for safety and this could be revisited when required.
|
||||
return Status(error::INVALID_ARGUMENT, "Unsupported type");
|
||||
} else if (auto q_uniform_type =
|
||||
type.dyn_cast<mlir::quant::UniformQuantizedType>()) {
|
||||
return GetTFLiteType(q_uniform_type.getStorageType(),
|
||||
q_uniform_type.isSigned());
|
||||
|
||||
} else if (auto q_peraxis_type =
|
||||
type.dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
|
||||
return GetTFLiteType(q_peraxis_type.getStorageType(),
|
||||
q_peraxis_type.isSigned());
|
||||
} else if (type.isa<mlir::TF::ResourceType>()) {
|
||||
// Treat tf.resource values as integer values in flatbuffer.
|
||||
// TODO(b/146131919): Maybe need to have a detailed design for supporting
|
||||
// other resource types beyonds hash table resources and resource
|
||||
// variables.
|
||||
return tflite::TensorType_INT32;
|
||||
}
|
||||
// TFLite export fills FLOAT32 for unknown data types. Returning an error
|
||||
// for now for safety and this could be revisited when required.
|
||||
return Status(error::INVALID_ARGUMENT, "Unsupported type");
|
||||
}
|
||||
|
||||
static bool IsConst(Operation* op) {
|
||||
|
||||
@ -95,40 +95,34 @@ static tflite::MirrorPadMode ConvertTFL_MirrorPaddingAttrForOptionWriter(
|
||||
|
||||
static tflite::TensorType ConvertDerivedTypeAttrForOptionWriter(
|
||||
mlir::Type type, flatbuffers::FlatBufferBuilder* builder) {
|
||||
switch (type.getKind()) {
|
||||
case mlir::StandardTypes::F16:
|
||||
return tflite::TensorType_FLOAT16;
|
||||
case mlir::StandardTypes::F32:
|
||||
return tflite::TensorType_FLOAT32;
|
||||
case mlir::TF::TensorFlowTypes::STRING:
|
||||
return tflite::TensorType_STRING;
|
||||
case mlir::StandardTypes::Complex: {
|
||||
auto etype = type.cast<mlir::ComplexType>().getElementType();
|
||||
if (etype.isF32()) {
|
||||
return tflite::TensorType_COMPLEX64;
|
||||
}
|
||||
llvm_unreachable("invalid complex Type in conversion");
|
||||
if (type.isF16()) {
|
||||
return tflite::TensorType_FLOAT16;
|
||||
} else if (type.isF32()) {
|
||||
return tflite::TensorType_FLOAT32;
|
||||
} else if (type.isa<mlir::TF::StringType>()) {
|
||||
return tflite::TensorType_STRING;
|
||||
} else if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
|
||||
if (complex_type.getElementType().isF32()) {
|
||||
return tflite::TensorType_COMPLEX64;
|
||||
}
|
||||
case mlir::StandardTypes::Integer: {
|
||||
const auto& itype = type.cast<mlir::IntegerType>();
|
||||
switch (itype.getWidth()) {
|
||||
case 1:
|
||||
return tflite::TensorType_BOOL;
|
||||
case 8:
|
||||
return tflite::TensorType_INT8;
|
||||
case 16:
|
||||
return tflite::TensorType_INT16;
|
||||
case 32:
|
||||
return tflite::TensorType_INT32;
|
||||
case 64:
|
||||
return tflite::TensorType_INT64;
|
||||
default:
|
||||
llvm_unreachable("invalid integer Type in conversion");
|
||||
}
|
||||
llvm_unreachable("invalid complex Type in conversion");
|
||||
} else if (auto itype = type.dyn_cast<mlir::IntegerType>()) {
|
||||
switch (itype.getWidth()) {
|
||||
case 1:
|
||||
return tflite::TensorType_BOOL;
|
||||
case 8:
|
||||
return tflite::TensorType_INT8;
|
||||
case 16:
|
||||
return tflite::TensorType_INT16;
|
||||
case 32:
|
||||
return tflite::TensorType_INT32;
|
||||
case 64:
|
||||
return tflite::TensorType_INT64;
|
||||
default:
|
||||
llvm_unreachable("invalid integer Type in conversion");
|
||||
}
|
||||
default:
|
||||
llvm_unreachable("invalid Type in conversion");
|
||||
}
|
||||
llvm_unreachable("invalid Type in conversion");
|
||||
}
|
||||
|
||||
// I32Attr already returns an int as required by flatbuffer builders.
|
||||
|
||||
@ -30,6 +30,7 @@ limitations under the License.
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/Matchers.h" // from @llvm-project
|
||||
#include "mlir/IR/OpImplementation.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
@ -253,9 +254,8 @@ struct TensorFlowLiteInlinerInterface : public DialectInlinerInterface {
|
||||
}
|
||||
};
|
||||
|
||||
struct TensorFlowLiteOpFolderDialectInterface
|
||||
: public OpFolderDialectInterface {
|
||||
using OpFolderDialectInterface::OpFolderDialectInterface;
|
||||
struct TensorFlowLiteDialectFoldInterface : public DialectFoldInterface {
|
||||
using DialectFoldInterface::DialectFoldInterface;
|
||||
|
||||
// Registered hook to check if the given region, which is attached to an
|
||||
// operation that is *not* isolated from above (i.e. no internal regions
|
||||
@ -275,7 +275,7 @@ TensorFlowLiteDialect::TensorFlowLiteDialect(mlir::MLIRContext *context)
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
|
||||
>();
|
||||
addInterfaces<TensorFlowLiteInlinerInterface,
|
||||
TensorFlowLiteOpFolderDialectInterface>();
|
||||
TensorFlowLiteDialectFoldInterface>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1028,9 +1028,12 @@ static LogicalResult Verify(PackOp op) {
|
||||
// Check axis bounds.
|
||||
if (input_type.hasRank()) {
|
||||
int64_t axis_value = op.axis().getSExtValue();
|
||||
if (abs(axis_value) > input_type.getRank())
|
||||
return op.emitOpError("op attribute 'axis' is out of bounds, got ")
|
||||
<< axis_value;
|
||||
if (axis_value < 0) axis_value += input_type.getRank() + 1;
|
||||
if (axis_value < 0 || axis_value >= input_type.getRank() + 1)
|
||||
return op.emitOpError()
|
||||
<< "op attribute 'axis' should be in range [-rank - 1, rank + 1), "
|
||||
<< "got rank = " << input_type.getRank()
|
||||
<< ", and axis = " << op.axis().getSExtValue();
|
||||
}
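  // Worked example (illustrative, not part of this change): packing two
  // tensor<1x4xi32> values has input rank 2, so 'axis' must lie in [-3, 3);
  // axis = -2 is normalized to -2 + 2 + 1 = 1 (result shape 1x2x4), while
  // axis = 3 on rank-1 inputs is rejected with the error emitted above.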
|
||||
|
||||
// Make sure all inputs have the same shape and element type.
|
||||
@ -1443,12 +1446,59 @@ void FakeQuantOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
|
||||
// TODO(b/133486129): Implement shape inference for unpack
|
||||
|
||||
static LogicalResult Verify(UnpackOp op) {
|
||||
// TODO(antiagainst): Implement other checks as in
|
||||
// tensorflow/lite/kernels/unpack.cc
|
||||
LogicalResult UnpackOp::inferReturnTypes(
|
||||
MLIRContext *context, Optional<Location> loc, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
UnpackOpAdaptor op(operands, attributes);
|
||||
// TODO(jpienaar): Refactor inferReturnTypes.
|
||||
if (failed(op.verify(loc.hasValue() ? *loc : UnknownLoc::get(context))))
|
||||
return failure();
|
||||
|
||||
if (op.getOperation()->getNumResults() != op.num())
|
||||
return op.emitOpError("output count should match 'num' attribute");
|
||||
if (operands.size() != 1) {
|
||||
return emitOptionalError(loc, "input count should be equal to 1");
|
||||
}
|
||||
|
||||
const int64_t num_value = op.num().getInt();
|
||||
auto input_type = operands[0].getType().dyn_cast<ShapedType>();
|
||||
if (!input_type || !input_type.hasRank()) {
|
||||
// If input is unranked, then so is output.
|
||||
inferredReturnTypes.assign(
|
||||
num_value, UnrankedTensorType::get(input_type.getElementType()));
|
||||
return success();
|
||||
}
|
||||
|
||||
if (input_type.getNumElements() <= 0) {
|
||||
return emitOptionalError(
|
||||
loc, "number of elements in input shoule be larger than 0");
|
||||
}
|
||||
|
||||
const int64_t rank = input_type.getRank();
|
||||
if (rank <= 0) {
|
||||
return emitOptionalError(loc, "input should be of rank larger than 0");
|
||||
}
|
||||
|
||||
int64_t axis_value = op.axis().getInt();
|
||||
if (axis_value < 0) {
|
||||
axis_value += rank;
|
||||
}
|
||||
if (axis_value < 0 || axis_value >= rank) {
|
||||
return emitOptionalError(
|
||||
loc, "attribute 'axis' should be in range [-rank, rank), got axis = ",
|
||||
op.axis().getInt(), ", and rank = ", rank);
|
||||
}
|
||||
|
||||
if (!ShapedType::isDynamic(input_type.getDimSize(axis_value)) &&
|
||||
input_type.getDimSize(axis_value) != num_value) {
|
||||
return emitOptionalError(loc, "output count should match 'num' attribute");
|
||||
}
|
||||
|
||||
auto output_shape = llvm::to_vector<4>(input_type.getShape());
|
||||
output_shape.erase(output_shape.begin() + axis_value);
|
||||
|
||||
auto output_type =
|
||||
RankedTensorType::get(output_shape, input_type.getElementType());
|
||||
inferredReturnTypes.assign(num_value, output_type);
|
||||
|
||||
return success();
|
||||
}
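// Worked example (illustrative, not part of this change): unpacking a
// tensor<2x3xi32> with axis = 1 and num = 3 removes dimension 1, so all three
// inferred results are tensor<2xi32>; an unranked input instead yields
// unranked results of the same element type.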
|
||||
|
||||
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "mlir/IR/OpImplementation.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
|
||||
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#define TFL_OPS
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
include "mlir/Interfaces/LoopLikeInterface.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td"
|
||||
@ -3028,7 +3029,8 @@ def TFL_TransposeOp : TFL_Op<"transpose", [
|
||||
def TFL_UnpackOp : TFL_Op<"unpack", [
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultElementType,
|
||||
SameOperandsAndResultsScale]> {
|
||||
SameOperandsAndResultsScale,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
let summary = "Unpacks a tensor along a dimension into multiple tensors";
|
||||
|
||||
let description = [{
|
||||
@ -3059,8 +3061,6 @@ def TFL_UnpackOp : TFL_Op<"unpack", [
|
||||
TFL_VariadicTensorOf<[F32, I1, I8, UI8, I32, QI8, QUI8, I16, QI16]>:$outputs
|
||||
);
|
||||
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
|
||||
let hasOptions = 1;
|
||||
}
|
||||
|
||||
|
||||
@ -1139,9 +1139,15 @@ func @packInputRank(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<1x
|
||||
|
||||
// -----
|
||||
|
||||
func @packNegInputRank(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<2x1x4xi32> {
|
||||
func @packNegInputAxis2(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<1x2x4xi32> {
|
||||
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = -2 : i32, values_count = 2 : i32}
|
||||
%0 = "tfl.pack"(%arg0, %arg1) {axis = -2 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<2x1x4xi32>
|
||||
%0 = "tfl.pack"(%arg0, %arg1) {axis = -2 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<1x2x4xi32>
|
||||
return %0 : tensor<1x2x4xi32>
|
||||
}
|
||||
|
||||
func @packNegInputAxis3(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<2x1x4xi32> {
|
||||
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = -3 : i32, values_count = 2 : i32}
|
||||
%0 = "tfl.pack"(%arg0, %arg1) {axis = -3 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<2x1x4xi32>
|
||||
return %0 : tensor<2x1x4xi32>
|
||||
}
|
||||
|
||||
@ -1172,7 +1178,7 @@ func @pack(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
|
||||
// -----
|
||||
|
||||
func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
|
||||
// expected-error @+1 {{op attribute 'axis' is out of bounds, got 3}}
|
||||
// expected-error @+1 {{op attribute 'axis' should be in range [-rank - 1, rank + 1), got rank = 1, and axis = 3}}
|
||||
%0 = "tfl.pack"(%arg0, %arg1) {axis = 3 : i32, values_count = 2 : i32} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
|
||||
return %0 : tensor<2x2xi32>
|
||||
}
|
||||
@ -1183,7 +1189,22 @@ func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
|
||||
// CHECK: "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32}
|
||||
%0:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
|
||||
return %0#0 : tensor<2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
|
||||
// CHECK: "tfl.unpack"(%arg0) {axis = -1 : i32, num = 3 : i32}
|
||||
%0:3 = "tfl.unpack"(%arg0) {axis = -1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
|
||||
return %0#0 : tensor<2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @unpack(%arg0: tensor<2x3xi32>) -> tensor<3xi32> {
|
||||
// CHECK: "tfl.unpack"(%arg0) {axis = -2 : i32, num = 2 : i32}
|
||||
%0:2 = "tfl.unpack"(%arg0) {axis = -2 : i32, num = 2 : i32} : (tensor<2x3xi32>) -> (tensor<3xi32>, tensor<3xi32>)
|
||||
return %0#0 : tensor<3xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
@ -1204,6 +1225,45 @@ func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
|
||||
|
||||
// -----
|
||||
|
||||
func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
|
||||
// expected-error @+1 {{attribute 'axis' should be in range [-rank, rank), got axis = 2, and rank = 2}}
|
||||
%0:3 = "tfl.unpack"(%arg0) {axis = 2 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
|
||||
return %0#0 : tensor<2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
|
||||
// expected-error @+1 {{attribute 'axis' should be in range [-rank, rank), got axis = -3, and rank = 2}}
|
||||
%0:3 = "tfl.unpack"(%arg0) {axis = -3 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
|
||||
return %0#0 : tensor<2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @unpack(%arg0: tensor<i32>) -> tensor<2xi32> {
|
||||
// expected-error @+1 {{input should be of rank larger than 0}}
|
||||
%0:3 = "tfl.unpack"(%arg0) {axis = 0 : i32, num = 3 : i32} : (tensor<i32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
|
||||
return %0#0 : tensor<2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
|
||||
// expected-error @+1 {{op inferred type incompatible with return type of operation}}
|
||||
%0:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2x1xi32>, tensor<2xi32>)
|
||||
return %0#0 : tensor<2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @unpack(%arg0: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>) {
|
||||
%0:2 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 2 : i32} : (tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>)
|
||||
return %0#0, %0#1 : tensor<*xi32>, tensor<*xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testMean
|
||||
func @testMean(%arg0: tensor<2x2xf32>, %arg1 : tensor<1xi32>) -> tensor<1x2xf32> {
|
||||
// CHECK: "tfl.mean"(%arg0, %arg1) {keep_dims = false}
|
||||
|
||||
@ -30,80 +30,66 @@ stream_executor::port::StatusOr<ConstantOp> CreateConstOpWithSingleValue(
|
||||
Type element_type = shaped_type.getElementType();
|
||||
ShapedType scalar_type = RankedTensorType::get({}, element_type);
|
||||
Attribute attr;
|
||||
switch (element_type.getKind()) {
|
||||
case mlir::StandardTypes::F16: {
|
||||
auto floatType = mlir::FloatType::getF16(element_type.getContext());
|
||||
auto floatAttr =
|
||||
mlir::FloatAttr::get(floatType, static_cast<float>(value));
|
||||
std::vector<Attribute> floatValues({floatAttr});
|
||||
attr = DenseElementsAttr::get(scalar_type, floatValues);
|
||||
break;
|
||||
}
|
||||
case mlir::StandardTypes::BF16: {
|
||||
auto floatType = mlir::FloatType::getBF16(element_type.getContext());
|
||||
auto floatAttr =
|
||||
mlir::FloatAttr::get(floatType, static_cast<float>(value));
|
||||
std::vector<Attribute> floatValues({floatAttr});
|
||||
attr = DenseElementsAttr::get(scalar_type, floatValues);
|
||||
break;
|
||||
}
|
||||
case mlir::StandardTypes::F32: {
|
||||
attr =
|
||||
DenseElementsAttr::get<float>(scalar_type, static_cast<float>(value));
|
||||
break;
|
||||
}
|
||||
case mlir::StandardTypes::Complex: {
|
||||
auto etype = element_type.cast<mlir::ComplexType>().getElementType();
|
||||
if (etype.isF32()) {
|
||||
auto dialect = etype.getContext()->getRegisteredDialect("tf");
|
||||
tensorflow::TensorProto repr;
|
||||
repr.set_dtype(tensorflow::DT_COMPLEX64);
|
||||
if (element_type.isF16()) {
|
||||
auto floatType = mlir::FloatType::getF16(element_type.getContext());
|
||||
auto floatAttr = mlir::FloatAttr::get(floatType, static_cast<float>(value));
|
||||
std::vector<Attribute> floatValues({floatAttr});
|
||||
attr = DenseElementsAttr::get(scalar_type, floatValues);
|
||||
} else if (element_type.isBF16()) {
|
||||
auto floatType = mlir::FloatType::getBF16(element_type.getContext());
|
||||
auto floatAttr = mlir::FloatAttr::get(floatType, static_cast<float>(value));
|
||||
std::vector<Attribute> floatValues({floatAttr});
|
||||
attr = DenseElementsAttr::get(scalar_type, floatValues);
|
||||
} else if (element_type.isF32()) {
|
||||
attr =
|
||||
DenseElementsAttr::get<float>(scalar_type, static_cast<float>(value));
|
||||
} else if (auto complex_type = element_type.dyn_cast<mlir::ComplexType>()) {
|
||||
auto etype = complex_type.getElementType();
|
||||
if (etype.isF32()) {
|
||||
auto dialect = etype.getContext()->getRegisteredDialect("tf");
|
||||
tensorflow::TensorProto repr;
|
||||
repr.set_dtype(tensorflow::DT_COMPLEX64);
|
||||
|
||||
tensorflow::TensorShapeProto* shape = repr.mutable_tensor_shape();
|
||||
shape->set_unknown_rank(false);
|
||||
shape->add_dim()->set_size(int64_t{1});
|
||||
std::string content;
|
||||
auto complex_value =
|
||||
std::complex<float>(static_cast<float>(value), 0.0f);
|
||||
content.assign(reinterpret_cast<const char*>(&complex_value),
|
||||
sizeof(complex_value));
|
||||
repr.set_tensor_content(content);
|
||||
std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
|
||||
tensorflow::TensorShapeProto* shape = repr.mutable_tensor_shape();
|
||||
shape->set_unknown_rank(false);
|
||||
shape->add_dim()->set_size(int64_t{1});
|
||||
std::string content;
|
||||
auto complex_value = std::complex<float>(static_cast<float>(value), 0.0f);
|
||||
content.assign(reinterpret_cast<const char*>(&complex_value),
|
||||
sizeof(complex_value));
|
||||
repr.set_tensor_content(content);
|
||||
std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
|
||||
|
||||
attr = mlir::OpaqueElementsAttr::get(dialect, scalar_type, mangled);
|
||||
attr = mlir::OpaqueElementsAttr::get(dialect, scalar_type, mangled);
|
||||
} else {
|
||||
return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
|
||||
"Unsupported type");
|
||||
}
|
||||
} else if (auto itype = element_type.dyn_cast<mlir::IntegerType>()) {
|
||||
switch (itype.getWidth()) {
|
||||
case 8:
|
||||
attr = DenseElementsAttr::get<int8_t>(scalar_type,
|
||||
static_cast<int8_t>(value));
|
||||
break;
|
||||
}
|
||||
return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
|
||||
"Unsupported type");
|
||||
case 16:
|
||||
attr = DenseElementsAttr::get<int16_t>(scalar_type,
|
||||
static_cast<int16_t>(value));
|
||||
break;
|
||||
case 32:
|
||||
attr = DenseElementsAttr::get<int32_t>(scalar_type,
|
||||
static_cast<int32_t>(value));
|
||||
break;
|
||||
case 64:
|
||||
attr = DenseElementsAttr::get<int64_t>(scalar_type,
|
||||
static_cast<int64_t>(value));
|
||||
break;
|
||||
default:
|
||||
return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
|
||||
"Unsupported type");
|
||||
}
|
||||
case mlir::StandardTypes::Integer: {
|
||||
const auto& itype = element_type.cast<mlir::IntegerType>();
|
||||
switch (itype.getWidth()) {
|
||||
case 8:
|
||||
attr = DenseElementsAttr::get<int8_t>(scalar_type,
|
||||
static_cast<int8_t>(value));
|
||||
break;
|
||||
case 16:
|
||||
attr = DenseElementsAttr::get<int16_t>(scalar_type,
|
||||
static_cast<int16_t>(value));
|
||||
break;
|
||||
case 32:
|
||||
attr = DenseElementsAttr::get<int32_t>(scalar_type,
|
||||
static_cast<int32_t>(value));
|
||||
break;
|
||||
case 64:
|
||||
attr = DenseElementsAttr::get<int64_t>(scalar_type,
|
||||
static_cast<int64_t>(value));
|
||||
break;
|
||||
default:
|
||||
return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
|
||||
"Unsupported type");
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
|
||||
"Unsupported type");
|
||||
} else {
|
||||
return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
|
||||
"Unsupported type");
|
||||
}
|
||||
return rewriter->create<ConstantOp>(loc, scalar_type, attr);
|
||||
}
|
||||
|
||||
@ -32,7 +32,10 @@ void init_types(py::module& m) {
|
||||
[](mlir::FunctionType& ft) { return ft.getResults().vec(); });
|
||||
|
||||
py::class_<mlir::FloatType, mlir::Type>(m, "FloatType")
|
||||
.def("get", &mlir::FloatType::get);
|
||||
.def("getBF16", &mlir::FloatType::getBF16)
|
||||
.def("getF16", &mlir::FloatType::getF16)
|
||||
.def("getF32", &mlir::FloatType::getF32)
|
||||
.def("getF64", &mlir::FloatType::getF64);
|
||||
|
||||
py::class_<mlir::IntegerType, mlir::Type>(m, "IntegerType")
|
||||
.def("get", py::overload_cast<unsigned, mlir::MLIRContext*>(
|
||||
|
||||
@ -722,6 +722,7 @@ cc_library(
|
||||
"//tensorflow/core:framework",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Support",
|
||||
@ -787,6 +788,7 @@ cc_library(
|
||||
"transforms/tpu_extract_head_tail_outside_compilation.cc",
|
||||
"transforms/tpu_extract_outside_compilation.cc",
|
||||
"transforms/tpu_host_computation_expansion.cc",
|
||||
"transforms/tpu_identity_pruning.cc",
|
||||
"transforms/tpu_merge_variables_with_execute.cc",
|
||||
"transforms/tpu_outside_compilation_cluster.cc",
|
||||
"transforms/tpu_rewrite_pass.cc",
|
||||
@ -1269,7 +1271,7 @@ cc_library(
|
||||
name = "tf_dialect_passes",
|
||||
srcs = [
|
||||
"transforms/constant_fold.cc",
|
||||
"transforms/dialect_hooks.cc",
|
||||
"transforms/decode_attributes_hook.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"transforms/constant_fold.h",
|
||||
@ -1632,6 +1634,7 @@ cc_library(
|
||||
deps = [
|
||||
":lower_tf_inc_gen",
|
||||
":tensorflow",
|
||||
":tensorflow_ops",
|
||||
":tensorflow_types",
|
||||
"//tensorflow/core:framework",
|
||||
"@llvm-project//llvm:Support",
|
||||
|
||||
@ -21,11 +21,13 @@ limitations under the License.
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/SCCIterator.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/iterator_range.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "mlir/Analysis/CallGraph.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Block.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
@ -35,6 +37,7 @@ limitations under the License.
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/IR/Visitors.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
@ -134,12 +137,46 @@ class BacktrackAnalysis {
|
||||
return GetAnalysisForRegion(region);
|
||||
}
|
||||
|
||||
// Returns the backtrack analysis for the given region if it exists.
|
||||
// If the region has not yet been analyzed, returns llvm::None.
|
||||
Optional<const InfoT*> GetAnalysisIfExists(Region& region) const {
|
||||
auto it = info_map_.find(®ion);
|
||||
if (it == info_map_.end()) return llvm::None;
|
||||
return &it->second;
|
||||
}
|
||||
|
||||
Optional<const InfoT*> GetAnalysisIfExists(FuncOp func) const {
|
||||
return GetAnalysisIfExists(func.getBody());
|
||||
}
|
||||
|
||||
private:
|
||||
llvm::SmallDenseMap<Region*, InfoT> info_map_;
|
||||
};
|
||||
|
||||
// Analyzes all regions attached to all operations in the module.
|
||||
BacktrackAnalysis::BacktrackAnalysis(ModuleOp module) {
|
||||
const CallGraph call_graph(module);
|
||||
|
||||
// Visit functions bottom up when doing the analysis. Note that SCC iterator
|
||||
// has the property that if there is an edge from SCC1->SCC2, SCC1 is visited
|
||||
// after SCC2, i.e., the graph is traversed bottom up just the way we want.
|
||||
auto scc_begin = llvm::scc_begin(&call_graph);
|
||||
auto scc_end = llvm::scc_end(&call_graph);
|
||||
for (auto& scc : make_range(scc_begin, scc_end)) {
|
||||
// Each SCC node is a collection of callgraph nodes that form a cycle. We
|
||||
// will visit these nodes in an arbitrary order. If a node being visited
|
||||
// calls a function that has not yet been analyzed, we will not be able to
|
||||
// backtrack through that function call (our analysis will be correct but
|
||||
// pessimistic).
|
||||
for (CallGraphNode* node : scc) {
|
||||
if (node->isExternal()) continue;
|
||||
Region* region = node->getCallableRegion();
|
||||
GetOrCreateAnalysis(*region);
|
||||
}
|
||||
}
|
||||
|
||||
// This above call graph analysis will cover all regions attached to functions
|
||||
// but we also need to analyze regions attached to other ops.
|
||||
module.walk([this](Operation* op) {
|
||||
for (Region& region : op->getRegions()) GetOrCreateAnalysis(region);
|
||||
});
|
||||
@ -160,6 +197,18 @@ Value BacktrackAnalysis::BacktrackValue(Value value) {
|
||||
value = island.GetYield().getOperand(res_index);
|
||||
} else if (isa<IdentityNOp, IdentityOp>(op)) {
|
||||
value = op->getOperand(res_index);
|
||||
} else if (auto call = dyn_cast<CallOpInterface>(op)) {
|
||||
FuncOp func = dyn_cast<FuncOp>(call.resolveCallable());
|
||||
if (!func) break;
|
||||
// Check if the function being called has been analyzed. If not,
|
||||
// we cannot backtrack the value further.
|
||||
Optional<const InfoT*> callee_info = GetAnalysisIfExists(func);
|
||||
if (!callee_info) break;
|
||||
Optional<int> passthrough_arg = callee_info.getValue()->GetArg(res_index);
|
||||
if (!passthrough_arg) break;
|
||||
value = call.getArgOperands()[passthrough_arg.getValue()];
|
||||
} else if (isa<tf_device::LaunchOp, tf_device::ClusterOp>(op)) {
|
||||
value = op->getRegion(0).front().getTerminator()->getOperand(res_index);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
@ -359,6 +408,13 @@ ResourceAliasAnalysisInfo::ResourceAliasAnalysisInfo(
|
||||
AddValueUniqueIDMapping(result, kUnknownResourceId);
|
||||
}
|
||||
}
|
||||
} else if (isa<tf_device::LaunchOp, tf_device::ClusterOp>(op)) {
|
||||
Region& region = op->getRegion(0);
|
||||
const auto& body_info = backtrack_analysis.GetAnalysisForRegion(region);
|
||||
for (auto result : filter_resources(op->getResults())) {
|
||||
Value body_result = body_info.GetValue(result.getResultNumber());
|
||||
PropagateInputToOutput(body_result, result);
|
||||
}
|
||||
} else {
|
||||
assign_unknown_id_to_all(op->getResults());
|
||||
}
|
||||
|
||||
@ -41,6 +41,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
|
||||
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include <cstddef>
|
||||
#include <memory>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/iterator_range.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
@ -184,11 +185,18 @@ class MlirAbstractOp : public TracingOperation {
|
||||
}
|
||||
|
||||
private:
|
||||
// Return true if there are still unfilled ODS slots for adding more inputs.
|
||||
bool IsNextODSArgAvailable();
|
||||
|
||||
MLIRContext* context_;
|
||||
MlirFunctionContext* function_context_;
|
||||
SmallVector<Value, 8> operands_;
|
||||
llvm::StringMap<Attribute> attrs_;
|
||||
std::unique_ptr<OperationState> state_;
|
||||
// This is the index of the next ODS operand that will be added with AddInput
|
||||
// or AddInputList.
|
||||
int current_ods_input_ = 0;
|
||||
const tensorflow::OpDef* op_def_ = nullptr;
|
||||
const char* op_name_ = nullptr;
|
||||
string tf_op_type_;
|
||||
// TODO(srbs): Use this.
|
||||
@ -267,6 +275,10 @@ Status MlirAbstractOp::Reset(const char* op, const char* device_name) {
|
||||
return tensorflow::errors::FailedPrecondition(
|
||||
"Reset called on already built op.");
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
tensorflow::OpRegistry::Global()->LookUpOpDef(op, &op_def_));
|
||||
assert(op_def_);
|
||||
|
||||
tf_op_type_ = op;
|
||||
std::string name = "tf.";
|
||||
name += op;
|
||||
@ -315,45 +327,17 @@ Status MlirAbstractOp::AddRef(Type type, Type* output_type) {
|
||||
Status MlirAbstractOp::Create(ArrayRef<Value> operands,
|
||||
OperationState** state) {
|
||||
state_->operands = llvm::to_vector<4>(operands);
|
||||
const tensorflow::OpDef* op_def;
|
||||
auto node_name = state_->name.getStringRef().drop_front(
|
||||
TensorFlowDialect::getDialectNamespace().size() + 1);
|
||||
TF_RETURN_IF_ERROR(
|
||||
tensorflow::OpRegistry::Global()->LookUpOpDef(node_name.str(), &op_def));
|
||||
Builder builder(context_);
|
||||
// Process operands according to the op_def and infer derived attributes.
|
||||
int current_operand = 0;
|
||||
for (const tensorflow::OpDef::ArgDef& input_arg : op_def->input_arg()) {
|
||||
if (!input_arg.number_attr().empty()) {
|
||||
// TODO(b/156122856): we don't support variadic operands.
|
||||
return tensorflow::errors::Unimplemented(
|
||||
"Unsupported 'number_attr' for '", input_arg.number_attr(), "'");
|
||||
} else if (!input_arg.type_list_attr().empty()) {
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
"Unsupported 'type_list_attr' for '", input_arg.number_attr(), "'");
|
||||
}
|
||||
if (current_operand >= operands.size()) {
|
||||
return tensorflow::errors::InvalidArgument("Missing operand for '",
|
||||
input_arg.name(), "'");
|
||||
}
|
||||
Type expected_type;
|
||||
if (input_arg.type() != tensorflow::DT_INVALID) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
ConvertDataTypeToTensor(input_arg.type(), builder, &expected_type));
|
||||
Type output_type;
|
||||
if (input_arg.is_ref())
|
||||
TF_RETURN_IF_ERROR(AddRef(expected_type, &output_type));
|
||||
expected_type = output_type;
|
||||
} else {
|
||||
expected_type = operands[current_operand].getType();
|
||||
}
|
||||
if (!input_arg.type_attr().empty()) {
|
||||
attrs_[input_arg.type_attr()] = TypeAttr::get(expected_type);
|
||||
}
|
||||
++current_operand;
|
||||
}
|
||||
|
||||
for (const tensorflow::OpDef::ArgDef& output_arg : op_def->output_arg()) {
|
||||
if (current_ods_input_ != op_def_->input_arg_size())
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
absl::StrCat("Mismatch in operands number: got ", current_ods_input_,
|
||||
" expected ", op_def_->input_arg_size(), " ; for op ",
|
||||
state_->name.getStringRef().str()));
|
||||
|
||||
// Process results according to the op_def and infer types for derived
|
||||
// attributes.
|
||||
for (const tensorflow::OpDef::ArgDef& output_arg : op_def_->output_arg()) {
|
||||
int original_size = state_->types.size();
|
||||
if (!output_arg.number_attr().empty()) {
|
||||
// Same type repeated "repeats" times.
|
||||
@ -605,12 +589,38 @@ Status MlirFunctionContext::AddParameter(tensorflow::DataType dtype,
|
||||
}
|
||||
|
||||
Status MlirAbstractOp::AddInput(AbstractTensorHandle* input) {
|
||||
if (current_ods_input_ >= op_def_->input_arg_size())
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
absl::StrCat("More Input() (", current_ods_input_, ") calls than the ",
|
||||
op_def_->input_arg_size(), " allowed input_args ; for op ",
|
||||
state_->name.getStringRef().str()));
|
||||
|
||||
auto* operand = dyn_cast<MlirTensor>(input);
|
||||
if (!operand) {
|
||||
if (!operand)
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
"Unable to cast input to MlirTensor");
|
||||
}
|
||||
operands_.push_back(operand->getValue());
|
||||
|
||||
// Get the next ArgDef and use it to infer the derived attributes associated
|
||||
// to this input.
|
||||
const tensorflow::OpDef::ArgDef& arg_def =
|
||||
op_def_->input_arg(current_ods_input_++);
|
||||
Type expected_type;
|
||||
if (arg_def.type() != tensorflow::DT_INVALID) {
|
||||
Builder builder(context_);
|
||||
TF_RETURN_IF_ERROR(
|
||||
tensorflow::ConvertDataType(arg_def.type(), builder, &expected_type));
|
||||
if (arg_def.is_ref()) {
|
||||
Type output_type;
|
||||
TF_RETURN_IF_ERROR(AddRef(expected_type, &output_type));
|
||||
expected_type = output_type;
|
||||
}
|
||||
} else {
|
||||
expected_type = operands_.back().getType();
|
||||
}
|
||||
if (!arg_def.type_attr().empty())
|
||||
attrs_[arg_def.type_attr()] = TypeAttr::get(expected_type);
|
||||
|
||||
return Status::OK();
|
||||
}
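// Illustrative note (not part of this change): for an op whose ArgDef carries
// a type_attr such as "T", adding an input of type tensor<f32> records
// attrs_["T"] = TypeAttr::get(tensor<f32>), so the derived attribute is
// inferred directly from the operand rather than later, during Create().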
|
||||
Status MlirFunctionContext::Finalize(OutputList* outputs,
|
||||
|
||||
@ -54,9 +54,6 @@ namespace tf_executor {
|
||||
|
||||
namespace {
|
||||
|
||||
using TF::DropRefType;
|
||||
using TF::DropTypeSubTypes;
|
||||
|
||||
struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface {
|
||||
using DialectInlinerInterface::DialectInlinerInterface;
|
||||
|
||||
@ -75,9 +72,8 @@ struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface {
|
||||
}
|
||||
};
|
||||
|
||||
struct TensorFlowExecutorOpFolderDialectInterface
|
||||
: public OpFolderDialectInterface {
|
||||
using OpFolderDialectInterface::OpFolderDialectInterface;
|
||||
struct TensorFlowExecutorDialectFoldInterface : public DialectFoldInterface {
|
||||
using DialectFoldInterface::DialectFoldInterface;
|
||||
|
||||
// Registered hook to check if the given region, which is attached to an
|
||||
// operation that is *not* isolated from above (i.e. no internal regions
|
||||
@ -100,7 +96,7 @@ TensorFlowExecutorDialect::TensorFlowExecutorDialect(MLIRContext *context)
|
||||
>();
|
||||
|
||||
addInterfaces<TensorFlowExecutorInlinerInterface,
|
||||
TensorFlowExecutorOpFolderDialectInterface>();
|
||||
TensorFlowExecutorDialectFoldInterface>();
|
||||
|
||||
addTypes<ControlType, TokenType>();
|
||||
}
|
||||
@ -551,8 +547,8 @@ LogicalResult Verify(SwitchNOp switchn) {
|
||||
<< operand0_tensor_type << " vs " << output_tensor_type;
|
||||
}
|
||||
Type broadcasted_type = OpTrait::util::getBroadcastedType(
|
||||
DropRefType(DropTypeSubTypes(operand0_tensor_type)),
|
||||
DropRefType(DropTypeSubTypes(output_tensor_type)));
|
||||
TF::DropRefAndSubTypes(operand0_tensor_type),
|
||||
TF::DropRefAndSubTypes(output_tensor_type));
|
||||
if (!broadcasted_type) {
|
||||
return switchn.emitOpError()
|
||||
<< "expects data operand to be broadcastable with all output types"
|
||||
@ -668,8 +664,8 @@ LogicalResult Verify(MergeOp merge) {
|
||||
<< operand_tensor_ty << " vs " << output_tensor_ty;
|
||||
}
|
||||
Type broadcasted_type = OpTrait::util::getBroadcastedType(
|
||||
DropRefType(DropTypeSubTypes(output_tensor_ty)),
|
||||
DropRefType(DropTypeSubTypes(operand_tensor_ty)));
|
||||
TF::DropRefAndSubTypes(output_tensor_ty),
|
||||
TF::DropRefAndSubTypes(operand_tensor_ty));
|
||||
if (!broadcasted_type)
|
||||
return merge.emitOpError()
|
||||
<< "expects all operands to be broadcastable with output type"
|
||||
|
||||
@ -136,7 +136,7 @@ Inputs must be of same size and shape.
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic, TF_SameOperandsAndResultElementTypeResolveRef, TF_CwiseBinary]>,
|
||||
def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary, TF_LayoutAgnostic, TF_SameOperandsAndResultElementTypeResolveRef]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = "Returns x + y element-wise.";
|
||||
|
||||
@ -859,15 +859,15 @@ about broadcasting
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x,
|
||||
TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y,
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Complex128, TF_Complex64]>:$x,
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Complex128, TF_Complex64]>:$y,
|
||||
|
||||
DefaultValuedAttr<BoolAttr, "false">:$adj_x,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$adj_y
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$output
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Complex128, TF_Complex64]>:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
@ -1404,6 +1404,38 @@ that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is.
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_CholeskyOp : TF_Op<"Cholesky", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
Computes the Cholesky decomposition of one or more square matrices.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
|
||||
form square matrices.
|
||||
|
||||
The input has to be symmetric and positive definite. Only the lower-triangular
|
||||
part of the input will be used for this operation. The upper-triangular part
|
||||
will not be read.
|
||||
|
||||
The output is a tensor of the same shape as the input
|
||||
containing the Cholesky decompositions for all input submatrices `[..., :, :]`.
|
||||
|
||||
**Note**: The gradient computation on GPU is faster for large matrices but
|
||||
not for large batch dimensions when the submatrices are small. In this
|
||||
case it might be faster to use the CPU.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$input
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
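// Illustrative usage (not part of this change): for the symmetric
// positive-definite matrix [[4., 2.], [2., 3.]], tf.linalg.cholesky returns
// the lower-triangular factor [[2., 0.], [1., ~1.4142]], and multiplying that
// factor by its transpose recovers the input.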
|
||||
|
||||
def TF_ClipByValueOp : TF_Op<"ClipByValue", [NoSideEffect, TF_SameOperandsAndResultElementTypeResolveRef]> {
|
||||
let summary = "Clips tensor values to a specified min and max.";
|
||||
|
||||
@ -2025,17 +2057,73 @@ and `B, D, F, H` as group 1. Thus we get the outputs:
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[BF16, F32, I32, TF_Uint32]>:$input,
|
||||
TensorOf<[BF16, F16, F32, I32, TF_Uint32]>:$input,
|
||||
I32Tensor:$group_assignment
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[BF16, F32, I32, TF_Uint32]>:$output
|
||||
TensorOf<[BF16, F16, F32, I32, TF_Uint32]>:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_CumprodOp : TF_Op<"Cumprod", [NoSideEffect, TF_AllTypesMatch<["x", "out"]>]> {
|
||||
let summary = [{
|
||||
Compute the cumulative product of the tensor `x` along `axis`.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
By default, this op performs an inclusive cumprod, which means that the first
|
||||
element of the input is identical to the first element of the output:
|
||||
|
||||
```python
|
||||
tf.cumprod([a, b, c]) # => [a, a * b, a * b * c]
|
||||
```
|
||||
|
||||
By setting the `exclusive` kwarg to `True`, an exclusive cumprod is
|
||||
performed instead:
|
||||
|
||||
```python
|
||||
tf.cumprod([a, b, c], exclusive=True) # => [1, a, a * b]
|
||||
```
|
||||
|
||||
By setting the `reverse` kwarg to `True`, the cumprod is performed in the
|
||||
opposite direction:
|
||||
|
||||
```python
|
||||
tf.cumprod([a, b, c], reverse=True) # => [a * b * c, b * c, c]
|
||||
```
|
||||
|
||||
This is more efficient than using separate `tf.reverse` ops.
|
||||
|
||||
The `reverse` and `exclusive` kwargs can also be combined:
|
||||
|
||||
```python
|
||||
tf.cumprod([a, b, c], exclusive=True, reverse=True) # => [b * c, c, 1]
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x,
|
||||
TF_I32OrI64Tensor:$axis,
|
||||
|
||||
DefaultValuedAttr<BoolAttr, "false">:$exclusive,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$reverse
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$out
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
|
||||
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_CumsumOp : TF_Op<"Cumsum", [NoSideEffect, TF_AllTypesMatch<["x", "out"]>]> {
|
||||
let summary = "Compute the cumulative sum of the tensor `x` along `axis`.";
|
||||
|
||||
@ -2084,6 +2172,10 @@ tf.cumsum([a, b, c], exclusive=True, reverse=True) # => [b + c, c, 0]
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
|
||||
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_DataFormatDimMapOp : TF_Op<"DataFormatDimMap", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
@ -2109,6 +2201,40 @@ the source data format.
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_DebugIdentityV2Op : TF_Op<"DebugIdentityV2", []> {
|
||||
let summary = "Debug Identity V2 Op.";
|
||||
|
||||
let description = [{
|
||||
Provides an identity mapping from input to output, while writing the content of
|
||||
the input tensor by calling DebugEventsWriter.
|
||||
|
||||
The semantics of the input tensor depends on tensor_debug_mode. In typical
|
||||
usage, the input tensor comes directly from the user computation only when
|
||||
graph_debug_mode is FULL_TENSOR (see protobuf/debug_event.proto for a
|
||||
list of all the possible values of graph_debug_mode). For the other debug modes,
|
||||
the input tensor should be produced by an additional op or subgraph that
|
||||
computes summary information about one or more tensors.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_Tensor:$input,
|
||||
|
||||
StrAttr:$tfdbg_context_id,
|
||||
StrAttr:$op_name,
|
||||
DefaultValuedAttr<I64Attr, "-1">:$output_slot,
|
||||
DefaultValuedAttr<I64Attr, "-1">:$tensor_debug_mode,
|
||||
DefaultValuedAttr<StrArrayAttr, "{}">:$debug_urls,
|
||||
DefaultValuedAttr<I64Attr, "1000">:$circular_buffer_size,
|
||||
StrAttr:$tfdbg_run_id
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_Tensor:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_DecodeAndCropJpegOp : TF_Op<"DecodeAndCropJpeg", [NoSideEffect]> {
|
||||
let summary = "Decode and Crop a JPEG-encoded image to a uint8 tensor.";
|
||||
|
||||
@ -2421,6 +2547,40 @@ this op runs. The length of the list is returned in two cases:
|
||||
);
|
||||
}
|
||||
|
||||
def TF_DiagOp : TF_Op<"Diag", [NoSideEffect, SameOperandsAndResultElementType]> {
  let summary = "Returns a diagonal tensor with given diagonal values.";

  let description = [{
Given a `diagonal`, this operation returns a tensor with the `diagonal` and
everything else padded with zeros. The diagonal is computed as follows:

Assume `diagonal` has dimensions [D1,..., Dk], then the output is a tensor of
rank 2k with dimensions [D1,..., Dk, D1,..., Dk] where:

`output[i1,..., ik, i1,..., ik] = diagonal[i1, ..., ik]` and 0 everywhere else.

For example:

```
# 'diagonal' is [1, 2, 3, 4]
tf.diag(diagonal) ==> [[1, 0, 0, 0]
                       [0, 2, 0, 0]
                       [0, 0, 3, 0]
                       [0, 0, 0, 4]]
```
  }];

  let arguments = (ins
    TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$diagonal
  );

  let results = (outs
    TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$output
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

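A minimal sketch of the `Diag` semantics described above, assuming the `tf.linalg.tensor_diag` Python endpoint (the TF2 name for `tf.diag`); the values mirror the example in the description:

```python
import tensorflow as tf

diagonal = tf.constant([1, 2, 3, 4])
# Places `diagonal` on the main diagonal of a rank-2k tensor, zeros elsewhere.
tf.linalg.tensor_diag(diagonal)
# [[1 0 0 0]
#  [0 2 0 0]
#  [0 0 3 0]
#  [0 0 0 4]]
```
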
def TF_DiagPartOp : TF_Op<"DiagPart", [NoSideEffect]> {
|
||||
let summary = "Returns the diagonal part of the tensor.";
|
||||
|
||||
@ -4185,6 +4345,22 @@ tf.imag(input) ==> [4.75, 5.75]
|
||||
TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_InfeedDequeueOp : TF_Op<"InfeedDequeue", []> {
|
||||
let summary = [{
|
||||
A placeholder op for a value that will be fed into the computation.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_ShapeAttr:$shape
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_Tensor:$output
|
||||
);
|
||||
|
||||
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_InitializeTableFromTextFileV2Op : TF_Op<"InitializeTableFromTextFileV2", []> {
|
||||
let summary = "Initializes a table from a text file.";
|
||||
|
||||
@ -5673,6 +5849,74 @@ tf.matrix_set_diag(input, diagonals, k = (-1, 2), align="LEFT_RIGHT")
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_MatrixTriangularSolveOp : TF_Op<"MatrixTriangularSolve", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
Solves systems of linear equations with upper or lower triangular matrices by backsubstitution.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
`matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions form
|
||||
square matrices. If `lower` is `True` then the strictly upper triangular part
|
||||
of each inner-most matrix is assumed to be zero and not accessed.
|
||||
If `lower` is False then the strictly lower triangular part of each inner-most
|
||||
matrix is assumed to be zero and not accessed.
|
||||
`rhs` is a tensor of shape `[..., M, N]`.
|
||||
|
||||
The output is a tensor of shape `[..., M, N]`. If `adjoint` is
|
||||
`True` then the innermost matrices in `output` satisfy matrix equations
|
||||
`matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`.
|
||||
If `adjoint` is `False` then the strictly then the innermost matrices in
|
||||
`output` satisfy matrix equations
|
||||
`adjoint(matrix[..., i, k]) * output[..., k, j] = rhs[..., i, j]`.
|
||||
|
||||
Note, the batch shapes for the inputs only need to broadcast.
|
||||
|
||||
Example:
|
||||
```python
|
||||
|
||||
a = tf.constant([[3, 0, 0, 0],
|
||||
[2, 1, 0, 0],
|
||||
[1, 0, 1, 0],
|
||||
[1, 1, 1, 1]], dtype=tf.float32)
|
||||
|
||||
b = tf.constant([[4],
|
||||
[2],
|
||||
[4],
|
||||
[2]], dtype=tf.float32)
|
||||
|
||||
x = tf.linalg.triangular_solve(a, b, lower=True)
|
||||
x
|
||||
# <tf.Tensor: shape=(4, 1), dtype=float32, numpy=
|
||||
# array([[ 1.3333334 ],
|
||||
# [-0.66666675],
|
||||
# [ 2.6666665 ],
|
||||
# [-1.3333331 ]], dtype=float32)>
|
||||
|
||||
# in python3 one can use `a@x`
|
||||
tf.matmul(a, x)
|
||||
# <tf.Tensor: shape=(4, 1), dtype=float32, numpy=
|
||||
# array([[4. ],
|
||||
# [2. ],
|
||||
# [4. ],
|
||||
# [1.9999999]], dtype=float32)>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$matrix,
|
||||
TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$rhs,
|
||||
|
||||
DefaultValuedAttr<BoolAttr, "true">:$lower,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$adjoint
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_MaxOp : TF_Op<"Max", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
Computes the maximum of elements across dimensions of a tensor.
|
||||
@ -5823,7 +6067,7 @@ def TF_MergeSummaryOp : TF_Op<"MergeSummary", [NoSideEffect, SameOperandsAndResu
|
||||
|
||||
let description = [{
|
||||
This op creates a
|
||||
[`Summary`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/summary.proto)
|
||||
[`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto)
|
||||
protocol buffer that contains the union of all the values in the input
|
||||
summaries.
|
||||
|
||||
@ -6054,7 +6298,7 @@ the result here is consistent with a truncating divide. E.g.
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_MulOp : TF_Op<"Mul", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef, TF_CwiseBinary]>,
|
||||
def TF_MulOp : TF_Op<"Mul", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary, TF_SameOperandsAndResultElementTypeResolveRef]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = "Returns x * y element-wise.";
|
||||
|
||||
@ -7215,9 +7459,6 @@ def TF_RangeDatasetOp : TF_Op<"RangeDataset", []> {
|
||||
Creates a dataset with a range of values. Corresponds to python's xrange.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
I64Tensor:$start,
|
||||
I64Tensor:$stop,
|
||||
@ -9464,7 +9705,7 @@ I.e., \\(y = x * x = x^2\\).
|
||||
|
||||
def TF_SquaredDifferenceOp : TF_Op<"SquaredDifference", [Commutative, NoSideEffect, ResultsBroadcastableShape]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = "Returns (x - y)(x - y) element-wise.";
|
||||
let summary = "Returns conj(x - y)(x - y) element-wise.";
|
||||
|
||||
let description = [{
|
||||
*NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting
|
||||
@ -9576,6 +9817,49 @@ def TF_StackV2Op : TF_Op<"StackV2", []> {
|
||||
);
|
||||
}
|
||||
|
||||
def TF_StatelessMultinomialOp : TF_Op<"StatelessMultinomial", [NoSideEffect]> {
  let summary = "Draws samples from a multinomial distribution.";

  let arguments = (ins
    TF_IntOrFpTensor:$logits,
    I32Tensor:$num_samples,
    TF_I32OrI64Tensor:$seed
  );

  let results = (outs
    TF_I32OrI64Tensor:$output
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
  TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<2>;
  TF_DerivedResultTypeAttr output_dtype = TF_DerivedResultTypeAttr<0>;
}

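A minimal sketch of how `StatelessMultinomial` is typically reached from Python, assuming the `tf.random.stateless_categorical` wrapper; the logits and seed values are illustrative:

```python
import tensorflow as tf

# Two rows of 3-class logits; draw 5 samples per row.
logits = tf.math.log([[0.5, 0.25, 0.25],
                      [0.1, 0.1, 0.8]])
seed = [7, 42]  # stateless ops take a 2-element seed

samples = tf.random.stateless_categorical(logits, num_samples=5, seed=seed)
# shape (2, 5); the same seed always yields the same samples.
```
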
def TF_StatelessRandomNormalOp : TF_Op<"StatelessRandomNormal", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
Outputs deterministic pseudorandom values from a normal distribution.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
The generated values will have mean 0 and standard deviation 1.
|
||||
|
||||
The outputs are a deterministic function of `shape` and `seed`.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_I32OrI64Tensor:$shape,
|
||||
TF_I32OrI64Tensor:$seed
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_FpTensor:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>;
|
||||
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_StatelessRandomUniformOp : TF_Op<"StatelessRandomUniform", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
Outputs deterministic pseudorandom random values from a uniform distribution.
|
||||
@ -9602,6 +9886,33 @@ The outputs are a deterministic function of `shape` and `seed`.
|
||||
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_StatelessRandomUniformIntOp : TF_Op<"StatelessRandomUniformInt", [NoSideEffect]> {
  let summary = [{
Outputs deterministic pseudorandom random integers from a uniform distribution.
  }];

  let description = [{
The generated values follow a uniform distribution in the range `[minval, maxval)`.

The outputs are a deterministic function of `shape`, `seed`, `minval`, and `maxval`.
  }];

  let arguments = (ins
    TF_I32OrI64Tensor:$shape,
    TF_I32OrI64Tensor:$seed,
    TF_I32OrI64Tensor:$minval,
    TF_I32OrI64Tensor:$maxval
  );

  let results = (outs
    TF_I32OrI64Tensor:$output
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
  TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>;
  TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>;
}

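The determinism promised by the description ("a deterministic function of `shape`, `seed`, `minval`, and `maxval`") can be seen from the `tf.random.stateless_uniform` wrapper with an integer dtype; a small sketch with illustrative values:

```python
import tensorflow as tf

seed = [1, 2]  # stateless seeds are a pair of scalars

a = tf.random.stateless_uniform([3], seed=seed, minval=0, maxval=10,
                                dtype=tf.int32)
b = tf.random.stateless_uniform([3], seed=seed, minval=0, maxval=10,
                                dtype=tf.int32)

# Identical inputs -> identical draws in [0, 10).
assert bool(tf.reduce_all(tf.equal(a, b)))
```
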
def TF_StatelessTruncatedNormalOp : TF_Op<"StatelessTruncatedNormal", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
Outputs deterministic pseudorandom values from a truncated normal distribution.
|
||||
@ -9871,7 +10182,7 @@ Examples:
|
||||
TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_SubOp : TF_Op<"Sub", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef, TF_CwiseBinary]>,
|
||||
def TF_SubOp : TF_Op<"Sub", [NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary, TF_SameOperandsAndResultElementTypeResolveRef]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = "Returns x - y element-wise.";
|
||||
|
||||
@ -9926,6 +10237,25 @@ retained with length 1.
|
||||
>];
|
||||
}
|
||||
|
||||
def TF_SymbolicGradientOp : TF_Op<"SymbolicGradient", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
Computes the gradient function for function f via backpropagation.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<TF_Tensor>:$input,
|
||||
|
||||
SymbolRefAttr:$f
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Variadic<TF_Tensor>:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>;
|
||||
TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
|
||||
}
|
||||
|
||||
def TF_TPUCompilationResultOp : TF_Op<"TPUCompilationResult", [NoSideEffect]> {
|
||||
let summary = "Returns the result of a TPU compilation.";
|
||||
|
||||
@ -11901,6 +12231,13 @@ https://www.tensorflow.org/performance/xla/operation_semantics#pad
|
||||
def TF_XlaRecvFromHostOp : TF_Op<"XlaRecvFromHost", []> {
|
||||
let summary = "An op to receive a tensor from the host.";
|
||||
|
||||
let description = [{
|
||||
output: the tensor that will be received from the host.
|
||||
Toutput: element type for output.
|
||||
shape: shape for output.
|
||||
key: A unique identifier for this region used to match up host transfers.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_ShapeAttr:$shape,
|
||||
StrAttr:$key
|
||||
@ -11945,6 +12282,31 @@ def TF_XlaReplicaIdOp : TF_Op<"XlaReplicaId", [NoSideEffect]> {
|
||||
);
|
||||
}
|
||||
|
||||
def TF_XlaScatterOp : TF_Op<"XlaScatter", [NoSideEffect]> {
|
||||
let summary = "Wraps the XLA Scatter operator documented at";
|
||||
|
||||
let description = [{
|
||||
https://www.tensorflow.org/xla/operation_semantics#scatter.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$operand,
|
||||
TF_I32OrI64Tensor:$scatter_indices,
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$updates,
|
||||
|
||||
SymbolRefAttr:$update_computation,
|
||||
StrAttr:$dimension_numbers,
|
||||
BoolAttr:$indices_are_sorted
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>;
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_XlaSelfAdjointEigOp : TF_Op<"XlaSelfAdjointEig", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
Computes the eigen decomposition of a batch of self-adjoint matrices
|
||||
@ -11977,6 +12339,12 @@ i=0...N-1.
|
||||
def TF_XlaSendToHostOp : TF_Op<"XlaSendToHost", []> {
|
||||
let summary = "An op to send a tensor to the host.";
|
||||
|
||||
let description = [{
|
||||
input: the tensor that will be sent to the host.
|
||||
Tinput: element type for input.
|
||||
key: A unique identifier for this region used to match up host transfers.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_Tensor:$input,
|
||||
|
||||
@ -12183,18 +12551,17 @@ Compiles a computations for execution on one or more TPU devices.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
For the internal use of the distributed TPU compiler. Note that currently only
|
||||
single TPU device is supported.
|
||||
For the internal use of the distributed TPU compiler.
|
||||
|
||||
'mlir_module' is a serialized MLIR module with a `main` function that contains
|
||||
target computation.
|
||||
'dynamic_shapes' contains dynamic shapes of arguments whose shapes were not
|
||||
known statically at TPUReplication rewrite time.
|
||||
'metadata' is a serialized TPUCompileMetadataProto describing
|
||||
the shapes and types of the inputs to the computation, as well as a mapping onto
|
||||
the TPU pod topology.
|
||||
'program' output is a string key that is passed to the _TPUExecute op and
|
||||
used to look up the program in the compilation cache.
|
||||
'metadata' is a serialized TPUCompileMetadataProto describing the shapes and
|
||||
types of the inputs to the computation, as well as a mapping onto the TPU pod
|
||||
topology.
|
||||
'program' output is a string key that is passed to the TPUExecute op and used to
|
||||
look up the program in the compilation cache.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
@ -12231,6 +12598,28 @@ rewrite passes must replace this op with a _TPUCompileMlir op `program` output.
|
||||
);
|
||||
}
|
||||
|
||||
def TF__UnaryOpsCompositionOp : TF_Op<"_UnaryOpsComposition", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = [{
|
||||
*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
expected to create these operators.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[F16, F32, F64]>:$x,
|
||||
|
||||
StrArrayAttr:$op_names
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[F16, F32, F64]>:$y
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF__XlaHostComputeMlirOp : TF_Op<"_XlaHostComputeMlir", []> {
|
||||
let summary = [{
|
||||
A pseudo-op to represent host-side computation in an XLA program.
|
||||
|
||||
@ -55,6 +55,8 @@ limitations under the License.
|
||||
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
|
||||
#include "mlir/IR/Types.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/DecodeAttributesInterfaces.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/FoldInterfaces.h" // from @llvm-project
|
||||
#include "mlir/Parser.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
@ -112,6 +114,22 @@ bool HasSingleUse(FuncOp func) {
|
||||
return true;
|
||||
}
|
||||
|
||||
struct TFConstantFoldInterface : public DialectFoldInterface {
|
||||
TFConstantFoldInterface(Dialect *dialect) : DialectFoldInterface(dialect) {}
|
||||
LogicalResult Fold(Operation *op, ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<OpFoldResult> &results) const final {
|
||||
return TensorFlowDialect::constantFold(op, operands, results);
|
||||
}
|
||||
};
|
||||
|
||||
struct TFDecodeAttributesInterface : public DialectDecodeAttributesInterface {
|
||||
TFDecodeAttributesInterface(Dialect *dialect)
|
||||
: DialectDecodeAttributesInterface(dialect) {}
|
||||
LogicalResult decode(OpaqueElementsAttr input, ElementsAttr &output) const {
|
||||
return TensorFlowDialect::decode(input, output);
|
||||
}
|
||||
};
|
||||
|
||||
struct TFInlinerInterface : public DialectInlinerInterface {
|
||||
using DialectInlinerInterface::DialectInlinerInterface;
|
||||
|
||||
@ -206,6 +224,9 @@ std::vector<TensorFlowDialect::AdditionalOpFunction>
|
||||
*TensorFlowDialect::additional_operation_hooks_ =
|
||||
new std::vector<TensorFlowDialect::AdditionalOpFunction>();
|
||||
|
||||
TensorFlowDialect::ConstantFoldHook TensorFlowDialect::constant_fold_hook_;
|
||||
TensorFlowDialect::DecodeConstantHook TensorFlowDialect::decode_constant_hook_;
|
||||
|
||||
TensorFlowDialect::TensorFlowDialect(MLIRContext *context)
|
||||
: Dialect(/*name=*/"tf", context, TypeID::get<TensorFlowDialect>()) {
|
||||
addOperations<
|
||||
@ -217,7 +238,8 @@ TensorFlowDialect::TensorFlowDialect(MLIRContext *context)
|
||||
#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
|
||||
>();
|
||||
addInterfaces<TFInlinerInterface>();
|
||||
addInterfaces<TFInlinerInterface, TFDecodeAttributesInterface,
|
||||
TFConstantFoldInterface>();
|
||||
addAttributes<ShapeAttr, FuncAttr>();
|
||||
|
||||
// Support unknown operations because not all TensorFlow operations are
|
||||
@ -385,20 +407,20 @@ Type TensorFlowDialect::parseType(DialectAsmParser &parser) const {
|
||||
// Prints a type registered to this dialect.
|
||||
void TensorFlowDialect::printType(Type ty, DialectAsmPrinter &os) const {
|
||||
assert(ty.isa<TensorFlowType>());
|
||||
switch (ty.getKind()) {
|
||||
default:
|
||||
llvm_unreachable("unexpected tensorflow type kind");
|
||||
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
|
||||
case TensorFlowTypes::enumerant: \
|
||||
os << name; \
|
||||
break;
|
||||
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
|
||||
if (auto derived_ty = ty.dyn_cast<tftype##Type>()) { \
|
||||
os << name; \
|
||||
return; \
|
||||
}
|
||||
#define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name) \
|
||||
case TensorFlowTypes::enumerant: \
|
||||
Print##tftype##Type(ty.cast<tftype##Type>(), os); \
|
||||
break;
|
||||
if (auto derived_ty = ty.dyn_cast<tftype##Type>()) { \
|
||||
Print##tftype##Type(derived_ty, os); \
|
||||
return; \
|
||||
}
|
||||
// NOLINTNEXTLINE
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
|
||||
}
|
||||
|
||||
llvm_unreachable("unexpected tensorflow type kind");
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
@ -116,10 +116,35 @@ class TensorFlowDialect : public Dialect {
|
||||
0, (addOperation(AbstractOperation::get<Args>(*this)), 0)...};
|
||||
}
|
||||
|
||||
using ConstantFoldHook = LogicalResult (*)(Operation *, ArrayRef<Attribute>,
|
||||
SmallVectorImpl<OpFoldResult> &);
|
||||
static void RegisterConstantFoldHook(ConstantFoldHook fn) {
|
||||
constant_fold_hook_ = std::move(fn);
|
||||
}
|
||||
|
||||
static LogicalResult constantFold(Operation *op, ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<OpFoldResult> &results) {
|
||||
if (constant_fold_hook_) return constant_fold_hook_(op, operands, results);
|
||||
return failure();
|
||||
}
|
||||
|
||||
using DecodeConstantHook = LogicalResult (*)(OpaqueElementsAttr input,
|
||||
ElementsAttr &output);
|
||||
static void RegisterDecodeConstantHook(DecodeConstantHook fn) {
|
||||
decode_constant_hook_ = std::move(fn);
|
||||
}
|
||||
static LogicalResult decode(OpaqueElementsAttr input, ElementsAttr &output) {
|
||||
if (decode_constant_hook_) return decode_constant_hook_(input, output);
|
||||
return failure();
|
||||
}
|
||||
|
||||
private:
|
||||
// Hook functions which may add additional operations to the dialect.
|
||||
// These are invoked at construction time.
|
||||
static std::vector<AdditionalOpFunction> *additional_operation_hooks_;
|
||||
|
||||
static ConstantFoldHook constant_fold_hook_;
|
||||
static DecodeConstantHook decode_constant_hook_;
|
||||
};
|
||||
|
||||
} // namespace TF
|
||||
|
||||
@ -97,10 +97,10 @@ An n-way switch statement, implementing the following:
|
||||
Variadic<TF_Tensor>:$input,
|
||||
|
||||
Confined<SymbolRefArrayAttr, [ArrayMinCount<1>]>:$branches,
|
||||
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,
|
||||
|
||||
// Used to map StatelessCase and Case to a common op.
|
||||
DefaultValuedAttr<BoolAttr, "false">:$is_stateless
|
||||
// Used to map StatelessCase and Case op defined in TensorFlow to a common
|
||||
// op.
|
||||
BoolAttr:$is_stateless
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@ -109,10 +109,55 @@ An n-way switch statement, implementing the following:
|
||||
|
||||
TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>;
|
||||
TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
|
||||
TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TF_CaseRegionOp : TF_Op<"CaseRegion",
      [SingleBlockImplicitTerminator<"YieldOp">, NoRegionArguments]> {
  let summary = [{
An n-way switch statement which calls a single branch function.
  }];

  let description = [{
An n-way switch statement, implementing the following:
    ```
    switch (branch_index) {
      case 0:
        output = branches[0](input);
        break;
      case 1:
        output = branches[1](input);
        break;
      ...
      case [[nbranches-1]]:
      default:
        output = branches[nbranches-1](input);
        break;
    }
    ```
  }];

  let arguments = (ins
    I32Tensor:$branch_index,

    // Used to map StatelessCase and Case op defined in TensorFlow to a common
    // op.
    BoolAttr:$is_stateless
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  let regions = (region VariadicRegion<SizedRegion<1>>:$branches);

  let verifier = [{
    return Verify(*this);
  }];
}

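At the Python level the n-way branch construct corresponds to `tf.switch_case`; a minimal sketch (branch functions and index are illustrative, and whether the graph is represented with `Case` or the region form above depends on the conversion pipeline):

```python
import tensorflow as tf

def branch0(): return tf.constant(10)
def branch1(): return tf.constant(20)
def branch2(): return tf.constant(30)

branch_index = tf.constant(1)
result = tf.switch_case(branch_index,
                        branch_fns=[branch0, branch1, branch2],
                        default=branch2)
# result == 20 for branch_index == 1; `default` handles out-of-range indices.
```
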
// In MLIR, the TensorFlow tensor value is represented as an ElementsAttr, with
|
||||
// its type encoding the tensor's shape and data type.
|
||||
def TF_ConstOp : TF_Op<"Const", [ConstantLike, NoSideEffect,
|
||||
@ -292,7 +337,7 @@ else_branch: A function that takes 'inputs' and returns a list of
|
||||
}
|
||||
|
||||
def TF_YieldOp : TF_Op<"Yield",
|
||||
[Terminator, ParentOneOf<["IfRegionOp", "WhileRegionOp"]>]> {
|
||||
[Terminator, ParentOneOf<["CaseRegionOp", "IfRegionOp", "WhileRegionOp"]>]> {
|
||||
let summary = "Yield operation";
|
||||
|
||||
let description = [{
|
||||
|
||||
@ -477,7 +477,47 @@ LogicalResult FoldConstantCaseOp::matchAndRewrite(
|
||||
|
||||
void CaseOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<FoldConstantCaseOp>(context);
|
||||
results.insert<FoldConstantCaseOp, DropAttributes<CaseOp>>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CaseRegionOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// TODO(lyandy): Extract similar checks for CaseOp.
|
||||
static LogicalResult Verify(CaseRegionOp op) {
|
||||
if (op.branches().empty())
|
||||
return op.emitOpError() << "expects to have at least 1 region";
|
||||
|
||||
if (!IsOfRankOrUnranked(op.branch_index(), 0))
|
||||
return op.emitOpError() << "expects 'branch_index' to be a scalar, but got "
|
||||
<< op.branch_index().getType();
|
||||
|
||||
DenseIntElementsAttr branch_index_attr;
|
||||
if (matchPattern(op.branch_index(), m_Constant(&branch_index_attr))) {
|
||||
assert(branch_index_attr.getNumElements() == 1);
|
||||
int64_t branch_index = branch_index_attr.getSplatValue<IntegerAttr>()
|
||||
.getValue()
|
||||
.getSExtValue();
|
||||
if (branch_index < 0)
|
||||
return op.emitOpError()
|
||||
<< "expects 'branch_index' to be non-negative, but got "
|
||||
<< branch_index;
|
||||
|
||||
if (branch_index >= op.branches().size())
|
||||
return op.emitOpError()
|
||||
<< "expects 'branch_index' to be less than the number of regions ("
|
||||
<< op.branches().size() << "), but got " << branch_index;
|
||||
}
|
||||
|
||||
for (auto region_and_idx : llvm::enumerate(op.branches())) {
|
||||
std::string region_name =
|
||||
llvm::formatv("region #{0}", region_and_idx.index()).str();
|
||||
if (failed(VerifyRegionResults(op, region_and_idx.value(), region_name)))
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -734,6 +774,35 @@ void ConcatV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
// CumsumOp and CumprodOp
//===----------------------------------------------------------------------===//

template <typename OpT, typename std::enable_if<llvm::is_one_of<
                             OpT, CumsumOp, CumprodOp>::value>::type * = nullptr>
static LogicalResult Verify(OpT op) {
  if (!IsOfRankOrUnranked(op.axis(), 0))
    return op.emitOpError("requires scalar axis operand");

  DenseIntElementsAttr axis_attr;
  if (matchPattern(op.axis(), m_Constant(&axis_attr))) {
    auto input_ty = op.x().getType().template dyn_cast<RankedTensorType>();
    if (input_ty) {
      int64_t rank = input_ty.getRank();
      assert(axis_attr.getNumElements() == 1 &&
             "scalar attribute should have exactly one element");
      int64_t axis = (*axis_attr.begin()).getSExtValue();
      if (axis < -rank || axis >= rank) {
        return op.emitError()
               << "axis operand should be within range [" << -rank << ", "
               << rank << "); actual value: " << axis;
      }
    }
  }

  return success();
}

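For reference, the axis range `[-rank, rank)` enforced by this verifier matches what the Python `tf.math.cumsum` API accepts; a small illustrative sketch on a rank-2 input:

```python
import tensorflow as tf

x = tf.constant([[1, 2, 3],
                 [4, 5, 6]])        # rank 2, so valid axes are in [-2, 2)

tf.math.cumsum(x, axis=0)   # [[1, 2, 3], [5, 7, 9]]
tf.math.cumsum(x, axis=-1)  # [[1, 3, 6], [4, 9, 15]]
# axis=2 (equal to the rank) is rejected, mirroring the check in Verify above.
```
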
//===----------------------------------------------------------------------===//
|
||||
// ConcatOffsetOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@ -33,7 +33,7 @@ namespace TF {
static inline LogicalResult VerifyRefTypeMatch(mlir::Type type,
                                               mlir::Type maybe_ref_type) {
  if (auto ref_type = maybe_ref_type.dyn_cast<mlir::TF::TensorFlowRefType>())
    return success(ref_type.RemoveRef().getKind() == type.getKind());
    return success(ref_type.RemoveRef().getTypeID() == type.getTypeID());
  return failure();
}

@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "mlir/Dialect/Traits.h" // from @llvm-project
|
||||
#include "mlir/IR/Dialect.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
|
||||
|
||||
@ -100,7 +101,7 @@ mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b,
|
||||
if (a == b) return a;
|
||||
}
|
||||
}
|
||||
if (a.getKind() != b.getKind()) return nullptr;
|
||||
if (a.getTypeID() != b.getTypeID()) return nullptr;
|
||||
|
||||
// If either is not a type that contain subtypes then the types are not cast
|
||||
// compatible.
|
||||
@ -178,127 +179,116 @@ ResultShapeIterator::ResultShapeIterator(Operation::result_iterator it)
|
||||
// TF types helper functions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool TensorFlowType::classof(Type type) {
|
||||
return type.getDialect().getNamespace() == "tf";
|
||||
}
|
||||
bool TensorFlowRefType::classof(Type type) {
|
||||
return type.isa<
|
||||
#define HANDLE_TF_TYPE(tftype, enumerant, name)
|
||||
#define HANDLE_TF_REF_TYPE(tftype, enumerant, name) tftype##Type,
|
||||
#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
|
||||
// NOLINTNEXTLINE
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
|
||||
>();
|
||||
}
|
||||
bool TensorFlowTypeWithSubtype::classof(Type type) {
|
||||
return type.isa<ResourceType, VariantType>();
|
||||
}
|
||||
|
||||
TensorFlowType TensorFlowRefType::get(Type type) {
|
||||
MLIRContext* ctx = type.getContext();
|
||||
switch (getElementTypeOrSelf(type).getKind()) {
|
||||
case StandardTypes::F16:
|
||||
return HalfRefType::get(ctx);
|
||||
case StandardTypes::F32:
|
||||
return FloatRefType::get(ctx);
|
||||
case StandardTypes::F64:
|
||||
return DoubleRefType::get(ctx);
|
||||
case StandardTypes::BF16:
|
||||
return Bfloat16RefType::get(ctx);
|
||||
case StandardTypes::Complex: {
|
||||
const auto& etype = type.cast<ComplexType>().getElementType();
|
||||
switch (getElementTypeOrSelf(etype).getKind()) {
|
||||
case StandardTypes::F32:
|
||||
return Complex64RefType::get(ctx);
|
||||
case StandardTypes::F64:
|
||||
return Complex128RefType::get(ctx);
|
||||
default:
|
||||
llvm_unreachable("unexpected complex type");
|
||||
}
|
||||
type = getElementTypeOrSelf(type);
|
||||
if (type.isF16()) {
|
||||
return HalfRefType::get(ctx);
|
||||
} else if (type.isF32()) {
|
||||
return FloatRefType::get(ctx);
|
||||
} else if (type.isF64()) {
|
||||
return DoubleRefType::get(ctx);
|
||||
} else if (type.isBF16()) {
|
||||
return Bfloat16RefType::get(ctx);
|
||||
} else if (auto complex_type = type.dyn_cast<ComplexType>()) {
|
||||
Type etype = complex_type.getElementType();
|
||||
if (etype.isF32()) {
|
||||
return Complex64RefType::get(ctx);
|
||||
} else if (etype.isF64()) {
|
||||
return Complex128RefType::get(ctx);
|
||||
}
|
||||
case StandardTypes::Integer: {
|
||||
const auto& itype = type.cast<IntegerType>();
|
||||
switch (itype.getWidth()) {
|
||||
case 1:
|
||||
return BoolRefType::get(ctx);
|
||||
case 8:
|
||||
return itype.isUnsigned() ? TensorFlowType(Uint8RefType::get(ctx))
|
||||
: Int8RefType::get(ctx);
|
||||
case 16:
|
||||
return itype.isUnsigned() ? TensorFlowType(Uint16RefType::get(ctx))
|
||||
: Int16RefType::get(ctx);
|
||||
case 32:
|
||||
return itype.isUnsigned() ? TensorFlowType(Uint32RefType::get(ctx))
|
||||
: Int32RefType::get(ctx);
|
||||
case 64:
|
||||
return itype.isUnsigned() ? TensorFlowType(Uint64RefType::get(ctx))
|
||||
: Int64RefType::get(ctx);
|
||||
default:
|
||||
llvm_unreachable("unexpected integer type");
|
||||
}
|
||||
llvm_unreachable("unexpected complex type");
|
||||
} else if (auto itype = type.dyn_cast<IntegerType>()) {
|
||||
switch (itype.getWidth()) {
|
||||
case 1:
|
||||
return BoolRefType::get(ctx);
|
||||
case 8:
|
||||
return itype.isUnsigned() ? TensorFlowType(Uint8RefType::get(ctx))
|
||||
: Int8RefType::get(ctx);
|
||||
case 16:
|
||||
return itype.isUnsigned() ? TensorFlowType(Uint16RefType::get(ctx))
|
||||
: Int16RefType::get(ctx);
|
||||
case 32:
|
||||
return itype.isUnsigned() ? TensorFlowType(Uint32RefType::get(ctx))
|
||||
: Int32RefType::get(ctx);
|
||||
case 64:
|
||||
return itype.isUnsigned() ? TensorFlowType(Uint64RefType::get(ctx))
|
||||
: Int64RefType::get(ctx);
|
||||
default:
|
||||
llvm_unreachable("unexpected integer type");
|
||||
}
|
||||
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
|
||||
case TensorFlowTypes::enumerant: \
|
||||
}
|
||||
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
|
||||
if (auto derived_ty = type.dyn_cast<tftype##Type>()) \
|
||||
return tftype##RefType::get(ctx);
|
||||
|
||||
#define HANDLE_TF_REF_TYPE(tftype, enumerant, name)
|
||||
// NOLINTNEXTLINE
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
|
||||
default:
|
||||
llvm_unreachable("unexpected type kind");
|
||||
}
|
||||
llvm_unreachable("unexpected type kind");
|
||||
}
|
||||
|
||||
Type TensorFlowRefType::RemoveRef() {
|
||||
MLIRContext* ctx = getContext();
|
||||
switch (getKind()) {
|
||||
case TensorFlowTypes::HALF_REF:
|
||||
return mlir::FloatType::getF16(ctx);
|
||||
case TensorFlowTypes::FLOAT_REF:
|
||||
return mlir::FloatType::getF32(ctx);
|
||||
case TensorFlowTypes::DOUBLE_REF:
|
||||
return mlir::FloatType::getF64(ctx);
|
||||
case TensorFlowTypes::BFLOAT16_REF:
|
||||
return mlir::FloatType::getBF16(ctx);
|
||||
case TensorFlowTypes::BOOL_REF:
|
||||
return mlir::IntegerType::get(1, ctx);
|
||||
case TensorFlowTypes::INT8_REF:
|
||||
return mlir::IntegerType::get(8, ctx);
|
||||
case TensorFlowTypes::INT16_REF:
|
||||
return mlir::IntegerType::get(16, ctx);
|
||||
case TensorFlowTypes::INT32_REF:
|
||||
return mlir::IntegerType::get(32, ctx);
|
||||
case TensorFlowTypes::INT64_REF:
|
||||
return mlir::IntegerType::get(64, ctx);
|
||||
case TensorFlowTypes::UINT8_REF:
|
||||
return mlir::IntegerType::get(8, IntegerType::Unsigned, ctx);
|
||||
case TensorFlowTypes::UINT16_REF:
|
||||
return mlir::IntegerType::get(16, IntegerType::Unsigned, ctx);
|
||||
case TensorFlowTypes::UINT32_REF:
|
||||
return mlir::IntegerType::get(32, IntegerType::Unsigned, ctx);
|
||||
case TensorFlowTypes::UINT64_REF:
|
||||
return mlir::IntegerType::get(64, IntegerType::Unsigned, ctx);
|
||||
case TensorFlowTypes::COMPLEX64_REF:
|
||||
return mlir::ComplexType::get(mlir::FloatType::getF32(ctx));
|
||||
case TensorFlowTypes::COMPLEX128_REF:
|
||||
return mlir::ComplexType::get(mlir::FloatType::getF64(ctx));
|
||||
if (isa<HalfRefType>()) return mlir::FloatType::getF16(ctx);
|
||||
if (isa<FloatRefType>()) return mlir::FloatType::getF32(ctx);
|
||||
if (isa<DoubleRefType>()) return mlir::FloatType::getF64(ctx);
|
||||
if (isa<Bfloat16RefType>()) return mlir::FloatType::getBF16(ctx);
|
||||
if (isa<BoolRefType>()) return mlir::IntegerType::get(1, ctx);
|
||||
if (isa<Int8RefType>()) return mlir::IntegerType::get(8, ctx);
|
||||
if (isa<Int16RefType>()) return mlir::IntegerType::get(16, ctx);
|
||||
if (isa<Int32RefType>()) return mlir::IntegerType::get(32, ctx);
|
||||
if (isa<Int64RefType>()) return mlir::IntegerType::get(64, ctx);
|
||||
if (isa<Uint8RefType>())
|
||||
return mlir::IntegerType::get(8, IntegerType::Unsigned, ctx);
|
||||
if (isa<Uint16RefType>())
|
||||
return mlir::IntegerType::get(16, IntegerType::Unsigned, ctx);
|
||||
if (isa<Uint32RefType>())
|
||||
return mlir::IntegerType::get(32, IntegerType::Unsigned, ctx);
|
||||
if (isa<Uint64RefType>())
|
||||
return mlir::IntegerType::get(64, IntegerType::Unsigned, ctx);
|
||||
if (isa<Complex64RefType>())
|
||||
return mlir::ComplexType::get(mlir::FloatType::getF32(ctx));
|
||||
if (isa<Complex128RefType>())
|
||||
return mlir::ComplexType::get(mlir::FloatType::getF64(ctx));
|
||||
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
|
||||
case TensorFlowTypes::enumerant##_REF: \
|
||||
return tftype##Type::get(ctx);
|
||||
if (isa<tftype##RefType>()) return tftype##Type::get(ctx);
|
||||
|
||||
#define HANDLE_TF_REF_TYPE(tftype, enumerant, name)
|
||||
// NOLINTNEXTLINE
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
|
||||
default:
|
||||
llvm_unreachable("unexpected tensorflow ref type kind");
|
||||
}
|
||||
llvm_unreachable("unexpected tensorflow ref type kind");
|
||||
}
|
||||
|
||||
Type TensorFlowTypeWithSubtype::RemoveSubtypes() {
|
||||
MLIRContext* ctx = getContext();
|
||||
switch (getKind()) {
|
||||
case TensorFlowTypes::VARIANT:
|
||||
return VariantType::get(ctx);
|
||||
case TensorFlowTypes::RESOURCE:
|
||||
return ResourceType::get(ctx);
|
||||
default:
|
||||
llvm_unreachable("unexpected tensorflow type with subtypes kind");
|
||||
}
|
||||
if (isa<VariantType>()) return VariantType::get(ctx);
|
||||
if (isa<ResourceType>()) return ResourceType::get(ctx);
|
||||
llvm_unreachable("unexpected tensorflow type with subtypes kind");
|
||||
}
|
||||
|
||||
ArrayRef<TensorType> TensorFlowTypeWithSubtype::GetSubtypes() {
|
||||
switch (getKind()) {
|
||||
case TensorFlowTypes::VARIANT:
|
||||
return this->cast<VariantType>().getSubtypes();
|
||||
case TensorFlowTypes::RESOURCE:
|
||||
return this->cast<ResourceType>().getSubtypes();
|
||||
default:
|
||||
llvm_unreachable("unexpected tensorflow type with subtypes kind");
|
||||
}
|
||||
if (auto variant_type = dyn_cast<VariantType>())
|
||||
return variant_type.getSubtypes();
|
||||
if (auto resource_type = dyn_cast<ResourceType>())
|
||||
return resource_type.getSubtypes();
|
||||
llvm_unreachable("unexpected tensorflow type with subtypes kind");
|
||||
}
|
||||
|
||||
// TODO(jpienaar): BroadcastCompatible and HasCompatibleElementTypes have
|
||||
@ -306,8 +296,11 @@ ArrayRef<TensorType> TensorFlowTypeWithSubtype::GetSubtypes() {
bool BroadcastCompatible(ArrayRef<Type> lhs, ArrayRef<Type> rhs) {
  if (lhs.size() != rhs.size()) return false;
  for (auto types : llvm::zip(lhs, rhs)) {
    auto lhs_type = std::get<0>(types);
    auto rhs_type = std::get<1>(types);
    // Drop ref types because they don't affect broadcast compatibility. E.g.,
    // `tensor<!tf.f32ref>` and `tensor<f32>` should be considered broadcast
    // compatible.
    auto lhs_type = DropRefType(std::get<0>(types));
    auto rhs_type = DropRefType(std::get<1>(types));

    // This should be true for all TF ops:
    auto lhs_tt = lhs_type.dyn_cast<TensorType>();
@ -366,27 +359,31 @@ bool AreCastCompatible(ArrayRef<Type> types) {
|
||||
return true;
|
||||
}
|
||||
|
||||
ShapedType DropTypeSubTypes(ShapedType ty) {
|
||||
Type element_ty = ty.getElementType();
|
||||
auto subtype_ty = element_ty.dyn_cast<TF::TensorFlowTypeWithSubtype>();
|
||||
if (!subtype_ty) return ty;
|
||||
// Assumes a function `GetDefaultTypeOf(ComposedType)` that returns the default
|
||||
// type for a composed type (such as a ref type or a type with subtypes).
|
||||
template <typename ComposedType>
|
||||
Type DropTypeHelper(Type ty) {
|
||||
Type element_ty = getElementTypeOrSelf(ty);
|
||||
auto composed_type = element_ty.dyn_cast<ComposedType>();
|
||||
if (!composed_type) return ty;
|
||||
|
||||
Type default_ty = GetDefaultTypeOf(subtype_ty);
|
||||
if (ty.hasRank()) return RankedTensorType::get(ty.getShape(), default_ty);
|
||||
|
||||
return UnrankedTensorType::get(default_ty);
|
||||
Type default_ty = GetDefaultTypeOf(composed_type);
|
||||
if (auto ranked_ty = ty.dyn_cast<RankedTensorType>()) {
|
||||
return RankedTensorType::get(ranked_ty.getShape(), default_ty);
|
||||
} else if (ty.dyn_cast<UnrankedTensorType>()) {
|
||||
return UnrankedTensorType::get(default_ty);
|
||||
} else {
|
||||
return default_ty;
|
||||
}
|
||||
}
|
||||
|
||||
ShapedType DropRefType(ShapedType ty) {
|
||||
Type element_ty = ty.getElementType();
|
||||
TF::TensorFlowRefType ref_ty = element_ty.dyn_cast<TF::TensorFlowRefType>();
|
||||
if (!ref_ty) return ty;
|
||||
|
||||
Type default_ty = TF::GetDefaultTypeOf(ref_ty);
|
||||
if (ty.hasRank()) return RankedTensorType::get(ty.getShape(), default_ty);
|
||||
|
||||
return UnrankedTensorType::get(default_ty);
|
||||
Type DropSubTypes(Type ty) {
|
||||
return DropTypeHelper<TF::TensorFlowTypeWithSubtype>(ty);
|
||||
}
|
||||
|
||||
Type DropRefType(Type ty) { return DropTypeHelper<TF::TensorFlowRefType>(ty); }
|
||||
|
||||
Type DropRefAndSubTypes(Type ty) { return DropRefType(DropSubTypes(ty)); }
|
||||
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
||||
|
||||
@ -83,10 +83,7 @@ class TensorFlowType : public Type {
|
||||
using Type::Type;
|
||||
|
||||
// Support method to enable LLVM-style type casting.
|
||||
static bool classof(Type type) {
|
||||
return type.getKind() >= Type::FIRST_TENSORFLOW_TYPE &&
|
||||
type.getKind() <= TensorFlowTypes::LAST_USED_TENSORFLOW_TYPE;
|
||||
}
|
||||
static bool classof(Type type);
|
||||
};
|
||||
|
||||
// Returns true if the specified type is a valid TensorFlow element type.
|
||||
@ -130,10 +127,7 @@ class TensorFlowRefType : public TensorFlowType {
|
||||
using TensorFlowType::TensorFlowType;
|
||||
|
||||
// Checks if a type is TensorFlow Ref type.
|
||||
static bool classof(Type type) {
|
||||
return type.getKind() >= TensorFlowTypes::FLOAT_REF &&
|
||||
type.getKind() <= TensorFlowTypes::LAST_USED_TENSORFLOW_TYPE;
|
||||
}
|
||||
static bool classof(Type type);
|
||||
|
||||
// Converts a type to the corresponding TensorFlowRef type.
|
||||
static TensorFlowType get(Type type);
|
||||
@ -263,10 +257,7 @@ class TensorFlowTypeWithSubtype : public TensorFlowType {
|
||||
using TensorFlowType::TensorFlowType;
|
||||
|
||||
// Checks if a type is TensorFlow type with subtypes.
|
||||
static bool classof(Type type) {
|
||||
return type.getKind() == TensorFlowTypes::VARIANT ||
|
||||
type.getKind() == TensorFlowTypes::RESOURCE;
|
||||
}
|
||||
static bool classof(Type type);
|
||||
|
||||
// Converts a TypeWithSubtype type to the same type but without its subtypes.
|
||||
Type RemoveSubtypes();
|
||||
@ -325,15 +316,21 @@ bool HasCompatibleElementTypes(Type lhs, Type rhs,
|
||||
// compatible.
|
||||
bool AreCastCompatible(ArrayRef<Type> types);
|
||||
|
||||
// If the given tensor has elements of type with subtypes, then returns a new
|
||||
// type after dropping subtypes info. Otherwise, returns the original type as
|
||||
// is.
|
||||
ShapedType DropTypeSubTypes(ShapedType ty);
|
||||
// If `ty` is a tensor type and its element type has subtypes, then returns a
|
||||
// new type of same shape but dropped subtypes for the element type.
|
||||
// Otherwise, if `ty` has subtypes, then returns corresponding type with dropped
|
||||
// subtypes.
|
||||
// Otherwise, returns the original type `ty`.
|
||||
Type DropSubTypes(Type ty);
|
||||
|
||||
// If the given tensor has elements of type ref, then returns a new type
|
||||
// of the shape, but corresponding non-ref type as element type. Otherwise,
|
||||
// returns the original type as is.
|
||||
ShapedType DropRefType(ShapedType ty);
|
||||
// If `ty` is a tensor type and has elements of a ref type, then returns a new
|
||||
// type of same shape but corresponding non-ref type as element type.
|
||||
// Otherwise, if `ty` is a ref type, then returns corresponding non-ref type.
|
||||
// Otherwise, returns the original type `ty`.
|
||||
Type DropRefType(Type ty);
|
||||
|
||||
// Convenience call for executing both `DropRefType` and `DropSubTypes`.
|
||||
Type DropRefAndSubTypes(Type ty);
|
||||
|
||||
} // end namespace TF
|
||||
} // end namespace mlir
|
||||
|
||||
@ -834,11 +834,11 @@ func @foldCase(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
|
||||
// CHECK: PartitionedCall
|
||||
// CHECK-SAME: device = "noodle"
|
||||
// CHECK-SAME: f = @add
|
||||
%4 = "tf.Case"(%2, %arg0, %arg1) {branches = [@sub, @add], output_shapes = [#tf.shape<>], device = "noodle"} : (tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
%4 = "tf.Case"(%2, %arg0, %arg1) {branches = [@sub, @add], output_shapes = [#tf.shape<>], device = "noodle", is_stateless = false} : (tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: PartitionedCall
|
||||
// CHECK-SAME: _cluster_launch = "not_ready"
|
||||
// CHECK-SAME: f = @sub
|
||||
%5 = "tf.Case"(%3, %4, %arg1) {branches = [@sub, @add], output_shapes = [#tf.shape<>], _cluster_launch = "not_ready"} : (tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
%5 = "tf.Case"(%3, %4, %arg1) {branches = [@sub, @add], output_shapes = [#tf.shape<>], _cluster_launch = "not_ready", is_stateless = false} : (tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
return %5 : tensor<f32>
|
||||
}
|
||||
|
||||
|
||||
@ -16,7 +16,7 @@ module {
|
||||
"tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
|
||||
%index = "tf.opA"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i1>) -> tensor<i32>
|
||||
%input = "tf.opB"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i1>) -> tensor<i32>
|
||||
%result = "tf.Case"(%index, %input) {branches = [@branch_0, @branch_1, @branch_2, @branch_3, @branch_4]} : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
%result = "tf.Case"(%index, %input) {branches = [@branch_0, @branch_1, @branch_2, @branch_3, @branch_4], is_stateless = false} : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
tf_executor.yield %result : tensor<i32>
|
||||
}
|
||||
tf_executor.fetch %output : tensor<i32>
|
||||
|
||||
@ -479,6 +479,7 @@ func @DynamicStitch_duplicates(%arg0: tensor<2x2xf32>) -> tensor<1x2xf32> {
|
||||
return %0 : tensor<1x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @Reciprocal
|
||||
func @Reciprocal(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
// CHECK: "tf.Div"(%[[ONE]], %arg0) : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
@ -486,6 +487,7 @@ func @Reciprocal(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @ScatterNd
|
||||
func @ScatterNd(%arg0: tensor<4x1xi32>, %arg1: tensor<4xf32>) -> tensor<8xf32> {
|
||||
// CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<8xf32>} : () -> tensor<8xf32>
|
||||
// CHECK: "tf.TensorScatterUpdate"(%[[ZERO]], %arg0, %arg1) : (tensor<8xf32>, tensor<4x1xi32>, tensor<4xf32>) -> tensor<8xf32>
|
||||
@ -494,3 +496,16 @@ func @ScatterNd(%arg0: tensor<4x1xi32>, %arg1: tensor<4xf32>) -> tensor<8xf32> {
|
||||
%0 = "tf.ScatterNd"(%arg0, %arg1, %shape) : (tensor<4x1xi32>, tensor<4xf32>, tensor<1xi32>) -> tensor<8xf32>
|
||||
return %0 : tensor<8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @_UnaryOpsComposition
|
||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<4xf32>
|
||||
func @_UnaryOpsComposition(%arg0: tensor<4xf32>) -> tensor<4xf32> {
|
||||
|
||||
// CHECK: %[[RESULT0:.*]] = "tf.Asin"(%[[ARG0]])
|
||||
// CHECK: %[[RESULT1:.*]] = "tf.Abs"(%[[RESULT0]])
|
||||
// CHECK: %[[RESULT2:.*]] = "tf.Log"(%[[RESULT1]])
|
||||
// CHECK: return %[[RESULT2]]
|
||||
|
||||
%0 = "tf._UnaryOpsComposition"(%arg0) {op_names = ["Asin", "Abs", "Log"]} : (tensor<4xf32>) -> tensor<4xf32>
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
|
||||
@ -136,6 +136,7 @@ func @if_region_captured_string(%arg0: tensor<i1>, %arg1: tensor<!tf.string>) ->
|
||||
// CHECK-NOT: _xla_outside_compilation
|
||||
// CHECK: "tf.IfRegion"
|
||||
// CHECK: "tf.StringToNumber"
|
||||
// CHECK-NOT: _xla_outside_compilation
|
||||
// CHECK: _xla_outside_compilation = "auto", is_stateless = true
|
||||
%1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
%2 = "tf.IfRegion"(%arg0) ( {
|
||||
|
||||
@ -43,7 +43,7 @@ func @main() {
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK: }
|
||||
%1:2 = tf_executor.island wraps "tf.Case"(%0#0) {Tin = [], Tout = ["tfdtype$DT_FLOAT"], branches = [@foo, @bar], device = "", output_shapes = []} : (tensor<i32>) -> tensor<*xf32> loc("Case")
|
||||
%1:2 = tf_executor.island wraps "tf.Case"(%0#0) {Tin = [], Tout = ["tfdtype$DT_FLOAT"], branches = [@foo, @bar], device = "", output_shapes = [], is_stateless = false} : (tensor<i32>) -> tensor<*xf32> loc("Case")
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
|
||||
@ -232,3 +232,55 @@ func @while_region_aliasing(%arg0: !tf_res, %arg1: !tf_res, %arg2: !tf_res) {
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
// Test aliasing through calls
|
||||
!tf_res = type tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
|
||||
// CHECK-LABEL: func @aliasing_through_calls
|
||||
func @aliasing_through_calls(%arg0: tensor<32xf32>) -> () {
|
||||
// expected-remark@below {{Result #0, ID 0 : 0, 1, 2}}
|
||||
%vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> !tf_res
|
||||
// expected-remark@below {{Result #0, ID 1 : Unknown}}
|
||||
// expected-remark@below {{Result #1, ID 2 : 0, 1, 2}}
|
||||
%c:2 = call @passthru(%vh0) : (!tf_res) -> (!tf_res, !tf_res)
|
||||
return
|
||||
}
|
||||
|
||||
// expected-remark@below {{Region #0, Arg #0, ID 1 : 1}}
|
||||
func @passthru(%arg0: !tf_res) -> (!tf_res, !tf_res) {
|
||||
// expected-remark@below {{Result #0, ID 0 : 0}}
|
||||
%vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> !tf_res
|
||||
return %vh0, %arg0 : !tf_res, !tf_res
|
||||
}
|
||||
|
||||
// -----
|
||||
// Test aliasing through tf_device.launch
|
||||
!tf_res = type tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
|
||||
// CHECK-LABEL: func @aliasing_through_launch
|
||||
func @aliasing_through_launch(%arg0: tensor<32xf32>) {
|
||||
// expected-remark@below {{Result #0, ID 0 : 0, 1}}
|
||||
%vh = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> !tf_res
|
||||
|
||||
// expected-remark@below {{Result #0, ID 1 : 0, 1}}
|
||||
%launch = "tf_device.launch"() ({
|
||||
tf_device.return %vh : !tf_res
|
||||
}) {device = ""} : () -> !tf_res
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
// Test aliasing through tf_device.cluster
|
||||
!tf_res = type tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
|
||||
// CHECK-LABEL: func @aliasing_through_cluster
|
||||
func @aliasing_through_cluster(%arg0: tensor<32xf32>) {
|
||||
// expected-remark@below {{Result #0, ID 0 : 0, 1}}
|
||||
%vh = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> !tf_res
|
||||
|
||||
// expected-remark@below {{Result #0, ID 1 : 0, 1}}
|
||||
%cluster = "tf_device.cluster"() ({
|
||||
tf_device.return %vh : !tf_res
|
||||
}) : () -> !tf_res
|
||||
return
|
||||
}
|
||||
|
||||
@ -424,3 +424,117 @@ func @propagate_if_region_inlined(
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Test propagation through WhileRegion (inlined calls)
|
||||
// CHECK-LABEL: func @propagate_while_region_inlined
|
||||
func @propagate_while_region_inlined(
|
||||
%arg0: !tf_res {tf.device = "/TPU:0"},
|
||||
%arg1: tensor<i32>) {
|
||||
tf_executor.graph {
|
||||
// CHECK: tf_executor.island
|
||||
%island = tf_executor.island {
|
||||
// CHECK-NEXT: "tf.Identity"
|
||||
// CHECK-SAME: {device = "/TPU:0"}
|
||||
%id0 = "tf.Identity"(%arg0) : (!tf_res) -> !tf_res
|
||||
// CHECK-NEXT: "tf.VarHandleOp"
|
||||
%var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0", device = "/TPU:1"} : () -> !tf_res
|
||||
// CHECK-NEXT: "tf.WhileRegion"
|
||||
"tf.WhileRegion"(%arg1, %id0, %var_handle) ({
|
||||
^bb0(%carg0: tensor<i32>, %carg1: !tf_res, %carg2: !tf_res):
|
||||
// CHECK: ^bb
|
||||
// CHECK: "tf.Identity"
|
||||
// CHECK-SAME: {device = "/TPU:0"}
|
||||
%cid0 = "tf.Identity"(%carg1) : (!tf_res) -> !tf_res loc("cid0")
|
||||
%read = "tf.ReadVariableOp"(%cid0) : (!tf_res) -> tensor<32xf32>
|
||||
%cst = constant dense<3.0> : tensor<32xf32>
|
||||
%cmp = "tf.Less"(%read, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xi1>
|
||||
%dims = constant dense<0> : tensor<1xi32>
|
||||
%reduce = "tf.All"(%cmp, %dims) {keep_dims = false} : (tensor<32xi1>, tensor<1xi32>) -> tensor<i1>
|
||||
"tf.Yield"(%reduce) : (tensor<i1>) -> ()
|
||||
}, {
|
||||
^bb0(%barg0: tensor<i32>, %barg1: !tf_res, %barg2: !tf_res):
|
||||
// CHECK: ^bb
|
||||
// CHECK: "tf.Identity"
|
||||
// CHECK-SAME: {device = "/TPU:0"}
|
||||
%bid0 = "tf.Identity"(%barg1) : (!tf_res) -> !tf_res
|
||||
// CHECK-NEXT: "tf.Identity"
|
||||
// CHECK-SAME: {device = "/TPU:1"}
|
||||
%id1 = "tf.Identity"(%barg2) : (!tf_res) -> !tf_res
|
||||
"tf.Yield"(%barg0, %bid0, %id1) : (tensor<i32>, !tf_res,!tf_res) -> ()
|
||||
}){is_stateless = false}
|
||||
: (tensor<i32>, !tf_res, !tf_res) -> (tensor<i32>, !tf_res, !tf_res)
|
||||
tf_executor.yield
|
||||
}
|
||||
tf_executor.fetch %island : !tf_executor.control
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Test propagation through WhileRegion (non-inlined calls)
|
||||
// CHECK-LABEL: func @propagate_while_region
|
||||
func @propagate_while_region(
|
||||
%arg0: !tf_res {tf.device = "/TPU:0"},
|
||||
%arg1: tensor<i32>) {
|
||||
tf_executor.graph {
|
||||
// CHECK: tf_executor.island
|
||||
%island = tf_executor.island {
|
||||
// CHECK-NEXT: "tf.Identity"
|
||||
// CHECK-SAME: {device = "/TPU:0"}
|
||||
%id0 = "tf.Identity"(%arg0) : (!tf_res) -> !tf_res
|
||||
// CHECK-NEXT: "tf.VarHandleOp"
|
||||
%var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0", device = "/TPU:1"} : () -> !tf_res
|
||||
// CHECK-NEXT: "tf.WhileRegion"
|
||||
"tf.WhileRegion"(%arg1, %id0, %var_handle) ({
|
||||
^bb0(%carg0: tensor<i32>, %carg1: !tf_res, %carg2: !tf_res):
|
||||
%cond = call @whileregion_cond(%carg0, %carg1, %carg2) : (tensor<i32>, !tf_res, !tf_res) -> tensor<i1>
|
||||
"tf.Yield"(%cond) : (tensor<i1>) -> ()
|
||||
}, {
|
||||
^bb0(%barg0: tensor<i32>, %barg1: !tf_res, %barg2: !tf_res):
|
||||
%new_values:3 = call @whileregion_body(%barg0, %barg1, %barg2) : (tensor<i32>, !tf_res,!tf_res) -> (tensor<i32>, !tf_res,!tf_res)
|
||||
"tf.Yield"(%new_values#0, %new_values#1, %new_values#2) : (tensor<i32>, !tf_res,!tf_res) -> ()
|
||||
}){is_stateless = false}
|
||||
: (tensor<i32>, !tf_res, !tf_res) -> (tensor<i32>, !tf_res, !tf_res)
|
||||
tf_executor.yield
|
||||
}
|
||||
tf_executor.fetch %island : !tf_executor.control
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @whileregion_body
|
||||
func @whileregion_body(%arg0: tensor<i32>, %arg1: !tf_res, %arg2: !tf_res) -> (tensor<i32>, !tf_res, !tf_res) {
|
||||
%graph:3 = tf_executor.graph {
|
||||
// CHECK: tf_executor.island
|
||||
%island:4 = tf_executor.island {
|
||||
// CHECK-NEXT: "tf.Identity"
|
||||
// CHECK-SAME: {device = "/TPU:0"}
|
||||
%id0 = "tf.Identity"(%arg1) : (!tf_res) -> !tf_res
|
||||
// CHECK-NEXT: "tf.Identity"
|
||||
// CHECK-SAME: {device = "/TPU:1"}
|
||||
%id1 = "tf.Identity"(%arg2) : (!tf_res) -> !tf_res
|
||||
tf_executor.yield %arg0, %id0, %id1 : tensor<i32>, !tf_res, !tf_res
|
||||
}
|
||||
tf_executor.fetch %island#0, %island#1, %island#2 : tensor<i32>, !tf_res, !tf_res
|
||||
}
|
||||
return %graph#0, %graph#1, %graph#2: tensor<i32>, !tf_res, !tf_res
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @whileregion_cond
|
||||
func @whileregion_cond(%arg0: tensor<i32>, %arg1: !tf_res, %arg2: !tf_res) -> tensor<i1> {
|
||||
%graph = tf_executor.graph {
|
||||
// CHECK: tf_executor.island
|
||||
%island:2 = tf_executor.island {
|
||||
// CHECK-NEXT: "tf.Identity"
|
||||
// CHECK-SAME: {device = "/TPU:0"}
|
||||
%id0 = "tf.Identity"(%arg1) : (!tf_res) -> !tf_res
|
||||
%read = "tf.ReadVariableOp"(%id0) : (!tf_res) -> tensor<32xf32>
|
||||
%cst = constant dense<3.0> : tensor<32xf32>
|
||||
%cmp = "tf.Less"(%read, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xi1>
|
||||
%dims = constant dense<0> : tensor<1xi32>
|
||||
%reduce = "tf.All"(%cmp, %dims) {keep_dims = false} : (tensor<32xi1>, tensor<1xi32>) -> tensor<i1>
|
||||
tf_executor.yield %reduce : tensor<i1>
|
||||
}
|
||||
tf_executor.fetch %island#0 : tensor<i1>
|
||||
}
|
||||
return %graph : tensor<i1>
|
||||
}
|
||||
|
||||
@ -112,26 +112,6 @@ func @internal_resource() -> tensor<*xi32> {
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that pass fails when there are remaining resource operationss that can
|
||||
// not be lifted.
|
||||
|
||||
func @lifting_failure() -> tensor<*xi32> {
|
||||
|
||||
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
|
||||
|
||||
// expected-error @+1 {{has remaining resource inputs that can not be lifted}}
|
||||
%1 = "tf_device.cluster"() ( {
|
||||
%2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
|
||||
%3 = "tf.SomeResourceOp"(%0, %2) : (tensor<*x!tf.resource>, tensor<*xi32>) -> tensor<*xi32>
|
||||
"tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
|
||||
tf_device.return %3 : tensor<*xi32>
|
||||
}) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32>
|
||||
|
||||
return %1 : tensor<*xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that pass lifts resource reads/writes from a loop, and removed unused
|
||||
// resources.
|
||||
|
||||
@ -347,30 +327,6 @@ func @while_cond(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32> {
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that pass reports error on unsupported ops in loop body.
|
||||
|
||||
func @cluster_with_loop() -> () {
|
||||
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
|
||||
"tf_device.cluster"() ( {
|
||||
%1 = "tf.While"(%0) {
|
||||
body = @while_body, cond = @while_cond, device = "", is_stateless = false}
|
||||
: (tensor<*x!tf.resource<tensor<f32>>>) -> (tensor<*x!tf.resource<tensor<f32>>>)
|
||||
tf_device.return
|
||||
}) {cluster_attr = "cluster_attr"} : () -> ()
|
||||
return
|
||||
}
|
||||
func @while_body(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> (tensor<*x!tf.resource<tensor<f32>>>) {
|
||||
// expected-error @+1 {{found unsupported operations on resource.}}
|
||||
"tf._UnknownOp"(%arg0) : (tensor<*x!tf.resource<tensor<f32>>>) -> ()
|
||||
return %arg0 : tensor<*x!tf.resource<tensor<f32>>>
|
||||
}
|
||||
func @while_cond(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32> {
|
||||
%read = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||
return %read : tensor<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that pass reports error on unsupported ops in loop cond.
|
||||
|
||||
func @cluster_with_loop() -> () {
|
||||
@ -409,7 +365,7 @@ func @cluster_with_case(%arg0: tensor<i32>) -> tensor<4xf32> {
|
||||
// CHECK: %[[CLUSTER:.*]]:2 = "tf_device.cluster"()
|
||||
%2 = "tf_device.cluster"() ( {
|
||||
// CHECK: %[[CASE:.*]]:2 = "tf.Case"(%[[ARG0]], %[[READ0]], %[[READ1]])
|
||||
%3:2 = "tf.Case"(%arg0, %0, %1) {branches = [@branch_0, @branch_1, @branch_2]}
|
||||
%3:2 = "tf.Case"(%arg0, %0, %1) {branches = [@branch_0, @branch_1, @branch_2], is_stateless = false}
|
||||
: (tensor<i32>, tensor<*x!tf.resource<tensor<4xf32>>>, tensor<*x!tf.resource<tensor<4xf32>>>)
|
||||
-> (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>)
|
||||
// CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[CASE]]#1, %[[CASE]]#0)
|
||||
|
||||
@ -223,7 +223,7 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
|
||||
// CHECK-SAME: %[[ARG_1:.*]]: tensor<!tf.resource<tensor<1x2x3xf32>>>
|
||||
func @shape_from_case_to_branch_functions(%arg0: tensor<i32>, %arg1: tensor<!tf.resource<tensor<1x2x3xf32>>>) -> tensor<1x2x3xf32> {
|
||||
// CHECK: %[[CASE:.*]] = "tf.Case"(%[[ARG_0]], %[[ARG_1]])
|
||||
%0 = "tf.Case"(%arg0, %arg1) {branches = [@branch_0, @branch_1]} : (tensor<i32>, tensor<!tf.resource<tensor<1x2x3xf32>>>) -> tensor<1x2x3xf32>
|
||||
%0 = "tf.Case"(%arg0, %arg1) {branches = [@branch_0, @branch_1], is_stateless = false} : (tensor<i32>, tensor<!tf.resource<tensor<1x2x3xf32>>>) -> tensor<1x2x3xf32>
|
||||
// CHECK: return %[[CASE]] : tensor<1x2x3xf32>
|
||||
return %0 : tensor<1x2x3xf32>
|
||||
}
|
||||
|
||||
@ -256,7 +256,7 @@ func @main(%arg0: tensor<i32>) -> () {
|
||||
%max_size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NOT: tf.EmptyTensorList
|
||||
%tl = "tf.EmptyTensorList"(%elem_shape, %max_size) : (tensor<0xi32>, tensor<i32>) -> tensor<!tf.variant<tensor<f32>>>
|
||||
%case_op = "tf.Case"(%arg0, %tl) {branches = [@branch_0, @branch_1, @branch_2]}
|
||||
%case_op = "tf.Case"(%arg0, %tl) {branches = [@branch_0, @branch_1, @branch_2], is_stateless = false}
|
||||
: (tensor<i32>, tensor<!tf.variant<tensor<f32>>>) -> tensor<!tf.variant<tensor<f32>>>
|
||||
// CHECK: "tf.Slice"
|
||||
%pop:2 = "tf.TensorListPopBack"(%case_op, %elem_shape) : (tensor<!tf.variant<tensor<f32>>>, tensor<0xi32>) -> (tensor<!tf.variant<tensor<f32>>>, tensor<f32>)
|
||||
|
||||
@ -848,7 +848,7 @@ func @testInvalidIfOp(tensor<i1>, tensor<*xf32>) -> tensor<2xf32> {
|
||||
|
||||
// Test invalid tf.Yield operation (parent should be IfRegion)
|
||||
func @testInvalidYieldOp(%arg0: f32) -> () {
|
||||
// expected-error @+1 {{'tf.Yield' op expects parent op to be one of 'tf.IfRegion, tf.WhileRegion'}}
|
||||
// expected-error @+1 {{'tf.Yield' op expects parent op to be one of 'tf.CaseRegion, tf.IfRegion, tf.WhileRegion'}}
|
||||
"tf.Yield"(%arg0) : (f32) -> ()
|
||||
}
|
||||
|
||||
@ -3313,3 +3313,88 @@ func @testBatchToSpaceInvalidOutputDepth(%arg0: tensor<16x8x8x3xf32>, %arg1: ten
|
||||
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<16x8x8x3xf32>, tensor<*xi32>) -> tensor<4x8x8x8xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testCaseRegionNoRegions(%arg0: tensor<i32>) {
|
||||
// expected-error @+1 {{expects to have at least 1 region}}
|
||||
"tf.CaseRegion"(%arg0) {is_stateless = false} : (tensor<i32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testCaseRegionBadBranchIndicesShape(%arg0: tensor<8xi32>) {
|
||||
// expected-error @+1 {{expects 'branch_index' to be a scalar, but got 'tensor<8xi32>'}}
|
||||
"tf.CaseRegion"(%arg0) ( {
|
||||
"tf.Yield"() : () -> ()
|
||||
}) {is_stateless = false} : (tensor<8xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testCaseRegionBadBranchIndicesNegative() {
|
||||
%0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
|
||||
// expected-error @+1 {{expects 'branch_index' to be non-negative, but got -1}}
|
||||
"tf.CaseRegion"(%0) ( {
|
||||
"tf.Yield"() : () -> ()
|
||||
}) {is_stateless = false} : (tensor<i32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testCaseRegionBadBranchIndicesPositive() {
|
||||
%0 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
// expected-error @+1 {{expects 'branch_index' to be less than the number of regions (1), but got 1}}
|
||||
"tf.CaseRegion"(%0) ( {
|
||||
"tf.Yield"() : () -> ()
|
||||
}) {is_stateless = false} : (tensor<i32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testCaseRegionMismatchedNumResults(%arg0: tensor<i32>) {
|
||||
// expected-error @+1 {{region #0 should have same number (1) of results as tf.CaseRegion but has 0 results}}
|
||||
%1 = "tf.CaseRegion"(%arg0) ( {
|
||||
"tf.Yield"() : () -> ()
|
||||
}) {is_stateless = false} : (tensor<i32>) -> tensor<i1>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testCaseRegionMismatchedResultTypes(%arg0: tensor<i32>, %arg1: tensor<f32>) {
|
||||
// expected-error @+1 {{region #0 result type tensor<f32> is incompatible with tf.CaseRegion result type tensor<i1> at index 0}}
|
||||
%1 = "tf.CaseRegion"(%arg0) ( {
|
||||
"tf.Yield"(%arg1) : (tensor<f32>) -> ()
|
||||
}) {is_stateless = false} : (tensor<i32>) -> tensor<i1>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test valid tf.Cumsum
|
||||
func @testCumsum(%arg: tensor<8x16xf32>, %axis: tensor<i32>) -> tensor<8x16xf32> {
|
||||
%0 = "tf.Cumsum"(%arg, %axis) : (tensor<8x16xf32>, tensor<i32>) -> tensor<8x16xf32>
|
||||
return %0 : tensor<8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testCumprod(%arg: tensor<8x16xf32>, %axis: tensor<2xi32>) -> tensor<8x16xf32> {
|
||||
// expected-error @+1 {{requires scalar axis operand}}
|
||||
%0 = "tf.Cumprod"(%arg, %axis) : (tensor<8x16xf32>, tensor<2xi32>) -> tensor<8x16xf32>
|
||||
return %0 : tensor<8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testCumprod(%arg: tensor<8x16xf32>) -> tensor<8x16xf32> {
|
||||
%axis = constant dense<-3> : tensor<i32>
|
||||
// expected-error @+1 {{axis operand should be within range [-2, 2)}}
|
||||
%0 = "tf.Cumprod"(%arg, %axis) : (tensor<8x16xf32>, tensor<i32>) -> tensor<8x16xf32>
|
||||
return %0 : tensor<8x16xf32>
|
||||
}
|
||||
|
||||
@ -9,7 +9,7 @@ func @select(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<i32>, tensor<f32
|
||||
// CHECK: return %[[first]],
|
||||
%0 = "tf.DeviceIndex"() {device = "", device_names = ["CPU", "GPU"]} : () -> tensor<i32>
|
||||
%1 = "tf.DeviceIndex"() {device = "", device_names = ["CPU", "GPU"]} : () -> tensor<i32>
|
||||
%4 = "tf.Case"(%1, %arg0, %arg1) {branches = [@sub, @add], output_shapes = [#tf.shape<>]} : (tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
%4 = "tf.Case"(%1, %arg0, %arg1) {branches = [@sub, @add], output_shapes = [#tf.shape<>], is_stateless = false} : (tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
|
||||
return %0, %4 : tensor<i32>, tensor<f32>
|
||||
}
|
||||
|
||||
@ -0,0 +1,93 @@
|
||||
// RUN: tf-opt %s -tf-tpu-identity-pruning | FileCheck %s --dump-input=always
|
||||
|
||||
// Tests Identity op in cluster is pruned away.
|
||||
|
||||
// CHECK-LABEL: func @testIdentity
|
||||
// CHECK-SAME: ([[ARG0:%.*]]: tensor<i32>)
|
||||
func @testIdentity(%arg0: tensor<i32>) {
|
||||
// CHECK-NOT: "tf.Identity"
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK-NEXT: tf_device.return [[ARG0]]
|
||||
%0 = "tf_device.cluster"() ( {
|
||||
%1 = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
|
||||
tf_device.return %1 : tensor<i32>
|
||||
}) : () -> tensor<i32>
|
||||
return
|
||||
}
|
||||
|
||||
// Tests IdentityN op in cluster is pruned away.
|
||||
|
||||
// CHECK-LABEL: func @testIdentityN
|
||||
// CHECK-SAME: ([[ARG0:%.*]]: tensor<i32>, [[ARG1:%.*]]: tensor<f32>)
|
||||
func @testIdentityN(%arg0: tensor<i32>, %arg1: tensor<f32>) {
|
||||
// CHECK-NOT: "tf.IdentityN"
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK-NEXT: tf_device.return [[ARG0]], [[ARG1]]
|
||||
%0:2 = "tf_device.cluster"() ( {
|
||||
%1:2 = "tf.IdentityN"(%arg0, %arg1) : (tensor<i32>, tensor<f32>) -> (tensor<i32>, tensor<f32>)
|
||||
tf_device.return %1#0, %1#1 : tensor<i32>, tensor<f32>
|
||||
}) : () -> (tensor<i32>, tensor<f32>)
|
||||
return
|
||||
}
|
||||
|
||||
// Tests transitive Identity ops reachable from the cluster are pruned away.
|
||||
|
||||
// CHECK-LABEL: func @testTransitiveIdentity
|
||||
// CHECK-SAME: ([[ARG0:%.*]]: tensor<i32>)
|
||||
func @testTransitiveIdentity(%arg0: tensor<i32>) {
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: "tf.PartitionedCall"([[ARG0]])
|
||||
// CHECK-SAME: f = @callee0
|
||||
%0 = "tf_device.cluster"() ( {
|
||||
%1 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @callee0} : (tensor<i32>) -> tensor<i32>
|
||||
tf_device.return %1 : tensor<i32>
|
||||
}) : () -> tensor<i32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @callee0
|
||||
// CHECK-SAME: ([[ARG0:%.*]]: tensor<i32>)
|
||||
func @callee0(%arg0: tensor<i32>) -> tensor<i32> {
|
||||
// CHECK-NOT: "tf.Identity"
|
||||
// CHECK: "tf.PartitionedCall"([[ARG0]])
|
||||
// CHECK-SAME: f = @callee1
|
||||
%0 = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
|
||||
%1 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @callee1} : (tensor<i32>) -> tensor<i32>
|
||||
return %1 : tensor<i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @callee1
|
||||
// CHECK-SAME: ([[ARG0:%.*]]: tensor<i32>)
|
||||
func @callee1(%arg0: tensor<i32>) -> tensor<i32> {
|
||||
// CHECK-NOT: "tf.Identity"
|
||||
// CHECK: return [[ARG0]]
|
||||
%0 = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
|
||||
// Tests Identity ops not reachable from the cluster are not pruned away.
|
||||
|
||||
// CHECK-LABEL: func @testIdentityOutsideCluster
|
||||
// CHECK-SAME: ([[ARG0:%.*]]: tensor<i32>)
|
||||
func @testIdentityOutsideCluster(%arg0: tensor<i32>) {
|
||||
// CHECK: [[IDENTITY:%.*]] = "tf.Identity"([[ARG0]])
|
||||
// CHECK: [[CLUSTER:%.*]] = "tf_device.cluster"
|
||||
// CHECK-NEXT: tf_device.return [[IDENTITY]]
|
||||
%0 = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
|
||||
%1 = "tf_device.cluster"() ( {
|
||||
tf_device.return %0 : tensor<i32>
|
||||
}) : () -> tensor<i32>
|
||||
// CHECK: "tf.PartitionedCall"([[CLUSTER]])
|
||||
// CHECK-SAME: f = @callee2
|
||||
%2 = "tf.PartitionedCall"(%1) {config = "", config_proto = "", executor_type = "", f = @callee2} : (tensor<i32>) -> tensor<i32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @callee2
|
||||
// CHECK-SAME: ([[ARG0:%.*]]: tensor<i32>)
|
||||
func @callee2(%arg0: tensor<i32>) -> tensor<i32> {
|
||||
// CHECK: [[IDENTITY:%.*]] = "tf.Identity"([[ARG0]])
|
||||
%0 = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
|
||||
// CHECK: return [[IDENTITY]]
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
@ -95,16 +95,16 @@ void CreateTPUBridgePipeline(OpPassManager &pm) {
|
||||
func_pm.addPass(CreateTPUHostComputationExpansionPass());
|
||||
func_pm.addPass(CreateTPUUpdateEmbeddingEnqueueOpInputsPass());
|
||||
}
|
||||
// Run another shape inference pass because resource decomposition might have
|
||||
// created new partial types.
|
||||
pm.addPass(TF::CreateTFShapeInferencePass());
|
||||
pm.addPass(TFDevice::CreateResourceOpLiftingPass());
|
||||
pm.addPass(TF::CreateTFFunctionalControlFlowToRegions());
|
||||
pm.addPass(mlir::createInlinerPass());
|
||||
pm.addPass(CreateTPUExtractHeadTailOutsideCompilationPass());
|
||||
pm.addPass(TF::CreateTFRegionControlFlowToFunctional());
|
||||
|
||||
// Run another shape inference pass because resource decomposition might have
|
||||
// created new partial types.
|
||||
pm.addPass(TF::CreateTFShapeInferencePass());
|
||||
pm.addNestedPass<FuncOp>(tf_executor::CreateTFExecutorConstantSinkingPass());
|
||||
pm.addPass(TFDevice::CreateResourceOpLiftingPass());
|
||||
pm.addPass(TF::CreateResourceDeviceInferencePass());
|
||||
pm.addPass(TFDevice::CreateClusterOutliningPass());
|
||||
pm.addPass(CreateTPUDynamicPaddingMapperPass());
|
||||
|
||||
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "mlir/IR/OpDefinition.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
@ -68,7 +69,7 @@ static bool ShouldBeFolded(Operation* inst) {
|
||||
|
||||
LogicalResult ConstantFoldFallbackHook(
|
||||
Operation* inst, ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<Attribute>& results) { // NOLINT
|
||||
SmallVectorImpl<OpFoldResult>& results) { // NOLINT
|
||||
// Instructions with side effects should not be constant folded to preserve
|
||||
// the original semantics.
|
||||
if (inst->getNumRegions() != 0 || !MemoryEffectOpInterface::hasNoEffect(inst))
|
||||
@ -126,8 +127,16 @@ LogicalResult ConstantFoldFallbackHook(
|
||||
// TODO(jpienaar): Avoid using global context & mutex here.
|
||||
static auto* mu = new tensorflow::mutex();
|
||||
tensorflow::mutex_lock l(*mu);
|
||||
return tensorflow::EvaluateOperation(inst, inputs, ctx, &results);
|
||||
SmallVector<Attribute, 8> constants;
|
||||
LogicalResult status =
|
||||
tensorflow::EvaluateOperation(inst, inputs, ctx, &constants);
|
||||
results.assign(constants.begin(), constants.end());
|
||||
return status;
|
||||
}
|
||||
|
||||
static bool init_hooks = ([] () {
|
||||
TensorFlowDialect::RegisterConstantFoldHook(ConstantFoldFallbackHook);
|
||||
}(), true);
|
||||
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
||||
|
||||
@ -27,7 +27,7 @@ namespace TF {
|
||||
|
||||
LogicalResult ConstantFoldFallbackHook(
|
||||
Operation *inst, ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<Attribute> &results); // NOLINT
|
||||
SmallVectorImpl<OpFoldResult> &results); // NOLINT
|
||||
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
||||
|
||||
@ -19,7 +19,6 @@ limitations under the License.
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Dialect.h" // from @llvm-project
|
||||
#include "mlir/IR/DialectHooks.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/Types.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
@ -35,31 +34,22 @@ namespace {
|
||||
|
||||
// Since this method is passed to MLIR as decode hook it has to conform
|
||||
// to LLVM style used by MLIR.
|
||||
bool DecodeOpaqueTensorHook(const OpaqueElementsAttr input,
|
||||
ElementsAttr& output) { // NOLINT
|
||||
LogicalResult DecodeOpaqueTensorHook(const OpaqueElementsAttr input,
|
||||
ElementsAttr& output) { // NOLINT
|
||||
Builder builder(input.getType().getContext());
|
||||
auto decoded_attr_or = tensorflow::DecodeOpaqueTensor(input, builder);
|
||||
if (!decoded_attr_or.ok()) {
|
||||
VLOG(2) << decoded_attr_or.status().error_message();
|
||||
return true;
|
||||
return failure();
|
||||
}
|
||||
|
||||
output = decoded_attr_or.ValueOrDie();
|
||||
return false;
|
||||
return success();
|
||||
}
|
||||
|
||||
// Hooks for the TensorFlow dialect.
|
||||
class TensorFlowHooks : public DialectHooks {
|
||||
public:
|
||||
DialectConstantFoldHook getConstantFoldHook() {
|
||||
return TF::ConstantFoldFallbackHook;
|
||||
}
|
||||
DialectConstantDecodeHook getDecodeHook() { return DecodeOpaqueTensorHook; }
|
||||
};
|
||||
static bool init_hooks = ([] () {
|
||||
TF::TensorFlowDialect::RegisterDecodeConstantHook(DecodeOpaqueTensorHook);
|
||||
}(), true);
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// Static initialization for TensorFlow dialect hooks registration.
|
||||
static DialectHooksRegistration<TensorFlowHooks> tf_hooks_registration("tf");
|
||||
|
||||
} // namespace mlir
|
||||
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
|
||||
@ -427,12 +428,38 @@ class LowerSparseMatMulOp : public OpRewritePattern<TF::SparseMatMulOp> {
|
||||
}
|
||||
};
|
||||
|
||||
// Lowers _UnaryOpsComposition op as a series of original TensorFlow ops that
|
||||
// were fused together.
|
||||
class Lower_UnaryOpsComposition
|
||||
: public OpRewritePattern<TF::_UnaryOpsCompositionOp> {
|
||||
public:
|
||||
using OpRewritePattern<TF::_UnaryOpsCompositionOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(TF::_UnaryOpsCompositionOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value result = op.x();
|
||||
for (StringRef op_name :
|
||||
op.op_names().getAsRange<StringAttr, StringRef>()) {
|
||||
std::string full_name = "tf." + op_name.str();
|
||||
// All ops in the sequences have the same result type as the original
|
||||
// result type.
|
||||
OperationState state(op.getLoc(), full_name, /*operands=*/{result},
|
||||
/*types=*/{op.getType()}, /*attributes=*/{});
|
||||
Operation *op = rewriter.createOperation(state);
|
||||
result = op->getResult(0);
|
||||
}
|
||||
rewriter.replaceOp(op, {result});
|
||||
return success();
|
||||
}
|
||||
};
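// Illustrative sketch (not part of the original change; the generic-form
// attribute syntax below is assumed): the pattern above expands a fused unary
// composition into a chain of the named ops, so an op such as
//   %r = "tf._UnaryOpsComposition"(%x) {op_names = ["Abs", "Sqrt"]}
//        : (tensor<4xf32>) -> tensor<4xf32>
// would be rewritten into
//   %0 = "tf.Abs"(%x) : (tensor<4xf32>) -> tensor<4xf32>
//   %r = "tf.Sqrt"(%0) : (tensor<4xf32>) -> tensor<4xf32>
// with every intermediate op reusing the original result type.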
|
||||
|
||||
} // namespace
|
||||
|
||||
void PopulateLoweringTFPatterns(MLIRContext *context,
|
||||
OwningRewritePatternList *patterns) {
|
||||
patterns->insert<LowerAddNOp, LowerDynamicStitchOp, LowerInvertPermutationOp,
|
||||
LowerPackOp, LowerSparseMatMulOp>(context);
|
||||
LowerPackOp, LowerSparseMatMulOp, Lower_UnaryOpsComposition>(
|
||||
context);
|
||||
populateWithGenerated(context, patterns);
|
||||
}
|
||||
|
||||
|
||||
@ -131,6 +131,25 @@ LogicalResult MarkUncompilableOps(
|
||||
return success();
|
||||
}
|
||||
|
||||
// Unmarks outside compilation for any op that has parents already
// marked for outside compilation, since the child will be extracted
// anyway.
|
||||
void UnmarkChildren(Block* block) {
|
||||
block->walk([&](Operation* op) {
|
||||
if (!op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) return;
|
||||
Operation* iter_op = op;
|
||||
bool remove_attr = false;
|
||||
while (auto* parent_op = iter_op->getParentOp()) {
|
||||
if (parent_op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
|
||||
remove_attr = true;
|
||||
break;
|
||||
}
|
||||
iter_op = parent_op;
|
||||
}
|
||||
if (remove_attr) op->removeAttr(kXlaOutsideCompilationAttr);
|
||||
});
|
||||
}
|
||||
|
||||
void MarkOpsForOutsideCompilation::runOnOperation() {
|
||||
auto module = getOperation();
|
||||
const Dialect* tf_dialect = getContext().getRegisteredDialect("tf");
|
||||
@ -168,6 +187,17 @@ void MarkOpsForOutsideCompilation::runOnOperation() {
|
||||
});
|
||||
|
||||
if (result.wasInterrupted()) return signalPassFailure();
|
||||
|
||||
module.walk([&](tf_device::ClusterOp cluster) {
|
||||
// Only if `allow_soft_placement` attribute is true should we unmark ops
|
||||
// for outside compilation.
|
||||
auto soft_placement_attr =
|
||||
cluster.getAttrOfType<BoolAttr>(kAllowSoftPlacementAttr);
|
||||
if (!(soft_placement_attr && soft_placement_attr.getValue())) {
|
||||
return;
|
||||
}
|
||||
UnmarkChildren(&cluster.GetBody());
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
@ -271,6 +271,9 @@ namespace TFTPU {
|
||||
// `_tpu_replicate` attribute.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUClusterFormationPass();
|
||||
|
||||
// Creates a pass that removes Identity/IdentityN ops from a cluster.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUIdentityPruningPass();
|
||||
|
||||
// Creates a pass that allows TPU program inputs to have layouts determined at
|
||||
// run time.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUDynamicLayoutPass();
|
||||
|
||||
@ -26,10 +26,13 @@ limitations under the License.
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/iterator_range.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/OpImplementation.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/OperationSupport.h" // from @llvm-project
|
||||
#include "mlir/IR/Types.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/IR/Visitors.h" // from @llvm-project
|
||||
@ -39,6 +42,9 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h"
|
||||
|
||||
#define DEBUG_TYPE "tf-resource-device-inference"
|
||||
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
@ -132,6 +138,13 @@ inline StringRef GetDeviceAttr(Operation* op) {
|
||||
return device_attr ? device_attr.getValue() : "";
|
||||
}
|
||||
|
||||
// Print operation with debug info (to get line number info for debugging)
|
||||
void dump(StringRef message, Operation* op) {
|
||||
llvm::dbgs() << message;
|
||||
op->print(llvm::dbgs(), OpPrintingFlags().enableDebugInfo(true));
|
||||
llvm::dbgs() << "\n";
|
||||
}
|
||||
|
||||
// Propagates device assignment inside a function.
|
||||
LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op,
|
||||
PerFunctionResult* result) {
|
||||
@ -153,26 +166,67 @@ LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op,
|
||||
if (failed(res)) return res;
|
||||
}
|
||||
|
||||
auto walk_res = func_op.walk([&](Operation* op) {
|
||||
if (auto var_handle = dyn_cast<VarHandleOp>(op)) {
|
||||
// Record VarHandleOp's device attribute.
|
||||
StringRef device_attr = GetDeviceAttr(op);
|
||||
if (device_attr.empty()) return WalkResult::advance();
|
||||
auto res = AddResourceDeviceAndEmitError(var_handle.resource(),
|
||||
device_attr, op, result);
|
||||
if (failed(res)) return WalkResult::interrupt();
|
||||
}
|
||||
if (auto identity = dyn_cast<IdentityOp>(op)) {
|
||||
// Try to construct IdentityOp's attribute from recorded assignment.
|
||||
if (!GetDeviceAttr(op).empty()) return WalkResult::advance();
|
||||
for (auto output : filter_resources(op->getResults())) {
|
||||
if (auto device = result->DeviceForResource(output))
|
||||
identity.setAttr(kDeviceAttr, builder.getStringAttr(*device));
|
||||
}
|
||||
return WalkResult::advance();
|
||||
}
|
||||
return WalkResult::advance();
|
||||
});
|
||||
// To support WhileRegion, we need to propagate device attributes from
|
||||
// WhileRegion operands to body/cond region arguments *prior* to visiting
|
||||
// these regions. Use tensorflow::walk() instead of MLIR core walker to
|
||||
// implement such a pre-order walk.
|
||||
auto walk_res = tensorflow::GenericWalk(
|
||||
func_op, [&](Operation* op, const tensorflow::WalkStage& stage) {
|
||||
// We just need to visit operations in pre-order mode.
|
||||
if (!stage.IsBeforeAllRegions()) return WalkResult::advance();
|
||||
|
||||
if (auto var_handle = dyn_cast<VarHandleOp>(op)) {
|
||||
// Record VarHandleOp's device attribute.
|
||||
StringRef device_attr = GetDeviceAttr(op);
|
||||
if (device_attr.empty()) return WalkResult::advance();
|
||||
auto res = AddResourceDeviceAndEmitError(var_handle.resource(),
|
||||
device_attr, op, result);
|
||||
if (failed(res)) return WalkResult::interrupt();
|
||||
} else if (auto identity = dyn_cast<IdentityOp>(op)) {
|
||||
LLVM_DEBUG(dump("Visiting ", identity));
|
||||
// Try to construct IdentityOp's attribute from recorded assignment.
|
||||
if (!GetDeviceAttr(op).empty()) return WalkResult::advance();
|
||||
for (auto output : filter_resources(op->getResults())) {
|
||||
LLVM_DEBUG(llvm::dbgs() << " Processing output #"
|
||||
<< output.getResultNumber() << "\n");
|
||||
if (auto device = result->DeviceForResource(output)) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< " Setting device = " << *device << "\n");
|
||||
identity.setAttr(kDeviceAttr, builder.getStringAttr(*device));
|
||||
}
|
||||
}
|
||||
} else if (auto while_region = dyn_cast<WhileRegionOp>(op)) {
|
||||
// For WhileRegion, do local analysis prior to visiting the attached
// regions and propagate device annotations to the cond and body
// region arguments. The annotations are the union of the annotations
// on the input and result. Resource alias analysis already propagates
// the resource ID from the inputs to the results of a while loop, so we
// just need to consider the results.
|
||||
LLVM_DEBUG(llvm::dbgs() << "Visiting WhileRegion\n");
|
||||
|
||||
for (auto output : filter_resources(while_region.getResults())) {
|
||||
auto device = result->DeviceForResource(output);
|
||||
int output_index = output.getResultNumber();
|
||||
if (!device) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< " No device for output #" << output_index << "\n");
|
||||
continue;
|
||||
}
|
||||
// Transfer the annotation to both region arguments
|
||||
for (Region* region : while_region.getRegions()) {
|
||||
BlockArgument arg = region->getArgument(output_index);
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< " Propagating device = '" << *device
|
||||
<< "' to arg #" << output_index << " of region #"
|
||||
<< region->getRegionNumber() << "\n");
|
||||
if (failed(AddResourceDeviceAndEmitError(arg, *device,
|
||||
while_region, result)))
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
}
|
||||
return WalkResult::advance();
|
||||
});
|
||||
return failure(walk_res.wasInterrupted());
|
||||
}
|
||||
|
||||
@ -201,6 +255,10 @@ void ResourceDeviceInference::runOnOperation() {
|
||||
Value arg_operand = caller_operands[arg.getArgNumber()];
|
||||
auto device = caller_res.DeviceForResource(arg_operand);
|
||||
if (!device) continue;
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "Propagating '" << *device << "' to arg #"
|
||||
<< arg.getArgNumber() << " of function @"
|
||||
<< callee.getName() << "\n");
|
||||
if (failed(AddResourceDeviceAndEmitError(arg, *device, caller,
|
||||
&callee_res,
|
||||
&callee_needs_recompute)))
|
||||
@ -240,6 +298,8 @@ void ResourceDeviceInference::runOnOperation() {
|
||||
"call");
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "Visiting call to function @" << func.getName() << "\n");
|
||||
if (failed(propagate_operands_to_callee_arguments(
|
||||
call, call.getArgOperands(), {func}, func_res)))
|
||||
return WalkResult::interrupt();
|
||||
|
||||
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/MapVector.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
@ -330,15 +331,6 @@ LogicalResult HoistResourceOpsFromCluster(tf_device::ClusterOp cluster,
|
||||
getUsedValuesDefinedAbove(new_cluster.body(), new_cluster.body(),
|
||||
captured_values);
|
||||
|
||||
for (Value v : captured_values) {
|
||||
auto tensor_type = v.getType().dyn_cast<TensorType>();
|
||||
if (!tensor_type) continue;
|
||||
if (!tensor_type.getElementType().isa<TF::ResourceType>()) continue;
|
||||
|
||||
return new_cluster.emitOpError()
|
||||
<< "has remaining resource inputs that can not be lifted";
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
@ -361,29 +353,23 @@ LogicalResult FindResourceArgUseInfo(
|
||||
ResourceArgUseInfo info;
|
||||
info.used = false;
|
||||
info.updated = false;
|
||||
bool do_not_touch = false;
|
||||
bool read_or_assigned = false;
|
||||
for (auto user : arg.getUsers()) {
|
||||
if (user == return_op) continue;
|
||||
info.used = true;
|
||||
if (auto read = llvm::dyn_cast<TF::ReadVariableOp>(user)) {
|
||||
info.used = true;
|
||||
read_or_assigned = true;
|
||||
info.data_type = read.getType();
|
||||
continue;
|
||||
}
|
||||
if (auto assign = llvm::dyn_cast<TF::AssignVariableOp>(user)) {
|
||||
info.used = true;
|
||||
read_or_assigned = true;
|
||||
info.updated = true;
|
||||
info.data_type = assign.value().getType();
|
||||
continue;
|
||||
}
|
||||
if (isa<TF::StackPushV2Op, TF::StackPopV2Op>(user)) {
|
||||
// Stacks will be handled by a separate pass.
|
||||
do_not_touch = true;
|
||||
break;
|
||||
}
|
||||
user->emitOpError("found unsupported operations on resource.");
|
||||
return failure();
|
||||
}
|
||||
if (!do_not_touch) (*result)[arg.getArgNumber()] = info;
|
||||
if (!info.used || read_or_assigned) (*result)[arg.getArgNumber()] = info;
|
||||
}
|
||||
return success();
|
||||
}
|
||||
@ -914,8 +900,8 @@ LogicalResult HandlePartitionedCallOpCallee(
|
||||
// resource-lifted new callee function in lifting_info.
|
||||
template <typename CallOpType>
|
||||
void UpdatePartitionedCallOpWithNewCallee(
|
||||
CallOpType call_op, const PartitionedCallLiftingInfo& lifting_info) {
|
||||
if (lifting_info.lifted_callee == nullptr) return;
|
||||
CallOpType call_op, PartitionedCallLiftingInfo& lifting_info) {
|
||||
if (!lifting_info.lifted_callee) return;
|
||||
// Replace output resource uses with the aliasing input, so that we can remove
|
||||
// this output.
|
||||
for (const auto& entry : lifting_info.old_outputs_aliasing_old_inputs) {
|
||||
@ -929,12 +915,10 @@ void UpdatePartitionedCallOpWithNewCallee(
|
||||
auto new_operands =
|
||||
FilterRange<Value, OperandRange>(call_op.args(), lifting_info.use_info);
|
||||
auto new_call = builder.create<CallOpType>(
|
||||
call_op.getLoc(),
|
||||
const_cast<FuncOp&>(lifting_info.lifted_callee).getType().getResults(),
|
||||
call_op.getLoc(), lifting_info.lifted_callee.getType().getResults(),
|
||||
new_operands, call_op.getAttrs());
|
||||
new_call.setAttr(
|
||||
"f", builder.getSymbolRefAttr(
|
||||
const_cast<FuncOp&>(lifting_info.lifted_callee).getName()));
|
||||
"f", builder.getSymbolRefAttr(lifting_info.lifted_callee.getName()));
|
||||
AddLoadsStoresOutsideControlFlowOp(
|
||||
new_call, lifting_info.arg_data_type_and_updated_output_index);
|
||||
// Replace uses.
|
||||
@ -949,7 +933,8 @@ void UpdatePartitionedCallOpWithNewCallee(
|
||||
}
|
||||
|
||||
LogicalResult HoistForFunctionalControlFlow(
|
||||
Block*, ModuleOp, llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>*);
|
||||
Block*, ModuleOp,
|
||||
llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>*);
|
||||
|
||||
// A templated routine for handling both PartitionedCallOp and
|
||||
// StatefulPartitionedCallOp. If the callee is already lifted, it just updates
|
||||
@ -958,9 +943,10 @@ LogicalResult HoistForFunctionalControlFlow(
|
||||
template <typename CallOpType>
|
||||
LogicalResult HandlePartitionedCallOp(
|
||||
CallOpType call_op, FuncOp callee, ModuleOp module,
|
||||
llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>* lifted_callees) {
|
||||
auto emplace_res =
|
||||
lifted_callees->try_emplace(callee, PartitionedCallLiftingInfo());
|
||||
llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>*
|
||||
lifted_callees) {
|
||||
auto emplace_res = lifted_callees->try_emplace(callee.getName(),
|
||||
PartitionedCallLiftingInfo());
|
||||
if (emplace_res.second) {
|
||||
// Unseen callee. Perform resource lifting on it.
|
||||
HoistForFunctionalControlFlow(&callee.front(), module, lifted_callees);
|
||||
@ -977,7 +963,7 @@ LogicalResult HandlePartitionedCallOp(
|
||||
// body/cond/branch/callee functions.
|
||||
LogicalResult HoistForFunctionalControlFlow(
|
||||
Block* block, ModuleOp module,
|
||||
llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>*
|
||||
llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>*
|
||||
lifted_partitioned_call_callees) {
|
||||
// Remove identity nodes to avoid aliasing.
|
||||
RemoveIdentity(block);
|
||||
@ -1056,7 +1042,7 @@ LogicalResult HoistForFunctionalControlFlow(
|
||||
// Returns failure if there are remaining resource-type values that cannot be
// lifted.
|
||||
void ResourceOpLiftingPass::runOnOperation() {
|
||||
llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>
|
||||
llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>
|
||||
lifted_partitioned_call_callees;
|
||||
ModuleOp module = getOperation();
|
||||
auto result = module.walk([&](FuncOp func_op) {
|
||||
@ -1121,7 +1107,7 @@ LogicalResult ResourceLiftingForFunctionalControlFlow(FuncOp function) {
|
||||
<< function.getBlocks().size();
|
||||
}
|
||||
|
||||
llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>
|
||||
llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>
|
||||
lifted_partitioned_call_callees;
|
||||
return HoistForFunctionalControlFlow(&function.front(),
|
||||
cast<ModuleOp>(function.getParentOp()),
|
||||
|
||||
@ -40,6 +40,7 @@ limitations under the License.
|
||||
#include "mlir/IR/SymbolTable.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/FoldInterfaces.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
@ -697,11 +698,8 @@ bool ShapeInference::RefineShapeForPassThroughOps(Operation* op) {
|
||||
// TODO(jpienaar): The tf.Cast op, which is uniformly inserted at the
// moment, cannot handle arbitrary types (e.g., it can't handle quantized
// types). This restriction can be relaxed if not only tf.Cast is used.
|
||||
auto kind = t.getKind();
|
||||
return (kind >= Type::FIRST_STANDARD_TYPE &&
|
||||
kind < Type::LAST_STANDARD_TYPE) ||
|
||||
(kind >= Type::FIRST_TENSORFLOW_TYPE &&
|
||||
kind < Type::LAST_TENSORFLOW_TYPE);
|
||||
return t.getDialect().getNamespace().empty() ||
|
||||
isa<TensorFlowDialect>(t.getDialect());
|
||||
};
|
||||
|
||||
bool changed = false;
|
||||
@ -1174,10 +1172,11 @@ LogicalResult ShapeInference::TryToFold(Operation* op) {
|
||||
if (!dialect) return failure();
|
||||
// Only attempt TF dialect fallback if there are no unknown operands.
|
||||
if (some_unknown && dialect == tf_dialect_) return failure();
|
||||
SmallVector<Attribute, 8> constants;
|
||||
if (failed(dialect->constantFoldHook(op, constant_operands, constants)))
|
||||
auto* interface = dialect->getRegisteredInterface<DialectFoldInterface>();
|
||||
if (!interface) return failure();
|
||||
|
||||
if (failed(interface->Fold(op, constant_operands, fold_results)))
|
||||
return failure();
|
||||
fold_results.assign(constants.begin(), constants.end());
|
||||
}
|
||||
|
||||
for (auto result : zip(op->getResults(), fold_results)) {
|
||||
|
||||
@ -0,0 +1,113 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
#include <tuple>
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/Region.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFTPU {
|
||||
|
||||
namespace {
|
||||
|
||||
// This pass removes Identity/IdentityN ops from the TPU computation and
|
||||
// reachable functions.
|
||||
// TODO(lyandy): Remove this pass once resource op lifting is migrated to use
|
||||
// resource alias analysis and support region based control flow. Removing
|
||||
// Identity ops may remove `_XlaSharding` annotation attribute if Identity ops
|
||||
// are used to propagate such information.
|
||||
|
||||
struct TPUIdentityPruning
|
||||
: public PassWrapper<TPUIdentityPruning, OperationPass<ModuleOp>> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
// Collects all reachable functions (via call ops) from a given region.
|
||||
SmallVector<FuncOp, 4> CollectReachableFunctions(Region& region) {
|
||||
llvm::SmallPtrSet<FuncOp, 4> reachable_funcs;
|
||||
|
||||
auto collect_reachable_funcs =
|
||||
[&reachable_funcs](Region& src, SmallVectorImpl<FuncOp>& funcs_to_visit) {
|
||||
src.walk([&reachable_funcs, &funcs_to_visit](CallOpInterface call_op) {
|
||||
auto func = dyn_cast_or_null<FuncOp>(call_op.resolveCallable());
|
||||
if (func && reachable_funcs.insert(func).second)
|
||||
funcs_to_visit.push_back(func);
|
||||
});
|
||||
};
|
||||
|
||||
SmallVector<FuncOp, 4> funcs_to_visit;
|
||||
collect_reachable_funcs(region, funcs_to_visit);
|
||||
|
||||
while (!funcs_to_visit.empty()) {
|
||||
SmallVector<FuncOp, 4> new_funcs_to_visit;
|
||||
for (FuncOp func_to_visit : funcs_to_visit) {
|
||||
if (!func_to_visit.getCallableRegion()) continue;
|
||||
collect_reachable_funcs(*func_to_visit.getCallableRegion(),
|
||||
new_funcs_to_visit);
|
||||
}
|
||||
funcs_to_visit.swap(new_funcs_to_visit);
|
||||
}
|
||||
|
||||
return llvm::to_vector<4>(reachable_funcs);
|
||||
}
|
||||
|
||||
// Removes Identity/IdentityN ops from a region and forwards their operands to
// their results.
|
||||
void RemoveIdentityFromRegion(Region& region) {
|
||||
region.walk([](Operation* op) {
|
||||
if (isa<TF::IdentityOp, TF::IdentityNOp>(op)) {
|
||||
op->replaceAllUsesWith(op->getOperands());
|
||||
op->erase();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void TPUIdentityPruning::runOnOperation() {
|
||||
SmallVector<tf_device::ClusterOp, 4> clusters;
|
||||
getOperation().walk(
|
||||
[&](tf_device::ClusterOp cluster) { clusters.push_back(cluster); });
|
||||
|
||||
for (tf_device::ClusterOp cluster : clusters) {
|
||||
RemoveIdentityFromRegion(cluster.body());
|
||||
auto reachable_funcs = CollectReachableFunctions(cluster.body());
|
||||
for (FuncOp reachable_func : reachable_funcs)
|
||||
RemoveIdentityFromRegion(*reachable_func.getCallableRegion());
|
||||
}
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUIdentityPruningPass() {
|
||||
return std::make_unique<TPUIdentityPruning>();
|
||||
}
|
||||
|
||||
static PassRegistration<TPUIdentityPruning> pass(
|
||||
"tf-tpu-identity-pruning",
|
||||
"Removes Identity/IdentityN ops from the TPU computation");
|
||||
|
||||
} // namespace TFTPU
|
||||
} // namespace mlir
|
||||
@ -177,7 +177,8 @@ Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def,
|
||||
restrict_functionalization_to_tpu_nodes
|
||||
? [](const Node* n) { return n->attrs().Find(kTpuReplicateAttr); }
|
||||
: NodeFilter{};
|
||||
return FunctionalizeControlFlow(graph, flib_def, node_filter);
|
||||
return FunctionalizeControlFlow(graph, flib_def, node_filter,
|
||||
/*include_functions=*/true);
|
||||
}
|
||||
|
||||
// Stateful helper class to import a TensorFlow model into an MLIR Module.
|
||||
|
||||
@ -219,22 +219,18 @@ StatusOr<mlir::OwningModuleRef> GraphdefToSplattedMlirTranslateFunction(
|
||||
if (auto attr = inst.getAttrOfType<mlir::ElementsAttr>(attr_id)) {
|
||||
mlir::Attribute rand_val;
|
||||
mlir::Type element_type = attr.getType().getElementType();
|
||||
if (element_type.isa<mlir::IntegerType>()) {
|
||||
rand_val = mlir::IntegerAttr::get(element_type, std::rand());
|
||||
} else if (element_type.isF16() || element_type.isF32() ||
|
||||
element_type.isF64()) {
|
||||
rand_val = mlir::FloatAttr::get(element_type,
|
||||
std::rand() * 1.0 / RAND_MAX);
|
||||
|
||||
switch (element_type.getKind()) {
|
||||
case mlir::StandardTypes::Integer:
|
||||
rand_val = mlir::IntegerAttr::get(element_type, std::rand());
|
||||
break;
|
||||
case mlir::StandardTypes::F16:
|
||||
case mlir::StandardTypes::F32:
|
||||
case mlir::StandardTypes::F64:
|
||||
rand_val = mlir::FloatAttr::get(element_type,
|
||||
std::rand() * 1.0 / RAND_MAX);
|
||||
break;
|
||||
default:
|
||||
inst.emitWarning()
|
||||
<< "Skipping splat conversion for "
|
||||
<< "an unsupported attribute type " << element_type;
|
||||
continue;
|
||||
} else {
|
||||
inst.emitWarning()
|
||||
<< "Skipping splat conversion for "
|
||||
<< "an unsupported attribute type " << element_type;
|
||||
continue;
|
||||
}
|
||||
auto new_attr =
|
||||
mlir::DenseElementsAttr::get(attr.getType(), rand_val);
|
||||
|
||||
@ -36,8 +36,8 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/bfloat16.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/tstring.h"
|
||||
|
||||
@ -91,64 +91,62 @@ Status ConvertDataType(DataType dtype, Builder builder, Type* type) {
|
||||
}
|
||||
|
||||
Status ConvertScalarTypeToDataType(Type type, DataType* dtype) {
|
||||
switch (type.getKind()) {
|
||||
case mlir::StandardTypes::F16:
|
||||
*dtype = DT_HALF;
|
||||
return Status::OK();
|
||||
case mlir::StandardTypes::F32:
|
||||
*dtype = DT_FLOAT;
|
||||
return Status::OK();
|
||||
case mlir::StandardTypes::F64:
|
||||
*dtype = DT_DOUBLE;
|
||||
return Status::OK();
|
||||
case mlir::StandardTypes::BF16:
|
||||
*dtype = DT_BFLOAT16;
|
||||
return Status::OK();
|
||||
case mlir::StandardTypes::Integer: {
|
||||
const auto& itype = type.cast<mlir::IntegerType>();
|
||||
switch (itype.getWidth()) {
|
||||
case 1:
|
||||
*dtype = DT_BOOL;
|
||||
return Status::OK();
|
||||
case 8:
|
||||
*dtype = itype.isUnsigned() ? DT_UINT8 : DT_INT8;
|
||||
return Status::OK();
|
||||
case 16:
|
||||
*dtype = itype.isUnsigned() ? DT_UINT16 : DT_INT16;
|
||||
return Status::OK();
|
||||
case 32:
|
||||
*dtype = itype.isUnsigned() ? DT_UINT32 : DT_INT32;
|
||||
return Status::OK();
|
||||
case 64:
|
||||
*dtype = itype.isUnsigned() ? DT_UINT64 : DT_INT64;
|
||||
return Status::OK();
|
||||
default:
|
||||
return errors::Unimplemented(
|
||||
absl::StrCat("Converting ", debugString(type), " to DataType"));
|
||||
}
|
||||
}
|
||||
case mlir::StandardTypes::Complex: {
|
||||
auto etype = type.cast<mlir::ComplexType>().getElementType();
|
||||
if (etype.isF32()) {
|
||||
*dtype = DT_COMPLEX64;
|
||||
return Status::OK();
|
||||
} else if (etype.isF64()) {
|
||||
*dtype = DT_COMPLEX128;
|
||||
return Status::OK();
|
||||
}
|
||||
return errors::Unimplemented(
|
||||
absl::StrCat("Converting ", debugString(type), " to DataType"));
|
||||
}
|
||||
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
|
||||
case mlir::TF::TensorFlowTypes::enumerant: \
|
||||
*dtype = DT_##enumerant; \
|
||||
if (type.isF16()) {
|
||||
*dtype = DT_HALF;
|
||||
return Status::OK();
|
||||
} else if (type.isF32()) {
|
||||
*dtype = DT_FLOAT;
|
||||
return Status::OK();
|
||||
} else if (type.isF64()) {
|
||||
*dtype = DT_DOUBLE;
|
||||
return Status::OK();
|
||||
} else if (type.isBF16()) {
|
||||
*dtype = DT_BFLOAT16;
|
||||
return Status::OK();
|
||||
} else if (auto itype = type.dyn_cast<mlir::IntegerType>()) {
|
||||
switch (itype.getWidth()) {
|
||||
case 1:
|
||||
*dtype = DT_BOOL;
|
||||
return Status::OK();
|
||||
case 8:
|
||||
*dtype = itype.isUnsigned() ? DT_UINT8 : DT_INT8;
|
||||
return Status::OK();
|
||||
case 16:
|
||||
*dtype = itype.isUnsigned() ? DT_UINT16 : DT_INT16;
|
||||
return Status::OK();
|
||||
case 32:
|
||||
*dtype = itype.isUnsigned() ? DT_UINT32 : DT_INT32;
|
||||
return Status::OK();
|
||||
case 64:
|
||||
*dtype = itype.isUnsigned() ? DT_UINT64 : DT_INT64;
|
||||
return Status::OK();
|
||||
default:
|
||||
return errors::Unimplemented(
|
||||
absl::StrCat("Converting ", debugString(type), " to DataType"));
|
||||
}
|
||||
} else if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
|
||||
auto etype = complex_type.getElementType();
|
||||
if (etype.isF32()) {
|
||||
*dtype = DT_COMPLEX64;
|
||||
return Status::OK();
|
||||
} else if (etype.isF64()) {
|
||||
*dtype = DT_COMPLEX128;
|
||||
return Status::OK();
|
||||
}
|
||||
return errors::Unimplemented(
|
||||
absl::StrCat("Converting ", debugString(type), " to DataType"));
|
||||
}
|
||||
|
||||
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
|
||||
if (type.isa<mlir::TF::tftype##Type>()) { \
|
||||
*dtype = DT_##enumerant; \
|
||||
return Status::OK(); \
|
||||
}
|
||||
// NOLINTNEXTLINE
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
|
||||
default:
|
||||
return errors::Unimplemented(
|
||||
absl::StrCat("Converting ", debugString(type), " to DataType"));
|
||||
}
|
||||
|
||||
return errors::Unimplemented(
|
||||
absl::StrCat("Converting ", debugString(type), " to DataType"));
|
||||
}
|
||||
|
||||
Status ConvertToDataType(Type type, DataType* dtype) {
|
||||
|
||||
@ -9,103 +9,41 @@ package(
|
||||
package_group(
|
||||
name = "friends",
|
||||
includes = ["//third_party/mlir:subpackages"],
|
||||
packages = [
|
||||
"//tensorflow/compiler/mlir/...",
|
||||
"//tensorflow/core/kernels/mlir_generated/...",
|
||||
],
|
||||
packages = ["//tensorflow/compiler/mlir/..."],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "passes",
|
||||
srcs = ["passes.cc"],
|
||||
hdrs = ["passes.h"],
|
||||
name = "cubin_creator",
|
||||
srcs = ["cubin_creator.cc"],
|
||||
hdrs = ["cubin_creator.h"],
|
||||
copts = if_cuda(["-DGOOGLE_CUDA=1"]),
|
||||
deps = [
|
||||
"@com_google_absl//absl/memory",
|
||||
"@llvm-project//mlir:GPUDialect",
|
||||
"@llvm-project//mlir:LLVMDialect",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:TargetNVVMIR",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
"//tensorflow/compiler/mlir/hlo",
|
||||
"//tensorflow/compiler/mlir/hlo:materialize_broadcasts", # buildcleaner: keep
|
||||
"//tensorflow/compiler/xla/service/gpu:stream_executor_util",
|
||||
"//tensorflow/compiler/xla/service:hlo_module_config",
|
||||
"//tensorflow/compiler/mlir/hlo:unfuse_batch_norm", # buildcleaner: keep
|
||||
"//tensorflow/compiler/xla:debug_options_flags",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
"//tensorflow/compiler/xla/service/gpu:target_constants",
|
||||
"//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend",
|
||||
"//tensorflow/core:cuda_libdevice_path",
|
||||
"//tensorflow/core:lib",
|
||||
] + if_cuda(["//tensorflow/stream_executor/gpu:asm_compiler"]),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "kernel_creator",
|
||||
srcs = ["kernel_creator.cc"],
|
||||
hdrs = ["kernel_creator.h"],
|
||||
copts = if_cuda(["-DGOOGLE_CUDA=1"]),
|
||||
deps = [
|
||||
":passes",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:Affine",
|
||||
"@llvm-project//mlir:AffineToStandardTransforms",
|
||||
"//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_dialect_registration",
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:CFGTransforms",
|
||||
"@llvm-project//mlir:GPUDialect",
|
||||
"@llvm-project//mlir:GPUToNVVMTransforms",
|
||||
"@llvm-project//mlir:GPUTransforms",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:LLVMDialect",
|
||||
"@llvm-project//mlir:LLVMTransforms",
|
||||
"@llvm-project//mlir:LinalgOps",
|
||||
"@llvm-project//mlir:LinalgToLLVM",
|
||||
"@llvm-project//mlir:LinalgTransforms",
|
||||
"@llvm-project//mlir:NVVMDialect",
|
||||
"@llvm-project//mlir:Parser",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:SCFDialect",
|
||||
"@llvm-project//mlir:GPUToGPURuntimeTransforms",
|
||||
"@llvm-project//mlir:SCFToGPUPass",
|
||||
"@llvm-project//mlir:SCFTransforms",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:TargetNVVMIR",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/hlo",
|
||||
"//tensorflow/compiler/mlir/hlo:all_passes",
|
||||
"//tensorflow/compiler/mlir/hlo:hlo_dialect_force_registration",
|
||||
"//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo",
|
||||
"//tensorflow/compiler/mlir/hlo:legalize_tanh_to_approximation",
|
||||
"//tensorflow/compiler/mlir/hlo:legalize_to_linalg",
|
||||
"//tensorflow/compiler/mlir/hlo:lhlo",
|
||||
"//tensorflow/compiler/mlir/hlo:lhlo_copy_removal",
|
||||
"//tensorflow/compiler/mlir/hlo:lhlo_fuse_linalg",
|
||||
"//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_affine",
|
||||
"//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_gpu",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
|
||||
"//tensorflow/compiler/mlir/hlo:materialize_broadcasts", # buildcleaner: keep
|
||||
"//tensorflow/compiler/mlir/hlo:unfuse_batch_norm", # buildcleaner: keep
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tools/kernel_gen/transforms:passes",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
|
||||
"//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend",
|
||||
"//tensorflow/compiler/xla:debug_options_flags",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/service/gpu:stream_executor_util",
|
||||
"//tensorflow/compiler/xla/service/gpu:target_constants",
|
||||
"//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend",
|
||||
"//tensorflow/compiler/xla/service/mlir_gpu:kernel_lowering",
|
||||
"//tensorflow/compiler/xla/service/mlir_gpu:passes",
|
||||
"//tensorflow/compiler/xla:debug_options_flags",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/core:cuda_libdevice_path",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
] + if_cuda(["//tensorflow/stream_executor/gpu:asm_compiler"]),
|
||||
)
|
||||
|
||||
@ -114,36 +52,11 @@ tf_cc_binary(
|
||||
srcs = ["tf_to_cubin.cc"],
|
||||
visibility = ["//tensorflow/core/kernels/mlir_generated:__pkg__"],
|
||||
deps = [
|
||||
":kernel_creator",
|
||||
":cubin_creator",
|
||||
"//tensorflow/compiler/mlir:init_mlir",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:Pass",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_binary(
|
||||
name = "tf_to_kernel",
|
||||
srcs = ["tf_to_kernel.cc"],
|
||||
visibility = ["//tensorflow/core/kernels/mlir_generated:__pkg__"],
|
||||
deps = [
|
||||
":kernel_creator",
|
||||
"//tensorflow/compiler/mlir:init_mlir",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:Analysis",
|
||||
"@llvm-project//llvm:CodeGen",
|
||||
"@llvm-project//llvm:Core",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//llvm:Target",
|
||||
"@llvm-project//llvm:X86CodeGen", # fixdeps: keep
|
||||
"@llvm-project//llvm:X86Disassembler", # fixdeps: keep
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:TargetLLVMIR",
|
||||
],
|
||||
)
|
||||
|
||||
@ -159,7 +72,6 @@ tf_cc_binary(
|
||||
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:MlirOptLib",
|
||||
"@llvm-project//mlir:MlirOptMain",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
|
||||
@ -13,33 +13,59 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tools/kernel_gen/passes.h"
//===- cubin_creator.cc -----------------------------------------*- C++ -*-===//
//
// This file implements the function to compile a TF kernel function to a cubin.
//
//===----------------------------------------------------------------------===//
#include "tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h"

#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/escaping.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Parser.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Target/NVVMIR.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/gpu/target_constants.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h"
#include "tensorflow/core/platform/cuda_libdevice_path.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/path.h"

#if GOOGLE_CUDA
#include "tensorflow/stream_executor/gpu/asm_compiler.h"
#endif

namespace mlir {
namespace kernel_gen {
namespace {
using tensorflow::Status;
using xla::InternalError;
using xla::StatusOr;

xla::StatusOr<std::string> GetLibdeviceDir(
StatusOr<std::string> GetLibdeviceDir(
const xla::HloModuleConfig& hlo_module_config) {
for (const std::string& cuda_root : tensorflow::CandidateCudaRoots(
hlo_module_config.debug_options().xla_gpu_cuda_data_dir())) {
@ -51,7 +77,7 @@ xla::StatusOr<std::string> GetLibdeviceDir(
return libdevice_dir;
}
}
return xla::InternalError(
return InternalError(
"Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice");
}

@ -87,11 +113,34 @@ struct UnfuseBatchNormPass
}
};

struct PropagateTensorFlowABIKnowledgePass
: public mlir::PassWrapper<PropagateTensorFlowABIKnowledgePass,
Status LowerTfOpToLhloWithDynamicShapes(mlir::ModuleOp module) {
mlir::PassManager pm(module.getContext());
auto enable_if_vlog_is_on = [](mlir::Pass* pass, mlir::Operation* op) {
return VLOG_IS_ON(1);
};
pm.enableIRPrinting(/*shouldPrintBeforePass=*/{},
/*shouldPrintAfterPass=*/enable_if_vlog_is_on,
/*printModuleScope=*/false,
/*printAfterOnlyOnChange=*/false, llvm::dbgs());
pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(false));
pm.addNestedPass<mlir::FuncOp>(
absl::make_unique<MaterializeBroadcastsPass>());
pm.addNestedPass<mlir::FuncOp>(absl::make_unique<UnfuseBatchNormPass>());
pm.addPass(mlir::mhlo::createLegalizeToLhloPass(
/*results_escape_functions=*/true));
pm.addNestedPass<mlir::FuncOp>(mlir::lmhlo::createLhloCopyRemovalPass());

if (failed(pm.run(module))) {
return InternalError("Lowering TF to LHLO failed.");
}
return Status::OK();
}

struct PropagateTensorFlowABIKnowledge
: public mlir::PassWrapper<PropagateTensorFlowABIKnowledge,
mlir::OperationPass<mlir::LLVM::LLVMFuncOp>> {
explicit PropagateTensorFlowABIKnowledgePass(
mlir::FunctionType type, llvm::ArrayRef<uint32_t> same_shape_)
explicit PropagateTensorFlowABIKnowledge(mlir::FunctionType type,
llvm::ArrayRef<uint32_t> same_shape_)
: func_type(type), same_shape(same_shape_) {}

void runOnOperation() override {
@ -125,7 +174,8 @@ struct PropagateTensorFlowABIKnowledgePass
for (mlir::Type arg_type : arg_types) {
if (!arg_type.isa<mlir::MemRefType>()) {
func.emitError() << "argument of surrounding func is not ranked memref";
return signalPassFailure();
signalPassFailure();
return;
}
positions.push_back(arg_pos);
// Set alignment and aliasing on the pointers.
@ -154,7 +204,8 @@ struct PropagateTensorFlowABIKnowledgePass
func.emitOpError() << "same shape constraints on arguments with "
"non-matching shapes: #"
<< first << " and #" << same;
return signalPassFailure();
signalPassFailure();
continue;
}

for (uint32_t i = 0; i < 2 * rank; ++i) {
@ -171,93 +222,91 @@ struct PropagateTensorFlowABIKnowledgePass
llvm::ArrayRef<uint32_t> same_shape;
};

class GpuKernelToBlobPass
: public mlir::PassWrapper<GpuKernelToBlobPass,
mlir::OperationPass<mlir::gpu::GPUModuleOp>> {
public:
GpuKernelToBlobPass(mlir::StringRef blob_annotation,
std::pair<int32_t, int32_t> compute_capability)
: blob_annotation_(blob_annotation),
compute_capability_(compute_capability) {}
Status PropagateTensorFlowABIKnowledgeToKernel(
mlir::ModuleOp module, llvm::ArrayRef<uint32_t> same_shape) {
// Grab the original signature from the single function.
auto func = *module.getBody()->op_begin<mlir::FuncOp>();

void runOnOperation() override {
mlir::gpu::GPUModuleOp module = getOperation();
mlir::PassManager pm(module.getContext());
auto enable_if_vlog_is_on = [](mlir::Pass*, mlir::Operation*) {
return VLOG_IS_ON(1);
};
pm.enableIRPrinting(/*shouldPrintBeforePass=*/{},
/*shouldPrintAfterPass=*/enable_if_vlog_is_on,
/*printModuleScope=*/false,
/*printAfterOnlyOnChange=*/false, llvm::dbgs());
auto& kernel_pm = pm.nest<::mlir::gpu::GPUModuleOp>();
kernel_pm.addNestedPass<mlir::LLVM::LLVMFuncOp>(
absl::make_unique<PropagateTensorFlowABIKnowledge>(func.getType(),
same_shape));

llvm::LLVMContext llvmContext;
auto llvmModule = mlir::translateModuleToNVVMIR(module, llvmContext);
if (!llvmModule) {
return signalPassFailure();
}

llvmModule->setModuleIdentifier("acme");
llvmModule->setDataLayout(xla::gpu::nvptx::kDataLayout);
xla::HloModuleConfig config;
config.set_debug_options(xla::GetDebugOptionsFromFlags());

auto enable_fusion = [](llvm::TargetMachine* target) {
target->Options.AllowFPOpFusion = llvm::FPOpFusion::FPOpFusionMode::Fast;
};

auto libdevice_dir_or = GetLibdeviceDir(config);
if (!libdevice_dir_or.ok()) {
return signalPassFailure();
}

auto ptx_or = xla::gpu::nvptx::CompileToPtx(
llvmModule.get(), compute_capability_, config,
libdevice_dir_or.ValueOrDie(), enable_fusion);
if (!ptx_or.ok()) {
return signalPassFailure();
}

auto ptx = ptx_or.ValueOrDie();

#if GOOGLE_CUDA
auto blob_or = tensorflow::se::CompileGpuAsm(
std::get<0>(compute_capability_), std::get<1>(compute_capability_),
ptx.c_str(), xla::gpu::PtxOptsFromConfig(config));
if (blob_or.ok()) {
const auto& blob = blob_or.ValueOrDie();
std::string blob_string(blob.begin(), blob.end());
module.setAttr(blob_annotation_,
mlir::StringAttr::get(blob_string, &getContext()));
return;
} else {
return signalPassFailure();
}
#endif
return signalPassFailure();
if (failed(pm.run(module))) {
return InternalError("Static knowledge propagation failed.");
}
return Status::OK();
}

private:
mlir::StringRef blob_annotation_;
std::pair<int32_t, int32_t> compute_capability_;
};

void RegisterDialects() {
static bool init_once = []() {
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
return true;
}();
(void)init_once;
}
} // namespace

std::unique_ptr<mlir::FunctionPass> createMaterializeBroadcastsPass() {
return absl::make_unique<MaterializeBroadcastsPass>();
}
StatusOr<std::vector<uint8_t>> tensorflow::kernel_gen::GenerateCubinForTfCode(
llvm::StringRef tf_code, std::pair<int32_t, int32_t> compute_capability,
llvm::ArrayRef<uint32_t> tile_sizes, llvm::ArrayRef<uint32_t> same_shape,
llvm::ArrayRef<uint32_t> unroll_factors) {
RegisterDialects();
mlir::MLIRContext context;
mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context);

std::unique_ptr<mlir::FunctionPass> createUnfuseBatchNormPass() {
return absl::make_unique<UnfuseBatchNormPass>();
}
TF_RETURN_IF_ERROR(LowerTfOpToLhloWithDynamicShapes(module.get()));
{
xla::mlir_gpu::LowerLHLOToGPUOptions options;
options.tile_sizes = tile_sizes;
options.unroll_factors = unroll_factors;
options.collapse_parallel_loops = false;
options.use_approximations = true;
TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerLHLOToGPU(module.get(), options));
}
TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get()));
TF_RETURN_IF_ERROR(
PropagateTensorFlowABIKnowledgeToKernel(module.get(), same_shape));

std::unique_ptr<mlir::OperationPass<mlir::LLVM::LLVMFuncOp>>
createPropagateTensorFlowABIKnowledgePass(mlir::FunctionType type,
llvm::ArrayRef<uint32_t> same_shape) {
return absl::make_unique<PropagateTensorFlowABIKnowledgePass>(type,
same_shape);
}
mlir::OwningModuleRef kernel_module =
xla::mlir_gpu::ExtractKernelModule(*module).ValueOrDie();
llvm::LLVMContext llvmContext;
auto llvmModule = mlir::translateModuleToNVVMIR(*kernel_module, llvmContext);
if (!llvmModule) {
return InternalError("Could not translate MLIR module to NVVM");
}

std::unique_ptr<mlir::OperationPass<mlir::gpu::GPUModuleOp>>
createGpuKernelToBlobPass(
mlir::StringRef blob_annotation,
const std::pair<int32_t, int32_t>& compute_capability) {
return absl::make_unique<GpuKernelToBlobPass>(blob_annotation,
compute_capability);
}
llvmModule->setModuleIdentifier("acme");
llvmModule->setDataLayout(xla::gpu::nvptx::kDataLayout);

} // namespace kernel_gen
} // namespace mlir
xla::HloModuleConfig config;
config.set_debug_options(xla::GetDebugOptionsFromFlags());

auto enable_fusion = [](llvm::TargetMachine* target) {
target->Options.AllowFPOpFusion = llvm::FPOpFusion::FPOpFusionMode::Fast;
};

TF_ASSIGN_OR_RETURN(std::string libdevice_dir, GetLibdeviceDir(config));
TF_ASSIGN_OR_RETURN(
std::string ptx,
xla::gpu::nvptx::CompileToPtx(llvmModule.get(), compute_capability,
config, libdevice_dir, enable_fusion));
VLOG(1) << ptx;

#if GOOGLE_CUDA
return tensorflow::se::CompileGpuAsm(
std::get<0>(compute_capability), std::get<1>(compute_capability),
ptx.c_str(), xla::gpu::PtxOptsFromConfig(config));
#else
return InternalError(
"GOOGLE_CUDA not defined. Did you specify --config=cuda ?");
#endif
}
@ -13,43 +13,30 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

//===- kernel_creator.h -----------------------------------------*- C++ -*-===//
//===- cubin_creator.h ------------------------------------------*- C++ -*-===//
//
// This file declares the function to compile a TF kernel function to a cubin.
//
//===----------------------------------------------------------------------===//
#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_KERNEL_CREATOR_H_
#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_KERNEL_CREATOR_H_
#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CUBIN_CREATOR_H_
#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CUBIN_CREATOR_H_

#include <utility>
#include <vector>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "tensorflow/compiler/xla/statusor.h"

namespace tensorflow {
namespace kernel_gen {

// Registers necessary dialects. It should be called before creating
// MLIRContext.
void RegisterDialects();

// Converts TF code to LLVM/NVVM. If `cubin_only` is true, then the conversion
// stops after cubin binary blob is generated. If `cubin_only` is false, lowers
// the host side to LLVM Dialect.
xla::StatusOr<mlir::OwningModuleRef> GenerateKernelForTfCode(
mlir::MLIRContext& mlir_context, llvm::StringRef tf_code, bool cubin_only,
xla::StatusOr<std::vector<uint8_t>> GenerateCubinForTfCode(
llvm::StringRef tf_code,
std::pair<int32_t, int32_t> compute_capability = {7, 5},
llvm::ArrayRef<uint32_t> tile_sizes = {16, 64},
llvm::ArrayRef<uint32_t> same_shape = {},
llvm::ArrayRef<uint32_t> unroll_factors = {});

// Extracts cubin from the converted module.
xla::StatusOr<std::string> ExtractGpuBinary(mlir::ModuleOp module);

} // namespace kernel_gen
} // namespace tensorflow

#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_KERNEL_CREATOR_H_
#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CUBIN_CREATOR_H_
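
A minimal caller-side sketch of the restored GenerateCubinForTfCode API above; the MLIR source placeholder, the sm_70 target, and the error handling are illustrative only and not part of this commit:

  // Compile a TF-dialect MLIR function to a cubin blob for sm_70, using the
  // default tiling and no unrolling.
  std::string tf_code = /* MLIR module containing the TF op to compile */;
  auto cubin_or = tensorflow::kernel_gen::GenerateCubinForTfCode(
      tf_code, /*compute_capability=*/{7, 0}, /*tile_sizes=*/{16, 64},
      /*same_shape=*/{}, /*unroll_factors=*/{});
  if (cubin_or.ok()) {
    std::vector<uint8_t> blob = cubin_or.ConsumeValueOrDie();
    // `blob` now holds the assembled GPU binary for the generated kernel.
  }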
@ -48,13 +48,11 @@ Type TFFrameworkDialect::parseType(DialectAsmParser &parser) const {

/// Print a type registered to this dialect.
void TFFrameworkDialect::printType(Type type, DialectAsmPrinter &os) const {
switch (type.getKind()) {
case TFFrameworkTypes::OpKernelContextType:
os << "op_kernel_context";
return;
default:
llvm_unreachable("unexpected TF Framework type kind");
if (type.isa<OpKernelContextType>()) {
os << "op_kernel_context";
return;
}
llvm_unreachable("unexpected TF Framework type kind");
}

template <typename OpTy>

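The printType rewrite above replaces the getKind() switch with isa<> dispatch, which extends naturally as dialect types are added. A hedged sketch showing how a second, purely hypothetical AllocatorType (not part of this change) would slot into the same pattern:

  void TFFrameworkDialect::printType(Type type, DialectAsmPrinter &os) const {
    if (type.isa<OpKernelContextType>()) {
      os << "op_kernel_context";
      return;
    }
    if (type.isa<AllocatorType>()) {  // hypothetical second dialect type
      os << "allocator";
      return;
    }
    llvm_unreachable("unexpected TF Framework type kind");
  }
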
@ -1,258 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
//===- kernel_creator.cc ----------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// This file implements the function to compile a TF kernel function to a cubin.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
#include "tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h"
|
||||
|
||||
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project
|
||||
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" // from @llvm-project
|
||||
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // from @llvm-project
|
||||
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" // from @llvm-project
|
||||
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h" // from @llvm-project
|
||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" // from @llvm-project
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project
|
||||
#include "mlir/Dialect/GPU/ParallelLoopMapper.h" // from @llvm-project
|
||||
#include "mlir/Dialect/GPU/Passes.h" // from @llvm-project
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
|
||||
#include "mlir/Dialect/LLVMIR/NVVMDialect.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project
|
||||
#include "mlir/Dialect/SCF/Passes.h" // from @llvm-project
|
||||
#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
|
||||
#include "mlir/Dialect/SCF/Transforms.h" // from @llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/Parser.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "mlir/Transforms/BufferPlacement.h" // from @llvm-project
|
||||
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
|
||||
#include "mlir/Transforms/Passes.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tools/kernel_gen/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
|
||||
#include "tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h"
|
||||
#include "tensorflow/compiler/xla/service/mlir_gpu/passes.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/path.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace kernel_gen {
|
||||
namespace {
|
||||
|
||||
using tensorflow::Status;
|
||||
using xla::InternalError;
|
||||
using xla::StatusOr;
|
||||
|
||||
constexpr llvm::StringRef kGpuBinaryAttrName = "nvvm.cubin";
|
||||
|
||||
Status LowerTFtoGPU(mlir::ModuleOp module, bool cubin_only,
|
||||
llvm::ArrayRef<uint32_t> tile_sizes,
|
||||
llvm::ArrayRef<uint32_t> unroll_factors) {
|
||||
mlir::PassManager pm(module.getContext());
|
||||
applyPassManagerCLOptions(pm);
|
||||
|
||||
pm.addPass(mlir::mhlo::createLegalizeTFPass(false));
|
||||
if (cubin_only) {
|
||||
pm.addNestedPass<mlir::FuncOp>(
|
||||
mlir::kernel_gen::createMaterializeBroadcastsPass());
|
||||
pm.addNestedPass<mlir::FuncOp>(
|
||||
mlir::kernel_gen::createUnfuseBatchNormPass());
|
||||
pm.addPass(mlir::mhlo::createLegalizeToLhloPass(
|
||||
/*results_escape_functions=*/true));
|
||||
// Moving `AllocOp`s and inserting missing `DeallocOp`s
|
||||
pm.addPass(::mlir::createBufferPlacementPass());
|
||||
pm.addNestedPass<mlir::FuncOp>(mlir::lmhlo::createLhloCopyRemovalPass());
|
||||
} else {
|
||||
pm.addPass(mlir::mhlo::createTransformUnrankedHloPass());
|
||||
pm.addPass(mlir::kernel_gen::transforms::CreateShapeToDescriptorsPass());
|
||||
pm.addPass(mlir::kernel_gen::transforms::CreateBufferizePass());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
}
|
||||
|
||||
// We have to anticipate later unrolling in tiling to make sure that we get
|
||||
// the requested tiling after unrolling. Compute the new tiling here if
|
||||
// needed.
|
||||
llvm::SmallVector<unsigned, 4> tiling_for_unrolling;
|
||||
llvm::SmallVector<int64_t, 4> as_int64;
|
||||
if (!unroll_factors.empty()) {
|
||||
tiling_for_unrolling.reserve(tile_sizes.size());
|
||||
for (auto pair : llvm::zip(tile_sizes, unroll_factors)) {
|
||||
tiling_for_unrolling.push_back(std::get<0>(pair) * std::get<1>(pair));
|
||||
as_int64.push_back(std::get<1>(pair));
|
||||
}
|
||||
} else {
|
||||
tiling_for_unrolling.append(tile_sizes.begin(), tile_sizes.end());
|
||||
}
|
||||
// Transform LHLO operations to LinAlg.
|
||||
pm.addPass(::mlir::lmhlo::createLegalizeLhloToLinalgPass());
|
||||
// Fuse linalg operations.
|
||||
pm.addPass(::mlir::lmhlo::createLhloFuseLinalgPass(
|
||||
/*use_parallel_loops=*/true, tiling_for_unrolling));
|
||||
// Transform the Linalg operations inside of the loop nest into parallel
|
||||
// loops.
|
||||
pm.addPass(::mlir::createConvertLinalgToParallelLoopsPass());
|
||||
// Canonicalize the code to simplify index computations. This is needed so
|
||||
// that loop bounds have the same value.
|
||||
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
|
||||
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
|
||||
// Fuse the inner-most loops.
|
||||
pm.addPass(xla::mlir_gpu::createFuseInnerParallelLoopsPass());
|
||||
// Run CSE to ensure that loads and stores to the same subview get
|
||||
// recognized as such.
|
||||
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
|
||||
// Forward stores to buffers to loads.
|
||||
pm.addPass(xla::mlir_gpu::createStoreForwardingPass());
|
||||
// Remove now unused temporary buffers.
|
||||
pm.addPass(xla::mlir_gpu::createDeadTempBufferRemovalPass());
|
||||
if (!unroll_factors.empty()) {
|
||||
pm.addPass(::mlir::createParallelLoopTilingPass(as_int64));
|
||||
}
|
||||
// Some basic cleanup.
|
||||
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
|
||||
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
|
||||
// Greedily map the remaining loop to GPU hardware dimensions.
|
||||
pm.addPass(xla::mlir_gpu::createMapParallelLoopsPass());
|
||||
// Apply the mapping.
|
||||
pm.addPass(mlir::createParallelLoopToGpuPass());
|
||||
|
||||
// Embed TF Framework ops.
|
||||
if (!cubin_only) {
|
||||
pm.addPass(mlir::kernel_gen::tf_framework::createEmbedTFFrameworkPass());
|
||||
}
|
||||
|
||||
// Some basic cleanup.
|
||||
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
|
||||
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
|
||||
// Make loops with min bounds into a conditional plus static bounds.
|
||||
// Only do this if we unrolled in the first place.
|
||||
if (!unroll_factors.empty()) {
|
||||
pm.addNestedPass<::mlir::FuncOp>(mlir::createForLoopSpecializationPass());
|
||||
}
|
||||
// Approximate Tanh using standard operations.
|
||||
pm.addNestedPass<::mlir::FuncOp>(
|
||||
::mlir::mhlo::createLegalizeTanhToApproximationPass());
|
||||
// Move scalar operations into the launch to ensure smaller signatures.
|
||||
pm.addPass(xla::mlir_gpu::createMoveScalarComputationsIntoGpuLaunchPass());
|
||||
// Take launches to launches with kernels.
|
||||
pm.addPass(::mlir::createGpuKernelOutliningPass());
|
||||
|
||||
if (cubin_only) {
|
||||
// Make kernel signature deterministic so that we can call it externally.
|
||||
pm.addPass(xla::mlir_gpu::createRewriteKernelSignaturePass());
|
||||
}
|
||||
pm.addPass(::mlir::createLowerAffinePass());
|
||||
pm.addPass(::mlir::createLowerToCFGPass());
|
||||
if (failed(pm.run(module))) {
|
||||
return InternalError("Lowering to GPU kernels failed.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PropagateTensorFlowABIKnowledgeToKernel(
|
||||
mlir::ModuleOp module, llvm::ArrayRef<uint32_t> same_shape) {
|
||||
// Grab the original signature from the single function.
|
||||
auto func = *module.getBody()->op_begin<mlir::FuncOp>();
|
||||
|
||||
mlir::PassManager pm(module.getContext());
|
||||
applyPassManagerCLOptions(pm);
|
||||
auto& kernel_pm = pm.nest<::mlir::gpu::GPUModuleOp>();
|
||||
kernel_pm.addNestedPass<mlir::LLVM::LLVMFuncOp>(
|
||||
mlir::kernel_gen::createPropagateTensorFlowABIKnowledgePass(
|
||||
func.getType(), same_shape));
|
||||
|
||||
if (failed(pm.run(module))) {
|
||||
return InternalError("Static knowledge propagation failed.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LowerGPUToLLVM(mlir::ModuleOp module, bool cubin_only,
|
||||
llvm::ArrayRef<uint32_t> same_shape,
|
||||
llvm::StringRef gpu_binary_attr_name,
|
||||
std::pair<int32_t, int32_t> compute_capability) {
|
||||
mlir::PassManager pm(module.getContext());
|
||||
applyPassManagerCLOptions(pm);
|
||||
|
||||
auto& kernel_pm = pm.nest<mlir::gpu::GPUModuleOp>();
|
||||
if (cubin_only) {
|
||||
// Grab the original signature from the single function.
|
||||
auto func = *module.getBody()->op_begin<mlir::FuncOp>();
|
||||
kernel_pm.addNestedPass<mlir::LLVM::LLVMFuncOp>(
|
||||
mlir::kernel_gen::createPropagateTensorFlowABIKnowledgePass(
|
||||
func.getType(), same_shape));
|
||||
}
|
||||
kernel_pm.addPass(mlir::createStripDebugInfoPass());
|
||||
kernel_pm.addPass(mlir::kernel_gen::createGpuKernelToBlobPass(
|
||||
gpu_binary_attr_name, compute_capability));
|
||||
|
||||
if (!cubin_only) {
|
||||
pm.addPass(mlir::kernel_gen::tf_framework::
|
||||
createTestTFFrameworkLegalizeToLLVMPass());
|
||||
pm.addPass(mlir::createGpuToLLVMConversionPass(gpu_binary_attr_name));
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
pm.addPass(mlir::createCSEPass());
|
||||
}
|
||||
|
||||
return failed(pm.run(module)) ? InternalError("Lowering to LLVM IR failed.")
|
||||
: Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void RegisterDialects() {
|
||||
static bool init_once = []() {
|
||||
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
|
||||
return true;
|
||||
}();
|
||||
(void)init_once;
|
||||
}
|
||||
|
||||
StatusOr<mlir::OwningModuleRef> GenerateKernelForTfCode(
|
||||
mlir::MLIRContext& context, llvm::StringRef tf_code, bool cubin_only,
|
||||
std::pair<int32_t, int32_t> compute_capability,
|
||||
llvm::ArrayRef<uint32_t> tile_sizes, llvm::ArrayRef<uint32_t> same_shape,
|
||||
llvm::ArrayRef<uint32_t> unroll_factors) {
|
||||
mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context);
|
||||
TF_RETURN_IF_ERROR(
|
||||
LowerTFtoGPU(module.get(), cubin_only, tile_sizes, unroll_factors));
|
||||
TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get()));
|
||||
TF_RETURN_IF_ERROR(LowerGPUToLLVM(module.get(), cubin_only, same_shape,
|
||||
kGpuBinaryAttrName, compute_capability));
|
||||
return module;
|
||||
}
|
||||
|
||||
StatusOr<std::string> ExtractGpuBinary(mlir::ModuleOp module) {
|
||||
auto gpu_modules = module.getOps<mlir::gpu::GPUModuleOp>();
|
||||
if (std::distance(gpu_modules.begin(), gpu_modules.end()) != 1) {
|
||||
return InternalError("There should be exactly one GPU Module");
|
||||
}
|
||||
mlir::gpu::GPUModuleOp gpu_mod = *gpu_modules.begin();
|
||||
auto blob = gpu_mod.getAttrOfType<mlir::StringAttr>(kGpuBinaryAttrName);
|
||||
if (!blob) {
|
||||
return InternalError("No binary blob found in the module");
|
||||
}
|
||||
return blob.getValue().str();
|
||||
}
|
||||
|
||||
} // namespace kernel_gen
|
||||
} // namespace tensorflow
|
||||
@ -1,43 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_PASSES_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_PASSES_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace kernel_gen {
|
||||
|
||||
std::unique_ptr<mlir::FunctionPass> createMaterializeBroadcastsPass();
|
||||
|
||||
std::unique_ptr<mlir::FunctionPass> createUnfuseBatchNormPass();
|
||||
|
||||
std::unique_ptr<mlir::OperationPass<mlir::LLVM::LLVMFuncOp>>
|
||||
createPropagateTensorFlowABIKnowledgePass(mlir::FunctionType type,
|
||||
llvm::ArrayRef<uint32_t> same_shape);
|
||||
|
||||
std::unique_ptr<mlir::OperationPass<mlir::gpu::GPUModuleOp>>
|
||||
createGpuKernelToBlobPass(
|
||||
mlir::StringRef blob_annotation,
|
||||
const std::pair<int32_t, int32_t>& compute_capability);
|
||||
|
||||
} // namespace kernel_gen
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_PASSES_H_
|
||||
@ -23,47 +23,10 @@

#include "absl/strings/string_view.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "tensorflow/compiler/mlir/init_mlir.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace tensorflow {
namespace kernel_gen {
namespace {

xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file,
int32_t architecture, llvm::ArrayRef<uint32_t> tile_sizes,
llvm::ArrayRef<uint32_t> same_shape,
llvm::ArrayRef<uint32_t> unroll_factors) {
std::pair<int32_t, int32_t> compute_capability(architecture / 10,
architecture % 10);
// Read TF code.
std::string tf_code;
TF_RETURN_IF_ERROR(
ReadFileToString(Env::Default(), input_file.str(), &tf_code));
// Compile.
RegisterDialects();
mlir::MLIRContext mlir_context;
TF_ASSIGN_OR_RETURN(
mlir::OwningModuleRef module,
GenerateKernelForTfCode(mlir_context, tf_code, /*cubin_only=*/true,
compute_capability, tile_sizes, same_shape,
unroll_factors));
// Extract cubin.
TF_ASSIGN_OR_RETURN(std::string cubin, ExtractGpuBinary(*module));

// Write cubin binary blob.
TF_RETURN_IF_ERROR(
WriteStringToFile(Env::Default(), output_file.str(), cubin));
return xla::Status::OK();
}

} // namespace
} // namespace kernel_gen
} // namespace tensorflow

int main(int argc, char** argv) {
llvm::cl::opt<std::string> input_file("input", llvm::cl::desc("input file"),
@ -88,15 +51,38 @@ int main(int argc, char** argv) {
llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated);

tensorflow::InitMlir y(&argc, &argv);
mlir::registerPassManagerCLOptions();
llvm::cl::ParseCommandLineOptions(argc, argv, "TF op GPU kernel generator\n");

auto status =
tensorflow::kernel_gen::Run(input_file, output_file, architecture,
tile_sizes, same_shape, unroll_factors);
std::pair<int32_t, int32_t> compute_capability(architecture / 10,
architecture % 10);

std::string tf_code;
auto read_status = tensorflow::ReadFileToString(tensorflow::Env::Default(),
input_file, &tf_code);
if (!read_status.ok()) {
LOG(ERROR) << read_status;
return 1;
}

auto cubin = tensorflow::kernel_gen::GenerateCubinForTfCode(
tf_code, compute_capability, tile_sizes, same_shape, unroll_factors);

if (!cubin.ok()) {
LOG(ERROR) << cubin.status();
return 1;
}

std::vector<uint8_t> cubin_data = cubin.ConsumeValueOrDie();

auto status = tensorflow::WriteStringToFile(
tensorflow::Env::Default(), output_file,
absl::string_view{reinterpret_cast<char*>(cubin_data.data()),
cubin_data.size()});

if (!status.ok()) {
LOG(ERROR) << status;
return 1;
}

return 0;
}

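A worked note on the architecture handling in main() above (an aside, not part of the diff): the two-digit architecture value is split into a compute-capability pair via architecture / 10 and architecture % 10, so 70 becomes {7, 0} (sm_70) and 75 becomes {7, 5} (sm_75).
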
@ -1,164 +0,0 @@
|
||||
// Copyright 2020 The TensorFlow Runtime Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//===- tf_to_kernel.cc ------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// This file implements the entry point to compile a tf op to a cubin file.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "llvm/Analysis/TargetLibraryInfo.h"
|
||||
#include "llvm/CodeGen/CommandFlags.h"
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/Host.h"
|
||||
#include "llvm/Support/TargetRegistry.h"
|
||||
#include "llvm/Support/TargetSelect.h"
|
||||
#include "llvm/Target/TargetMachine.h"
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "mlir/Target/LLVMIR.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/init_mlir.h"
|
||||
#include "tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace kernel_gen {
|
||||
namespace {
|
||||
|
||||
static llvm::codegen::RegisterCodeGenFlags CGF;
|
||||
|
||||
std::unique_ptr<llvm::TargetMachine> GetTargetMachine(llvm::Module* module) {
|
||||
llvm::Triple triple(module->getTargetTriple());
|
||||
if (triple.getTriple().empty()) {
|
||||
triple = llvm::Triple(llvm::sys::getDefaultTargetTriple());
|
||||
module->setTargetTriple(triple.getTriple());
|
||||
}
|
||||
|
||||
std::string error;
|
||||
const llvm::Target* target =
|
||||
llvm::TargetRegistry::lookupTarget("", triple, error);
|
||||
if (!target) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
llvm::TargetOptions target_options =
|
||||
llvm::codegen::InitTargetOptionsFromCodeGenFlags();
|
||||
return std::unique_ptr<llvm::TargetMachine>(target->createTargetMachine(
|
||||
triple.str(), "generic", "", target_options, llvm::Reloc::Model::PIC_));
|
||||
}
|
||||
|
||||
// Compiles the given MLIR module via LLVM into an executable binary format.
|
||||
xla::StatusOr<std::string> EmitToBinary(mlir::ModuleOp module) {
|
||||
// Translate the module.
|
||||
llvm::LLVMContext llvm_context;
|
||||
std::unique_ptr<llvm::Module> llvm_module =
|
||||
mlir::translateModuleToLLVMIR(module, llvm_context);
|
||||
|
||||
// Set up the output stream.
|
||||
llvm::SmallString<8> outstr;
|
||||
llvm::raw_svector_ostream ostream(outstr);
|
||||
ostream.SetUnbuffered();
|
||||
|
||||
llvm::legacy::PassManager codegen_passes;
|
||||
codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass(
|
||||
llvm::Triple(llvm_module->getTargetTriple())));
|
||||
|
||||
// TODO(b/163818770): Apply optimizations before dumping .a file.
|
||||
auto target_machine = GetTargetMachine(llvm_module.get());
|
||||
llvm_module->setDataLayout(target_machine->createDataLayout());
|
||||
if (target_machine->addPassesToEmitFile(codegen_passes, ostream, nullptr,
|
||||
llvm::CGFT_ObjectFile, false)) {
|
||||
return xla::InternalError("Failed add passes to emit file");
|
||||
}
|
||||
codegen_passes.run(*llvm_module);
|
||||
return ostream.str().str();
|
||||
}
|
||||
|
||||
xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file,
|
||||
int32_t architecture, llvm::ArrayRef<uint32_t> tile_sizes,
|
||||
llvm::ArrayRef<uint32_t> same_shape,
|
||||
llvm::ArrayRef<uint32_t> unroll_factors) {
|
||||
std::pair<int32_t, int32_t> compute_capability(architecture / 10,
|
||||
architecture % 10);
|
||||
// Read TF code.
|
||||
std::string tf_code;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ReadFileToString(Env::Default(), input_file.str(), &tf_code));
|
||||
// Compile.
|
||||
RegisterDialects();
|
||||
mlir::MLIRContext mlir_context;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
mlir::OwningModuleRef module,
|
||||
GenerateKernelForTfCode(mlir_context, tf_code, /*cubin_only=*/false,
|
||||
compute_capability, tile_sizes, same_shape,
|
||||
unroll_factors));
|
||||
// Get binary.
|
||||
TF_ASSIGN_OR_RETURN(std::string binary, EmitToBinary(*module));
|
||||
|
||||
// Write .a file.
|
||||
TF_RETURN_IF_ERROR(
|
||||
WriteStringToFile(Env::Default(), output_file.str(), binary));
|
||||
return xla::Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace kernel_gen
|
||||
} // namespace tensorflow
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
llvm::cl::opt<std::string> input_file("input", llvm::cl::desc("input file"),
|
||||
llvm::cl::value_desc("filename"),
|
||||
llvm::cl::init("foo.mlir"));
|
||||
llvm::cl::opt<std::string> output_file(
|
||||
"output", llvm::cl::desc("output file"), llvm::cl::value_desc("filename"),
|
||||
llvm::cl::init("foo.bin"));
|
||||
llvm::cl::opt<int32_t> architecture(
|
||||
"arch", llvm::cl::desc("target architecture (e.g. 50 for sm_50)"),
|
||||
llvm::cl::init(50));
|
||||
llvm::cl::list<uint32_t> tile_sizes(
|
||||
"tile_sizes", llvm::cl::desc("tile sizes to use"), llvm::cl::ZeroOrMore,
|
||||
llvm::cl::CommaSeparated);
|
||||
llvm::cl::list<uint32_t> unroll_factors(
|
||||
"unroll_factors",
|
||||
llvm::cl::desc("factors to unroll by, separated by commas"),
|
||||
llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated);
|
||||
llvm::cl::list<uint32_t> same_shape(
|
||||
"same_shape",
|
||||
llvm::cl::desc("arguments with same shape, separated by commas"),
|
||||
llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated);
|
||||
|
||||
tensorflow::InitMlir y(&argc, &argv);
|
||||
llvm::InitializeNativeTarget();
|
||||
llvm::InitializeNativeTargetAsmPrinter();
|
||||
mlir::registerPassManagerCLOptions();
|
||||
llvm::cl::ParseCommandLineOptions(argc, argv, "TF op GPU kernel generator\n");
|
||||
|
||||
auto status =
|
||||
tensorflow::kernel_gen::Run(input_file, output_file, architecture,
|
||||
tile_sizes, same_shape, unroll_factors);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << status;
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
@ -77,7 +77,6 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_llvm",
|
||||
"//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:GPUDialect",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:LLVMDialect",
|
||||
"@llvm-project//mlir:LLVMTransforms",
|
||||
|
||||
@ -15,7 +15,6 @@ limitations under the License.
|
||||
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
@ -53,11 +52,10 @@ class TestTFFrameworkToLLVMPass
|
||||
// Set target.
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<LLVM::LLVMDialect>();
|
||||
target.addLegalDialect<gpu::GPUDialect>();
|
||||
target.addIllegalDialect<tf_framework::TFFrameworkDialect>();
|
||||
target.addIllegalOp<LLVM::DialectCastOp>();
|
||||
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
||||
|
||||
if (failed(applyPartialConversion(m, target, patterns))) {
|
||||
if (failed(applyFullConversion(m, target, patterns))) {
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
@ -55,6 +55,7 @@ cc_library(
|
||||
"transforms/passes.h",
|
||||
],
|
||||
deps = [
|
||||
":attribute_importer",
|
||||
":type_to_shape",
|
||||
":xla_legalize_tf_with_tf2xla",
|
||||
"//tensorflow/compiler/mlir/hlo",
|
||||
@ -69,7 +70,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/client/lib:conv_grad_size_util",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core/kernels:conv_grad_shape_utils",
|
||||
"//tensorflow/core/lib/bfloat16",
|
||||
"//tensorflow/core/platform:bfloat16",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:Dialect",
|
||||
|
||||
@ -22,7 +22,7 @@ limitations under the License.
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
|
||||
#include "tensorflow/core/platform/bfloat16.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace xla {
|
||||
@ -83,6 +83,9 @@ StatusOr<llvm::SmallVector<AffineMap, 1>> GetPermutationIfAvailable(
|
||||
strides[dim] = accumulated_stride;
|
||||
accumulated_stride *= shape.dimensions(dim);
|
||||
}
|
||||
if (accumulated_stride == 0) {
|
||||
return llvm::SmallVector<AffineMap, 1>{};
|
||||
}
|
||||
return llvm::SmallVector<AffineMap, 1>{
|
||||
makeStridedLinearLayoutMap(strides, /*offset=*/0, builder.getContext())};
|
||||
}
|
||||
|
||||
@ -8,6 +8,6 @@ HloModule TestModule
|
||||
ENTRY TestComputation {
|
||||
x = f32[3, 2]{1,0} parameter(0)
|
||||
|
||||
// CHECK: "lmhlo.copy"(%{{.*}}, %{{.*}}) : (memref<3x2xf32>, memref<3x2xf32, #[[MAP]]>) -> ()
|
||||
// CHECK: "lmhlo.copy"(%{{.*}}, %{{.*}}) {name = "copy.1"} : (memref<3x2xf32>, memref<3x2xf32, #[[MAP]]>) -> ()
|
||||
ROOT x.copy = f32[3, 2]{0,1} copy(x)
|
||||
}
|
||||
|
||||
@ -44,7 +44,7 @@ attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} {
|
||||
// CHECK-LABEL: func @case
|
||||
// CHECK-SAME: %[[BRANCH_INDEX:.*]]: tensor<i32>, %[[ARG0:.*]]: tensor<f32>, %[[ARG1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>)
|
||||
func @case(%index: tensor<i32>, %arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
|
||||
%0:2 = "tf.Case"(%index, %arg0, %arg1) {branches = [@exponential, @log, @floor]} : (tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
|
||||
%0:2 = "tf.Case"(%index, %arg0, %arg1) {branches = [@exponential, @log, @floor], is_stateless = true} : (tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
|
||||
// CHECK: %[[TUPLE_INPUT:.*]] = "mhlo.tuple"(%[[ARG0]], %[[ARG1]]) : (tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>>
|
||||
// CHECK: %[[CASE:.*]]:2 = "mhlo.case"(%[[BRANCH_INDEX]], %[[TUPLE_INPUT]], %[[TUPLE_INPUT]], %[[TUPLE_INPUT]]) ( {
|
||||
// CHECK: ^bb0(%[[TUPLE_ARG:.*]]: tuple<tensor<f32>, tensor<f32>>):
|
||||
|
||||
@ -265,6 +265,31 @@ func @non_max_suppression_v4(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2
|
||||
return %0#0 : tensor<2xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: bessel_i0e
|
||||
func @bessel_i0e(%arg0: tensor<3xf16>, %arg1: tensor<3xf32>, %arg2: tensor<3xf64>) -> (tensor<3xf16>, tensor<3xf32>, tensor<3xf64>) {
|
||||
// CHECK-NOT: tf.BesselI0e
|
||||
%0 = "tf.BesselI0e"(%arg0) : (tensor<3xf16>) -> (tensor<3xf16>)
|
||||
%1 = "tf.BesselI0e"(%arg1) : (tensor<3xf32>) -> (tensor<3xf32>)
|
||||
%2 = "tf.BesselI0e"(%arg2) : (tensor<3xf64>) -> (tensor<3xf64>)
|
||||
return %0, %1, %2 : tensor<3xf16>, tensor<3xf32>, tensor<3xf64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: bessel_i1e
|
||||
func @bessel_i1e(%arg0: tensor<3xf16>, %arg1: tensor<3xf32>, %arg2: tensor<3xf64>) -> (tensor<3xf16>, tensor<3xf32>, tensor<3xf64>) {
|
||||
// CHECK-NOT: tf.BesselI1e
|
||||
%0 = "tf.BesselI1e"(%arg0) : (tensor<3xf16>) -> (tensor<3xf16>)
|
||||
%1 = "tf.BesselI1e"(%arg1) : (tensor<3xf32>) -> (tensor<3xf32>)
|
||||
%2 = "tf.BesselI1e"(%arg2) : (tensor<3xf64>) -> (tensor<3xf64>)
|
||||
return %0, %1, %2 : tensor<3xf16>, tensor<3xf32>, tensor<3xf64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: diag
|
||||
func @diag(%arg0: tensor<2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK-NOT: tf.Diag
|
||||
%0 = "tf.Diag"(%arg0) : (tensor<2xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
|
||||
// available but doesn't support this instance.
|
||||
}
|
||||
|
||||
@ -3484,6 +3484,20 @@ func @conv3d_backprop_filter(%input: tensor<2x8x8x8x1xf32>, %out_backprop: tenso
|
||||
return %result : tensor<2x8x8x8x1xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @collective_permute
|
||||
func @collective_permute(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> {
|
||||
%source_target_pairs = "tf.Const" () {
|
||||
value = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi32>
|
||||
} : () -> tensor<3x2xi32>
|
||||
|
||||
// CHECK: "mhlo.collective_permute"
|
||||
// CHECK-SAME: source_target_pairs = dense<{{\[}}[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>
|
||||
%0 = "tf.CollectivePermute"(%arg0, %source_target_pairs) {
|
||||
} : (tensor<128x32xf32>, tensor<3x2xi32>) -> tensor<128x32xf32>
|
||||
|
||||
return %0 : tensor<128x32xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @cross_replica_sum
|
||||
func @cross_replica_sum(%input: tensor<10xf32>) -> tensor<10xf32> {
|
||||
%replica_groups = "tf.Const" () {
|
||||
@ -3775,7 +3789,7 @@ func @unsorted_segment_prod(%data: tensor<8x?x64xf32>, %segment_ids : tensor<?x1
|
||||
// CHECK-LABEL: @unsorted_segment_min
|
||||
func @unsorted_segment_min(%data: tensor<8x?x64xf32>, %segment_ids : tensor<?x16xi32>) -> (tensor<4x?xf32>) {
|
||||
%num_segments = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: mhlo.constant dense<0x7F800000> : tensor<f32>
|
||||
// CHECK: mhlo.constant dense<3.40282347E+38> : tensor<f32>
|
||||
// CHECK: mhlo.scatter
|
||||
// CHECK: mhlo.minimum
|
||||
%0 = "tf.UnsortedSegmentMin"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor<?x16xi32>, tensor<i32>) -> (tensor<4x?xf32>)
|
||||
@ -3785,7 +3799,7 @@ func @unsorted_segment_min(%data: tensor<8x?x64xf32>, %segment_ids : tensor<?x16
|
||||
// CHECK-LABEL: @unsorted_segment_max
|
||||
func @unsorted_segment_max(%data: tensor<8x?x64xf32>, %segment_ids : tensor<?x16xi32>) -> (tensor<4x?xf32>) {
|
||||
%num_segments = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: mhlo.constant dense<0xFF800000> : tensor<f32>
|
||||
// CHECK: mhlo.constant dense<-3.40282347E+38> : tensor<f32>
|
||||
// CHECK: mhlo.scatter
|
||||
// CHECK: mhlo.maximum
|
||||
%0 = "tf.UnsortedSegmentMax"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor<?x16xi32>, tensor<i32>) -> (tensor<4x?xf32>)
|
||||
@ -4668,6 +4682,20 @@ func @cumsum_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<i32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Cumprod op legalizations.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: func @cumprod
|
||||
func @cumprod(%arg0: tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK: [[INIT:%.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
|
||||
// CHECK: "mhlo.reduce_window"({{.*}}, [[INIT]]) ( {
|
||||
// CHECK: mhlo.mul
|
||||
%0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
%1 = "tf.Cumprod"(%arg0, %0) {exclusive = false, reverse = false} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32>
|
||||
return %1 : tensor<4xf32>
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Qr op legalization
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -4766,3 +4794,20 @@ func @softplus_f64(%arg0: tensor<8x16xf64>) -> tensor<8x16xf64> {
|
||||
// CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf64>
|
||||
return %0 : tensor<8x16xf64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @xla_gather
|
||||
func @xla_gather(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10x1x300xf32> {
|
||||
%cst = "tf.Const"() { value = dense<[1, 1, 300]> : tensor<3xi64> } : () -> tensor<3xi64>
|
||||
|
||||
// CHECK: "mhlo.gather"
|
||||
// CHECK-SAME: dimension_numbers =
|
||||
// CHECK-SAME: collapsed_slice_dims = dense<0> : tensor<1xi64>
|
||||
// CHECK-SAME: index_vector_dim = 1 : i64
|
||||
// CHECK-SAME: offset_dims = dense<1> : tensor<1xi64>
|
||||
// CHECK-SAME: start_index_map = dense<0> : tensor<1xi64>
|
||||
// CHECK-SAME: indices_are_sorted = true
|
||||
// CHECK-SAME: slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>
|
||||
|
||||
%0 = "tf.XlaGather"(%arg0, %arg1, %cst) {dimension_numbers = "\0A\01\01\12\01\00\1A\01\00 \01", indices_are_sorted = true} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi64>) -> tensor<10x1x300xf32>
|
||||
return %0 : tensor<10x1x300xf32>
|
||||
}
|
||||
|
||||
@ -50,6 +50,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
|
||||
#include "tensorflow/compiler/mlir/xla/attribute_importer.h"
|
||||
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/conv_grad_size_util.h"
|
||||
#include "tensorflow/compiler/xla/client/padding.h"
|
||||
@ -57,7 +58,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/framework/kernel_shape_util.h"
|
||||
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
|
||||
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
|
||||
#include "tensorflow/core/platform/bfloat16.h"
|
||||
#include "tensorflow/core/util/padding.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
|
||||
@ -262,49 +263,21 @@ tensorflow::TensorShape ToTensorShape(
|
||||
sizes.begin(), sizes.end()));
|
||||
}
|
||||
|
||||
// Returns minimal value for the given int or float element type.
|
||||
static ConstOp GetMinValueForType(Type ty, Location loc,
|
||||
PatternRewriter *rewriter) {
|
||||
RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
|
||||
|
||||
DenseElementsAttr attr;
|
||||
if (auto float_ty = ty.dyn_cast_or_null<FloatType>()) {
|
||||
APFloat neg_inf =
|
||||
APFloat::getInf(float_ty.getFloatSemantics(), /*negative=*/true);
|
||||
attr = DenseElementsAttr::get(scalar_ty, neg_inf);
|
||||
} else {
|
||||
auto int_ty = ty.cast<IntegerType>();
|
||||
APInt min_val = APInt::getSignedMinValue(int_ty.getWidth());
|
||||
attr = DenseElementsAttr::get(scalar_ty, min_val);
|
||||
}
|
||||
return rewriter->create<ConstOp>(loc, attr);
|
||||
}
|
||||
|
||||
// Returns maximal value for the given int or float element type.
|
||||
static ConstOp GetMaxValueForType(Type ty, Location loc,
|
||||
PatternRewriter *rewriter) {
|
||||
RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
|
||||
|
||||
DenseElementsAttr attr;
|
||||
if (auto float_ty = ty.dyn_cast_or_null<FloatType>()) {
|
||||
APFloat pos_inf =
|
||||
APFloat::getInf(float_ty.getFloatSemantics(), /*negative=*/false);
|
||||
attr = DenseElementsAttr::get(scalar_ty, pos_inf);
|
||||
} else {
|
||||
auto int_ty = ty.cast<IntegerType>();
|
||||
APInt max_val = APInt::getSignedMaxValue(int_ty.getWidth());
|
||||
attr = DenseElementsAttr::get(scalar_ty, max_val);
|
||||
}
|
||||
return rewriter->create<ConstOp>(loc, attr);
|
||||
}
|
||||
|
||||
// Returns int or float scalar DenseElementsAttr attribute with the given
|
||||
// element type and the value.
|
||||
// Returns int, float, or complex scalar DenseElementsAttr attribute with the
|
||||
// given element type and the value.
|
||||
static ConstOp GetScalarConstOfType(Type ty, Location loc, int64_t raw_value,
|
||||
OpBuilder *builder) {
|
||||
return builder->create<ConstOp>(loc, hlo::GetScalarOfType(ty, raw_value));
|
||||
}
|
||||
|
||||
// Returns a limit scalar const op for the given type.
|
||||
// Requires FloatType or IntegerType
|
||||
static ConstOp GetScalarLimitConstOfType(Type ty, Location loc,
|
||||
hlo::ScalarLimit limit,
|
||||
OpBuilder *builder) {
|
||||
return builder->create<ConstOp>(loc, hlo::GetScalarLimitOfType(ty, limit));
|
||||
}
|
||||
|
||||
// Creates an mhlo::SliceOp where the major dimensions have full size, and
|
||||
// the minor dimensions have the provided offsets and sizes.
|
||||
static Value SliceInMinorDims(Location loc, Value v,
|
||||
@ -1065,6 +1038,21 @@ static void BuildSortComparisonBody(llvm::ArrayRef<Type> element_types,
|
||||
builder->create<mhlo::ReturnOp>(loc, compare);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// XlaGather op utilities.
|
||||
//===----------------------------------------------------------------------===//

bool HasValidGatherDims(StringAttr attr) {
  ::xla::GatherDimensionNumbers dims;
  return dims.ParseFromString(attr.getValue().str());
}

GatherDimensionNumbers GetGatherDimNumsAttr(StringAttr attr, Builder *builder) {
  ::xla::GatherDimensionNumbers dims;
  if (!dims.ParseFromString(attr.getValue().str())) return {};
  return ::xla::ConvertGatherDimensionNumbers(dims, builder);
}

//===----------------------------------------------------------------------===//
// Op converters.
//===----------------------------------------------------------------------===//
@ -2385,15 +2373,16 @@ class ConvertMaxPoolOp : public OpRewritePattern<OpTy> {
        op.input().getType().template cast<TensorType>().getElementType();
    if (!element_type.isSignlessIntOrFloat()) return failure();
    Location loc = op.getLoc();
    ConstOp init = GetMinValueForType(element_type, loc, &rewriter);
    ConstOp init = GetScalarLimitConstOfType(element_type, loc,
                                             hlo::kInfinityLowest, &rewriter);

    auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
    if (!input_ty) return failure();
    DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr<num_dims>(
        input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter);
    auto reduce = rewriter.create<ReduceWindowOp>(
        loc, op.getType(), op.input(), init.getResult(),
        GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()),
        loc, op.getType(), op.input(), init, GetI64ElementsAttr(op.ksize()),
        GetI64ElementsAttr(op.strides()),
        /*base_dilations=*/DenseIntElementsAttr(),
        /*window_dilations=*/DenseIntElementsAttr(), paddings_attr);
    BuildReduceBody<MaxOp>(element_type, &reduce.body(), &rewriter);
@ -3636,7 +3625,8 @@ class ConvertMaxOp

  static Value GetInitialValue(Type reduce_element_type, Location loc,
                               PatternRewriter *rewriter) {
    return GetMinValueForType(reduce_element_type, loc, rewriter);
    return GetScalarLimitConstOfType(reduce_element_type, loc,
                                     hlo::kInfinityLowest, rewriter);
  }
};

@ -3653,7 +3643,8 @@ class ConvertMinOp

  static Value GetInitialValue(Type reduce_element_type, Location loc,
                               PatternRewriter *rewriter) {
    return GetMaxValueForType(reduce_element_type, loc, rewriter);
    return GetScalarLimitConstOfType(reduce_element_type, loc,
                                     hlo::kInfinityMax, rewriter);
  }
};

@ -3789,7 +3780,8 @@ class ConvertArgMaxOp

  static Value GetInitialValue(Type reduce_element_type, Location loc,
                               PatternRewriter &rewriter) {
    return GetMinValueForType(reduce_element_type, loc, &rewriter);
    return GetScalarLimitConstOfType(reduce_element_type, loc,
                                     hlo::kInfinityLowest, &rewriter);
  }

  static StringRef GetDirection() { return "GT"; }
@ -4728,7 +4720,7 @@ class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern<OpTy> {
    auto output_type =
        RankedTensorType::get(output_shape, data_type.getElementType());

    // Broadccast the initial value for reduction. This will become the
    // Broadcast the initial value for reduction. This will become the
    // 'operand' parameter to scatter to for the final scatter op.
    Value init = ConcreteClass::GetInitialValue(data_type.getElementType(),
                                                op.getLoc(), &rewriter);
@ -4768,7 +4760,8 @@ class ConvertUnsortedSegmentMaxOp

  static Value GetInitialValue(Type reduce_element_type, Location loc,
                               PatternRewriter *rewriter) {
    return GetMinValueForType(reduce_element_type, loc, rewriter);
    return GetScalarLimitConstOfType(reduce_element_type, loc, hlo::kLowest,
                                     rewriter);
  }
};

@ -4781,7 +4774,8 @@ class ConvertUnsortedSegmentMinOp

  static Value GetInitialValue(Type reduce_element_type, Location loc,
                               PatternRewriter *rewriter) {
    return GetMaxValueForType(reduce_element_type, loc, rewriter);
    return GetScalarLimitConstOfType(reduce_element_type, loc, hlo::kMax,
                                     rewriter);
  }
};
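
Across the hunks above, the old GetMinValueForType / GetMaxValueForType helpers are replaced by GetScalarLimitConstOfType plus an hlo scalar-limit kind: kInfinityLowest / kInfinityMax for the windowed and arg reductions, kLowest / kMax for the unsorted-segment reductions. As a rough, standalone illustration of the two flavours of "lowest" involved (plain C++ with std::numeric_limits, not the TensorFlow helper itself), the values differ for floating-point element types:

#include <cstdio>
#include <limits>

int main() {
  // Roughly what kInfinityLowest would pick for a float element type: the
  // mathematical identity of a max-reduction.
  float neg_inf = -std::numeric_limits<float>::infinity();
  // Roughly what kLowest would pick: the smallest *finite* float, presumably
  // so that segments contributing no elements still yield a finite result.
  float lowest_finite = std::numeric_limits<float>::lowest();
  std::printf("-inf = %g, lowest finite = %g\n", neg_inf, lowest_finite);
  return 0;
}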

@ -5092,17 +5086,19 @@ class ConvertXlaDynamicUpdateSliceOp
  }
};

/// Converts the Cumsum TensorFlow op to the HLO ReduceWindow op by setting
/// appropriate window dimensions, with 'add' as the reduction function. The
/// input tensor needs to have a static shape, and 'axis' must be const. The
/// TableGen pattern is not used for this rewrite because it involves regions.
class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
  using OpRewritePattern<TF::CumsumOp>::OpRewritePattern;
// Converts the Cumsum or Cumprod TensorFlow op to the HLO ReduceWindow op by
// setting appropriate window dimensions, with the given aggregation op as the
// reduction function. The input tensor needs to have a static shape, and 'axis'
// must be const. The TableGen pattern is not used for this rewrite because it
// involves regions.
template <typename OpT, typename AggregationOp>
class ConvertCumOp : public OpRewritePattern<OpT> {
  using OpRewritePattern<OpT>::OpRewritePattern;

  LogicalResult matchAndRewrite(TF::CumsumOp op,
  LogicalResult matchAndRewrite(OpT op,
                                PatternRewriter &rewriter) const override {
    auto input = op.x();
    auto input_type = input.getType().dyn_cast<ShapedType>();
    auto input_type = input.getType().template dyn_cast<ShapedType>();
    if (!input_type || !input_type.hasStaticShape()) {
      return failure();
    }
@ -5135,6 +5131,10 @@ class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
    // Convert if we need to enlarge the element type's bitwidth to avoid
    // precision loss.
    Type input_element_type = input_type.getElementType();

    // TODO(hinsu): Handle complex element types.
    if (!input_element_type.isIntOrFloat()) return failure();

    Type sum_element_type = GetSumAccumulationType(input_element_type);
    input = rewriter.create<ConvertOp>(op.getLoc(), input, sum_element_type);

@ -5148,8 +5148,9 @@ class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
        RankedTensorType::get({rank, 2}, rewriter.getIntegerType(64)),
        paddings);

    Value init =
        GetScalarConstOfType(sum_element_type, op.getLoc(), 0, &rewriter);
    int64_t init_value = (std::is_same<AggregationOp, AddOp>::value) ? 0 : 1;
    Value init = GetScalarConstOfType(sum_element_type, op.getLoc(), init_value,
                                      &rewriter);

    auto reduce = rewriter.create<ReduceWindowOp>(
        op.getLoc(), input_type, input, init,
@ -5157,7 +5158,7 @@ class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
        GetI64ElementsAttr(rewriter.getI64ArrayAttr(window_strides)),
        /*base_dilations=*/DenseIntElementsAttr(),
        /*window_dilations=*/DenseIntElementsAttr(), paddings_attr);
    BuildReduceBody<AddOp>(sum_element_type, &reduce.body(), &rewriter);
    BuildReduceBody<AggregationOp>(sum_element_type, &reduce.body(), &rewriter);
    Value result = reduce.getResult();

    if (op.exclusive()) {
@ -5193,6 +5194,9 @@ class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
  }
};

using ConvertCumsumOp = ConvertCumOp<TF::CumsumOp, AddOp>;
using ConvertCumprodOp = ConvertCumOp<TF::CumprodOp, MulOp>;
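
The templated converter picks the reduction's identity element from the aggregation op, mirroring the init_value line above: 0 for a cumulative sum, 1 for a cumulative product. A minimal standalone sketch of that selection (the AddOp/MulOp structs here are placeholder tags, not the MHLO ops):

#include <cstdint>
#include <type_traits>

// Placeholder tags standing in for the MHLO aggregation ops.
struct AddOp {};
struct MulOp {};

// Identity element of the aggregation: 0 for add (cumsum), 1 for mul (cumprod).
template <typename AggregationOp>
constexpr int64_t CumulativeInitValue() {
  return std::is_same<AggregationOp, AddOp>::value ? 0 : 1;
}

static_assert(CumulativeInitValue<AddOp>() == 0, "cumsum starts from 0");
static_assert(CumulativeInitValue<MulOp>() == 1, "cumprod starts from 1");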

// Converts the Tensorflow ShapeOp to a sequence of Shape dialect and Standard
// dialect lowerings. This involves extracting the shape type, extracting and
// converting each dimension to a known integer type, and repacking into a final
@ -5857,7 +5861,7 @@ void PopulateLegalizeTfPatterns(MLIRContext *context,
      ConvertConv2DOp, ConvertConv3DOp, ConvertDepthConv2DOp,
      ConvertConv2DBackpropFilterOp, ConvertConv3DBackpropFilterOp,
      ConvertConv2DBackpropInputOp, ConvertConv3DBackpropInputOp,
      ConvertCumsumOp, ConvertDiagPartOp, ConvertEinsumOp,
      ConvertCumprodOp, ConvertCumsumOp, ConvertDiagPartOp, ConvertEinsumOp,
      ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op,
      ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV2Op,
      ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp,

@ -51,6 +51,10 @@ def GetHLOAxisFromTFAxisVariadic : NativeCodeCall<
    "$0, (*$1.begin()).getType().cast<RankedTensorType>().getRank(), "
    "&$_builder)">;

def CastElementsToI64Elements : NativeCodeCall<
    "hlo::ConvertElementsAttr("
    "$0, $_builder.getIntegerType(64)).cast<DenseIntElementsAttr>()">;

def : Pattern<
    (TF_FusedBatchNormOp:$root $x, $scale, $offset, $mean, $variance, $epsilon,
                         $exponential_avg_factor, $data_format,
@ -255,12 +259,16 @@ def : Pat<(TF_ConcatV2Op $inputs, (TF_ConstOp OneElementAttr:$axis)),
          [(HasRankedFirstOperand $inputs)]>;

//===----------------------------------------------------------------------===//
// CrossReplicaSum op patterns.
// CollectivePermute op patterns.
//===----------------------------------------------------------------------===//

def CastElementsToI64Elements : NativeCodeCall<
    "hlo::ConvertElementsAttr("
    "$0, $_builder.getIntegerType(64)).cast<DenseIntElementsAttr>()">;
def : Pat<(TF_CollectivePermuteOp $input, (TF_ConstOp $source_target_pairs)),
          (HLO_CollectivePermuteOp $input,
            (CastElementsToI64Elements $source_target_pairs))>;

//===----------------------------------------------------------------------===//
// CrossReplicaSum op patterns.
//===----------------------------------------------------------------------===//

def : Pat<(TF_CrossReplicaSumOp $input, (TF_ConstOp $group_assignment)),
          (HLO_CrossReplicaSumOp $input,
@ -660,3 +668,18 @@ def : Pattern<(TF_SoftplusOp AnyTensor:$features),
            ),
            (replaceWithValue $output)
          ]>;

//===----------------------------------------------------------------------===//
// XlaGather op.
//===----------------------------------------------------------------------===//

def ToGatherDimNumsAttr : NativeCodeCall<"GetGatherDimNumsAttr($0, &$_builder)">;

def HasValidGatherDims : Constraint<CPred<"HasValidGatherDims($0)">>;

def : Pat<(TF_XlaGatherOp $operand, $start_indices, (TF_ConstOp $slice_sizes),
               $dimension_numbers, $indices_are_sorted),
          (HLO_GatherOp $operand, $start_indices,
               (ToGatherDimNumsAttr $dimension_numbers),
               $slice_sizes, $indices_are_sorted),
          [(HasValidGatherDims $dimension_numbers)]>;
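
The new pattern only fires when HasValidGatherDims succeeds, i.e. when the op's string `dimension_numbers` attribute parses as an `xla::GatherDimensionNumbers` proto (see the C++ helpers earlier in this diff). As an illustrative sketch of producing such an attribute value from the proto side (field names as in xla_data.proto; the include path and build dependencies are assumed, and the particular dimension values are made up):

#include <string>

#include "tensorflow/compiler/xla/xla_data.pb.h"

// Serializes a GatherDimensionNumbers proto into the string form that a
// TF_XlaGatherOp's `dimension_numbers` attribute carries, and that
// HasValidGatherDims round-trips with ParseFromString.
std::string MakeGatherDimensionNumbersAttr() {
  ::xla::GatherDimensionNumbers dims;
  dims.add_offset_dims(1);
  dims.add_collapsed_slice_dims(0);
  dims.add_start_index_map(0);
  dims.set_index_vector_dim(1);
  return dims.SerializeAsString();
}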

@ -102,6 +102,8 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
    TypeID::get<TF::BatchMatMulV2Op>(),
    TypeID::get<TF::BatchToSpaceNDOp>(),
    TypeID::get<TF::BatchToSpaceOp>(),
    TypeID::get<TF::BesselI0eOp>(),
    TypeID::get<TF::BesselI1eOp>(),
    TypeID::get<TF::BiasAddGradOp>(),
    TypeID::get<TF::BiasAddOp>(),
    TypeID::get<TF::BitwiseAndOp>(),
@ -116,6 +118,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
    TypeID::get<TF::CrossOp>(),
    TypeID::get<TF::DataFormatDimMapOp>(),
    TypeID::get<TF::DataFormatVecPermuteOp>(),
    TypeID::get<TF::DiagOp>(),
    TypeID::get<TF::DigammaOp>(),
    TypeID::get<TF::DivNoNanOp>(),
    TypeID::get<TF::EluGradOp>(),

@ -34,7 +34,6 @@ limitations under the License.
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassOptions.h"  // from @llvm-project
#include "mlir/Translation.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
@ -182,7 +181,10 @@ template <typename OpType>
StatusOr<OpType> LhloDialectEmitter::CreateOpWithoutAttrs(
    HloInstruction* instr) {
  Location loc = getLocation(instr);
  ArrayRef<std::pair<Identifier, Attribute>> attrs;
  std::pair<Identifier, Attribute> attrs[] = {
      {Identifier::get("name", builder_.getContext()),
       builder_.getStringAttr(instr->name())},
  };
  ArrayRef<Type> rets{};

  llvm::SmallVector<Value, 4> operands;
@ -252,15 +254,14 @@ Status LhloDialectEmitter::DefaultAction(HloInstruction* instr) {
  return Status::OK();
}

StatusOr<mlir::Operation*> LhloDialectEmitter::EmitSortOp(
    HloInstruction* instr) {
StatusOr<lmhlo::SortOp> LhloDialectEmitter::EmitSortOp(HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto sort, CreateOpWithoutAttrs<lmhlo::SortOp>(instr));
  auto* sort_instr = ::xla::Cast<::xla::HloSortInstruction>(instr);
  sort.dimensionAttr(builder_.getI64IntegerAttr(sort_instr->sort_dimension()));
  sort.is_stableAttr(builder_.getBoolAttr(sort_instr->is_stable()));
  TF_RETURN_IF_ERROR(::xla::HloFunctionImporter::ImportAsRegion(
      *sort_instr->called_computations()[0], &sort.comparator(), &builder_));
  return sort.getOperation();
  return sort;
}

Status LhloDialectEmitter::HandleSort(HloInstruction* instr) {

@ -19,6 +19,7 @@ limitations under the License.
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Module.h"  // from @llvm-project
#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"

@ -41,7 +42,7 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
        builder_(module.getContext()),
        i8_type_(builder_.getIntegerType(8)) {}

  ::xla::StatusOr<mlir::Operation*> EmitSortOp(::xla::HloInstruction* instr);
  ::xla::StatusOr<lmhlo::SortOp> EmitSortOp(::xla::HloInstruction* instr);

 private:
  template <typename OpType>
Some files were not shown because too many files have changed in this diff.