merge with master

Daniel Nguyen 2020-08-17 19:31:57 +00:00
commit 333c864732
503 changed files with 14384 additions and 7153 deletions

View File

@ -88,6 +88,7 @@
dataset when it is safe to do so. The optimization can be disabled via
the `experimental_optimization.reorder_data_discarding_ops` dataset
option.
* `tf.data.Options` were previously immutable and can now be overridden.
* `tf.image`:
* Added deterministic `tf.image.stateless_random_*` functions for each
`tf.image.random_*` function. Added a new op
@ -106,7 +107,8 @@
* Error messages when Functional API construction goes wrong (and when ops cannot be converted to Keras layers automatically) should be clearer and easier to understand.
* `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape`
as an alternative to accepting a `callable` loss.
* Added `beta` parameter to FTRL optimizer to match paper.
* Added `beta` hyperparameter to FTRL optimizer classes (Keras and others)
to match FTRL paper (https://research.google.com/pubs/archive/41159.pdf).
* Added `mobilenet_v3` to Keras applications.
* `Optimizer.__init__` now accepts a `gradient_aggregator` to allow for
customization of how gradients are aggregated across devices, as well as
@ -155,6 +157,14 @@
* <ADD RELEASE NOTES HERE>
* Tracing and Debugging:
* <ADD RELEASE NOTES HERE>
* `tf.train.Checkpoint`:
* Now accepts a `root` argument in the initialization, which generates a
checkpoint with a root object. This allows users to create a `Checkpoint`
object that is compatible with Keras `model.save_weights()` and
`model.load_weights()`. The checkpoint is also compatible with the
checkpoint saved in the `variables/` folder in the SavedModel.
* When restoring, `save_path` can be a path to a SavedModel. The function
will automatically find the checkpoint in the SavedModel.
* Other:
* We have replaced uses of "whitelist" and "blacklist" with "allowlist"
and "denylist" where possible. Please see
@ -251,6 +261,7 @@ stjohnso98, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
* Mutable tables now restore checkpointed values when loaded from SavedModel.
* GPU
* TF 2.3 includes PTX kernels only for [compute capability](https://developer.nvidia.com/cuda-gpus) 7.0 to reduce the TF pip binary size. Earlier releases included PTX for a variety of older compute capabilities.
* Removed the environment variable `TF_USE_CUDNN`.
* Others
* Retain the parent name scope for ops added inside `tf.while_loop`/`tf.cond`/`tf.switch_case`.
* Update `tf.vectorized_map` to support vectorizing `tf.while_loop` and TensorList operations.

View File

@ -220,6 +220,7 @@ cc_library(
name = "logging",
srcs = ["logging.cc"],
hdrs = ["logging.h"],
visibility = ["//visibility:public"],
deps = [
":c_api_macros",
"//tensorflow/core/platform:logging",

View File

@ -240,6 +240,7 @@ tf_cuda_cc_test(
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/experimental/gradients:array_grad",
"//tensorflow/c/experimental/gradients:math_grad",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/cc/profiler",

View File

@ -30,18 +30,26 @@ TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/true,
/*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesFuncRemoteOutputs) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/false, /*remote=*/true,
/*heavy_load_on_streaming_rpc=*/false,
/*remote_func_outputs=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsyncFuncRemoteOutputs) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/true,
/*heavy_load_on_streaming_rpc=*/false,
/*remote_func_outputs=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false,
/*heavy_load_on_streaming_rpc=*/false);
}
// TODO(b/162618595): Enable this test once we remove the check of remote
// outputs in ProcessFunctionLibraryRuntime.
TEST(CAPI, DISABLED_RemoteExecuteSilentCopiesLocalFuncRemoteOutputs) {
TEST(CAPI, RemoteExecuteSilentCopiesLocalFuncRemoteOutputs) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/false, /*remote=*/false,
/*heavy_load_on_streaming_rpc=*/false,
/*remote_func_outputs=*/true);
}
TEST(CAPI, DISABLED_RemoteExecuteSilentCopiesLocalAsyncFuncRemoteOutputs) {
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncRemoteOutputs) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false,
/*heavy_load_on_streaming_rpc=*/false,
/*remote_func_outputs=*/true);

View File

@ -169,6 +169,13 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr));
}
if (remote_func_outputs) {
const string backing_device =
TFE_TensorHandleBackingDeviceName(retvals[0], status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
EXPECT_EQ(backing_device, task2_name);
}
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

View File

@ -85,7 +85,11 @@ class GraphOperation : public TracingOperation {
return errors::FailedPrecondition(
"GraphOperation::Reset must be called before calling SetOpName.");
}
op_.reset(TF_NewOperation(g_, op_type_.c_str(), op_name));
// TODO(b/145674566): We use Graph::NewName to get a unique name here but
// this may not be consistent with python's naming policy.
mutex_lock l(g_->mu);
op_.reset(new TF_OperationDescription(g_, op_type_.c_str(),
g_->graph.NewName(op_name).c_str()));
return Status::OK();
}
const string& Name() const override { return op_type_; }

View File

@ -557,7 +557,7 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
auto* add_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(add_op, "Add", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractOpSetOpName(add_op, "my_add1", s);
TF_AbstractOpSetOpName(add_op, "my_add", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractTensor* inputs[2] = {arg0, arg1};
TF_OutputList* add_outputs = TF_NewOutputList();
@ -579,7 +579,7 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
auto* add_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(add_op, "Add", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractOpSetOpName(add_op, "my_add2", s);
TF_AbstractOpSetOpName(add_op, "my_add", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractTensor* inputs[2] = {arg1, arg1};
TF_OutputList* add_outputs = TF_NewOutputList();

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/c/eager/gradients.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
@ -23,25 +24,97 @@ limitations under the License.
namespace tensorflow {
namespace gradients {
Status GradientRegistry::Register(const string& op_name,
GradientFunctionFactory factory) {
namespace {
Status ZerosLike(AbstractContext* ctx, AbstractTensorHandle* t,
AbstractTensorHandle** result) {
AbstractOperationPtr op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(op->Reset("ZerosLike", /*raw_device_name=*/nullptr));
if (isa<tracing::TracingOperation>(op.get())) {
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
absl::StrCat("ZerosLike", ToId(t)).c_str()));
}
TF_RETURN_IF_ERROR(op->AddInput(t));
int num_outputs = 1;
std::vector<AbstractTensorHandle*> outputs(num_outputs);
TF_RETURN_IF_ERROR(
op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs));
*result = outputs[0];
return Status::OK();
}
} // namespace
class IncomingGradientsImpl : public IncomingGradients {
public:
explicit IncomingGradientsImpl(
absl::Span<AbstractTensorHandle* const> grad_inputs, Context* ctx,
DefaultGradientFunction* default_gradients)
: grad_inputs_(grad_inputs),
ctx_(ctx),
default_gradients_(default_gradients) {}
AbstractTensorHandle* operator[](int i) const override {
return default_gradients_->get(ctx_, grad_inputs_, i);
}
size_t size() const override { return grad_inputs_.size(); }
private:
absl::Span<AbstractTensorHandle* const> grad_inputs_;
Context* ctx_;
DefaultGradientFunction* default_gradients_;
};
AllZerosDefaultGradients::AllZerosDefaultGradients(const ForwardOperation& op)
: outputs_(op.outputs) {
for (auto output : outputs_) {
output->Ref();
}
}
AbstractTensorHandle* AllZerosDefaultGradients::get(
Context* ctx, absl::Span<AbstractTensorHandle* const> grad_inputs, int i) {
if (grad_inputs[i]) {
return grad_inputs[i];
}
if (cached_default_grads_[i]) {
return cached_default_grads_[i].get();
}
AbstractTensorHandle* result = nullptr;
Status s = ZerosLike(ctx->ctx, outputs_[i], &result);
if (!s.ok()) {
if (result) {
result->Unref();
}
VLOG(1) << "Failed to create ZerosLike for index " << i;
return nullptr;
}
cached_default_grads_[i].reset(result);
return result;
}
PassThroughDefaultGradients::PassThroughDefaultGradients(
const ForwardOperation& op) {}
AbstractTensorHandle* PassThroughDefaultGradients::get(
Context* ctx, absl::Span<AbstractTensorHandle* const> grad_inputs, int i) {
return grad_inputs[i];
}
Status GradientRegistry::Register(
const string& op_name, BackwardFunctionFactory backward_function_factory) {
auto iter = registry_.find(op_name);
if (iter != registry_.end()) {
const string error_msg = "Gradient already exists for op: " + op_name + ".";
return errors::AlreadyExists(error_msg);
}
registry_.insert({op_name, factory});
registry_.insert({op_name, backward_function_factory});
return Status::OK();
}
Status GradientRegistry::Lookup(
const ForwardOperation& op,
std::unique_ptr<GradientFunction>* grad_fn) const {
std::unique_ptr<BackwardFunction>* backward_function) const {
auto iter = registry_.find(op.op_name);
if (iter == registry_.end()) {
const string error_msg = "No gradient defined for op: " + op.op_name + ".";
return errors::NotFound(error_msg);
}
grad_fn->reset(iter->second(op));
backward_function->reset(iter->second(op));
return Status::OK();
}
@ -92,33 +165,8 @@ AbstractTensorHandle* TapeTensor::OnesLike() const {
}
return outputs[0];
}
AbstractTensorHandle* TapeTensor::ZerosLike() const {
AbstractOperationPtr op(ctx_->CreateOperation());
// TODO(srbs): Consider adding a TF_RETURN_NULLPTR_IF_ERROR.
Status s = op->Reset("ZerosLike", /*raw_device_name=*/nullptr);
if (!s.ok()) {
return nullptr;
}
if (isa<tracing::TracingOperation>(op.get())) {
s = dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
absl::StrCat("ZerosLike", ToId(handle_)).c_str());
if (!s.ok()) {
return nullptr;
}
}
s = op->AddInput(handle_);
if (!s.ok()) {
return nullptr;
}
int num_outputs = 1;
// TODO(srbs): Figure out who is in charge of releasing this.
std::vector<AbstractTensorHandle*> outputs(num_outputs);
s = op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs);
if (!s.ok()) {
return nullptr;
}
return outputs[0];
}
AbstractTensorHandle* TapeTensor::ZerosLike() const { return nullptr; }
// Returns the number of elements in the gradient tensor.
int64 TapeVSpace::NumElements(AbstractTensorHandle* tensor) const {
@ -159,13 +207,16 @@ AbstractTensorHandle* TapeVSpace::AggregateGradients(
// Calls the passed-in backward function.
Status TapeVSpace::CallBackwardFunction(
GradientFunction* backward_function,
BackwardFunction* backward_function,
const std::vector<int64>& unneeded_gradients,
gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
std::vector<AbstractTensorHandle*>* result) const {
if (backward_function == nullptr) return Status::OK();
Context ctx = {ctx_};
return backward_function->Compute(&ctx, output_gradients, result);
IncomingGradientsImpl incoming_gradients(
output_gradients, &ctx, backward_function->GetDefaultGradientFunction());
return backward_function->GetGradientFunction()->Compute(
&ctx, incoming_gradients, result);
}
// Looks up the ID of a Gradient.
@ -373,15 +424,15 @@ Status Execute(AbstractOperation* op_, AbstractContext* ctx,
}
tape->RecordOperation(
op_->Name(), tape_tensors, input_ids, input_dtypes,
[registry, forward_op_]() -> GradientFunction* {
std::unique_ptr<GradientFunction> grad_fn;
Status s = registry.Lookup(*forward_op_, &grad_fn);
[registry, forward_op_]() -> BackwardFunction* {
std::unique_ptr<BackwardFunction> backward_fn;
Status s = registry.Lookup(*forward_op_, &backward_fn);
if (!s.ok()) {
return nullptr;
}
return grad_fn.release();
return backward_fn.release();
},
[](GradientFunction* ptr) {
[](BackwardFunction* ptr) {
if (ptr) {
delete ptr;
}

View File

@ -55,18 +55,25 @@ struct Context {
public:
AbstractContext* ctx;
};
class IncomingGradients {
public:
virtual AbstractTensorHandle* operator[](int i) const = 0;
virtual size_t size() const = 0;
virtual ~IncomingGradients() {}
};
class GradientFunction {
public:
// TODO(srbs): Figure out how to support CompositeTensors, e.g. IndexedSlices,
// in `grad_inputs`.
virtual Status Compute(Context* ctx,
absl::Span<AbstractTensorHandle* const> grad_inputs,
virtual Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
std::vector<AbstractTensorHandle*>* grad_outputs) = 0;
virtual ~GradientFunction() {}
};
// Metadata from the forward operation that is made available to the
// gradient registerer to instantiate a GradientFunction.
// gradient registerer to instantiate a BackwardFunction.
struct ForwardOperation {
public:
string op_name;
@ -76,18 +83,86 @@ struct ForwardOperation {
AbstractContext* ctx;
};
using GradientFunctionFactory =
std::function<GradientFunction*(const ForwardOperation& op)>;
// Map from op name to a `GradientFunctionFactory`.
class GradientRegistry {
// Interface for building default zeros gradients for op outputs which are
// missing incoming gradients. Custom implementations of this can be used to
// control which of the forward op's output tensors (or their metadata) need
// to be kept around in memory to build the default zeros grad.
//
// Some common helper implementations are provided below.
class DefaultGradientFunction {
public:
Status Register(const string& op, GradientFunctionFactory factory);
Status Lookup(const ForwardOperation& op,
std::unique_ptr<GradientFunction>* grad_fn) const;
virtual AbstractTensorHandle* get(
Context* ctx, absl::Span<AbstractTensorHandle* const> grad_inputs,
int i) = 0;
virtual ~DefaultGradientFunction() {}
};
// Returns zeros for any `nullptr` in `grad_inputs`.
//
// This may require keeping track of all of the forward op's output
// tensors and hence may incur a higher memory footprint. Use sparingly.
//
// Multiple calls to `AllZerosDefaultGradients::get` return the same tensor
// handle.
//
// The destructor of this class `Unref`'s any cached tensor handles so users of
// those tensor handles should `Ref` them in order to keep them alive if needed.
class AllZerosDefaultGradients : public DefaultGradientFunction {
public:
explicit AllZerosDefaultGradients(const ForwardOperation& op);
AbstractTensorHandle* get(Context* ctx,
absl::Span<AbstractTensorHandle* const> grad_inputs,
int i) override;
private:
absl::flat_hash_map<string, GradientFunctionFactory> registry_;
// TODO(srbs): We do not always need to keep the tensors around. In immediate
// execution mode we just need to store the shape and dtype. During tracing
// we may need to keep the tensor around if the shape is not fully defined.
std::vector<AbstractTensorHandle*> outputs_;
std::vector<AbstractTensorHandlePtr> cached_default_grads_;
};
// Passes through `grad_inputs` as-is. The `GradientFunction`
// will be expected to deal with nullptr in `grad_inputs` if any.
class PassThroughDefaultGradients : public DefaultGradientFunction {
public:
explicit PassThroughDefaultGradients(const ForwardOperation& op);
AbstractTensorHandle* get(Context* ctx,
absl::Span<AbstractTensorHandle* const> grad_inputs,
int i) override;
};
// A `BackwardFunction` wraps a `GradientFunction` and a
// `DefaultGradientFunction`. Both are owned by this class' instance.
class BackwardFunction {
public:
BackwardFunction(GradientFunction* gradient_function,
DefaultGradientFunction* default_gradients)
: gradient_function_(gradient_function),
default_gradients_(default_gradients) {}
GradientFunction* GetGradientFunction() { return gradient_function_.get(); }
DefaultGradientFunction* GetDefaultGradientFunction() {
return default_gradients_.get();
}
private:
std::unique_ptr<GradientFunction> gradient_function_;
std::unique_ptr<DefaultGradientFunction> default_gradients_;
};
using BackwardFunctionFactory =
std::function<BackwardFunction*(const ForwardOperation& op)>;
// Map from op name to a `BackwardFunctionFactory`.
class GradientRegistry {
public:
Status Register(const string& op,
BackwardFunctionFactory backward_function_factory);
Status Lookup(const ForwardOperation& op,
std::unique_ptr<BackwardFunction>* backward_function) const;
private:
absl::flat_hash_map<string, BackwardFunctionFactory> registry_;
};
// Returns a unique id for the tensor which is used by the tape to build
@ -106,9 +181,16 @@ int64 ToId(AbstractTensorHandle* t);
// allow us to trace the data dependencies between operations and hence compute
// gradients.
//
// This also implements `ZerosLike` and `OnesLike` to create the default
// This also implements `OnesLike` to create the default
// incoming gradients for tensors which do not already have an incoming
// gradient.
//
// `ZerosLike` is not expected to be called and returns a nullptr. The creation
// of default zeros grads is handled by the `DefaultGradientFunction` registered
// for each op.
// TODO(srbs): We need to define `ZerosLike` here to keep the compiler happy.
// Figure out a way to avoid this.
// TODO(srbs): Should ZerosLike check-fail instead of returning nullptr?
class TapeTensor {
public:
TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx);
@ -123,7 +205,7 @@ class TapeTensor {
private:
AbstractTensorHandle* handle_;
// The context where OnesLike and ZerosLike ops are to be created.
// The context where OnesLike ops are to be created.
AbstractContext* ctx_;
};
@ -132,7 +214,7 @@ class TapeTensor {
// gradient and for performing gradient aggregation.
// See `tensorflow::eager::VSpace` for more details.
class TapeVSpace
: public eager::VSpace<AbstractTensorHandle, GradientFunction, TapeTensor> {
: public eager::VSpace<AbstractTensorHandle, BackwardFunction, TapeTensor> {
public:
explicit TapeVSpace(AbstractContext* ctx) : ctx_(ctx) {}
~TapeVSpace() override {}
@ -147,7 +229,7 @@ class TapeVSpace
// Calls the passed-in backward function.
Status CallBackwardFunction(
GradientFunction* backward_function,
BackwardFunction* backward_function,
const std::vector<int64>& unneeded_gradients,
gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
std::vector<AbstractTensorHandle*>* result) const override;
@ -168,8 +250,14 @@ class TapeVSpace
};
// A tracing/immediate-execution agnostic tape.
//
// Gradient functions defined for this library support handling null incoming
// gradients. `Tape::ComputeGradient` should be called with
// `build_default_zeros_grads=false`. Calling with
// `build_default_zeros_grads=true` (the default) is equivalent but just results
// in extra work because `TapeTensor::ZerosLike` returns a `nullptr` anyway.
using Tape = tensorflow::eager::GradientTape<AbstractTensorHandle,
GradientFunction, TapeTensor>;
BackwardFunction, TapeTensor>;
} // namespace gradients
} // namespace tensorflow
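To make the interfaces above concrete, here is a minimal sketch (not part of this diff) of how a gradient for a single-output op such as "Identity" could be wired up; the class and factory names are illustrative, and only the interfaces declared in this header are used:
#include "tensorflow/c/eager/gradients.h"
namespace tensorflow {
namespace gradients {
namespace {
// d(Identity)/dx = dy: forward the incoming gradient unchanged.
class IdentityGradientFunction : public GradientFunction {
 public:
  Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
                 std::vector<AbstractTensorHandle*>* grad_outputs) override {
    grad_outputs->resize(1, nullptr);
    AbstractTensorHandle* grad_input = grad_inputs[0];
    // With PassThroughDefaultGradients the incoming gradient may be nullptr;
    // pass it through as-is and only Ref it when it exists.
    if (grad_input) {
      grad_input->Ref();
    }
    (*grad_outputs)[0] = grad_input;
    return Status::OK();
  }
};
}  // namespace
// Matches `BackwardFunctionFactory`, so it can be passed to
// `GradientRegistry::Register("Identity", IdentityRegisterer)`.
BackwardFunction* IdentityRegisterer(const ForwardOperation& op) {
  return new BackwardFunction(new IdentityGradientFunction,
                              new PassThroughDefaultGradients(op));
}
}  // namespace gradients
}  // namespace tensorflow
Pairing the gradient function with `PassThroughDefaultGradients` rather than `AllZerosDefaultGradients` is appropriate exactly when it can tolerate nullptr incoming gradients, and the tape is then driven with `build_default_zeros_grads=false` as the comment above describes.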

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/gradients/array_grad.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h"
@ -50,6 +51,7 @@ class CppGradients
Status RegisterGradients(GradientRegistry* registry) {
TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));
return Status::OK();
}
@ -94,6 +96,26 @@ Status Exp(AbstractContext* ctx, Tape* tape,
registry);
}
// Computes `IdentityN(inputs)` and records it on the tape.
Status IdentityN(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractOperationPtr identity_n_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(Reset(identity_n_op.get(), "IdentityN",
/*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(identity_n_op.get())) {
TF_RETURN_IF_ERROR(dyn_cast<TracingOperation>(identity_n_op.get())
->SetOpName("my_identity_n"));
}
TF_RETURN_IF_ERROR(AddInputList(identity_n_op.get(), inputs, &forward_op));
int num_retvals = outputs.size();
return Execute(identity_n_op.get(), ctx, outputs, &num_retvals, &forward_op,
tape, registry);
}
// Computes
// y = inputs[0] + inputs[1]
// return grad(y, {inputs[0], inputs[1]})
@ -116,7 +138,8 @@ Status AddGradModel(AbstractContext* ctx,
vspace, /*target_tensor_ids=*/{ToId(add_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads));
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
for (auto add_output : add_outputs) {
add_output->Unref();
}
@ -146,7 +169,8 @@ Status ExpGradModel(AbstractContext* ctx,
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(exp_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads));
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
for (auto exp_output : exp_outputs) {
exp_output->Unref();
}
@ -155,6 +179,41 @@ Status ExpGradModel(AbstractContext* ctx,
return Status::OK();
}
// Computes
// ignored, y = IdentityN(inputs[0], inputs[1])
// return grad(y, {inputs[0], inputs[1]})
// This should return [nullptr, 1].
Status IdentityNGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0]));
tape->Watch(ToId(inputs[1]));
vector<AbstractTensorHandle*> identity_n_outputs(2);
TF_RETURN_IF_ERROR(IdentityN(ctx, tape, inputs,
absl::MakeSpan(identity_n_outputs), registry));
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(identity_n_outputs[1])},
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
for (auto identity_n_output : identity_n_outputs) {
identity_n_output->Unref();
}
outputs[0] = out_grads[0];
outputs[1] = out_grads[1];
delete tape;
return Status::OK();
}
AbstractContext* BuildFunction(const char* fn_name) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
@ -389,18 +448,79 @@ TEST_P(CppGradients, TestExpGrad) {
result_tensor = nullptr;
}
TEST_P(CppGradients, TestIdentityNGrad) {
// Pseudo-code:
//
// tape.watch(x1)
// tape.watch(x2)
// unused, y = IdentityN([x1, x2])
// outputs = tape.gradient(y, [x1, x2])
// Expected: [nullptr, 1]
//
// This test is interesting because the current implementation of GradientTape
// would return [0, 1] whereas we use build_default_zeros_grads=false here
// so we get back [nullptr, 1].
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x1;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x1.reset(x_raw);
}
AbstractTensorHandlePtr x2;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x2.reset(x_raw);
}
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
std::vector<AbstractTensorHandle*> outputs(2);
s = RunModel(IdentityNGradModel, ctx.get(), {x1.get(), x2.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
EXPECT_EQ(outputs[0], nullptr);
TF_Tensor* result_tensor;
s = getValue(outputs[1], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, 1.0);
outputs[1]->Unref();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
}
// TODO(b/160888630): Enable this test with mlir after AddInputList is
// supported. It is needed for AddN op which is used for gradient aggregation.
// supported. It is needed for IdentityN.
// TODO(b/164171226): Enable this test with tfrt after AddInputList is
// supported. It is needed for IdentityN.
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients,
::testing::Combine(::testing::Values("graphdef", "mlir"),
/*tfrt*/ ::testing::Values(true, false),
::testing::Combine(::testing::Values("graphdef"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients,
::testing::Combine(::testing::Values("graphdef", "mlir"),
::testing::Combine(::testing::Values("graphdef"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#endif

View File

@ -80,6 +80,21 @@ cc_library(
],
)
tf_cc_test(
name = "parallel_device_lib_test",
srcs = ["parallel_device_lib_test.cc"],
deps = [
":parallel_device_lib",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
cc_library(
name = "parallel_device_testlib",
testonly = 1,

View File

@ -118,6 +118,9 @@ class DeviceThread {
int expected_max_outputs_ TF_GUARDED_BY(execution_mutex_);
// Outputs
std::vector<TensorHandlePtr> op_outputs_ TF_GUARDED_BY(execution_mutex_);
// TF_Status is an incomplete type and so can't be stack allocated. To avoid
// unnecessary allocations on each Execute call, we keep one heap-allocated
// version for the thread.
StatusPtr status_ TF_GUARDED_BY(execution_mutex_);
const std::string device_;
@ -188,6 +191,9 @@ std::vector<TensorHandlePtr> DeviceThread::Join(TF_Status* status) {
if (TF_GetCode(status_.get()) != TF_OK) {
TF_SetStatus(status, TF_GetCode(status_.get()),
TF_Message(status_.get()));
// Reset the member `status_` so future op executions (after recovery from
// the bad `status`) start with an OK status.
TF_SetStatus(status_.get(), TF_OK, "");
}
execution_state_ = ExecutionState::kIdle;
result = std::move(op_outputs_);
@ -319,21 +325,36 @@ ParallelDevice::Execute(TFE_Context* context,
std::move(device_inputs), attributes,
expected_max_outputs);
}
StatusPtr first_bad_status(nullptr);
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
DeviceThread* device_thread = device_threads_[device_index].get();
per_device_output_tensors.push_back(device_thread->Join(status));
if (TF_GetCode(status) != TF_OK) return result;
// We will run every Join even if there are bad statuses in case the user
// wants to recover and continue running ops on the parallel device (which
// would otherwise deadlock).
if (TF_GetCode(status) != TF_OK && first_bad_status == nullptr) {
first_bad_status.reset(TF_NewStatus());
TF_SetStatus(first_bad_status.get(), TF_GetCode(status),
TF_Message(status));
}
if (device_index == 0) {
first_op_output_count = per_device_output_tensors.rbegin()->size();
} else {
if (per_device_output_tensors.rbegin()->size() != first_op_output_count) {
TF_SetStatus(status, TF_INTERNAL,
if (first_bad_status == nullptr &&
per_device_output_tensors.rbegin()->size() != first_op_output_count) {
first_bad_status.reset(TF_NewStatus());
TF_SetStatus(first_bad_status.get(), TF_INTERNAL,
"Parallel ops produced different numbers of tensors.");
return result;
}
}
}
if (first_bad_status != nullptr) {
TF_SetStatus(status, TF_GetCode(first_bad_status.get()),
TF_Message(first_bad_status.get()));
return result;
}
// For each output of the original operation, pack the per-device
// TensorHandles we've computed into a single parallel TensorHandle.
std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;

View File

@ -0,0 +1,84 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace parallel_device {
TEST(PARALLEL_DEVICE_LIB, TestOpWithError) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
TF_CreateConfig(
/*xla*/ false,
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
2),
TF_DeleteBuffer);
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
status.get());
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::vector<std::string> devices{
"/job:localhost/replica:0/task:0/device:CPU:0",
"/job:localhost/replica:0/task:0/device:CPU:1"};
ParallelDevice parallel_device(std::move(devices));
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> handle_op(
TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetAttrType(handle_op.get(), "dtype", TF_FLOAT);
TFE_OpSetAttrShape(handle_op.get(), "shape", /*dims=*/nullptr, /*num_dims=*/0,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
auto outputs =
parallel_device.Execute(context.get(), std::vector<ParallelTensor*>(),
"VarHandleOp", TFE_OpGetAttrs(handle_op.get()),
/*expected_max_outputs=*/1, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const std::vector<std::unique_ptr<ParallelTensor>>& handles = *outputs;
std::vector<ParallelTensor*> handle_inputs;
handle_inputs.reserve(handles.size());
for (auto& handle : handles) {
handle_inputs.push_back(handle.get());
}
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> read_op(
TFE_NewOp(context.get(), "ReadVariableOp", status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetAttrType(read_op.get(), "dtype", TF_FLOAT);
parallel_device.Execute(context.get(), handle_inputs, "ReadVariableOp",
TFE_OpGetAttrs(read_op.get()),
/*expected_max_outputs=*/1, status.get());
ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK);
TF_SetStatus(status.get(), TF_OK, "");
// Check that ops still run successfully on the device.
parallel_device.Execute(context.get(), std::vector<ParallelTensor*>(),
"VarHandleOp", TFE_OpGetAttrs(handle_op.get()),
/*expected_max_outputs=*/1, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
} // namespace parallel_device
} // namespace tensorflow

View File

@ -146,13 +146,16 @@ class GradientTape {
// once) and produces the gradient of the target tensors with respect to the
// source tensors. The output gradients are used if not empty and not
// null. The result is populated with one tensor per target element.
// When running backward functions, builds zeros-like tensors for
// incoming grads which are nullptrs, unless `build_default_zeros_grads`
// is set to false.
Status ComputeGradient(
const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
const gtl::ArraySlice<int64> target_tensor_ids,
const gtl::ArraySlice<int64> source_tensor_ids,
const std::unordered_map<int64, TapeTensor>& sources_that_are_targets,
gtl::ArraySlice<Gradient*> output_gradients,
std::vector<Gradient*>* result);
std::vector<Gradient*>* result, bool build_default_zeros_grads = true);
bool IsPersistent() const { return persistent_; }
@ -655,8 +658,8 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
const gtl::ArraySlice<int64> target_tensor_ids,
const gtl::ArraySlice<int64> source_tensor_ids,
const std::unordered_map<int64, TapeTensor>& sources_that_are_targets,
gtl::ArraySlice<Gradient*> output_gradients,
std::vector<Gradient*>* result) {
gtl::ArraySlice<Gradient*> output_gradients, std::vector<Gradient*>* result,
bool build_default_zeros_grads) {
std::unordered_set<int64> sources_set(source_tensor_ids.begin(),
source_tensor_ids.end());
BackpropInitialState<BackwardFunction, TapeTensor> state = PrepareBackprop(
@ -717,14 +720,14 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
const int64 id = trace.output_tensor_info[i].GetID();
auto grad_it = gradients.find(id);
if (grad_it == gradients.end()) {
auto func_name_it =
FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type);
if (func_name_it != FunctionsAcceptingNoneForIndicesMap()->end() &&
func_name_it->second.find(i) != func_name_it->second.end()) {
out_gradients.push_back(nullptr);
} else {
out_gradients.push_back(nullptr);
zero_indices.push_back(i);
out_gradients.push_back(nullptr);
if (build_default_zeros_grads) {
auto func_name_it =
FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type);
if (func_name_it == FunctionsAcceptingNoneForIndicesMap()->end() ||
func_name_it->second.find(i) == func_name_it->second.end()) {
zero_indices.push_back(i);
}
}
} else {
any_gradient_nonzero = true;
@ -745,6 +748,7 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
}
}
std::vector<Gradient*> in_gradients;
DCHECK(build_default_zeros_grads || zero_indices.empty());
if (any_gradient_nonzero) {
for (const auto i : zero_indices) {
out_gradients[i] = trace.output_tensor_info[i].ZerosLike();

View File

@ -26,6 +26,7 @@ cc_library(
}),
deps = [
":aws_crypto",
"//tensorflow/c:logging",
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"@aws",

View File

@ -38,6 +38,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.h"
#include "tensorflow/c/logging.h"
#include "tensorflow/c/tf_status.h"
// Implementation of a filesystem for S3 environments.
@ -281,6 +282,7 @@ void Cleanup(TF_RandomAccessFile* file) {
static int64_t ReadS3Client(S3File* s3_file, uint64_t offset, size_t n,
char* buffer, TF_Status* status) {
TF_VLog(3, "ReadFile using S3Client\n");
Aws::S3::Model::GetObjectRequest get_object_request;
get_object_request.WithBucket(s3_file->bucket).WithKey(s3_file->object);
Aws::String bytes =
@ -306,12 +308,14 @@ static int64_t ReadS3Client(S3File* s3_file, uint64_t offset, size_t n,
static int64_t ReadS3TransferManager(S3File* s3_file, uint64_t offset, size_t n,
char* buffer, TF_Status* status) {
TF_VLog(3, "Using TransferManager\n");
auto create_download_stream = [&]() {
return Aws::New<TFS3UnderlyingStream>(
"S3ReadStream",
Aws::New<Aws::Utils::Stream::PreallocatedStreamBuf>(
"S3ReadStream", reinterpret_cast<unsigned char*>(buffer), n));
};
TF_VLog(3, "Created stream to read with transferManager\n");
auto handle = s3_file->transfer_manager->DownloadFile(
s3_file->bucket, s3_file->object, offset, n, create_download_stream);
handle->WaitUntilFinished();
@ -322,6 +326,10 @@ static int64_t ReadS3TransferManager(S3File* s3_file, uint64_t offset, size_t n,
Aws::Http::HttpResponseCode::REQUESTED_RANGE_NOT_SATISFIABLE &&
retries++ < kDownloadRetries) {
// Only failed parts will be downloaded again.
TF_VLog(
1,
"Retrying read of s3://%s/%s after failure. Current retry count: %u\n",
s3_file->bucket.c_str(), s3_file->object.c_str(), retries);
s3_file->transfer_manager->RetryDownload(handle);
handle->WaitUntilFinished();
}
@ -341,6 +349,8 @@ static int64_t ReadS3TransferManager(S3File* s3_file, uint64_t offset, size_t n,
int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
char* buffer, TF_Status* status) {
auto s3_file = static_cast<S3File*>(file->plugin_file);
TF_VLog(1, "ReadFilefromS3 s3://%s/%s from %u for n: %u\n",
s3_file->bucket.c_str(), s3_file->object.c_str(), offset, n);
if (s3_file->use_multi_part_download)
return ReadS3TransferManager(s3_file, offset, n, buffer, status);
else
@ -416,6 +426,8 @@ void Sync(const TF_WritableFile* file, TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
return;
}
TF_VLog(1, "WriteFileToS3: s3://%s/%s\n", s3_file->bucket.c_str(),
s3_file->object.c_str());
auto position = static_cast<int64_t>(s3_file->outfile->tellp());
auto handle = s3_file->transfer_manager->UploadFile(
s3_file->outfile, s3_file->bucket, s3_file->object,
@ -426,6 +438,10 @@ void Sync(const TF_WritableFile* file, TF_Status* status) {
while (handle->GetStatus() == Aws::Transfer::TransferStatus::FAILED &&
retries++ < kUploadRetries) {
// if multipart upload was used, only the failed parts will be re-sent
TF_VLog(1,
"Retrying upload of s3://%s/%s after failure. Current retry count: "
"%u\n",
s3_file->bucket.c_str(), s3_file->object.c_str(), retries);
s3_file->transfer_manager->RetryUpload(s3_file->outfile, handle);
handle->WaitUntilFinished();
}
@ -613,6 +629,7 @@ void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
void Stat(const TF_Filesystem* filesystem, const char* path,
TF_FileStatistics* stats, TF_Status* status) {
TF_VLog(1, "Stat on path: %s\n", path);
Aws::String bucket, object;
ParseS3Path(path, true, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
@ -737,6 +754,8 @@ static void SimpleCopyFile(const Aws::String& source,
const Aws::String& bucket_dst,
const Aws::String& object_dst, S3File* s3_file,
TF_Status* status) {
TF_VLog(1, "SimpleCopyFile from %s to %s/%s\n", bucket_dst.c_str(),
object_dst.c_str());
Aws::S3::Model::CopyObjectRequest copy_object_request;
copy_object_request.WithCopySource(source)
.WithBucket(bucket_dst)
@ -801,6 +820,8 @@ static void MultiPartCopy(const Aws::String& source,
const Aws::String& object_dst, const size_t num_parts,
const uint64_t file_size, S3File* s3_file,
TF_Status* status) {
TF_VLog(1, "MultiPartCopy from %s to %s/%s\n", bucket_dst.c_str(),
object_dst.c_str());
Aws::S3::Model::CreateMultipartUploadRequest create_multipart_upload_request;
create_multipart_upload_request.WithBucket(bucket_dst).WithKey(object_dst);
@ -827,6 +848,8 @@ static void MultiPartCopy(const Aws::String& source,
auto chunk_size =
s3_file->multi_part_chunk_sizes[Aws::Transfer::TransferDirection::UPLOAD];
TF_VLog(1, "Copying from %s in %u parts of size %u each\n", source.c_str(),
num_parts, chunk_size);
size_t retries = 0;
while (retries++ < 3) {
// Queue up parts.
@ -891,6 +914,9 @@ static void MultiPartCopy(const Aws::String& source,
status);
} else {
// Retry.
TF_Log(TF_ERROR,
"Retrying failed copy of part %u due to an error with S3\n",
part_number);
num_finished_parts--;
}
}
@ -967,6 +993,7 @@ void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst,
void DeleteFile(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
TF_VLog(1, "DeleteFile: %s\n", path);
Aws::String bucket, object;
ParseS3Path(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
@ -985,6 +1012,7 @@ void DeleteFile(const TF_Filesystem* filesystem, const char* path,
void CreateDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
TF_VLog(1, "CreateDir: %s\n", path);
Aws::String bucket, object;
ParseS3Path(path, true, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
@ -1026,6 +1054,7 @@ void CreateDir(const TF_Filesystem* filesystem, const char* path,
void DeleteDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
TF_VLog(1, "DeleteDir: %s\n", path);
Aws::String bucket, object;
ParseS3Path(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
@ -1060,6 +1089,7 @@ void DeleteDir(const TF_Filesystem* filesystem, const char* path,
void RenameFile(const TF_Filesystem* filesystem, const char* src,
const char* dst, TF_Status* status) {
TF_VLog(1, "RenameFile from: %s to %s\n", src, dst);
Aws::String bucket_src, object_src;
ParseS3Path(src, false, &bucket_src, &object_src, status);
if (TF_GetCode(status) != TF_OK) return;
@ -1120,6 +1150,7 @@ void RenameFile(const TF_Filesystem* filesystem, const char* src,
int GetChildren(const TF_Filesystem* filesystem, const char* path,
char*** entries, TF_Status* status) {
TF_VLog(1, "GetChildren for path: %s\n", path);
Aws::String bucket, prefix;
ParseS3Path(path, true, &bucket, &prefix, status);
if (TF_GetCode(status) != TF_OK) return -1;

View File

@ -3,6 +3,24 @@ package(
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "array_grad",
srcs = ["array_grad.cc"],
hdrs = [
"array_grad.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/c/eager:gradients",
"//tensorflow/core/lib/llvm_rtti",
],
)
cc_library(
name = "math_grad",
srcs = ["math_grad.cc"],

View File

@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/array_grad.h"
namespace tensorflow {
namespace gradients {
namespace {
using std::vector;
class IdentityNGradientFunction : public GradientFunction {
public:
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
grad_outputs->resize(grad_inputs.size(), nullptr);
for (int i = 0; i < grad_inputs.size(); i++) {
auto grad_input = grad_inputs[i];
// TODO(srbs): Should we add a copy constructor to AbstractTensorHandle
// that takes care of this, similar to `Tensor`?
if (grad_input) {
grad_input->Ref();
}
(*grad_outputs)[i] = grad_input;
}
return Status::OK();
}
~IdentityNGradientFunction() override {}
};
} // namespace
BackwardFunction* IdentityNRegisterer(const ForwardOperation& op) {
auto gradient_function = new IdentityNGradientFunction;
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}
} // namespace gradients
} // namespace tensorflow

View File

@ -0,0 +1,26 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_ARRAY_GRAD_H_
#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_ARRAY_GRAD_H_
#include "tensorflow/c/eager/gradients.h"
namespace tensorflow {
namespace gradients {
BackwardFunction* IdentityNRegisterer(const ForwardOperation& op);
} // namespace gradients
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_ARRAY_GRAD_H_

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
@ -29,8 +30,7 @@ namespace {
class AddGradientFunction : public GradientFunction {
public:
Status Compute(Context* ctx,
absl::Span<AbstractTensorHandle* const> grad_inputs,
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
grad_outputs->resize(2);
vector<AbstractTensorHandle*> identity_outputs(1);
@ -54,8 +54,7 @@ class ExpGradientFunction : public GradientFunction {
explicit ExpGradientFunction(AbstractTensorHandle* exp) : exp_(exp) {
exp->Ref();
}
Status Compute(Context* ctx,
absl::Span<AbstractTensorHandle* const> grad_inputs,
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
vector<AbstractTensorHandle*> conj_outputs(1);
TF_RETURN_IF_ERROR(
@ -74,12 +73,22 @@ class ExpGradientFunction : public GradientFunction {
} // namespace
GradientFunction* AddRegisterer(const ForwardOperation& op) {
return new AddGradientFunction;
BackwardFunction* AddRegisterer(const ForwardOperation& op) {
auto gradient_function = new AddGradientFunction;
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}
GradientFunction* ExpRegisterer(const ForwardOperation& op) {
return new ExpGradientFunction(op.outputs[0]);
BackwardFunction* ExpRegisterer(const ForwardOperation& op) {
auto gradient_function = new ExpGradientFunction(op.outputs[0]);
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}
} // namespace gradients

View File

@ -19,8 +19,8 @@ limitations under the License.
namespace tensorflow {
namespace gradients {
GradientFunction* AddRegisterer(const ForwardOperation& op);
GradientFunction* ExpRegisterer(const ForwardOperation& op);
BackwardFunction* AddRegisterer(const ForwardOperation& op);
BackwardFunction* ExpRegisterer(const ForwardOperation& op);
} // namespace gradients
} // namespace tensorflow

View File

@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/platform/bfloat16.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"

View File

@ -38,6 +38,7 @@ tf_kernel_library(
"//third_party/eigen3",
],
)
tf_kernel_library(
name = "histogram_summary_op",
prefix = "histogram_summary_op",
@ -52,7 +53,6 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "merge_summary_op",
prefix = "merge_summary_op",

View File

@ -20,8 +20,8 @@ limitations under the License.
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/lib/histogram/histogram.h"
#include "tensorflow/core/platform/bfloat16.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"

View File

@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/platform/default/logging.h"
#include "tensorflow/core/platform/logging.h"
namespace {

View File

@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/platform/bfloat16.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"

View File

@ -65,9 +65,24 @@ static ElementsAttr getSplat(Builder* b, Value val, T constant) {
// Returns DenseElementsAttr of rank zero with the given element type and the
// value.
// Requires `ty` to be either FloatType of IntegerType.
// Requires `ty` to be either FloatType, IntegerType, or ComplexType.
DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value);
// Enum type used to specify scalar argument to GetScalarLimitOfType.
enum ScalarLimit {
kLowest, // The scalar corresponding to numeric_limits<T>::lowest.
kInfinityLowest, // Like kLowest, but returns -infinity where available.
kMax, // The scalar corresponding to numeric_limits<T>::max.
kInfinityMax, // Like kMax, but returns infinity where available.
};
// Returns a scalar limit value for the given type.
//
// The argument 'limit' describes which scalar value to return.
//
// Requires `ty` to be either FloatType or IntegerType.
DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit);
} // namespace hlo
} // namespace mlir
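A brief usage sketch (not part of this change) of the new helper, as it might appear in a lowering that needs the init value of a max-reduction; the function name and surrounding context are illustrative, and the include of this header is assumed:
// Returns -infinity for float element types and the (signed or unsigned)
// minimum for integer element types, per the semantics documented above.
static mlir::DenseElementsAttr GetMaxReductionInitValue(mlir::Type element_ty) {
  return mlir::hlo::GetScalarLimitOfType(element_ty, mlir::hlo::kInfinityLowest);
}
`GetScalarOfType` itself now also accepts complex element types, so a zero of type `complex<f32>` or `complex<f64>` can be built the same way as for floats and integers.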

View File

@ -60,10 +60,76 @@ DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
if (auto float_ty = ty.dyn_cast<FloatType>()) {
APFloat value(float_ty.getFloatSemantics(), raw_value);
return DenseElementsAttr::get(scalar_ty, value);
} else if (auto int_ty = ty.dyn_cast<IntegerType>()) {
APInt value(int_ty.getWidth(), static_cast<int64_t>(raw_value), true);
return DenseElementsAttr::get(scalar_ty, value);
} else if (auto complex_ty = ty.dyn_cast<ComplexType>()) {
Type complex_element_ty = complex_ty.getElementType();
if (complex_element_ty.isF32()) {
return DenseElementsAttr::get(
scalar_ty, static_cast<std::complex<float>>(raw_value));
} else if (complex_element_ty.isF64()) {
return DenseElementsAttr::get(
scalar_ty, static_cast<std::complex<double>>(raw_value));
}
}
auto int_ty = ty.cast<IntegerType>();
APInt value(int_ty.getWidth(), static_cast<int64_t>(raw_value), true);
return DenseElementsAttr::get(scalar_ty, value);
llvm_unreachable("unsupported type");
}
static APFloat GetScalarLimitOfFloatType(FloatType float_ty,
ScalarLimit limit) {
auto &semantics = float_ty.getFloatSemantics();
switch (limit) {
case kLowest:
return APFloat::getLargest(semantics, /*negative=*/true);
case kInfinityLowest:
return APFloat::getInf(semantics, /*negative=*/true);
case kMax:
return APFloat::getLargest(semantics, /*negative=*/false);
case kInfinityMax:
return APFloat::getInf(semantics, /*negative=*/false);
}
llvm_unreachable("invalid limit");
}
// Returns a scalar limit value for the given integer type.
//
// The argument 'limit' describes which scalar value to return.
static APInt GetScalarLimitOfIntegerType(IntegerType integer_ty,
ScalarLimit limit) {
unsigned width = integer_ty.getWidth();
switch (limit) {
case kLowest:
case kInfinityLowest:
if (integer_ty.isUnsigned()) {
return APInt::getMinValue(width);
} else {
return APInt::getSignedMinValue(width);
}
case kMax:
case kInfinityMax:
if (integer_ty.isUnsigned()) {
return APInt::getMaxValue(width);
} else {
return APInt::getSignedMaxValue(width);
}
}
llvm_unreachable("invalid limit");
}
DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit) {
RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
if (auto float_ty = ty.dyn_cast<FloatType>()) {
return DenseElementsAttr::get(scalar_ty,
GetScalarLimitOfFloatType(float_ty, limit));
} else if (auto integer_ty = ty.dyn_cast<IntegerType>()) {
return DenseElementsAttr::get(
scalar_ty, GetScalarLimitOfIntegerType(integer_ty, limit));
}
llvm_unreachable("unsupported type");
}
} // namespace hlo

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "llvm/Support/ToolOutputFile.h"
#include "mlir-hlo/Dialect/mhlo/IR/register.h"
#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
@ -80,6 +81,8 @@ int main(int argc, char **argv) {
llvm::InitLLVM y(argc, argv);
// Register any pass manager command line options.
mlir::registerAsmPrinterCLOptions();
mlir::registerMLIRContextCLOptions();
mlir::registerPassManagerCLOptions();
mlir::PassPipelineCLParser passPipeline("", "Compiler passes to run");

View File

@ -29,6 +29,7 @@ filegroup(
"ir/tfl_ops.td",
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
],
@ -227,6 +228,7 @@ cc_library(
"@llvm-project//mlir:DerivedAttributeOpInterface",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:LoopLikeInterface",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:SideEffects",
@ -500,6 +502,7 @@ gentbl(
tblgen = "//tensorflow/compiler/mlir/lite/quantization:op_quant_spec_getters_gen",
td_file = "ir/tfl_ops.td",
td_srcs = [
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
"ir/tfl_op_interfaces.td",

View File

@ -133,63 +133,59 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
return Status(error::INVALID_ARGUMENT,
"'isSigned' can only be set for 8-bits integer type");
}
switch (type.getKind()) {
case mlir::StandardTypes::F32:
return tflite::TensorType_FLOAT32;
case mlir::StandardTypes::F16:
return tflite::TensorType_FLOAT16;
case mlir::StandardTypes::F64:
return tflite::TensorType_FLOAT64;
case mlir::TF::TensorFlowTypes::STRING:
return tflite::TensorType_STRING;
case mlir::TF::TensorFlowTypes::QUINT8:
return tflite::TensorType_UINT8;
case mlir::StandardTypes::Complex: {
auto ftype = type.cast<mlir::ComplexType>().getElementType();
if (ftype && ftype.isF32()) {
return tflite::TensorType_COMPLEX64;
}
if (ftype && ftype.isF64()) {
return tflite::TensorType_COMPLEX128;
}
return Status(error::INVALID_ARGUMENT, "Unsupported type");
if (type.isF32()) {
return tflite::TensorType_FLOAT32;
} else if (type.isF16()) {
return tflite::TensorType_FLOAT16;
} else if (type.isF64()) {
return tflite::TensorType_FLOAT64;
} else if (type.isa<mlir::TF::StringType>()) {
return tflite::TensorType_STRING;
} else if (type.isa<mlir::TF::Quint8Type>()) {
return tflite::TensorType_UINT8;
} else if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
auto ftype = complex_type.getElementType();
if (ftype.isF32()) {
return tflite::TensorType_COMPLEX64;
}
case mlir::StandardTypes::Integer: {
const auto& itype = type.cast<mlir::IntegerType>();
switch (itype.getWidth()) {
case 1:
return tflite::TensorType_BOOL;
case 8:
return itype.isUnsigned() ? tflite::TensorType_UINT8
: tflite::TensorType_INT8;
case 16:
return tflite::TensorType_INT16;
case 32:
return tflite::TensorType_INT32;
case 64:
return tflite::TensorType_INT64;
}
if (ftype.isF64()) {
return tflite::TensorType_COMPLEX128;
}
case mlir::quant::QuantizationTypes::UniformQuantized: {
auto qtype = type.cast<mlir::quant::UniformQuantizedType>();
return GetTFLiteType(qtype.getStorageType(), qtype.isSigned());
return Status(error::INVALID_ARGUMENT, "Unsupported type");
} else if (auto itype = type.dyn_cast<mlir::IntegerType>()) {
switch (itype.getWidth()) {
case 1:
return tflite::TensorType_BOOL;
case 8:
return itype.isUnsigned() ? tflite::TensorType_UINT8
: tflite::TensorType_INT8;
case 16:
return tflite::TensorType_INT16;
case 32:
return tflite::TensorType_INT32;
case 64:
return tflite::TensorType_INT64;
}
case mlir::quant::QuantizationTypes::UniformQuantizedPerAxis: {
auto qtype = type.cast<mlir::quant::UniformQuantizedPerAxisType>();
return GetTFLiteType(qtype.getStorageType(), qtype.isSigned());
}
case mlir::TF::TensorFlowTypes::RESOURCE: {
// Treat tf.resource values as integer values in flatbuffer.
// TODO(b/146131919): Maybe need to have a detailed design for supporting
// other resource types beyond hash table resources and resource
// variables.
return tflite::TensorType_INT32;
}
default:
// TFLite export fills FLOAT32 for unknown data types. Returning an error
// for now for safety and this could be revisited when required.
return Status(error::INVALID_ARGUMENT, "Unsupported type");
} else if (auto q_uniform_type =
type.dyn_cast<mlir::quant::UniformQuantizedType>()) {
return GetTFLiteType(q_uniform_type.getStorageType(),
q_uniform_type.isSigned());
} else if (auto q_peraxis_type =
type.dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
return GetTFLiteType(q_peraxis_type.getStorageType(),
q_peraxis_type.isSigned());
} else if (type.isa<mlir::TF::ResourceType>()) {
// Treat tf.resource values as integer values in flatbuffer.
// TODO(b/146131919): Maybe need to have a detailed design for supporting
// other resource types beyond hash table resources and resource
// variables.
return tflite::TensorType_INT32;
}
// TFLite export fills FLOAT32 for unknown data types. Returning an error
// for now for safety and this could be revisited when required.
return Status(error::INVALID_ARGUMENT, "Unsupported type");
}
static bool IsConst(Operation* op) {

View File

@ -95,40 +95,34 @@ static tflite::MirrorPadMode ConvertTFL_MirrorPaddingAttrForOptionWriter(
static tflite::TensorType ConvertDerivedTypeAttrForOptionWriter(
mlir::Type type, flatbuffers::FlatBufferBuilder* builder) {
switch (type.getKind()) {
case mlir::StandardTypes::F16:
return tflite::TensorType_FLOAT16;
case mlir::StandardTypes::F32:
return tflite::TensorType_FLOAT32;
case mlir::TF::TensorFlowTypes::STRING:
return tflite::TensorType_STRING;
case mlir::StandardTypes::Complex: {
auto etype = type.cast<mlir::ComplexType>().getElementType();
if (etype.isF32()) {
return tflite::TensorType_COMPLEX64;
}
llvm_unreachable("invalid complex Type in conversion");
if (type.isF16()) {
return tflite::TensorType_FLOAT16;
} else if (type.isF32()) {
return tflite::TensorType_FLOAT32;
} else if (type.isa<mlir::TF::StringType>()) {
return tflite::TensorType_STRING;
} else if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
if (complex_type.getElementType().isF32()) {
return tflite::TensorType_COMPLEX64;
}
case mlir::StandardTypes::Integer: {
const auto& itype = type.cast<mlir::IntegerType>();
switch (itype.getWidth()) {
case 1:
return tflite::TensorType_BOOL;
case 8:
return tflite::TensorType_INT8;
case 16:
return tflite::TensorType_INT16;
case 32:
return tflite::TensorType_INT32;
case 64:
return tflite::TensorType_INT64;
default:
llvm_unreachable("invalid integer Type in conversion");
}
llvm_unreachable("invalid complex Type in conversion");
} else if (auto itype = type.dyn_cast<mlir::IntegerType>()) {
switch (itype.getWidth()) {
case 1:
return tflite::TensorType_BOOL;
case 8:
return tflite::TensorType_INT8;
case 16:
return tflite::TensorType_INT16;
case 32:
return tflite::TensorType_INT32;
case 64:
return tflite::TensorType_INT64;
default:
llvm_unreachable("invalid integer Type in conversion");
}
default:
llvm_unreachable("invalid Type in conversion");
}
llvm_unreachable("invalid Type in conversion");
}
// I32Attr already returns an int as required by flatbuffer builders.

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/OpImplementation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
@ -253,9 +254,8 @@ struct TensorFlowLiteInlinerInterface : public DialectInlinerInterface {
}
};
struct TensorFlowLiteOpFolderDialectInterface
: public OpFolderDialectInterface {
using OpFolderDialectInterface::OpFolderDialectInterface;
struct TensorFlowLiteDialectFoldInterface : public DialectFoldInterface {
using DialectFoldInterface::DialectFoldInterface;
// Registered hook to check if the given region, which is attached to an
// operation that is *not* isolated from above (i.e. no internal regions
@ -275,7 +275,7 @@ TensorFlowLiteDialect::TensorFlowLiteDialect(mlir::MLIRContext *context)
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
>();
addInterfaces<TensorFlowLiteInlinerInterface,
TensorFlowLiteOpFolderDialectInterface>();
TensorFlowLiteDialectFoldInterface>();
}
//===----------------------------------------------------------------------===//
@ -1028,9 +1028,12 @@ static LogicalResult Verify(PackOp op) {
// Check axis bounds.
if (input_type.hasRank()) {
int64_t axis_value = op.axis().getSExtValue();
if (abs(axis_value) > input_type.getRank())
return op.emitOpError("op attribute 'axis' is out of bounds, got ")
<< axis_value;
if (axis_value < 0) axis_value += input_type.getRank() + 1;
if (axis_value < 0 || axis_value >= input_type.getRank() + 1)
return op.emitOpError()
<< "op attribute 'axis' should be in range [-rank - 1, rank + 1), "
<< "got rank = " << input_type.getRank()
<< ", and axis = " << op.axis().getSExtValue();
}
// Make sure all inputs have the same shape and element type.
@ -1443,12 +1446,59 @@ void FakeQuantOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
// TODO(b/133486129): Implement shape inference for unpack
static LogicalResult Verify(UnpackOp op) {
// TODO(antiagainst): Implement other checks as in
// tensorflow/lite/kernels/unpack.cc
LogicalResult UnpackOp::inferReturnTypes(
MLIRContext *context, Optional<Location> loc, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
UnpackOpAdaptor op(operands, attributes);
// TODO(jpienaar): Refactor inferReturnTypes.
if (failed(op.verify(loc.hasValue() ? *loc : UnknownLoc::get(context))))
return failure();
if (op.getOperation()->getNumResults() != op.num())
return op.emitOpError("output count should match 'num' attribute");
if (operands.size() != 1) {
return emitOptionalError(loc, "input count should be equal to 1");
}
const int64_t num_value = op.num().getInt();
auto input_type = operands[0].getType().dyn_cast<ShapedType>();
if (!input_type || !input_type.hasRank()) {
// If input is unranked, then so is output.
inferredReturnTypes.assign(
num_value, UnrankedTensorType::get(input_type.getElementType()));
return success();
}
if (input_type.getNumElements() <= 0) {
return emitOptionalError(
loc, "number of elements in input shoule be larger than 0");
}
const int64_t rank = input_type.getRank();
if (rank <= 0) {
return emitOptionalError(loc, "input should be of rank larger than 0");
}
int64_t axis_value = op.axis().getInt();
if (axis_value < 0) {
axis_value += rank;
}
if (axis_value < 0 || axis_value >= rank) {
return emitOptionalError(
loc, "attribute 'axis' should be in range [-rank, rank), got axis = ",
op.axis().getInt(), ", and rank = ", rank);
}
if (!ShapedType::isDynamic(input_type.getDimSize(axis_value)) &&
input_type.getDimSize(axis_value) != num_value) {
return emitOptionalError(loc, "output count should match 'num' attribute");
}
auto output_shape = llvm::to_vector<4>(input_type.getShape());
output_shape.erase(output_shape.begin() + axis_value);
auto output_type =
RankedTensorType::get(output_shape, input_type.getElementType());
inferredReturnTypes.assign(num_value, output_type);
return success();
}

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "mlir/IR/OpImplementation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project

View File

@ -19,6 +19,7 @@ limitations under the License.
#define TFL_OPS
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td"
@ -3028,7 +3029,8 @@ def TFL_TransposeOp : TFL_Op<"transpose", [
def TFL_UnpackOp : TFL_Op<"unpack", [
NoSideEffect,
SameOperandsAndResultElementType,
SameOperandsAndResultsScale]> {
SameOperandsAndResultsScale,
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Unpacks a tensor along a dimension into multiple tensors";
let description = [{
@ -3059,8 +3061,6 @@ def TFL_UnpackOp : TFL_Op<"unpack", [
TFL_VariadicTensorOf<[F32, I1, I8, UI8, I32, QI8, QUI8, I16, QI16]>:$outputs
);
let verifier = [{ return Verify(*this); }];
let hasOptions = 1;
}

View File

@ -1139,9 +1139,15 @@ func @packInputRank(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<1x
// -----
func @packNegInputRank(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<2x1x4xi32> {
func @packNegInputAxis2(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<1x2x4xi32> {
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = -2 : i32, values_count = 2 : i32}
%0 = "tfl.pack"(%arg0, %arg1) {axis = -2 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<2x1x4xi32>
%0 = "tfl.pack"(%arg0, %arg1) {axis = -2 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<1x2x4xi32>
return %0 : tensor<1x2x4xi32>
}
func @packNegInputAxis3(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<2x1x4xi32> {
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = -3 : i32, values_count = 2 : i32}
%0 = "tfl.pack"(%arg0, %arg1) {axis = -3 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<2x1x4xi32>
return %0 : tensor<2x1x4xi32>
}
@ -1172,7 +1178,7 @@ func @pack(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
// -----
func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
// expected-error @+1 {{op attribute 'axis' is out of bounds, got 3}}
// expected-error @+1 {{op attribute 'axis' should be in range [-rank - 1, rank + 1), got rank = 1, and axis = 3}}
%0 = "tfl.pack"(%arg0, %arg1) {axis = 3 : i32, values_count = 2 : i32} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
@ -1183,7 +1189,22 @@ func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
// CHECK: "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32}
%0:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
return %0#0 : tensor<2xi32>
}
// -----
func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
// CHECK: "tfl.unpack"(%arg0) {axis = -1 : i32, num = 3 : i32}
%0:3 = "tfl.unpack"(%arg0) {axis = -1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
return %0#0 : tensor<2xi32>
}
// -----
func @unpack(%arg0: tensor<2x3xi32>) -> tensor<3xi32> {
// CHECK: "tfl.unpack"(%arg0) {axis = -2 : i32, num = 2 : i32}
%0:2 = "tfl.unpack"(%arg0) {axis = -2 : i32, num = 2 : i32} : (tensor<2x3xi32>) -> (tensor<3xi32>, tensor<3xi32>)
return %0#0 : tensor<3xi32>
}
// -----
@ -1204,6 +1225,45 @@ func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
// -----
func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
// expected-error @+1 {{attribute 'axis' should be in range [-rank, rank), got axis = 2, and rank = 2}}
%0:3 = "tfl.unpack"(%arg0) {axis = 2 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
return %0#0 : tensor<2xi32>
}
// -----
func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
// expected-error @+1 {{attribute 'axis' should be in range [-rank, rank), got axis = -3, and rank = 2}}
%0:3 = "tfl.unpack"(%arg0) {axis = -3 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
return %0#0 : tensor<2xi32>
}
// -----
func @unpack(%arg0: tensor<i32>) -> tensor<2xi32> {
// expected-error @+1 {{input should be of rank larger than 0}}
%0:3 = "tfl.unpack"(%arg0) {axis = 0 : i32, num = 3 : i32} : (tensor<i32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
return %0#0 : tensor<2xi32>
}
// -----
func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
// expected-error @+1 {{op inferred type incompatible with return type of operation}}
%0:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2x1xi32>, tensor<2xi32>)
return %0#0 : tensor<2xi32>
}
// -----
func @unpack(%arg0: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>) {
%0:2 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 2 : i32} : (tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>)
return %0#0, %0#1 : tensor<*xi32>, tensor<*xi32>
}
// -----
// CHECK-LABEL: testMean
func @testMean(%arg0: tensor<2x2xf32>, %arg1 : tensor<1xi32>) -> tensor<1x2xf32> {
// CHECK: "tfl.mean"(%arg0, %arg1) {keep_dims = false}

View File

@ -30,80 +30,66 @@ stream_executor::port::StatusOr<ConstantOp> CreateConstOpWithSingleValue(
Type element_type = shaped_type.getElementType();
ShapedType scalar_type = RankedTensorType::get({}, element_type);
Attribute attr;
switch (element_type.getKind()) {
case mlir::StandardTypes::F16: {
auto floatType = mlir::FloatType::getF16(element_type.getContext());
auto floatAttr =
mlir::FloatAttr::get(floatType, static_cast<float>(value));
std::vector<Attribute> floatValues({floatAttr});
attr = DenseElementsAttr::get(scalar_type, floatValues);
break;
}
case mlir::StandardTypes::BF16: {
auto floatType = mlir::FloatType::getBF16(element_type.getContext());
auto floatAttr =
mlir::FloatAttr::get(floatType, static_cast<float>(value));
std::vector<Attribute> floatValues({floatAttr});
attr = DenseElementsAttr::get(scalar_type, floatValues);
break;
}
case mlir::StandardTypes::F32: {
attr =
DenseElementsAttr::get<float>(scalar_type, static_cast<float>(value));
break;
}
case mlir::StandardTypes::Complex: {
auto etype = element_type.cast<mlir::ComplexType>().getElementType();
if (etype.isF32()) {
auto dialect = etype.getContext()->getRegisteredDialect("tf");
tensorflow::TensorProto repr;
repr.set_dtype(tensorflow::DT_COMPLEX64);
if (element_type.isF16()) {
auto floatType = mlir::FloatType::getF16(element_type.getContext());
auto floatAttr = mlir::FloatAttr::get(floatType, static_cast<float>(value));
std::vector<Attribute> floatValues({floatAttr});
attr = DenseElementsAttr::get(scalar_type, floatValues);
} else if (element_type.isBF16()) {
auto floatType = mlir::FloatType::getBF16(element_type.getContext());
auto floatAttr = mlir::FloatAttr::get(floatType, static_cast<float>(value));
std::vector<Attribute> floatValues({floatAttr});
attr = DenseElementsAttr::get(scalar_type, floatValues);
} else if (element_type.isF32()) {
attr =
DenseElementsAttr::get<float>(scalar_type, static_cast<float>(value));
} else if (auto complex_type = element_type.dyn_cast<mlir::ComplexType>()) {
auto etype = complex_type.getElementType();
if (etype.isF32()) {
auto dialect = etype.getContext()->getRegisteredDialect("tf");
tensorflow::TensorProto repr;
repr.set_dtype(tensorflow::DT_COMPLEX64);
tensorflow::TensorShapeProto* shape = repr.mutable_tensor_shape();
shape->set_unknown_rank(false);
shape->add_dim()->set_size(int64_t{1});
std::string content;
auto complex_value =
std::complex<float>(static_cast<float>(value), 0.0f);
content.assign(reinterpret_cast<const char*>(&complex_value),
sizeof(complex_value));
repr.set_tensor_content(content);
std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
tensorflow::TensorShapeProto* shape = repr.mutable_tensor_shape();
shape->set_unknown_rank(false);
shape->add_dim()->set_size(int64_t{1});
std::string content;
auto complex_value = std::complex<float>(static_cast<float>(value), 0.0f);
content.assign(reinterpret_cast<const char*>(&complex_value),
sizeof(complex_value));
repr.set_tensor_content(content);
std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
attr = mlir::OpaqueElementsAttr::get(dialect, scalar_type, mangled);
attr = mlir::OpaqueElementsAttr::get(dialect, scalar_type, mangled);
} else {
return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
"Unsupported type");
}
} else if (auto itype = element_type.dyn_cast<mlir::IntegerType>()) {
switch (itype.getWidth()) {
case 8:
attr = DenseElementsAttr::get<int8_t>(scalar_type,
static_cast<int8_t>(value));
break;
}
return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
"Unsupported type");
case 16:
attr = DenseElementsAttr::get<int16_t>(scalar_type,
static_cast<int16_t>(value));
break;
case 32:
attr = DenseElementsAttr::get<int32_t>(scalar_type,
static_cast<int32_t>(value));
break;
case 64:
attr = DenseElementsAttr::get<int64_t>(scalar_type,
static_cast<int64_t>(value));
break;
default:
return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
"Unsupported type");
}
case mlir::StandardTypes::Integer: {
const auto& itype = element_type.cast<mlir::IntegerType>();
switch (itype.getWidth()) {
case 8:
attr = DenseElementsAttr::get<int8_t>(scalar_type,
static_cast<int8_t>(value));
break;
case 16:
attr = DenseElementsAttr::get<int16_t>(scalar_type,
static_cast<int16_t>(value));
break;
case 32:
attr = DenseElementsAttr::get<int32_t>(scalar_type,
static_cast<int32_t>(value));
break;
case 64:
attr = DenseElementsAttr::get<int64_t>(scalar_type,
static_cast<int64_t>(value));
break;
default:
return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
"Unsupported type");
}
break;
}
default:
return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
"Unsupported type");
} else {
return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
"Unsupported type");
}
return rewriter->create<ConstantOp>(loc, scalar_type, attr);
}

View File

@ -32,7 +32,10 @@ void init_types(py::module& m) {
[](mlir::FunctionType& ft) { return ft.getResults().vec(); });
py::class_<mlir::FloatType, mlir::Type>(m, "FloatType")
.def("get", &mlir::FloatType::get);
.def("getBF16", &mlir::FloatType::getBF16)
.def("getF16", &mlir::FloatType::getF16)
.def("getF32", &mlir::FloatType::getF32)
.def("getF64", &mlir::FloatType::getF64);
py::class_<mlir::IntegerType, mlir::Type>(m, "IntegerType")
.def("get", py::overload_cast<unsigned, mlir::MLIRContext*>(

View File

@ -722,6 +722,7 @@ cc_library(
"//tensorflow/core:framework",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
@ -787,6 +788,7 @@ cc_library(
"transforms/tpu_extract_head_tail_outside_compilation.cc",
"transforms/tpu_extract_outside_compilation.cc",
"transforms/tpu_host_computation_expansion.cc",
"transforms/tpu_identity_pruning.cc",
"transforms/tpu_merge_variables_with_execute.cc",
"transforms/tpu_outside_compilation_cluster.cc",
"transforms/tpu_rewrite_pass.cc",
@ -1269,7 +1271,7 @@ cc_library(
name = "tf_dialect_passes",
srcs = [
"transforms/constant_fold.cc",
"transforms/dialect_hooks.cc",
"transforms/decode_attributes_hook.cc",
],
hdrs = [
"transforms/constant_fold.h",
@ -1632,6 +1634,7 @@ cc_library(
deps = [
":lower_tf_inc_gen",
":tensorflow",
":tensorflow_ops",
":tensorflow_types",
"//tensorflow/core:framework",
"@llvm-project//llvm:Support",

View File

@ -21,11 +21,13 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/CallGraph.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Block.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
@ -35,6 +37,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
@ -134,12 +137,46 @@ class BacktrackAnalysis {
return GetAnalysisForRegion(region);
}
// Returns the backtrack analysis for the given region if it exists.
// If the region has not yet been analyzed, returns llvm::None.
Optional<const InfoT*> GetAnalysisIfExists(Region& region) const {
auto it = info_map_.find(&region);
if (it == info_map_.end()) return llvm::None;
return &it->second;
}
Optional<const InfoT*> GetAnalysisIfExists(FuncOp func) const {
return GetAnalysisIfExists(func.getBody());
}
private:
llvm::SmallDenseMap<Region*, InfoT> info_map_;
};
// Analyzes all regions attached to all operations in the module.
BacktrackAnalysis::BacktrackAnalysis(ModuleOp module) {
const CallGraph call_graph(module);
// Visit functions bottom up when doing the analysis. Note that SCC iterator
// has the property that if there is an edge from SCC1->SCC2, SCC1 is visited
// after SCC2, i.e., the graph is traversed bottom up just the way we want.
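// Illustrative sketch (hypothetical functions, not part of this module): for
//   func @caller() { ... call @callee() ... }
//   func @callee() { ... }
// the SCC iterator yields {callee} before {caller}, so callee's region info
// is already available when caller is analyzed.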
auto scc_begin = llvm::scc_begin(&call_graph);
auto scc_end = llvm::scc_end(&call_graph);
for (auto& scc : make_range(scc_begin, scc_end)) {
// Each SCC node is a collection of callgraph nodes that form a cycle. We
// will visit these nodes in an arbitrary order. If a node being visited
// calls a function that has not yet been analyzed, we will not be able to
// backtrack through that function call (our analysis will be correct but
// pessimistic).
for (CallGraphNode* node : scc) {
if (node->isExternal()) continue;
Region* region = node->getCallableRegion();
GetOrCreateAnalysis(*region);
}
}
// The call graph analysis above covers all regions attached to functions,
// but we also need to analyze regions attached to other ops.
module.walk([this](Operation* op) {
for (Region& region : op->getRegions()) GetOrCreateAnalysis(region);
});
@ -160,6 +197,18 @@ Value BacktrackAnalysis::BacktrackValue(Value value) {
value = island.GetYield().getOperand(res_index);
} else if (isa<IdentityNOp, IdentityOp>(op)) {
value = op->getOperand(res_index);
} else if (auto call = dyn_cast<CallOpInterface>(op)) {
FuncOp func = dyn_cast<FuncOp>(call.resolveCallable());
if (!func) break;
// Check if the function being called has been analyzed. If not,
// we cannot backtrack the value further.
Optional<const InfoT*> callee_info = GetAnalysisIfExists(func);
if (!callee_info) break;
Optional<int> passthrough_arg = callee_info.getValue()->GetArg(res_index);
if (!passthrough_arg) break;
value = call.getArgOperands()[passthrough_arg.getValue()];
} else if (isa<tf_device::LaunchOp, tf_device::ClusterOp>(op)) {
value = op->getRegion(0).front().getTerminator()->getOperand(res_index);
} else {
break;
}
@ -359,6 +408,13 @@ ResourceAliasAnalysisInfo::ResourceAliasAnalysisInfo(
AddValueUniqueIDMapping(result, kUnknownResourceId);
}
}
} else if (isa<tf_device::LaunchOp, tf_device::ClusterOp>(op)) {
Region& region = op->getRegion(0);
const auto& body_info = backtrack_analysis.GetAnalysisForRegion(region);
for (auto result : filter_resources(op->getResults())) {
Value body_result = body_info.GetValue(result.getResultNumber());
PropagateInputToOutput(body_result, result);
}
} else {
assign_unknown_id_to_all(op->getResults());
}

View File

@ -41,6 +41,7 @@ tf_cuda_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/core/platform:errors",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",

View File

@ -16,6 +16,7 @@ limitations under the License.
#include <cstddef>
#include <memory>
#include "absl/strings/str_cat.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/raw_ostream.h"
@ -184,11 +185,18 @@ class MlirAbstractOp : public TracingOperation {
}
private:
// Return true if there are still unfilled ODS slots for adding more inputs.
bool IsNextODSArgAvailable();
MLIRContext* context_;
MlirFunctionContext* function_context_;
SmallVector<Value, 8> operands_;
llvm::StringMap<Attribute> attrs_;
std::unique_ptr<OperationState> state_;
// This is the index of the next ODS operand that will be added with AddInput
// or AddInputList.
int current_ods_input_ = 0;
const tensorflow::OpDef* op_def_ = nullptr;
const char* op_name_ = nullptr;
string tf_op_type_;
// TODO(srbs): Use this.
@ -267,6 +275,10 @@ Status MlirAbstractOp::Reset(const char* op, const char* device_name) {
return tensorflow::errors::FailedPrecondition(
"Reset called on already built op.");
}
TF_RETURN_IF_ERROR(
tensorflow::OpRegistry::Global()->LookUpOpDef(op, &op_def_));
assert(op_def_);
tf_op_type_ = op;
std::string name = "tf.";
name += op;
@ -315,45 +327,17 @@ Status MlirAbstractOp::AddRef(Type type, Type* output_type) {
Status MlirAbstractOp::Create(ArrayRef<Value> operands,
OperationState** state) {
state_->operands = llvm::to_vector<4>(operands);
const tensorflow::OpDef* op_def;
auto node_name = state_->name.getStringRef().drop_front(
TensorFlowDialect::getDialectNamespace().size() + 1);
TF_RETURN_IF_ERROR(
tensorflow::OpRegistry::Global()->LookUpOpDef(node_name.str(), &op_def));
Builder builder(context_);
// Process operands according to the op_def and infer derived attributes.
int current_operand = 0;
for (const tensorflow::OpDef::ArgDef& input_arg : op_def->input_arg()) {
if (!input_arg.number_attr().empty()) {
// TODO(b/156122856): we don't support variadic operands.
return tensorflow::errors::Unimplemented(
"Unsupported 'number_attr' for '", input_arg.number_attr(), "'");
} else if (!input_arg.type_list_attr().empty()) {
return tensorflow::errors::InvalidArgument(
"Unsupported 'type_list_attr' for '", input_arg.number_attr(), "'");
}
if (current_operand >= operands.size()) {
return tensorflow::errors::InvalidArgument("Missing operand for '",
input_arg.name(), "'");
}
Type expected_type;
if (input_arg.type() != tensorflow::DT_INVALID) {
TF_RETURN_IF_ERROR(
ConvertDataTypeToTensor(input_arg.type(), builder, &expected_type));
Type output_type;
if (input_arg.is_ref())
TF_RETURN_IF_ERROR(AddRef(expected_type, &output_type));
expected_type = output_type;
} else {
expected_type = operands[current_operand].getType();
}
if (!input_arg.type_attr().empty()) {
attrs_[input_arg.type_attr()] = TypeAttr::get(expected_type);
}
++current_operand;
}
for (const tensorflow::OpDef::ArgDef& output_arg : op_def->output_arg()) {
if (current_ods_input_ != op_def_->input_arg_size())
return tensorflow::errors::InvalidArgument(
absl::StrCat("Mismatch in operands number: got ", current_ods_input_,
" expected ", op_def_->input_arg_size(), " ; for op ",
state_->name.getStringRef().str()));
// Process results according to the op_def and infer types for derived
// attributes.
for (const tensorflow::OpDef::ArgDef& output_arg : op_def_->output_arg()) {
int original_size = state_->types.size();
if (!output_arg.number_attr().empty()) {
// Same type repeated "repeats" times.
@ -605,12 +589,38 @@ Status MlirFunctionContext::AddParameter(tensorflow::DataType dtype,
}
Status MlirAbstractOp::AddInput(AbstractTensorHandle* input) {
if (current_ods_input_ >= op_def_->input_arg_size())
return tensorflow::errors::InvalidArgument(
absl::StrCat("More Input() (", current_ods_input_, ") calls than the ",
op_def_->input_arg_size(), " allowed input_args ; for op ",
state_->name.getStringRef().str()));
auto* operand = dyn_cast<MlirTensor>(input);
if (!operand) {
if (!operand)
return tensorflow::errors::InvalidArgument(
"Unable to cast input to MlirTensor");
}
operands_.push_back(operand->getValue());
// Get the next ArgDef and use it to infer the derived attributes associated
// to this input.
const tensorflow::OpDef::ArgDef& arg_def =
op_def_->input_arg(current_ods_input_++);
Type expected_type;
if (arg_def.type() != tensorflow::DT_INVALID) {
Builder builder(context_);
TF_RETURN_IF_ERROR(
tensorflow::ConvertDataType(arg_def.type(), builder, &expected_type));
if (arg_def.is_ref()) {
Type output_type;
TF_RETURN_IF_ERROR(AddRef(expected_type, &output_type));
expected_type = output_type;
}
} else {
expected_type = operands_.back().getType();
}
if (!arg_def.type_attr().empty())
attrs_[arg_def.type_attr()] = TypeAttr::get(expected_type);
return Status::OK();
}
Status MlirFunctionContext::Finalize(OutputList* outputs,

View File

@ -54,9 +54,6 @@ namespace tf_executor {
namespace {
using TF::DropRefType;
using TF::DropTypeSubTypes;
struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
@ -75,9 +72,8 @@ struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface {
}
};
struct TensorFlowExecutorOpFolderDialectInterface
: public OpFolderDialectInterface {
using OpFolderDialectInterface::OpFolderDialectInterface;
struct TensorFlowExecutorDialectFoldInterface : public DialectFoldInterface {
using DialectFoldInterface::DialectFoldInterface;
// Registered hook to check if the given region, which is attached to an
// operation that is *not* isolated from above (i.e. no internal regions
@ -100,7 +96,7 @@ TensorFlowExecutorDialect::TensorFlowExecutorDialect(MLIRContext *context)
>();
addInterfaces<TensorFlowExecutorInlinerInterface,
TensorFlowExecutorOpFolderDialectInterface>();
TensorFlowExecutorDialectFoldInterface>();
addTypes<ControlType, TokenType>();
}
@ -551,8 +547,8 @@ LogicalResult Verify(SwitchNOp switchn) {
<< operand0_tensor_type << " vs " << output_tensor_type;
}
Type broadcasted_type = OpTrait::util::getBroadcastedType(
DropRefType(DropTypeSubTypes(operand0_tensor_type)),
DropRefType(DropTypeSubTypes(output_tensor_type)));
TF::DropRefAndSubTypes(operand0_tensor_type),
TF::DropRefAndSubTypes(output_tensor_type));
if (!broadcasted_type) {
return switchn.emitOpError()
<< "expects data operand to be broadcastable with all output types"
@ -668,8 +664,8 @@ LogicalResult Verify(MergeOp merge) {
<< operand_tensor_ty << " vs " << output_tensor_ty;
}
Type broadcasted_type = OpTrait::util::getBroadcastedType(
DropRefType(DropTypeSubTypes(output_tensor_ty)),
DropRefType(DropTypeSubTypes(operand_tensor_ty)));
TF::DropRefAndSubTypes(output_tensor_ty),
TF::DropRefAndSubTypes(operand_tensor_ty));
if (!broadcasted_type)
return merge.emitOpError()
<< "expects all operands to be broadcastable with output type"

View File

@ -136,7 +136,7 @@ Inputs must be of same size and shape.
let hasFolder = 1;
}
def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic, TF_SameOperandsAndResultElementTypeResolveRef, TF_CwiseBinary]>,
def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary, TF_LayoutAgnostic, TF_SameOperandsAndResultElementTypeResolveRef]>,
WithBroadcastableBinOpBuilder {
let summary = "Returns x + y element-wise.";
@ -859,15 +859,15 @@ about broadcasting
}];
let arguments = (ins
TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x,
TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Complex128, TF_Complex64]>:$x,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Complex128, TF_Complex64]>:$y,
DefaultValuedAttr<BoolAttr, "false">:$adj_x,
DefaultValuedAttr<BoolAttr, "false">:$adj_y
);
let results = (outs
TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$output
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Complex128, TF_Complex64]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -1404,6 +1404,38 @@ that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_CholeskyOp : TF_Op<"Cholesky", [NoSideEffect]> {
let summary = [{
Computes the Cholesky decomposition of one or more square matrices.
}];
let description = [{
The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
form square matrices.
The input has to be symmetric and positive definite. Only the lower-triangular
part of the input will be used for this operation. The upper-triangular part
will not be read.
The output is a tensor of the same shape as the input
containing the Cholesky decompositions for all input submatrices `[..., :, :]`.
**Note**: The gradient computation on GPU is faster for large matrices but
not for large batch dimensions when the submatrices are small. In this
case it might be faster to use the CPU.
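For illustration, a minimal Python sketch (assuming the standard
`tf.linalg.cholesky` binding dispatches to this op):
```python
x = tf.constant([[4.0, 2.0],
                 [2.0, 3.0]])  # symmetric positive definite
l = tf.linalg.cholesky(x)
# l is lower triangular; l @ tf.transpose(l) reconstructs x.
```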
}];
let arguments = (ins
TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$input
);
let results = (outs
TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_ClipByValueOp : TF_Op<"ClipByValue", [NoSideEffect, TF_SameOperandsAndResultElementTypeResolveRef]> {
let summary = "Clips tensor values to a specified min and max.";
@ -2025,17 +2057,73 @@ and `B, D, F, H` as group 1. Thus we get the outputs:
}];
let arguments = (ins
TensorOf<[BF16, F32, I32, TF_Uint32]>:$input,
TensorOf<[BF16, F16, F32, I32, TF_Uint32]>:$input,
I32Tensor:$group_assignment
);
let results = (outs
TensorOf<[BF16, F32, I32, TF_Uint32]>:$output
TensorOf<[BF16, F16, F32, I32, TF_Uint32]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_CumprodOp : TF_Op<"Cumprod", [NoSideEffect, TF_AllTypesMatch<["x", "out"]>]> {
let summary = [{
Compute the cumulative product of the tensor `x` along `axis`.
}];
let description = [{
By default, this op performs an inclusive cumprod, which means that the first
element of the input is identical to the first element of the output:
```python
tf.cumprod([a, b, c]) # => [a, a * b, a * b * c]
```
By setting the `exclusive` kwarg to `True`, an exclusive cumprod is
performed instead:
```python
tf.cumprod([a, b, c], exclusive=True) # => [1, a, a * b]
```
By setting the `reverse` kwarg to `True`, the cumprod is performed in the
opposite direction:
```python
tf.cumprod([a, b, c], reverse=True) # => [a * b * c, b * c, c]
```
This is more efficient than using separate `tf.reverse` ops.
The `reverse` and `exclusive` kwargs can also be combined:
```python
tf.cumprod([a, b, c], exclusive=True, reverse=True) # => [b * c, c, 1]
```
}];
let arguments = (ins
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x,
TF_I32OrI64Tensor:$axis,
DefaultValuedAttr<BoolAttr, "false">:$exclusive,
DefaultValuedAttr<BoolAttr, "false">:$reverse
);
let results = (outs
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$out
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
let verifier = [{
return Verify(*this);
}];
}
def TF_CumsumOp : TF_Op<"Cumsum", [NoSideEffect, TF_AllTypesMatch<["x", "out"]>]> {
let summary = "Compute the cumulative sum of the tensor `x` along `axis`.";
@ -2084,6 +2172,10 @@ tf.cumsum([a, b, c], exclusive=True, reverse=True) # => [b + c, c, 0]
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
let verifier = [{
return Verify(*this);
}];
}
def TF_DataFormatDimMapOp : TF_Op<"DataFormatDimMap", [NoSideEffect, SameOperandsAndResultType]> {
@ -2109,6 +2201,40 @@ the source data format.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_DebugIdentityV2Op : TF_Op<"DebugIdentityV2", []> {
let summary = "Debug Identity V2 Op.";
let description = [{
Provides an identity mapping from input to output, while writing the content of
the input tensor by calling DebugEventsWriter.
The semantics of the input tensor depends on tensor_debug_mode. In typical
usage, the input tensor comes directly from the user computation only when
graph_debug_mode is FULL_TENSOR (see protobuf/debug_event.proto for a
list of all the possible values of graph_debug_mode). For the other debug modes,
the input tensor should be produced by an additional op or subgraph that
computes summary information about one or more tensors.
}];
let arguments = (ins
TF_Tensor:$input,
StrAttr:$tfdbg_context_id,
StrAttr:$op_name,
DefaultValuedAttr<I64Attr, "-1">:$output_slot,
DefaultValuedAttr<I64Attr, "-1">:$tensor_debug_mode,
DefaultValuedAttr<StrArrayAttr, "{}">:$debug_urls,
DefaultValuedAttr<I64Attr, "1000">:$circular_buffer_size,
StrAttr:$tfdbg_run_id
);
let results = (outs
TF_Tensor:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_DecodeAndCropJpegOp : TF_Op<"DecodeAndCropJpeg", [NoSideEffect]> {
let summary = "Decode and Crop a JPEG-encoded image to a uint8 tensor.";
@ -2421,6 +2547,40 @@ this op runs. The length of the list is returned in two cases:
);
}
def TF_DiagOp : TF_Op<"Diag", [NoSideEffect, SameOperandsAndResultElementType]> {
let summary = "Returns a diagonal tensor with a given diagonal values.";
let description = [{
Given a `diagonal`, this operation returns a tensor with the `diagonal` and
everything else padded with zeros. The diagonal is computed as follows:
Assume `diagonal` has dimensions [D1,..., Dk], then the output is a tensor of
rank 2k with dimensions [D1,..., Dk, D1,..., Dk] where:
`output[i1,..., ik, i1,..., ik] = diagonal[i1, ..., ik]` and 0 everywhere else.
For example:
```
# 'diagonal' is [1, 2, 3, 4]
tf.diag(diagonal) ==> [[1, 0, 0, 0]
[0, 2, 0, 0]
[0, 0, 3, 0]
[0, 0, 0, 4]]
```
}];
let arguments = (ins
TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$diagonal
);
let results = (outs
TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_DiagPartOp : TF_Op<"DiagPart", [NoSideEffect]> {
let summary = "Returns the diagonal part of the tensor.";
@ -4185,6 +4345,22 @@ tf.imag(input) ==> [4.75, 5.75]
TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>;
}
def TF_InfeedDequeueOp : TF_Op<"InfeedDequeue", []> {
let summary = [{
A placeholder op for a value that will be fed into the computation.
}];
let arguments = (ins
TF_ShapeAttr:$shape
);
let results = (outs
TF_Tensor:$output
);
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
}
def TF_InitializeTableFromTextFileV2Op : TF_Op<"InitializeTableFromTextFileV2", []> {
let summary = "Initializes a table from a text file.";
@ -5673,6 +5849,74 @@ tf.matrix_set_diag(input, diagonals, k = (-1, 2), align="LEFT_RIGHT")
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_MatrixTriangularSolveOp : TF_Op<"MatrixTriangularSolve", [NoSideEffect]> {
let summary = [{
Solves systems of linear equations with upper or lower triangular matrices by backsubstitution.
}];
let description = [{
`matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions form
square matrices. If `lower` is `True` then the strictly upper triangular part
of each inner-most matrix is assumed to be zero and not accessed.
If `lower` is False then the strictly lower triangular part of each inner-most
matrix is assumed to be zero and not accessed.
`rhs` is a tensor of shape `[..., M, N]`.
The output is a tensor of shape `[..., M, N]`. If `adjoint` is
`True` then the innermost matrices in `output` satisfy matrix equations
`matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`.
If `adjoint` is `False` then the innermost matrices in
`output` satisfy matrix equations
`adjoint(matrix[..., i, k]) * output[..., k, j] = rhs[..., i, j]`.
Note, the batch shapes for the inputs only need to broadcast.
Example:
```python
a = tf.constant([[3, 0, 0, 0],
[2, 1, 0, 0],
[1, 0, 1, 0],
[1, 1, 1, 1]], dtype=tf.float32)
b = tf.constant([[4],
[2],
[4],
[2]], dtype=tf.float32)
x = tf.linalg.triangular_solve(a, b, lower=True)
x
# <tf.Tensor: shape=(4, 1), dtype=float32, numpy=
# array([[ 1.3333334 ],
# [-0.66666675],
# [ 2.6666665 ],
# [-1.3333331 ]], dtype=float32)>
# in python3 one can use `a@x`
tf.matmul(a, x)
# <tf.Tensor: shape=(4, 1), dtype=float32, numpy=
# array([[4. ],
# [2. ],
# [4. ],
# [1.9999999]], dtype=float32)>
```
}];
let arguments = (ins
TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$matrix,
TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$rhs,
DefaultValuedAttr<BoolAttr, "true">:$lower,
DefaultValuedAttr<BoolAttr, "false">:$adjoint
);
let results = (outs
TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_MaxOp : TF_Op<"Max", [NoSideEffect]> {
let summary = [{
Computes the maximum of elements across dimensions of a tensor.
@ -5823,7 +6067,7 @@ def TF_MergeSummaryOp : TF_Op<"MergeSummary", [NoSideEffect, SameOperandsAndResu
let description = [{
This op creates a
[`Summary`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/summary.proto)
[`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto)
protocol buffer that contains the union of all the values in the input
summaries.
@ -6054,7 +6298,7 @@ the result here is consistent with a truncating divide. E.g.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_MulOp : TF_Op<"Mul", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef, TF_CwiseBinary]>,
def TF_MulOp : TF_Op<"Mul", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary, TF_SameOperandsAndResultElementTypeResolveRef]>,
WithBroadcastableBinOpBuilder {
let summary = "Returns x * y element-wise.";
@ -7215,9 +7459,6 @@ def TF_RangeDatasetOp : TF_Op<"RangeDataset", []> {
Creates a dataset with a range of values. Corresponds to python's xrange.
}];
let description = [{
}];
let arguments = (ins
I64Tensor:$start,
I64Tensor:$stop,
@ -9464,7 +9705,7 @@ I.e., \\(y = x * x = x^2\\).
def TF_SquaredDifferenceOp : TF_Op<"SquaredDifference", [Commutative, NoSideEffect, ResultsBroadcastableShape]>,
WithBroadcastableBinOpBuilder {
let summary = "Returns (x - y)(x - y) element-wise.";
let summary = "Returns conj(x - y)(x - y) element-wise.";
let description = [{
*NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting
@ -9576,6 +9817,49 @@ def TF_StackV2Op : TF_Op<"StackV2", []> {
);
}
def TF_StatelessMultinomialOp : TF_Op<"StatelessMultinomial", [NoSideEffect]> {
let summary = "Draws samples from a multinomial distribution.";
let arguments = (ins
TF_IntOrFpTensor:$logits,
I32Tensor:$num_samples,
TF_I32OrI64Tensor:$seed
);
let results = (outs
TF_I32OrI64Tensor:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<2>;
TF_DerivedResultTypeAttr output_dtype = TF_DerivedResultTypeAttr<0>;
}
def TF_StatelessRandomNormalOp : TF_Op<"StatelessRandomNormal", [NoSideEffect]> {
let summary = [{
Outputs deterministic pseudorandom values from a normal distribution.
}];
let description = [{
The generated values will have mean 0 and standard deviation 1.
The outputs are a deterministic function of `shape` and `seed`.
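A brief usage sketch (assuming the `tf.random.stateless_normal` Python API
backs this op):
```python
a = tf.random.stateless_normal(shape=[2, 3], seed=[1, 2])
b = tf.random.stateless_normal(shape=[2, 3], seed=[1, 2])
# a and b contain identical values because shape and seed are identical.
```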
}];
let arguments = (ins
TF_I32OrI64Tensor:$shape,
TF_I32OrI64Tensor:$seed
);
let results = (outs
TF_FpTensor:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>;
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
}
def TF_StatelessRandomUniformOp : TF_Op<"StatelessRandomUniform", [NoSideEffect]> {
let summary = [{
Outputs deterministic pseudorandom values from a uniform distribution.
@ -9602,6 +9886,33 @@ The outputs are a deterministic function of `shape` and `seed`.
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
}
def TF_StatelessRandomUniformIntOp : TF_Op<"StatelessRandomUniformInt", [NoSideEffect]> {
let summary = [{
Outputs deterministic pseudorandom integers from a uniform distribution.
}];
let description = [{
The generated values follow a uniform distribution in the range `[minval, maxval)`.
The outputs are a deterministic function of `shape`, `seed`, `minval`, and `maxval`.
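A brief usage sketch (assuming the `tf.random.stateless_uniform` Python API
with an integer `dtype` lowers to this op):
```python
ints = tf.random.stateless_uniform(
    shape=[4], seed=[7, 3], minval=0, maxval=10, dtype=tf.int32)
# Re-running with the same arguments always returns the same integers.
```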
}];
let arguments = (ins
TF_I32OrI64Tensor:$shape,
TF_I32OrI64Tensor:$seed,
TF_I32OrI64Tensor:$minval,
TF_I32OrI64Tensor:$maxval
);
let results = (outs
TF_I32OrI64Tensor:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>;
TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>;
}
def TF_StatelessTruncatedNormalOp : TF_Op<"StatelessTruncatedNormal", [NoSideEffect]> {
let summary = [{
Outputs deterministic pseudorandom values from a truncated normal distribution.
@ -9871,7 +10182,7 @@ Examples:
TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
}
def TF_SubOp : TF_Op<"Sub", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef, TF_CwiseBinary]>,
def TF_SubOp : TF_Op<"Sub", [NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary, TF_SameOperandsAndResultElementTypeResolveRef]>,
WithBroadcastableBinOpBuilder {
let summary = "Returns x - y element-wise.";
@ -9926,6 +10237,25 @@ retained with length 1.
>];
}
def TF_SymbolicGradientOp : TF_Op<"SymbolicGradient", [NoSideEffect]> {
let summary = [{
Computes the gradient function for function f via backpropagation.
}];
let arguments = (ins
Variadic<TF_Tensor>:$input,
SymbolRefAttr:$f
);
let results = (outs
Variadic<TF_Tensor>:$output
);
TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>;
TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
}
def TF_TPUCompilationResultOp : TF_Op<"TPUCompilationResult", [NoSideEffect]> {
let summary = "Returns the result of a TPU compilation.";
@ -11901,6 +12231,13 @@ https://www.tensorflow.org/performance/xla/operation_semantics#pad
def TF_XlaRecvFromHostOp : TF_Op<"XlaRecvFromHost", []> {
let summary = "An op to receive a tensor from the host.";
let description = [{
output: the tensor that will be received from the host.
Toutput: element type for output.
shape: shape for output.
key: A unique identifier for this region used to match up host transfers.
}];
let arguments = (ins
TF_ShapeAttr:$shape,
StrAttr:$key
@ -11945,6 +12282,31 @@ def TF_XlaReplicaIdOp : TF_Op<"XlaReplicaId", [NoSideEffect]> {
);
}
def TF_XlaScatterOp : TF_Op<"XlaScatter", [NoSideEffect]> {
let summary = "Wraps the XLA Scatter operator documented at";
let description = [{
https://www.tensorflow.org/xla/operation_semantics#scatter.
}];
let arguments = (ins
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$operand,
TF_I32OrI64Tensor:$scatter_indices,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$updates,
SymbolRefAttr:$update_computation,
StrAttr:$dimension_numbers,
BoolAttr:$indices_are_sorted
);
let results = (outs
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
);
TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>;
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_XlaSelfAdjointEigOp : TF_Op<"XlaSelfAdjointEig", [NoSideEffect]> {
let summary = [{
Computes the eigen decomposition of a batch of self-adjoint matrices
@ -11977,6 +12339,12 @@ i=0...N-1.
def TF_XlaSendToHostOp : TF_Op<"XlaSendToHost", []> {
let summary = "An op to send a tensor to the host.";
let description = [{
input: the tensor that will be sent to the host.
Tinput: element type for input.
key: A unique identifier for this region used to match up host transfers.
}];
let arguments = (ins
TF_Tensor:$input,
@ -12183,18 +12551,17 @@ Compiles a computations for execution on one or more TPU devices.
}];
let description = [{
For the internal use of the distributed TPU compiler. Note that currently only
single TPU device is supported.
For the internal use of the distributed TPU compiler.
'mlir_module' is a serialized MLIR module with a `main` function that contains
target computation.
'dynamic_shapes' contains dynamic shapes of arguments whose shapes were not
known statically at TPUReplication rewrite time.
'metadata' is a serialized TPUCompileMetadataProto describing
the shapes and types of the inputs to the computation, as well as a mapping onto
the TPU pod topology.
'program' output is a string key that is passed to the _TPUExecute op and
used to look up the program in the compilation cache.
'metadata' is a serialized TPUCompileMetadataProto describing the shapes and
types of the inputs to the computation, as well as a mapping onto the TPU pod
topology.
'program' output is a string key that is passed to the TPUExecute op and used to
look up the program in the compilation cache.
}];
let arguments = (ins
@ -12231,6 +12598,28 @@ rewrite passes must replace this op with a _TPUCompileMlir op `program` output.
);
}
def TF__UnaryOpsCompositionOp : TF_Op<"_UnaryOpsComposition", [NoSideEffect, SameOperandsAndResultType]> {
let summary = [{
*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
}];
let description = [{
expected to create these operators.
}];
let arguments = (ins
TensorOf<[F16, F32, F64]>:$x,
StrArrayAttr:$op_names
);
let results = (outs
TensorOf<[F16, F32, F64]>:$y
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF__XlaHostComputeMlirOp : TF_Op<"_XlaHostComputeMlir", []> {
let summary = [{
A pseudo-op to represent host-side computation in an XLA program.

View File

@ -55,6 +55,8 @@ limitations under the License.
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Interfaces/DecodeAttributesInterfaces.h" // from @llvm-project
#include "mlir/Interfaces/FoldInterfaces.h" // from @llvm-project
#include "mlir/Parser.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
@ -112,6 +114,22 @@ bool HasSingleUse(FuncOp func) {
return true;
}
struct TFConstantFoldInterface : public DialectFoldInterface {
TFConstantFoldInterface(Dialect *dialect) : DialectFoldInterface(dialect) {}
LogicalResult Fold(Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) const final {
return TensorFlowDialect::constantFold(op, operands, results);
}
};
struct TFDecodeAttributesInterface : public DialectDecodeAttributesInterface {
TFDecodeAttributesInterface(Dialect *dialect)
: DialectDecodeAttributesInterface(dialect) {}
LogicalResult decode(OpaqueElementsAttr input, ElementsAttr &output) const {
return TensorFlowDialect::decode(input, output);
}
};
struct TFInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
@ -206,6 +224,9 @@ std::vector<TensorFlowDialect::AdditionalOpFunction>
*TensorFlowDialect::additional_operation_hooks_ =
new std::vector<TensorFlowDialect::AdditionalOpFunction>();
TensorFlowDialect::ConstantFoldHook TensorFlowDialect::constant_fold_hook_;
TensorFlowDialect::DecodeConstantHook TensorFlowDialect::decode_constant_hook_;
TensorFlowDialect::TensorFlowDialect(MLIRContext *context)
: Dialect(/*name=*/"tf", context, TypeID::get<TensorFlowDialect>()) {
addOperations<
@ -217,7 +238,8 @@ TensorFlowDialect::TensorFlowDialect(MLIRContext *context)
#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
>();
addInterfaces<TFInlinerInterface>();
addInterfaces<TFInlinerInterface, TFDecodeAttributesInterface,
TFConstantFoldInterface>();
addAttributes<ShapeAttr, FuncAttr>();
// Support unknown operations because not all TensorFlow operations are
@ -385,20 +407,20 @@ Type TensorFlowDialect::parseType(DialectAsmParser &parser) const {
// Prints a type registered to this dialect.
void TensorFlowDialect::printType(Type ty, DialectAsmPrinter &os) const {
assert(ty.isa<TensorFlowType>());
switch (ty.getKind()) {
default:
llvm_unreachable("unexpected tensorflow type kind");
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
case TensorFlowTypes::enumerant: \
os << name; \
break;
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
if (auto derived_ty = ty.dyn_cast<tftype##Type>()) { \
os << name; \
return; \
}
#define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name) \
case TensorFlowTypes::enumerant: \
Print##tftype##Type(ty.cast<tftype##Type>(), os); \
break;
if (auto derived_ty = ty.dyn_cast<tftype##Type>()) { \
Print##tftype##Type(derived_ty, os); \
return; \
}
// NOLINTNEXTLINE
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
}
llvm_unreachable("unexpected tensorflow type kind");
}
namespace {

View File

@ -116,10 +116,35 @@ class TensorFlowDialect : public Dialect {
0, (addOperation(AbstractOperation::get<Args>(*this)), 0)...};
}
using ConstantFoldHook = LogicalResult (*)(Operation *, ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &);
static void RegisterConstantFoldHook(ConstantFoldHook fn) {
constant_fold_hook_ = std::move(fn);
}
static LogicalResult constantFold(Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
if (constant_fold_hook_) return constant_fold_hook_(op, operands, results);
return failure();
}
using DecodeConstantHook = LogicalResult (*)(OpaqueElementsAttr input,
ElementsAttr &output);
static void RegisterDecodeConstantHook(DecodeConstantHook fn) {
decode_constant_hook_ = std::move(fn);
}
static LogicalResult decode(OpaqueElementsAttr input, ElementsAttr &output) {
if (decode_constant_hook_) return decode_constant_hook_(input, output);
return failure();
}
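// Hypothetical registration sketch (for illustration only; the fallback
// function name below is an assumption, not part of this header):
//   TensorFlowDialect::RegisterConstantFoldHook(
//       [](Operation *op, ArrayRef<Attribute> operands,
//          SmallVectorImpl<OpFoldResult> &results) {
//         return MyConstantFoldFallback(op, operands, results);
//       });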
private:
// Hook functions which may add additional operations to the dialect.
// These are invoked at construction time.
static std::vector<AdditionalOpFunction> *additional_operation_hooks_;
static ConstantFoldHook constant_fold_hook_;
static DecodeConstantHook decode_constant_hook_;
};
} // namespace TF

View File

@ -97,10 +97,10 @@ An n-way switch statement, implementing the following:
Variadic<TF_Tensor>:$input,
Confined<SymbolRefArrayAttr, [ArrayMinCount<1>]>:$branches,
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,
// Used to map StatelessCase and Case to a common op.
DefaultValuedAttr<BoolAttr, "false">:$is_stateless
// Used to map StatelessCase and Case op defined in TensorFlow to a common
// op.
BoolAttr:$is_stateless
);
let results = (outs
@ -109,10 +109,55 @@ An n-way switch statement, implementing the following:
TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>;
TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;
let hasCanonicalizer = 1;
}
def TF_CaseRegionOp : TF_Op<"CaseRegion",
[SingleBlockImplicitTerminator<"YieldOp">, NoRegionArguments]> {
let summary = [{
An n-way switch statement which calls a single branch function.
}];
let description = [{
An n-way switch statement, implementing the following:
```
switch (branch_index) {
case 0:
output = branches[0](input);
break;
case 1:
output = branches[1](input);
break;
...
case [[nbranches-1]]:
default:
output = branches[nbranches-1](input);
break;
}
```
}];
let arguments = (ins
I32Tensor:$branch_index,
// Used to map StatelessCase and Case op defined in TensorFlow to a common
// op.
BoolAttr:$is_stateless
);
let results = (outs
Variadic<TF_Tensor>:$output
);
let regions = (region VariadicRegion<SizedRegion<1>>:$branches);
let verifier = [{
return Verify(*this);
}];
}
// In MLIR, the TensorFlow tensor value is represented as an ElementsAttr, with
// its type encoding the tensor's shape and data type.
def TF_ConstOp : TF_Op<"Const", [ConstantLike, NoSideEffect,
@@ -292,7 +337,7 @@ else_branch: A function that takes 'inputs' and returns a list of
}
def TF_YieldOp : TF_Op<"Yield",
[Terminator, ParentOneOf<["IfRegionOp", "WhileRegionOp"]>]> {
[Terminator, ParentOneOf<["CaseRegionOp", "IfRegionOp", "WhileRegionOp"]>]> {
let summary = "Yield operation";
let description = [{

View File

@@ -477,7 +477,47 @@ LogicalResult FoldConstantCaseOp::matchAndRewrite(
void CaseOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<FoldConstantCaseOp>(context);
results.insert<FoldConstantCaseOp, DropAttributes<CaseOp>>(context);
}
//===----------------------------------------------------------------------===//
// CaseRegionOp
//===----------------------------------------------------------------------===//
// TODO(lyandy): Extract similar checks for CaseOp.
static LogicalResult Verify(CaseRegionOp op) {
if (op.branches().empty())
return op.emitOpError() << "expects to have at least 1 region";
if (!IsOfRankOrUnranked(op.branch_index(), 0))
return op.emitOpError() << "expects 'branch_index' to be a scalar, but got "
<< op.branch_index().getType();
DenseIntElementsAttr branch_index_attr;
if (matchPattern(op.branch_index(), m_Constant(&branch_index_attr))) {
assert(branch_index_attr.getNumElements() == 1);
int64_t branch_index = branch_index_attr.getSplatValue<IntegerAttr>()
.getValue()
.getSExtValue();
if (branch_index < 0)
return op.emitOpError()
<< "expects 'branch_index' to be non-negative, but got "
<< branch_index;
if (branch_index >= op.branches().size())
return op.emitOpError()
<< "expects 'branch_index' to be less than the number of regions ("
<< op.branches().size() << "), but got " << branch_index;
}
for (auto region_and_idx : llvm::enumerate(op.branches())) {
std::string region_name =
llvm::formatv("region #{0}", region_and_idx.index()).str();
if (failed(VerifyRegionResults(op, region_and_idx.value(), region_name)))
return failure();
}
return success();
}
//===----------------------------------------------------------------------===//
@@ -734,6 +774,35 @@ void ConcatV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
context);
}
//===----------------------------------------------------------------------===//
// CumsumOp and CumprodOp
//===----------------------------------------------------------------------===//
template <typename OpT, typename std::enable_if<llvm::is_one_of<
OpT, CumsumOp, CumprodOp>::value>::type * = nullptr>
static LogicalResult Verify(OpT op) {
if (!IsOfRankOrUnranked(op.axis(), 0))
return op.emitOpError("requires scalar axis operand");
DenseIntElementsAttr axis_attr;
if (matchPattern(op.axis(), m_Constant(&axis_attr))) {
auto input_ty = op.x().getType().template dyn_cast<RankedTensorType>();
if (input_ty) {
int64_t rank = input_ty.getRank();
assert(axis_attr.getNumElements() == 1 &&
"scalar attribute should have exactly one element");
int64_t axis = (*axis_attr.begin()).getSExtValue();
if (axis < -rank || axis >= rank) {
return op.emitError()
<< "axis operand should be within range [" << -rank << ", "
<< rank << "); actual value: " << axis;
}
}
}
return success();
}
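
A small sketch of the `std::enable_if` / `is_one_of` constraint that lets a single templated verifier serve both `CumsumOp` and `CumprodOp`, assuming placeholder op structs and a local `is_one_of` in place of LLVM's:

```cpp
// Sketch only; requires C++17 for the fold expression.
#include <iostream>
#include <type_traits>

struct CumsumOp {};
struct CumprodOp {};
struct AddOp {};

// Rough equivalent of llvm::is_one_of: true if T matches any of Ts.
template <typename T, typename... Ts>
struct is_one_of
    : std::integral_constant<bool, (std::is_same<T, Ts>::value || ...)> {};

// Only participates in overload resolution for the listed op types, so one
// definition serves both cumulative ops.
template <typename OpT, typename std::enable_if<
                             is_one_of<OpT, CumsumOp, CumprodOp>::value>::type
                             * = nullptr>
bool Verify(OpT) {
  return true;  // the real verifier checks the axis operand here
}

int main() {
  std::cout << Verify(CumsumOp{}) << "\n";   // compiles: Cumsum is allowed
  std::cout << Verify(CumprodOp{}) << "\n";  // compiles: Cumprod is allowed
  // Verify(AddOp{});  // would not compile: AddOp is not in the allowed set
  return 0;
}
```

For a rank-2 input, the range check above accepts axes -2, -1, 0 and 1, which is why the constant axis -3 in the Cumprod test later in this diff is rejected.
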
//===----------------------------------------------------------------------===//
// ConcatOffsetOp
//===----------------------------------------------------------------------===//

View File

@@ -33,7 +33,7 @@ namespace TF {
static inline LogicalResult VerifyRefTypeMatch(mlir::Type type,
mlir::Type maybe_ref_type) {
if (auto ref_type = maybe_ref_type.dyn_cast<mlir::TF::TensorFlowRefType>())
return success(ref_type.RemoveRef().getKind() == type.getKind());
return success(ref_type.RemoveRef().getTypeID() == type.getTypeID());
return failure();
}

View File

@@ -17,6 +17,7 @@ limitations under the License.
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/Traits.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
@@ -100,7 +101,7 @@ mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b,
if (a == b) return a;
}
}
if (a.getKind() != b.getKind()) return nullptr;
if (a.getTypeID() != b.getTypeID()) return nullptr;
// If either is not a type that contain subtypes then the types are not cast
// compatible.
@@ -178,127 +179,116 @@ ResultShapeIterator::ResultShapeIterator(Operation::result_iterator it)
// TF types helper functions
//===----------------------------------------------------------------------===//
bool TensorFlowType::classof(Type type) {
return type.getDialect().getNamespace() == "tf";
}
bool TensorFlowRefType::classof(Type type) {
return type.isa<
#define HANDLE_TF_TYPE(tftype, enumerant, name)
#define HANDLE_TF_REF_TYPE(tftype, enumerant, name) tftype##Type,
#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
// NOLINTNEXTLINE
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
>();
}
bool TensorFlowTypeWithSubtype::classof(Type type) {
return type.isa<ResourceType, VariantType>();
}
TensorFlowType TensorFlowRefType::get(Type type) {
MLIRContext* ctx = type.getContext();
switch (getElementTypeOrSelf(type).getKind()) {
case StandardTypes::F16:
return HalfRefType::get(ctx);
case StandardTypes::F32:
return FloatRefType::get(ctx);
case StandardTypes::F64:
return DoubleRefType::get(ctx);
case StandardTypes::BF16:
return Bfloat16RefType::get(ctx);
case StandardTypes::Complex: {
const auto& etype = type.cast<ComplexType>().getElementType();
switch (getElementTypeOrSelf(etype).getKind()) {
case StandardTypes::F32:
return Complex64RefType::get(ctx);
case StandardTypes::F64:
return Complex128RefType::get(ctx);
default:
llvm_unreachable("unexpected complex type");
}
type = getElementTypeOrSelf(type);
if (type.isF16()) {
return HalfRefType::get(ctx);
} else if (type.isF32()) {
return FloatRefType::get(ctx);
} else if (type.isF64()) {
return DoubleRefType::get(ctx);
} else if (type.isBF16()) {
return Bfloat16RefType::get(ctx);
} else if (auto complex_type = type.dyn_cast<ComplexType>()) {
Type etype = complex_type.getElementType();
if (etype.isF32()) {
return Complex64RefType::get(ctx);
} else if (etype.isF64()) {
return Complex128RefType::get(ctx);
}
case StandardTypes::Integer: {
const auto& itype = type.cast<IntegerType>();
switch (itype.getWidth()) {
case 1:
return BoolRefType::get(ctx);
case 8:
return itype.isUnsigned() ? TensorFlowType(Uint8RefType::get(ctx))
: Int8RefType::get(ctx);
case 16:
return itype.isUnsigned() ? TensorFlowType(Uint16RefType::get(ctx))
: Int16RefType::get(ctx);
case 32:
return itype.isUnsigned() ? TensorFlowType(Uint32RefType::get(ctx))
: Int32RefType::get(ctx);
case 64:
return itype.isUnsigned() ? TensorFlowType(Uint64RefType::get(ctx))
: Int64RefType::get(ctx);
default:
llvm_unreachable("unexpected integer type");
}
llvm_unreachable("unexpected complex type");
} else if (auto itype = type.dyn_cast<IntegerType>()) {
switch (itype.getWidth()) {
case 1:
return BoolRefType::get(ctx);
case 8:
return itype.isUnsigned() ? TensorFlowType(Uint8RefType::get(ctx))
: Int8RefType::get(ctx);
case 16:
return itype.isUnsigned() ? TensorFlowType(Uint16RefType::get(ctx))
: Int16RefType::get(ctx);
case 32:
return itype.isUnsigned() ? TensorFlowType(Uint32RefType::get(ctx))
: Int32RefType::get(ctx);
case 64:
return itype.isUnsigned() ? TensorFlowType(Uint64RefType::get(ctx))
: Int64RefType::get(ctx);
default:
llvm_unreachable("unexpected integer type");
}
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
case TensorFlowTypes::enumerant: \
}
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
if (auto derived_ty = type.dyn_cast<tftype##Type>()) \
return tftype##RefType::get(ctx);
#define HANDLE_TF_REF_TYPE(tftype, enumerant, name)
// NOLINTNEXTLINE
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
default:
llvm_unreachable("unexpected type kind");
}
llvm_unreachable("unexpected type kind");
}
Type TensorFlowRefType::RemoveRef() {
MLIRContext* ctx = getContext();
switch (getKind()) {
case TensorFlowTypes::HALF_REF:
return mlir::FloatType::getF16(ctx);
case TensorFlowTypes::FLOAT_REF:
return mlir::FloatType::getF32(ctx);
case TensorFlowTypes::DOUBLE_REF:
return mlir::FloatType::getF64(ctx);
case TensorFlowTypes::BFLOAT16_REF:
return mlir::FloatType::getBF16(ctx);
case TensorFlowTypes::BOOL_REF:
return mlir::IntegerType::get(1, ctx);
case TensorFlowTypes::INT8_REF:
return mlir::IntegerType::get(8, ctx);
case TensorFlowTypes::INT16_REF:
return mlir::IntegerType::get(16, ctx);
case TensorFlowTypes::INT32_REF:
return mlir::IntegerType::get(32, ctx);
case TensorFlowTypes::INT64_REF:
return mlir::IntegerType::get(64, ctx);
case TensorFlowTypes::UINT8_REF:
return mlir::IntegerType::get(8, IntegerType::Unsigned, ctx);
case TensorFlowTypes::UINT16_REF:
return mlir::IntegerType::get(16, IntegerType::Unsigned, ctx);
case TensorFlowTypes::UINT32_REF:
return mlir::IntegerType::get(32, IntegerType::Unsigned, ctx);
case TensorFlowTypes::UINT64_REF:
return mlir::IntegerType::get(64, IntegerType::Unsigned, ctx);
case TensorFlowTypes::COMPLEX64_REF:
return mlir::ComplexType::get(mlir::FloatType::getF32(ctx));
case TensorFlowTypes::COMPLEX128_REF:
return mlir::ComplexType::get(mlir::FloatType::getF64(ctx));
if (isa<HalfRefType>()) return mlir::FloatType::getF16(ctx);
if (isa<FloatRefType>()) return mlir::FloatType::getF32(ctx);
if (isa<DoubleRefType>()) return mlir::FloatType::getF64(ctx);
if (isa<Bfloat16RefType>()) return mlir::FloatType::getBF16(ctx);
if (isa<BoolRefType>()) return mlir::IntegerType::get(1, ctx);
if (isa<Int8RefType>()) return mlir::IntegerType::get(8, ctx);
if (isa<Int16RefType>()) return mlir::IntegerType::get(16, ctx);
if (isa<Int32RefType>()) return mlir::IntegerType::get(32, ctx);
if (isa<Int64RefType>()) return mlir::IntegerType::get(64, ctx);
if (isa<Uint8RefType>())
return mlir::IntegerType::get(8, IntegerType::Unsigned, ctx);
if (isa<Uint16RefType>())
return mlir::IntegerType::get(16, IntegerType::Unsigned, ctx);
if (isa<Uint32RefType>())
return mlir::IntegerType::get(32, IntegerType::Unsigned, ctx);
if (isa<Uint64RefType>())
return mlir::IntegerType::get(64, IntegerType::Unsigned, ctx);
if (isa<Complex64RefType>())
return mlir::ComplexType::get(mlir::FloatType::getF32(ctx));
if (isa<Complex128RefType>())
return mlir::ComplexType::get(mlir::FloatType::getF64(ctx));
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
case TensorFlowTypes::enumerant##_REF: \
return tftype##Type::get(ctx);
if (isa<tftype##RefType>()) return tftype##Type::get(ctx);
#define HANDLE_TF_REF_TYPE(tftype, enumerant, name)
// NOLINTNEXTLINE
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
default:
llvm_unreachable("unexpected tensorflow ref type kind");
}
llvm_unreachable("unexpected tensorflow ref type kind");
}
Type TensorFlowTypeWithSubtype::RemoveSubtypes() {
MLIRContext* ctx = getContext();
switch (getKind()) {
case TensorFlowTypes::VARIANT:
return VariantType::get(ctx);
case TensorFlowTypes::RESOURCE:
return ResourceType::get(ctx);
default:
llvm_unreachable("unexpected tensorflow type with subtypes kind");
}
if (isa<VariantType>()) return VariantType::get(ctx);
if (isa<ResourceType>()) return ResourceType::get(ctx);
llvm_unreachable("unexpected tensorflow type with subtypes kind");
}
ArrayRef<TensorType> TensorFlowTypeWithSubtype::GetSubtypes() {
switch (getKind()) {
case TensorFlowTypes::VARIANT:
return this->cast<VariantType>().getSubtypes();
case TensorFlowTypes::RESOURCE:
return this->cast<ResourceType>().getSubtypes();
default:
llvm_unreachable("unexpected tensorflow type with subtypes kind");
}
if (auto variant_type = dyn_cast<VariantType>())
return variant_type.getSubtypes();
if (auto resource_type = dyn_cast<ResourceType>())
return resource_type.getSubtypes();
llvm_unreachable("unexpected tensorflow type with subtypes kind");
}
// TODO(jpienaar): BroadcastCompatible and HasCompatibleElementTypes have
@@ -306,8 +296,11 @@ ArrayRef<TensorType> TensorFlowTypeWithSubtype::GetSubtypes() {
bool BroadcastCompatible(ArrayRef<Type> lhs, ArrayRef<Type> rhs) {
if (lhs.size() != rhs.size()) return false;
for (auto types : llvm::zip(lhs, rhs)) {
auto lhs_type = std::get<0>(types);
auto rhs_type = std::get<1>(types);
// Drop ref types because they don't affect broadcast compatibility. E.g.,
// `tensor<!tf.f32ref>` and `tensor<f32>` should be considered broadcast
// compatible.
auto lhs_type = DropRefType(std::get<0>(types));
auto rhs_type = DropRefType(std::get<1>(types));
// This should be true for all TF ops:
auto lhs_tt = lhs_type.dyn_cast<TensorType>();
@@ -366,27 +359,31 @@ bool AreCastCompatible(ArrayRef<Type> types) {
return true;
}
ShapedType DropTypeSubTypes(ShapedType ty) {
Type element_ty = ty.getElementType();
auto subtype_ty = element_ty.dyn_cast<TF::TensorFlowTypeWithSubtype>();
if (!subtype_ty) return ty;
// Assumes a function `GetDefaultTypeOf(ComposedType)` that returns the default
// type for a composed type (such as a ref type or a type with subtypes).
template <typename ComposedType>
Type DropTypeHelper(Type ty) {
Type element_ty = getElementTypeOrSelf(ty);
auto composed_type = element_ty.dyn_cast<ComposedType>();
if (!composed_type) return ty;
Type default_ty = GetDefaultTypeOf(subtype_ty);
if (ty.hasRank()) return RankedTensorType::get(ty.getShape(), default_ty);
return UnrankedTensorType::get(default_ty);
Type default_ty = GetDefaultTypeOf(composed_type);
if (auto ranked_ty = ty.dyn_cast<RankedTensorType>()) {
return RankedTensorType::get(ranked_ty.getShape(), default_ty);
} else if (ty.dyn_cast<UnrankedTensorType>()) {
return UnrankedTensorType::get(default_ty);
} else {
return default_ty;
}
}
ShapedType DropRefType(ShapedType ty) {
Type element_ty = ty.getElementType();
TF::TensorFlowRefType ref_ty = element_ty.dyn_cast<TF::TensorFlowRefType>();
if (!ref_ty) return ty;
Type default_ty = TF::GetDefaultTypeOf(ref_ty);
if (ty.hasRank()) return RankedTensorType::get(ty.getShape(), default_ty);
return UnrankedTensorType::get(default_ty);
Type DropSubTypes(Type ty) {
return DropTypeHelper<TF::TensorFlowTypeWithSubtype>(ty);
}
Type DropRefType(Type ty) { return DropTypeHelper<TF::TensorFlowRefType>(ty); }
Type DropRefAndSubTypes(Type ty) { return DropRefType(DropSubTypes(ty)); }
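
A rough sketch of the `DropTypeHelper` idea above, one templated helper that strips a given wrapper kind if present and is then composed for `DropRefAndSubTypes`, assuming simple structs and `std::variant` in place of MLIR types:

```cpp
// Sketch only: Plain/Ref/WithSubtypes are invented analogues of a plain
// element type, a TF ref type and a type with subtypes.
#include <iostream>
#include <string>
#include <variant>

struct Plain { std::string name; };
struct Ref { Plain inner; };
struct WithSubtypes { Plain inner; };

using Ty = std::variant<Plain, Ref, WithSubtypes>;

// One helper, parameterized on the wrapper kind: unwrap if present,
// otherwise return the type unchanged.
template <typename Wrapper>
Ty DropWrapper(Ty ty) {
  if (auto *w = std::get_if<Wrapper>(&ty)) return w->inner;
  return ty;
}

Ty DropRefAndSubTypes(Ty ty) {
  return DropWrapper<Ref>(DropWrapper<WithSubtypes>(ty));
}

int main() {
  Ty ty = Ref{Plain{"f32"}};
  Ty stripped = DropRefAndSubTypes(ty);
  std::cout << std::get<Plain>(stripped).name << "\n";  // prints "f32"
  return 0;
}
```
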
} // namespace TF
} // namespace mlir

View File

@@ -83,10 +83,7 @@ class TensorFlowType : public Type {
using Type::Type;
// Support method to enable LLVM-style type casting.
static bool classof(Type type) {
return type.getKind() >= Type::FIRST_TENSORFLOW_TYPE &&
type.getKind() <= TensorFlowTypes::LAST_USED_TENSORFLOW_TYPE;
}
static bool classof(Type type);
};
// Returns true if the specified type is a valid TensorFlow element type.
@@ -130,10 +127,7 @@ class TensorFlowRefType : public TensorFlowType {
using TensorFlowType::TensorFlowType;
// Checks if a type is TensorFlow Ref type.
static bool classof(Type type) {
return type.getKind() >= TensorFlowTypes::FLOAT_REF &&
type.getKind() <= TensorFlowTypes::LAST_USED_TENSORFLOW_TYPE;
}
static bool classof(Type type);
// Converts a type to the corresponding TensorFlowRef type.
static TensorFlowType get(Type type);
@@ -263,10 +257,7 @@ class TensorFlowTypeWithSubtype : public TensorFlowType {
using TensorFlowType::TensorFlowType;
// Checks if a type is TensorFlow type with subtypes.
static bool classof(Type type) {
return type.getKind() == TensorFlowTypes::VARIANT ||
type.getKind() == TensorFlowTypes::RESOURCE;
}
static bool classof(Type type);
// Converts a TypeWithSubtype type to the same type but without its subtypes.
Type RemoveSubtypes();
@@ -325,15 +316,21 @@ bool HasCompatibleElementTypes(Type lhs, Type rhs,
// compatible.
bool AreCastCompatible(ArrayRef<Type> types);
// If the given tensor has elements of type with subtypes, then returns a new
// type after dropping subtypes info. Otherwise, returns the original type as
// is.
ShapedType DropTypeSubTypes(ShapedType ty);
// If `ty` is a tensor type and its element type has subtypes, then returns a
// new type of same shape but dropped subtypes for the element type.
// Otherwise, if `ty` has subtypes, then returns corresponding type with dropped
// subtypes.
// Otherwise, returns the original type `ty`.
Type DropSubTypes(Type ty);
// If the given tensor has elements of type ref, then returns a new type
// of the shape, but corresponding non-ref type as element type. Otherwise,
// returns the original type as is.
ShapedType DropRefType(ShapedType ty);
// If `ty` is a tensor type and has elements of a ref type, then returns a new
// type of same shape but corresponding non-ref type as element type.
// Otherwise, if `ty` is a ref type, then returns corresponding non-ref type.
// Otherwise, returns the original type `ty`.
Type DropRefType(Type ty);
// Convenience call for executing both `DropRefType` and `DropSubTypes`.
Type DropRefAndSubTypes(Type ty);
} // end namespace TF
} // end namespace mlir

View File

@@ -834,11 +834,11 @@ func @foldCase(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
// CHECK: PartitionedCall
// CHECK-SAME: device = "noodle"
// CHECK-SAME: f = @add
%4 = "tf.Case"(%2, %arg0, %arg1) {branches = [@sub, @add], output_shapes = [#tf.shape<>], device = "noodle"} : (tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<f32>
%4 = "tf.Case"(%2, %arg0, %arg1) {branches = [@sub, @add], output_shapes = [#tf.shape<>], device = "noodle", is_stateless = false} : (tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: PartitionedCall
// CHECK-SAME: _cluster_launch = "not_ready"
// CHECK-SAME: f = @sub
%5 = "tf.Case"(%3, %4, %arg1) {branches = [@sub, @add], output_shapes = [#tf.shape<>], _cluster_launch = "not_ready"} : (tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<f32>
%5 = "tf.Case"(%3, %4, %arg1) {branches = [@sub, @add], output_shapes = [#tf.shape<>], _cluster_launch = "not_ready", is_stateless = false} : (tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<f32>
return %5 : tensor<f32>
}

View File

@@ -16,7 +16,7 @@ module {
"tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
%index = "tf.opA"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i1>) -> tensor<i32>
%input = "tf.opB"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i1>) -> tensor<i32>
%result = "tf.Case"(%index, %input) {branches = [@branch_0, @branch_1, @branch_2, @branch_3, @branch_4]} : (tensor<i32>, tensor<i32>) -> tensor<i32>
%result = "tf.Case"(%index, %input) {branches = [@branch_0, @branch_1, @branch_2, @branch_3, @branch_4], is_stateless = false} : (tensor<i32>, tensor<i32>) -> tensor<i32>
tf_executor.yield %result : tensor<i32>
}
tf_executor.fetch %output : tensor<i32>

View File

@@ -479,6 +479,7 @@ func @DynamicStitch_duplicates(%arg0: tensor<2x2xf32>) -> tensor<1x2xf32> {
return %0 : tensor<1x2xf32>
}
// CHECK-LABEL: @Reciprocal
func @Reciprocal(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK: "tf.Div"(%[[ONE]], %arg0) : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
@@ -486,6 +487,7 @@ func @Reciprocal(%arg0: tensor<*xf32>) -> tensor<*xf32> {
return %0 : tensor<*xf32>
}
// CHECK-LABEL: @ScatterNd
func @ScatterNd(%arg0: tensor<4x1xi32>, %arg1: tensor<4xf32>) -> tensor<8xf32> {
// CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<8xf32>} : () -> tensor<8xf32>
// CHECK: "tf.TensorScatterUpdate"(%[[ZERO]], %arg0, %arg1) : (tensor<8xf32>, tensor<4x1xi32>, tensor<4xf32>) -> tensor<8xf32>
@@ -494,3 +496,16 @@ func @ScatterNd(%arg0: tensor<4x1xi32>, %arg1: tensor<4xf32>) -> tensor<8xf32> {
%0 = "tf.ScatterNd"(%arg0, %arg1, %shape) : (tensor<4x1xi32>, tensor<4xf32>, tensor<1xi32>) -> tensor<8xf32>
return %0 : tensor<8xf32>
}
// CHECK-LABEL: @_UnaryOpsComposition
// CHECK-SAME: %[[ARG0:.*]]: tensor<4xf32>
func @_UnaryOpsComposition(%arg0: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: %[[RESULT0:.*]] = "tf.Asin"(%[[ARG0]])
// CHECK: %[[RESULT1:.*]] = "tf.Abs"(%[[RESULT0]])
// CHECK: %[[RESULT2:.*]] = "tf.Log"(%[[RESULT1]])
// CHECK: return %[[RESULT2]]
%0 = "tf._UnaryOpsComposition"(%arg0) {op_names = ["Asin", "Abs", "Log"]} : (tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}

View File

@@ -136,6 +136,7 @@ func @if_region_captured_string(%arg0: tensor<i1>, %arg1: tensor<!tf.string>) ->
// CHECK-NOT: _xla_outside_compilation
// CHECK: "tf.IfRegion"
// CHECK: "tf.StringToNumber"
// CHECK-NOT: _xla_outside_compilation
// CHECK: _xla_outside_compilation = "auto", is_stateless = true
%1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%2 = "tf.IfRegion"(%arg0) ( {

View File

@@ -43,7 +43,7 @@ func @main() {
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK: }
%1:2 = tf_executor.island wraps "tf.Case"(%0#0) {Tin = [], Tout = ["tfdtype$DT_FLOAT"], branches = [@foo, @bar], device = "", output_shapes = []} : (tensor<i32>) -> tensor<*xf32> loc("Case")
%1:2 = tf_executor.island wraps "tf.Case"(%0#0) {Tin = [], Tout = ["tfdtype$DT_FLOAT"], branches = [@foo, @bar], device = "", output_shapes = [], is_stateless = false} : (tensor<i32>) -> tensor<*xf32> loc("Case")
tf_executor.fetch
}
return

View File

@@ -232,3 +232,55 @@ func @while_region_aliasing(%arg0: !tf_res, %arg1: !tf_res, %arg2: !tf_res) {
return
}
// -----
// Test aliasing through calls
!tf_res = type tensor<*x!tf.resource<tensor<32xf32>>>
// CHECK-LABEL: func @aliasing_through_calls
func @aliasing_through_calls(%arg0: tensor<32xf32>) -> () {
// expected-remark@below {{Result #0, ID 0 : 0, 1, 2}}
%vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> !tf_res
// expected-remark@below {{Result #0, ID 1 : Unknown}}
// expected-remark@below {{Result #1, ID 2 : 0, 1, 2}}
%c:2 = call @passthru(%vh0) : (!tf_res) -> (!tf_res, !tf_res)
return
}
// expected-remark@below {{Region #0, Arg #0, ID 1 : 1}}
func @passthru(%arg0: !tf_res) -> (!tf_res, !tf_res) {
// expected-remark@below {{Result #0, ID 0 : 0}}
%vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> !tf_res
return %vh0, %arg0 : !tf_res, !tf_res
}
// -----
// Test aliasing through tf_device.launch
!tf_res = type tensor<*x!tf.resource<tensor<32xf32>>>
// CHECK-LABEL: func @aliasing_through_launch
func @aliasing_through_launch(%arg0: tensor<32xf32>) {
// expected-remark@below {{Result #0, ID 0 : 0, 1}}
%vh = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> !tf_res
// expected-remark@below {{Result #0, ID 1 : 0, 1}}
%launch = "tf_device.launch"() ({
tf_device.return %vh : !tf_res
}) {device = ""} : () -> !tf_res
return
}
// -----
// Test aliasing through tf_device.cluster
!tf_res = type tensor<*x!tf.resource<tensor<32xf32>>>
// CHECK-LABEL: func @aliasing_through_cluster
func @aliasing_through_cluster(%arg0: tensor<32xf32>) {
// expected-remark@below {{Result #0, ID 0 : 0, 1}}
%vh = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> !tf_res
// expected-remark@below {{Result #0, ID 1 : 0, 1}}
%cluster = "tf_device.cluster"() ({
tf_device.return %vh : !tf_res
}) : () -> !tf_res
return
}

View File

@@ -424,3 +424,117 @@ func @propagate_if_region_inlined(
}
return
}
// Test propagation through WhileRegion (inlined calls)
// CHECK-LABEL: func @propagate_while_region_inlined
func @propagate_while_region_inlined(
%arg0: !tf_res {tf.device = "/TPU:0"},
%arg1: tensor<i32>) {
tf_executor.graph {
// CHECK: tf_executor.island
%island = tf_executor.island {
// CHECK-NEXT: "tf.Identity"
// CHECK-SAME: {device = "/TPU:0"}
%id0 = "tf.Identity"(%arg0) : (!tf_res) -> !tf_res
// CHECK-NEXT: "tf.VarHandleOp"
%var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0", device = "/TPU:1"} : () -> !tf_res
// CHECK-NEXT: "tf.WhileRegion"
"tf.WhileRegion"(%arg1, %id0, %var_handle) ({
^bb0(%carg0: tensor<i32>, %carg1: !tf_res, %carg2: !tf_res):
// CHECK: ^bb
// CHECK: "tf.Identity"
// CHECK-SAME: {device = "/TPU:0"}
%cid0 = "tf.Identity"(%carg1) : (!tf_res) -> !tf_res loc("cid0")
%read = "tf.ReadVariableOp"(%cid0) : (!tf_res) -> tensor<32xf32>
%cst = constant dense<3.0> : tensor<32xf32>
%cmp = "tf.Less"(%read, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xi1>
%dims = constant dense<0> : tensor<1xi32>
%reduce = "tf.All"(%cmp, %dims) {keep_dims = false} : (tensor<32xi1>, tensor<1xi32>) -> tensor<i1>
"tf.Yield"(%reduce) : (tensor<i1>) -> ()
}, {
^bb0(%barg0: tensor<i32>, %barg1: !tf_res, %barg2: !tf_res):
// CHECK: ^bb
// CHECK: "tf.Identity"
// CHECK-SAME: {device = "/TPU:0"}
%bid0 = "tf.Identity"(%barg1) : (!tf_res) -> !tf_res
// CHECK-NEXT: "tf.Identity"
// CHECK-SAME: {device = "/TPU:1"}
%id1 = "tf.Identity"(%barg2) : (!tf_res) -> !tf_res
"tf.Yield"(%barg0, %bid0, %id1) : (tensor<i32>, !tf_res,!tf_res) -> ()
}){is_stateless = false}
: (tensor<i32>, !tf_res, !tf_res) -> (tensor<i32>, !tf_res, !tf_res)
tf_executor.yield
}
tf_executor.fetch %island : !tf_executor.control
}
return
}
// Test propagation through WhileRegion (non-inlined calls)
// CHECK-LABEL: func @propagate_while_region
func @propagate_while_region(
%arg0: !tf_res {tf.device = "/TPU:0"},
%arg1: tensor<i32>) {
tf_executor.graph {
// CHECK: tf_executor.island
%island = tf_executor.island {
// CHECK-NEXT: "tf.Identity"
// CHECK-SAME: {device = "/TPU:0"}
%id0 = "tf.Identity"(%arg0) : (!tf_res) -> !tf_res
// CHECK-NEXT: "tf.VarHandleOp"
%var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0", device = "/TPU:1"} : () -> !tf_res
// CHECK-NEXT: "tf.WhileRegion"
"tf.WhileRegion"(%arg1, %id0, %var_handle) ({
^bb0(%carg0: tensor<i32>, %carg1: !tf_res, %carg2: !tf_res):
%cond = call @whileregion_cond(%carg0, %carg1, %carg2) : (tensor<i32>, !tf_res, !tf_res) -> tensor<i1>
"tf.Yield"(%cond) : (tensor<i1>) -> ()
}, {
^bb0(%barg0: tensor<i32>, %barg1: !tf_res, %barg2: !tf_res):
%new_values:3 = call @whileregion_body(%barg0, %barg1, %barg2) : (tensor<i32>, !tf_res,!tf_res) -> (tensor<i32>, !tf_res,!tf_res)
"tf.Yield"(%new_values#0, %new_values#1, %new_values#2) : (tensor<i32>, !tf_res,!tf_res) -> ()
}){is_stateless = false}
: (tensor<i32>, !tf_res, !tf_res) -> (tensor<i32>, !tf_res, !tf_res)
tf_executor.yield
}
tf_executor.fetch %island : !tf_executor.control
}
return
}
// CHECK-LABEL: func @whileregion_body
func @whileregion_body(%arg0: tensor<i32>, %arg1: !tf_res, %arg2: !tf_res) -> (tensor<i32>, !tf_res, !tf_res) {
%graph:3 = tf_executor.graph {
// CHECK: tf_executor.island
%island:4 = tf_executor.island {
// CHECK-NEXT: "tf.Identity"
// CHECK-SAME: {device = "/TPU:0"}
%id0 = "tf.Identity"(%arg1) : (!tf_res) -> !tf_res
// CHECK-NEXT: "tf.Identity"
// CHECK-SAME: {device = "/TPU:1"}
%id1 = "tf.Identity"(%arg2) : (!tf_res) -> !tf_res
tf_executor.yield %arg0, %id0, %id1 : tensor<i32>, !tf_res, !tf_res
}
tf_executor.fetch %island#0, %island#1, %island#2 : tensor<i32>, !tf_res, !tf_res
}
return %graph#0, %graph#1, %graph#2: tensor<i32>, !tf_res, !tf_res
}
// CHECK-LABEL: func @whileregion_cond
func @whileregion_cond(%arg0: tensor<i32>, %arg1: !tf_res, %arg2: !tf_res) -> tensor<i1> {
%graph = tf_executor.graph {
// CHECK: tf_executor.island
%island:2 = tf_executor.island {
// CHECK-NEXT: "tf.Identity"
// CHECK-SAME: {device = "/TPU:0"}
%id0 = "tf.Identity"(%arg1) : (!tf_res) -> !tf_res
%read = "tf.ReadVariableOp"(%id0) : (!tf_res) -> tensor<32xf32>
%cst = constant dense<3.0> : tensor<32xf32>
%cmp = "tf.Less"(%read, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xi1>
%dims = constant dense<0> : tensor<1xi32>
%reduce = "tf.All"(%cmp, %dims) {keep_dims = false} : (tensor<32xi1>, tensor<1xi32>) -> tensor<i1>
tf_executor.yield %reduce : tensor<i1>
}
tf_executor.fetch %island#0 : tensor<i1>
}
return %graph : tensor<i1>
}

View File

@@ -112,26 +112,6 @@ func @internal_resource() -> tensor<*xi32> {
// -----
// Tests that pass fails when there are remaining resource operations that can
// not be lifted.
func @lifting_failure() -> tensor<*xi32> {
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
// expected-error @+1 {{has remaining resource inputs that can not be lifted}}
%1 = "tf_device.cluster"() ( {
%2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
%3 = "tf.SomeResourceOp"(%0, %2) : (tensor<*x!tf.resource>, tensor<*xi32>) -> tensor<*xi32>
"tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
tf_device.return %3 : tensor<*xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32>
return %1 : tensor<*xi32>
}
// -----
// Tests that pass lifts resource reads/writes from a loop, and removed unused
// resources.
@@ -347,30 +327,6 @@ func @while_cond(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32> {
// -----
// Tests that pass reports error on unsupported ops in loop body.
func @cluster_with_loop() -> () {
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
"tf_device.cluster"() ( {
%1 = "tf.While"(%0) {
body = @while_body, cond = @while_cond, device = "", is_stateless = false}
: (tensor<*x!tf.resource<tensor<f32>>>) -> (tensor<*x!tf.resource<tensor<f32>>>)
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
return
}
func @while_body(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> (tensor<*x!tf.resource<tensor<f32>>>) {
// expected-error @+1 {{found unsupported operations on resource.}}
"tf._UnknownOp"(%arg0) : (tensor<*x!tf.resource<tensor<f32>>>) -> ()
return %arg0 : tensor<*x!tf.resource<tensor<f32>>>
}
func @while_cond(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32> {
%read = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
return %read : tensor<f32>
}
// -----
// Tests that pass reports error on unsupported ops in loop cond.
func @cluster_with_loop() -> () {
@@ -409,7 +365,7 @@ func @cluster_with_case(%arg0: tensor<i32>) -> tensor<4xf32> {
// CHECK: %[[CLUSTER:.*]]:2 = "tf_device.cluster"()
%2 = "tf_device.cluster"() ( {
// CHECK: %[[CASE:.*]]:2 = "tf.Case"(%[[ARG0]], %[[READ0]], %[[READ1]])
%3:2 = "tf.Case"(%arg0, %0, %1) {branches = [@branch_0, @branch_1, @branch_2]}
%3:2 = "tf.Case"(%arg0, %0, %1) {branches = [@branch_0, @branch_1, @branch_2], is_stateless = false}
: (tensor<i32>, tensor<*x!tf.resource<tensor<4xf32>>>, tensor<*x!tf.resource<tensor<4xf32>>>)
-> (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>)
// CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[CASE]]#1, %[[CASE]]#0)

View File

@@ -223,7 +223,7 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
// CHECK-SAME: %[[ARG_1:.*]]: tensor<!tf.resource<tensor<1x2x3xf32>>>
func @shape_from_case_to_branch_functions(%arg0: tensor<i32>, %arg1: tensor<!tf.resource<tensor<1x2x3xf32>>>) -> tensor<1x2x3xf32> {
// CHECK: %[[CASE:.*]] = "tf.Case"(%[[ARG_0]], %[[ARG_1]])
%0 = "tf.Case"(%arg0, %arg1) {branches = [@branch_0, @branch_1]} : (tensor<i32>, tensor<!tf.resource<tensor<1x2x3xf32>>>) -> tensor<1x2x3xf32>
%0 = "tf.Case"(%arg0, %arg1) {branches = [@branch_0, @branch_1], is_stateless = false} : (tensor<i32>, tensor<!tf.resource<tensor<1x2x3xf32>>>) -> tensor<1x2x3xf32>
// CHECK: return %[[CASE]] : tensor<1x2x3xf32>
return %0 : tensor<1x2x3xf32>
}

View File

@@ -256,7 +256,7 @@ func @main(%arg0: tensor<i32>) -> () {
%max_size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
// CHECK-NOT: tf.EmptyTensorList
%tl = "tf.EmptyTensorList"(%elem_shape, %max_size) : (tensor<0xi32>, tensor<i32>) -> tensor<!tf.variant<tensor<f32>>>
%case_op = "tf.Case"(%arg0, %tl) {branches = [@branch_0, @branch_1, @branch_2]}
%case_op = "tf.Case"(%arg0, %tl) {branches = [@branch_0, @branch_1, @branch_2], is_stateless = false}
: (tensor<i32>, tensor<!tf.variant<tensor<f32>>>) -> tensor<!tf.variant<tensor<f32>>>
// CHECK: "tf.Slice"
%pop:2 = "tf.TensorListPopBack"(%case_op, %elem_shape) : (tensor<!tf.variant<tensor<f32>>>, tensor<0xi32>) -> (tensor<!tf.variant<tensor<f32>>>, tensor<f32>)

View File

@@ -848,7 +848,7 @@ func @testInvalidIfOp(tensor<i1>, tensor<*xf32>) -> tensor<2xf32> {
// Test invalid tf.Yield operation (parent should be IfRegion)
func @testInvalidYieldOp(%arg0: f32) -> () {
// expected-error @+1 {{'tf.Yield' op expects parent op to be one of 'tf.IfRegion, tf.WhileRegion'}}
// expected-error @+1 {{'tf.Yield' op expects parent op to be one of 'tf.CaseRegion, tf.IfRegion, tf.WhileRegion'}}
"tf.Yield"(%arg0) : (f32) -> ()
}
@@ -3313,3 +3313,88 @@ func @testBatchToSpaceInvalidOutputDepth(%arg0: tensor<16x8x8x3xf32>, %arg1: ten
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<16x8x8x3xf32>, tensor<*xi32>) -> tensor<4x8x8x8xf32>
return
}
// -----
func @testCaseRegionNoRegions(%arg0: tensor<i32>) {
// expected-error @+1 {{expects to have at least 1 region}}
"tf.CaseRegion"(%arg0) {is_stateless = false} : (tensor<i32>) -> ()
return
}
// -----
func @testCaseRegionBadBranchIndicesShape(%arg0: tensor<8xi32>) {
// expected-error @+1 {{expects 'branch_index' to be a scalar, but got 'tensor<8xi32>'}}
"tf.CaseRegion"(%arg0) ( {
"tf.Yield"() : () -> ()
}) {is_stateless = false} : (tensor<8xi32>) -> ()
return
}
// -----
func @testCaseRegionBadBranchIndicesNegative() {
%0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
// expected-error @+1 {{expects 'branch_index' to be non-negative, but got -1}}
"tf.CaseRegion"(%0) ( {
"tf.Yield"() : () -> ()
}) {is_stateless = false} : (tensor<i32>) -> ()
return
}
// -----
func @testCaseRegionBadBranchIndicesPositive() {
%0 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// expected-error @+1 {{expects 'branch_index' to be less than the number of regions (1), but got 1}}
"tf.CaseRegion"(%0) ( {
"tf.Yield"() : () -> ()
}) {is_stateless = false} : (tensor<i32>) -> ()
return
}
// -----
func @testCaseRegionMismatchedNumResults(%arg0: tensor<i32>) {
// expected-error @+1 {{region #0 should have same number (1) of results as tf.CaseRegion but has 0 results}}
%1 = "tf.CaseRegion"(%arg0) ( {
"tf.Yield"() : () -> ()
}) {is_stateless = false} : (tensor<i32>) -> tensor<i1>
return
}
// -----
func @testCaseRegionMismatchedResultTypes(%arg0: tensor<i32>, %arg1: tensor<f32>) {
// expected-error @+1 {{region #0 result type tensor<f32> is incompatible with tf.CaseRegion result type tensor<i1> at index 0}}
%1 = "tf.CaseRegion"(%arg0) ( {
"tf.Yield"(%arg1) : (tensor<f32>) -> ()
}) {is_stateless = false} : (tensor<i32>) -> tensor<i1>
return
}
// -----
// Test valid tf.Cumsum
func @testCumsum(%arg: tensor<8x16xf32>, %axis: tensor<i32>) -> tensor<8x16xf32> {
%0 = "tf.Cumsum"(%arg, %axis) : (tensor<8x16xf32>, tensor<i32>) -> tensor<8x16xf32>
return %0 : tensor<8x16xf32>
}
// -----
func @testCumprod(%arg: tensor<8x16xf32>, %axis: tensor<2xi32>) -> tensor<8x16xf32> {
// expected-error @+1 {{requires scalar axis operand}}
%0 = "tf.Cumprod"(%arg, %axis) : (tensor<8x16xf32>, tensor<2xi32>) -> tensor<8x16xf32>
return %0 : tensor<8x16xf32>
}
// -----
func @testCumprod(%arg: tensor<8x16xf32>) -> tensor<8x16xf32> {
%axis = constant dense<-3> : tensor<i32>
// expected-error @+1 {{axis operand should be within range [-2, 2)}}
%0 = "tf.Cumprod"(%arg, %axis) : (tensor<8x16xf32>, tensor<i32>) -> tensor<8x16xf32>
return %0 : tensor<8x16xf32>
}

View File

@@ -9,7 +9,7 @@ func @select(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<i32>, tensor<f32
// CHECK: return %[[first]],
%0 = "tf.DeviceIndex"() {device = "", device_names = ["CPU", "GPU"]} : () -> tensor<i32>
%1 = "tf.DeviceIndex"() {device = "", device_names = ["CPU", "GPU"]} : () -> tensor<i32>
%4 = "tf.Case"(%1, %arg0, %arg1) {branches = [@sub, @add], output_shapes = [#tf.shape<>]} : (tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<f32>
%4 = "tf.Case"(%1, %arg0, %arg1) {branches = [@sub, @add], output_shapes = [#tf.shape<>], is_stateless = false} : (tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<f32>
return %0, %4 : tensor<i32>, tensor<f32>
}

View File

@@ -0,0 +1,93 @@
// RUN: tf-opt %s -tf-tpu-identity-pruning | FileCheck %s --dump-input=always
// Tests Identity op in cluster is pruned away.
// CHECK-LABEL: func @testIdentity
// CHECK-SAME: ([[ARG0:%.*]]: tensor<i32>)
func @testIdentity(%arg0: tensor<i32>) {
// CHECK-NOT: "tf.Identity"
// CHECK: "tf_device.cluster"
// CHECK-NEXT: tf_device.return [[ARG0]]
%0 = "tf_device.cluster"() ( {
%1 = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
tf_device.return %1 : tensor<i32>
}) : () -> tensor<i32>
return
}
// Tests IdentityN op in cluster is pruned away.
// CHECK-LABEL: func @testIdentityN
// CHECK-SAME: ([[ARG0:%.*]]: tensor<i32>, [[ARG1:%.*]]: tensor<f32>)
func @testIdentityN(%arg0: tensor<i32>, %arg1: tensor<f32>) {
// CHECK-NOT: "tf.IdentityN"
// CHECK: "tf_device.cluster"
// CHECK-NEXT: tf_device.return [[ARG0]], [[ARG1]]
%0:2 = "tf_device.cluster"() ( {
%1:2 = "tf.IdentityN"(%arg0, %arg1) : (tensor<i32>, tensor<f32>) -> (tensor<i32>, tensor<f32>)
tf_device.return %1#0, %1#1 : tensor<i32>, tensor<f32>
}) : () -> (tensor<i32>, tensor<f32>)
return
}
// Tests transitive Identity ops reachable from the cluster are pruned away.
// CHECK-LABEL: func @testTransitiveIdentity
// CHECK-SAME: ([[ARG0:%.*]]: tensor<i32>)
func @testTransitiveIdentity(%arg0: tensor<i32>) {
// CHECK: "tf_device.cluster"
// CHECK: "tf.PartitionedCall"([[ARG0]])
// CHECK-SAME: f = @callee0
%0 = "tf_device.cluster"() ( {
%1 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @callee0} : (tensor<i32>) -> tensor<i32>
tf_device.return %1 : tensor<i32>
}) : () -> tensor<i32>
return
}
// CHECK-LABEL: func @callee0
// CHECK-SAME: ([[ARG0:%.*]]: tensor<i32>)
func @callee0(%arg0: tensor<i32>) -> tensor<i32> {
// CHECK-NOT: "tf.Identity"
// CHECK: "tf.PartitionedCall"([[ARG0]])
// CHECK-SAME: f = @callee1
%0 = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
%1 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @callee1} : (tensor<i32>) -> tensor<i32>
return %1 : tensor<i32>
}
// CHECK-LABEL: func @callee1
// CHECK-SAME: ([[ARG0:%.*]]: tensor<i32>)
func @callee1(%arg0: tensor<i32>) -> tensor<i32> {
// CHECK-NOT: "tf.Identity"
// CHECK: return [[ARG0]]
%0 = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
// Tests Identity ops not reachable from the cluster are not pruned away.
// CHECK-LABEL: func @testIdentityOutsideCluster
// CHECK-SAME: ([[ARG0:%.*]]: tensor<i32>)
func @testIdentityOutsideCluster(%arg0: tensor<i32>) {
// CHECK: [[IDENTITY:%.*]] = "tf.Identity"([[ARG0]])
// CHECK: [[CLUSTER:%.*]] = "tf_device.cluster"
// CHECK-NEXT: tf_device.return [[IDENTITY]]
%0 = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
%1 = "tf_device.cluster"() ( {
tf_device.return %0 : tensor<i32>
}) : () -> tensor<i32>
// CHECK: "tf.PartitionedCall"([[CLUSTER]])
// CHECK-SAME: f = @callee2
%2 = "tf.PartitionedCall"(%1) {config = "", config_proto = "", executor_type = "", f = @callee2} : (tensor<i32>) -> tensor<i32>
return
}
// CHECK-LABEL: func @callee2
// CHECK-SAME: ([[ARG0:%.*]]: tensor<i32>)
func @callee2(%arg0: tensor<i32>) -> tensor<i32> {
// CHECK: [[IDENTITY:%.*]] = "tf.Identity"([[ARG0]])
%0 = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
// CHECK: return [[IDENTITY]]
return %0 : tensor<i32>
}

View File

@@ -95,16 +95,16 @@ void CreateTPUBridgePipeline(OpPassManager &pm) {
func_pm.addPass(CreateTPUHostComputationExpansionPass());
func_pm.addPass(CreateTPUUpdateEmbeddingEnqueueOpInputsPass());
}
// Run another shape inference pass because resource decomposition might have
// created new partial types.
pm.addPass(TF::CreateTFShapeInferencePass());
pm.addPass(TFDevice::CreateResourceOpLiftingPass());
pm.addPass(TF::CreateTFFunctionalControlFlowToRegions());
pm.addPass(mlir::createInlinerPass());
pm.addPass(CreateTPUExtractHeadTailOutsideCompilationPass());
pm.addPass(TF::CreateTFRegionControlFlowToFunctional());
// Run another shape inference pass because resource decomposition might have
// created new partial types.
pm.addPass(TF::CreateTFShapeInferencePass());
pm.addNestedPass<FuncOp>(tf_executor::CreateTFExecutorConstantSinkingPass());
pm.addPass(TFDevice::CreateResourceOpLiftingPass());
pm.addPass(TF::CreateResourceDeviceInferencePass());
pm.addPass(TFDevice::CreateClusterOutliningPass());
pm.addPass(CreateTPUDynamicPaddingMapperPass());

View File

@@ -17,6 +17,7 @@ limitations under the License.
#include <algorithm>
#include "mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/c/eager/c_api.h"
@@ -68,7 +69,7 @@ static bool ShouldBeFolded(Operation* inst) {
LogicalResult ConstantFoldFallbackHook(
Operation* inst, ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute>& results) { // NOLINT
SmallVectorImpl<OpFoldResult>& results) { // NOLINT
// Instructions with side effects should not be constant folded to preserve
// the original semantics.
if (inst->getNumRegions() != 0 || !MemoryEffectOpInterface::hasNoEffect(inst))
@@ -126,8 +127,16 @@ LogicalResult ConstantFoldFallbackHook(
// TODO(jpienaar): Avoid using global context & mutex here.
static auto* mu = new tensorflow::mutex();
tensorflow::mutex_lock l(*mu);
return tensorflow::EvaluateOperation(inst, inputs, ctx, &results);
SmallVector<Attribute, 8> constants;
LogicalResult status =
tensorflow::EvaluateOperation(inst, inputs, ctx, &constants);
results.assign(constants.begin(), constants.end());
return status;
}
static bool init_hooks = ([] () {
TensorFlowDialect::RegisterConstantFoldHook(ConstantFoldFallbackHook);
}(), true);
} // namespace TF
} // namespace mlir

View File

@@ -27,7 +27,7 @@ namespace TF {
LogicalResult ConstantFoldFallbackHook(
Operation *inst, ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute> &results); // NOLINT
SmallVectorImpl<OpFoldResult> &results); // NOLINT
} // namespace TF
} // namespace mlir

View File

@@ -19,7 +19,6 @@ limitations under the License.
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/DialectHooks.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
@@ -35,31 +34,22 @@ namespace {
// Since this method is passed to MLIR as decode hook it has to conform
// to LLVM style used by MLIR.
bool DecodeOpaqueTensorHook(const OpaqueElementsAttr input,
ElementsAttr& output) { // NOLINT
LogicalResult DecodeOpaqueTensorHook(const OpaqueElementsAttr input,
ElementsAttr& output) { // NOLINT
Builder builder(input.getType().getContext());
auto decoded_attr_or = tensorflow::DecodeOpaqueTensor(input, builder);
if (!decoded_attr_or.ok()) {
VLOG(2) << decoded_attr_or.status().error_message();
return true;
return failure();
}
output = decoded_attr_or.ValueOrDie();
return false;
return success();
}
// Hooks for the TensorFlow dialect.
class TensorFlowHooks : public DialectHooks {
public:
DialectConstantFoldHook getConstantFoldHook() {
return TF::ConstantFoldFallbackHook;
}
DialectConstantDecodeHook getDecodeHook() { return DecodeOpaqueTensorHook; }
};
static bool init_hooks = ([] () {
TF::TensorFlowDialect::RegisterDecodeConstantHook(DecodeOpaqueTensorHook);
}(), true);
} // anonymous namespace
// Static initialization for TensorFlow dialect hooks registration.
static DialectHooksRegistration<TensorFlowHooks> tf_hooks_registration("tf");
} // namespace mlir
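
The decode hook now reports its outcome through `LogicalResult` instead of a bare `bool` whose `true` value meant failure. A tiny sketch of that convention change, assuming a `MiniResult` placeholder for `mlir::LogicalResult`:

```cpp
// Sketch only: MiniResult mimics the success()/failure()/failed() surface of
// mlir::LogicalResult; Decode is a placeholder for the decode hook above.
#include <iostream>

struct MiniResult { bool is_failure; };
MiniResult success() { return {false}; }
MiniResult failure() { return {true}; }
bool failed(MiniResult r) { return r.is_failure; }

MiniResult Decode(bool input_is_valid) {
  if (!input_is_valid) return failure();  // was `return true;` in the bool form
  return success();                       // was `return false;`
}

int main() {
  std::cout << failed(Decode(false)) << "\n";  // prints 1: decoding failed
  return 0;
}
```
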

View File

@@ -26,6 +26,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/core/util/tensor_format.h"
@@ -427,12 +428,38 @@ class LowerSparseMatMulOp : public OpRewritePattern<TF::SparseMatMulOp> {
}
};
// Lowers _UnaryOpsComposition op as a series of original TensorFlow ops that
// were fused together.
class Lower_UnaryOpsComposition
: public OpRewritePattern<TF::_UnaryOpsCompositionOp> {
public:
using OpRewritePattern<TF::_UnaryOpsCompositionOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TF::_UnaryOpsCompositionOp op,
PatternRewriter &rewriter) const override {
Value result = op.x();
for (StringRef op_name :
op.op_names().getAsRange<StringAttr, StringRef>()) {
std::string full_name = "tf." + op_name.str();
// All ops in the sequences have the same result type as the original
// result type.
OperationState state(op.getLoc(), full_name, /*operands=*/{result},
/*types=*/{op.getType()}, /*attributes=*/{});
Operation *op = rewriter.createOperation(state);
result = op->getResult(0);
}
rewriter.replaceOp(op, {result});
return success();
}
};
} // namespace
void PopulateLoweringTFPatterns(MLIRContext *context,
OwningRewritePatternList *patterns) {
patterns->insert<LowerAddNOp, LowerDynamicStitchOp, LowerInvertPermutationOp,
LowerPackOp, LowerSparseMatMulOp>(context);
LowerPackOp, LowerSparseMatMulOp, Lower_UnaryOpsComposition>(
context);
populateWithGenerated(context, patterns);
}
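
The `Lower_UnaryOpsComposition` pattern above expands the fused op into the chain of ops named by its `op_names` attribute, threading each result into the next (the `@_UnaryOpsComposition` test earlier in this diff shows the Asin, Abs, Log case). A rough scalar analogue of that threading, assuming a made-up name-to-function table instead of `OperationState`/`PatternRewriter`:

```cpp
// Sketch only: scalar math functions stand in for the TF ops created through
// the rewriter; the lookup table and the input value are invented.
#include <cmath>
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  const std::map<std::string, double (*)(double)> kUnaryOps = {
      {"Asin", [](double x) { return std::asin(x); }},
      {"Abs", [](double x) { return std::fabs(x); }},
      {"Log", [](double x) { return std::log(x); }}};

  const std::vector<std::string> op_names = {"Asin", "Abs", "Log"};
  double result = 0.5;  // stand-in for the op's input
  for (const std::string &name : op_names) {
    // Each named op consumes the previous result, exactly as the pattern
    // threads `result` through the newly created operations.
    result = kUnaryOps.at(name)(result);
  }
  std::cout << result << "\n";  // log(|asin(0.5)|)
  return 0;
}
```
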

View File

@@ -131,6 +131,25 @@ LogicalResult MarkUncompilableOps(
return success();
}
// Unmarks outside compilation for any op that has parents already
// marked for outside compilation since the child will be extracted
// anyways.
void UnmarkChildren(Block* block) {
block->walk([&](Operation* op) {
if (!op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) return;
Operation* iter_op = op;
bool remove_attr = false;
while (auto* parent_op = iter_op->getParentOp()) {
if (parent_op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
remove_attr = true;
break;
}
iter_op = parent_op;
}
if (remove_attr) op->removeAttr(kXlaOutsideCompilationAttr);
});
}
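
A stand-alone sketch of the parent-climbing check `UnmarkChildren` performs, assuming a toy `Node` struct in place of MLIR operations, `getParentOp()` and the `_xla_outside_compilation` attribute:

```cpp
// Sketch only: Node, parent and marked are placeholders for Operation,
// getParentOp() and the outside-compilation attribute lookup.
#include <iostream>

struct Node {
  bool marked = false;
  Node *parent = nullptr;
};

// Drop the marker from an op whenever some ancestor already carries it: the
// ancestor will be extracted for outside compilation and takes the child along.
void UnmarkIfAncestorMarked(Node &op) {
  if (!op.marked) return;
  for (Node *p = op.parent; p != nullptr; p = p->parent) {
    if (p->marked) {
      op.marked = false;
      return;
    }
  }
}

int main() {
  Node region_op;                 // e.g. an op already marked for extraction
  region_op.marked = true;
  Node child{true, &region_op};   // a marked op nested inside it
  UnmarkIfAncestorMarked(child);
  std::cout << child.marked << "\n";  // prints 0: the child marker is dropped
  return 0;
}
```
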
void MarkOpsForOutsideCompilation::runOnOperation() {
auto module = getOperation();
const Dialect* tf_dialect = getContext().getRegisteredDialect("tf");
@@ -168,6 +187,17 @@ void MarkOpsForOutsideCompilation::runOnOperation() {
});
if (result.wasInterrupted()) return signalPassFailure();
module.walk([&](tf_device::ClusterOp cluster) {
// Only if `allow_soft_placement` attribute is true should we unmark ops
// for outside compilation.
auto soft_placement_attr =
cluster.getAttrOfType<BoolAttr>(kAllowSoftPlacementAttr);
if (!(soft_placement_attr && soft_placement_attr.getValue())) {
return;
}
UnmarkChildren(&cluster.GetBody());
});
}
} // namespace

View File

@@ -271,6 +271,9 @@ namespace TFTPU {
// `_tpu_replicate` attribute.
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUClusterFormationPass();
// Creates a pass that removes Identity/IdentityN ops from a cluster.
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUIdentityPruningPass();
// Creates a pass that allows TPU program inputs to have layouts determined at
// run time.
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUDynamicLayoutPass();

View File

@@ -26,10 +26,13 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/OpImplementation.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/IR/Visitors.h" // from @llvm-project
@@ -39,6 +42,9 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h"
#define DEBUG_TYPE "tf-resource-device-inference"
namespace mlir {
namespace TF {
@@ -132,6 +138,13 @@ inline StringRef GetDeviceAttr(Operation* op) {
return device_attr ? device_attr.getValue() : "";
}
// Print operation with debug info (to get line number info for debugging)
void dump(StringRef message, Operation* op) {
llvm::dbgs() << message;
op->print(llvm::dbgs(), OpPrintingFlags().enableDebugInfo(true));
llvm::dbgs() << "\n";
}
// Propagates device assignment inside a function.
LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op,
PerFunctionResult* result) {
@@ -153,26 +166,67 @@ LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op,
if (failed(res)) return res;
}
auto walk_res = func_op.walk([&](Operation* op) {
if (auto var_handle = dyn_cast<VarHandleOp>(op)) {
// Record VarHandleOp's device attribute.
StringRef device_attr = GetDeviceAttr(op);
if (device_attr.empty()) return WalkResult::advance();
auto res = AddResourceDeviceAndEmitError(var_handle.resource(),
device_attr, op, result);
if (failed(res)) return WalkResult::interrupt();
}
if (auto identity = dyn_cast<IdentityOp>(op)) {
// Try to construct IdentityOp's attribute from recorded assignment.
if (!GetDeviceAttr(op).empty()) return WalkResult::advance();
for (auto output : filter_resources(op->getResults())) {
if (auto device = result->DeviceForResource(output))
identity.setAttr(kDeviceAttr, builder.getStringAttr(*device));
}
return WalkResult::advance();
}
return WalkResult::advance();
});
// To support WhileRegion, we need to propagate device attributes from
// WhileRegion operands to body/cond region arguments *prior* to visiting
// these regions. Use tensorflow::walk() instead of MLIR core walker to
// implement such a pre-order walk.
auto walk_res = tensorflow::GenericWalk(
func_op, [&](Operation* op, const tensorflow::WalkStage& stage) {
// We just need to visit operations in pre-order mode.
if (!stage.IsBeforeAllRegions()) return WalkResult::advance();
if (auto var_handle = dyn_cast<VarHandleOp>(op)) {
// Record VarHandleOp's device attribute.
StringRef device_attr = GetDeviceAttr(op);
if (device_attr.empty()) return WalkResult::advance();
auto res = AddResourceDeviceAndEmitError(var_handle.resource(),
device_attr, op, result);
if (failed(res)) return WalkResult::interrupt();
} else if (auto identity = dyn_cast<IdentityOp>(op)) {
LLVM_DEBUG(dump("Visiting ", identity));
// Try to construct IdentityOp's attribute from recorded assignment.
if (!GetDeviceAttr(op).empty()) return WalkResult::advance();
for (auto output : filter_resources(op->getResults())) {
LLVM_DEBUG(llvm::dbgs() << " Processing output #"
<< output.getResultNumber() << "\n");
if (auto device = result->DeviceForResource(output)) {
LLVM_DEBUG(llvm::dbgs()
<< " Setting device = " << *device << "\n");
identity.setAttr(kDeviceAttr, builder.getStringAttr(*device));
}
}
} else if (auto while_region = dyn_cast<WhileRegionOp>(op)) {
// For WhileRegion, do local analysis prior to visiting the attached
// regions and propagate device annotations to the cond and body
// region arguments. The annotations are the union of annotations
// on the input and result. Resource alias analysis already propagates
// resource ID from the inputs to the results for a while, so just
// need to consider the results.
LLVM_DEBUG(llvm::dbgs() << "Visiting WhileRegion\n");
for (auto output : filter_resources(while_region.getResults())) {
auto device = result->DeviceForResource(output);
int output_index = output.getResultNumber();
if (!device) {
LLVM_DEBUG(llvm::dbgs()
<< " No device for output #" << output_index << "\n");
continue;
}
// Transfer the annotation to both region arguments
for (Region* region : while_region.getRegions()) {
BlockArgument arg = region->getArgument(output_index);
LLVM_DEBUG(llvm::dbgs()
<< " Propagating device = '" << *device
<< "' to arg #" << output_index << " of region #"
<< region->getRegionNumber() << "\n");
if (failed(AddResourceDeviceAndEmitError(arg, *device,
while_region, result)))
return WalkResult::interrupt();
}
}
}
return WalkResult::advance();
});
return failure(walk_res.wasInterrupted());
}
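
The comment above spells out the ordering requirement: a `WhileRegion` has to be visited before its cond/body regions so device annotations can be pushed onto the region arguments first. A minimal sketch of such a pre-order traversal, assuming a toy `Node` tree rather than the real `tensorflow::GenericWalk`/`WalkStage` machinery:

```cpp
// Sketch only: Node and PreOrderWalk are illustrative stand-ins for MLIR
// operations with nested regions and the pre-order walk utility.
#include <functional>
#include <iostream>
#include <string>
#include <vector>

struct Node {
  std::string name;
  std::vector<Node> regions;
};

// Visit the op itself first ("before all regions"), then recurse into its
// nested regions, so information attached to the op can be propagated down.
void PreOrderWalk(const Node &n,
                  const std::function<void(const Node &)> &visit) {
  visit(n);
  for (const Node &region : n.regions) PreOrderWalk(region, visit);
}

int main() {
  Node while_op{"tf.WhileRegion", {Node{"cond", {}}, Node{"body", {}}}};
  PreOrderWalk(while_op,
               [](const Node &n) { std::cout << n.name << "\n"; });
  // Prints tf.WhileRegion before cond and body, matching the order the pass
  // needs for propagating device attributes onto the region arguments.
  return 0;
}
```
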
@@ -201,6 +255,10 @@ void ResourceDeviceInference::runOnOperation() {
Value arg_operand = caller_operands[arg.getArgNumber()];
auto device = caller_res.DeviceForResource(arg_operand);
if (!device) continue;
LLVM_DEBUG(llvm::dbgs()
<< "Propagating '" << *device << "' to arg #"
<< arg.getArgNumber() << " of function @"
<< callee.getName() << "\n");
if (failed(AddResourceDeviceAndEmitError(arg, *device, caller,
&callee_res,
&callee_needs_recompute)))
@@ -240,6 +298,8 @@ void ResourceDeviceInference::runOnOperation() {
"call");
return WalkResult::interrupt();
}
LLVM_DEBUG(llvm::dbgs()
<< "Visiting call to function @" << func.getName() << "\n");
if (failed(propagate_operands_to_callee_arguments(
call, call.getArgOperands(), {func}, func_res)))
return WalkResult::interrupt();

View File

@@ -21,6 +21,7 @@ limitations under the License.
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
@@ -330,15 +331,6 @@ LogicalResult HoistResourceOpsFromCluster(tf_device::ClusterOp cluster,
getUsedValuesDefinedAbove(new_cluster.body(), new_cluster.body(),
captured_values);
for (Value v : captured_values) {
auto tensor_type = v.getType().dyn_cast<TensorType>();
if (!tensor_type) continue;
if (!tensor_type.getElementType().isa<TF::ResourceType>()) continue;
return new_cluster.emitOpError()
<< "has remaining resource inputs that can not be lifted";
}
return success();
}
@@ -361,29 +353,23 @@ LogicalResult FindResourceArgUseInfo(
ResourceArgUseInfo info;
info.used = false;
info.updated = false;
bool do_not_touch = false;
bool read_or_assigned = false;
for (auto user : arg.getUsers()) {
if (user == return_op) continue;
info.used = true;
if (auto read = llvm::dyn_cast<TF::ReadVariableOp>(user)) {
info.used = true;
read_or_assigned = true;
info.data_type = read.getType();
continue;
}
if (auto assign = llvm::dyn_cast<TF::AssignVariableOp>(user)) {
info.used = true;
read_or_assigned = true;
info.updated = true;
info.data_type = assign.value().getType();
continue;
}
if (isa<TF::StackPushV2Op, TF::StackPopV2Op>(user)) {
// Stacks will be handled by a separate pass.
do_not_touch = true;
break;
}
user->emitOpError("found unsupported operations on resource.");
return failure();
}
if (!do_not_touch) (*result)[arg.getArgNumber()] = info;
if (!info.used || read_or_assigned) (*result)[arg.getArgNumber()] = info;
}
return success();
}
@@ -914,8 +900,8 @@ LogicalResult HandlePartitionedCallOpCallee(
// resource-lifted new callee function in lifting_info.
template <typename CallOpType>
void UpdatePartitionedCallOpWithNewCallee(
CallOpType call_op, const PartitionedCallLiftingInfo& lifting_info) {
if (lifting_info.lifted_callee == nullptr) return;
CallOpType call_op, PartitionedCallLiftingInfo& lifting_info) {
if (!lifting_info.lifted_callee) return;
// Replace output resource uses with the aliasing input, so that we can remove
// this output.
for (const auto& entry : lifting_info.old_outputs_aliasing_old_inputs) {
@@ -929,12 +915,10 @@ void UpdatePartitionedCallOpWithNewCallee(
auto new_operands =
FilterRange<Value, OperandRange>(call_op.args(), lifting_info.use_info);
auto new_call = builder.create<CallOpType>(
call_op.getLoc(),
const_cast<FuncOp&>(lifting_info.lifted_callee).getType().getResults(),
call_op.getLoc(), lifting_info.lifted_callee.getType().getResults(),
new_operands, call_op.getAttrs());
new_call.setAttr(
"f", builder.getSymbolRefAttr(
const_cast<FuncOp&>(lifting_info.lifted_callee).getName()));
"f", builder.getSymbolRefAttr(lifting_info.lifted_callee.getName()));
AddLoadsStoresOutsideControlFlowOp(
new_call, lifting_info.arg_data_type_and_updated_output_index);
// Replace uses.
@@ -949,7 +933,8 @@ LogicalResult HoistForFunctionalControlFlow(
}
LogicalResult HoistForFunctionalControlFlow(
Block*, ModuleOp, llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>*);
Block*, ModuleOp,
llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>*);
// A templated routine for handling both PartitionedCallOp and
// StatefulPartitionedCallOp. If the callee is already lifted, it just updates
@ -958,9 +943,10 @@ LogicalResult HoistForFunctionalControlFlow(
template <typename CallOpType>
LogicalResult HandlePartitionedCallOp(
CallOpType call_op, FuncOp callee, ModuleOp module,
llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>* lifted_callees) {
auto emplace_res =
lifted_callees->try_emplace(callee, PartitionedCallLiftingInfo());
llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>*
lifted_callees) {
auto emplace_res = lifted_callees->try_emplace(callee.getName(),
PartitionedCallLiftingInfo());
if (emplace_res.second) {
// Unseen callee. Perform resource lifting on it.
HoistForFunctionalControlFlow(&callee.front(), module, lifted_callees);
@ -977,7 +963,7 @@ LogicalResult HandlePartitionedCallOp(
// body/cond/branch/callee functions.
LogicalResult HoistForFunctionalControlFlow(
Block* block, ModuleOp module,
llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>*
llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>*
lifted_partitioned_call_callees) {
// Remove identity nodes to avoid aliasing.
RemoveIdentity(block);
@ -1056,7 +1042,7 @@ LogicalResult HoistForFunctionalControlFlow(
// Returns failure if there are remaining resource-type values that can not be
// lifted.
void ResourceOpLiftingPass::runOnOperation() {
llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>
llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>
lifted_partitioned_call_callees;
ModuleOp module = getOperation();
auto result = module.walk([&](FuncOp func_op) {
@ -1121,7 +1107,7 @@ LogicalResult ResourceLiftingForFunctionalControlFlow(FuncOp function) {
<< function.getBlocks().size();
}
llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>
llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>
lifted_partitioned_call_callees;
return HoistForFunctionalControlFlow(&function.front(),
cast<ModuleOp>(function.getParentOp()),
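
The hunks above switch the lifting cache from FuncOp keys to the callee's symbol name. A minimal sketch of the compute-once pattern behind `try_emplace`, using a hypothetical `LiftingInfo` payload and a hypothetical `LiftCallee` step standing in for the real lifting work:

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringRef.h"

struct LiftingInfo {
  bool lifted = false;
};

// Hypothetical lifting step; stands in for HoistForFunctionalControlFlow.
void LiftCallee(llvm::StringRef callee_name, LiftingInfo& info) {
  info.lifted = true;
}

LiftingInfo& GetOrComputeLiftingInfo(
    llvm::StringRef callee_name,
    llvm::SmallDenseMap<llvm::StringRef, LiftingInfo>& cache) {
  auto emplace_res = cache.try_emplace(callee_name, LiftingInfo());
  // `second` is true only the first time this callee name is seen, so the
  // lifting work runs at most once per callee.
  if (emplace_res.second) LiftCallee(callee_name, emplace_res.first->second);
  return emplace_res.first->second;
}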

View File

@ -40,6 +40,7 @@ limitations under the License.
#include "mlir/IR/SymbolTable.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
#include "mlir/Interfaces/FoldInterfaces.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
@ -697,11 +698,8 @@ bool ShapeInference::RefineShapeForPassThroughOps(Operation* op) {
// TODO(jpienaar): The tf.Cast op, which is uniformly inserted at the
// moment, cannot handle arbitrary types (e.g., it can't handle quantized
// types). This restriction can be relaxed if not only tf.Cast is used.
auto kind = t.getKind();
return (kind >= Type::FIRST_STANDARD_TYPE &&
kind < Type::LAST_STANDARD_TYPE) ||
(kind >= Type::FIRST_TENSORFLOW_TYPE &&
kind < Type::LAST_TENSORFLOW_TYPE);
return t.getDialect().getNamespace().empty() ||
isa<TensorFlowDialect>(t.getDialect());
};
bool changed = false;
@ -1174,10 +1172,11 @@ LogicalResult ShapeInference::TryToFold(Operation* op) {
if (!dialect) return failure();
// Only attempt TF dialect fallback if there are no unknown operands.
if (some_unknown && dialect == tf_dialect_) return failure();
SmallVector<Attribute, 8> constants;
if (failed(dialect->constantFoldHook(op, constant_operands, constants)))
auto* interface = dialect->getRegisteredInterface<DialectFoldInterface>();
if (!interface) return failure();
if (failed(interface->Fold(op, constant_operands, fold_results)))
return failure();
fold_results.assign(constants.begin(), constants.end());
}
for (auto result : zip(op->getResults(), fold_results)) {
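
The second hunk above replaces the removed `constantFoldHook` with a query for the dialect's `DialectFoldInterface`. A minimal sketch of that fallback path, mirroring the call in the hunk and assuming the operands have already been gathered as constant attributes:

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"            // OpFoldResult
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/FoldInterfaces.h"  // DialectFoldInterface
#include "mlir/Support/LogicalResult.h"

mlir::LogicalResult TryDialectFold(
    mlir::Operation* op, llvm::ArrayRef<mlir::Attribute> constant_operands,
    llvm::SmallVectorImpl<mlir::OpFoldResult>& fold_results) {
  mlir::Dialect* dialect = op->getDialect();
  if (!dialect) return mlir::failure();
  // The dialect-level fold hook is now exposed as a registered interface.
  auto* interface =
      dialect->getRegisteredInterface<mlir::DialectFoldInterface>();
  if (!interface) return mlir::failure();
  return interface->Fold(op, constant_operands, fold_results);
}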

View File

@ -0,0 +1,113 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <tuple>
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/Region.h" // from @llvm-project
#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
namespace TFTPU {
namespace {
// This pass removes Identity/IdentityN ops from the TPU computation and
// reachable functions.
// TODO(lyandy): Remove this pass once resource op lifting is migrated to use
// resource alias analysis and supports region-based control flow. Removing
// Identity ops may remove `_XlaSharding` annotation attribute if Identity ops
// are used to propagate such information.
struct TPUIdentityPruning
: public PassWrapper<TPUIdentityPruning, OperationPass<ModuleOp>> {
void runOnOperation() override;
};
// Collects all reachable functions (via call ops) from a given region.
SmallVector<FuncOp, 4> CollectReachableFunctions(Region& region) {
llvm::SmallPtrSet<FuncOp, 4> reachable_funcs;
auto collect_reachable_funcs =
[&reachable_funcs](Region& src, SmallVectorImpl<FuncOp>& funcs_to_visit) {
src.walk([&reachable_funcs, &funcs_to_visit](CallOpInterface call_op) {
auto func = dyn_cast_or_null<FuncOp>(call_op.resolveCallable());
if (func && reachable_funcs.insert(func).second)
funcs_to_visit.push_back(func);
});
};
SmallVector<FuncOp, 4> funcs_to_visit;
collect_reachable_funcs(region, funcs_to_visit);
while (!funcs_to_visit.empty()) {
SmallVector<FuncOp, 4> new_funcs_to_visit;
for (FuncOp func_to_visit : funcs_to_visit) {
if (!func_to_visit.getCallableRegion()) continue;
collect_reachable_funcs(*func_to_visit.getCallableRegion(),
new_funcs_to_visit);
}
funcs_to_visit.swap(new_funcs_to_visit);
}
return llvm::to_vector<4>(reachable_funcs);
}
// Removes Identity/IdentityN ops from a region, forwarding their operands to
// their results.
void RemoveIdentityFromRegion(Region& region) {
region.walk([](Operation* op) {
if (isa<TF::IdentityOp, TF::IdentityNOp>(op)) {
op->replaceAllUsesWith(op->getOperands());
op->erase();
}
});
}
void TPUIdentityPruning::runOnOperation() {
SmallVector<tf_device::ClusterOp, 4> clusters;
getOperation().walk(
[&](tf_device::ClusterOp cluster) { clusters.push_back(cluster); });
for (tf_device::ClusterOp cluster : clusters) {
RemoveIdentityFromRegion(cluster.body());
auto reachable_funcs = CollectReachableFunctions(cluster.body());
for (FuncOp reachable_func : reachable_funcs)
RemoveIdentityFromRegion(*reachable_func.getCallableRegion());
}
}
} // anonymous namespace
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUIdentityPruningPass() {
return std::make_unique<TPUIdentityPruning>();
}
static PassRegistration<TPUIdentityPruning> pass(
"tf-tpu-identity-pruning",
"Removes Identity/IdentityN ops from the TPU computation");
} // namespace TFTPU
} // namespace mlir
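
A usage sketch for the new pass. The header path for the factory declaration is an assumption; with tf-opt, the same pass is reachable through the registered -tf-tpu-identity-pruning flag shown above:

#include "mlir/Pass/PassManager.h"
// Assumed location of the CreateTPUIdentityPruningPass() declaration.
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"

void AddTPUIdentityPruning(mlir::PassManager& pm) {
  // Module-level pass: prunes Identity/IdentityN inside tf_device.cluster
  // regions and in functions reachable from those clusters.
  pm.addPass(mlir::TFTPU::CreateTPUIdentityPruningPass());
}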

View File

@ -177,7 +177,8 @@ Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def,
restrict_functionalization_to_tpu_nodes
? [](const Node* n) { return n->attrs().Find(kTpuReplicateAttr); }
: NodeFilter{};
return FunctionalizeControlFlow(graph, flib_def, node_filter);
return FunctionalizeControlFlow(graph, flib_def, node_filter,
/*include_functions=*/true);
}
// Stateful helper class to import a TensorFlow model into an MLIR Module.

View File

@ -219,22 +219,18 @@ StatusOr<mlir::OwningModuleRef> GraphdefToSplattedMlirTranslateFunction(
if (auto attr = inst.getAttrOfType<mlir::ElementsAttr>(attr_id)) {
mlir::Attribute rand_val;
mlir::Type element_type = attr.getType().getElementType();
if (element_type.isa<mlir::IntegerType>()) {
rand_val = mlir::IntegerAttr::get(element_type, std::rand());
} else if (element_type.isF16() || element_type.isF32() ||
element_type.isF64()) {
rand_val = mlir::FloatAttr::get(element_type,
std::rand() * 1.0 / RAND_MAX);
switch (element_type.getKind()) {
case mlir::StandardTypes::Integer:
rand_val = mlir::IntegerAttr::get(element_type, std::rand());
break;
case mlir::StandardTypes::F16:
case mlir::StandardTypes::F32:
case mlir::StandardTypes::F64:
rand_val = mlir::FloatAttr::get(element_type,
std::rand() * 1.0 / RAND_MAX);
break;
default:
inst.emitWarning()
<< "Skipping splat conversion for "
<< "an unsupported attribute type " << element_type;
continue;
} else {
inst.emitWarning()
<< "Skipping splat conversion for "
<< "an unsupported attribute type " << element_type;
continue;
}
auto new_attr =
mlir::DenseElementsAttr::get(attr.getType(), rand_val);
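
For reference, the splat substitution in the hunk above boils down to building a dense attribute from a single scalar; a minimal sketch, assuming a shaped type and a scalar attribute of its element type:

#include "mlir/IR/Attributes.h"
#include "mlir/IR/StandardTypes.h"

mlir::DenseElementsAttr MakeRandomSplat(mlir::ShapedType type,
                                        mlir::Attribute scalar) {
  // Every element of the resulting attribute is the same scalar value.
  return mlir::DenseElementsAttr::get(type, scalar);
}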

View File

@ -36,8 +36,8 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/bfloat16.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tstring.h"

View File

@ -91,64 +91,62 @@ Status ConvertDataType(DataType dtype, Builder builder, Type* type) {
}
Status ConvertScalarTypeToDataType(Type type, DataType* dtype) {
switch (type.getKind()) {
case mlir::StandardTypes::F16:
*dtype = DT_HALF;
return Status::OK();
case mlir::StandardTypes::F32:
*dtype = DT_FLOAT;
return Status::OK();
case mlir::StandardTypes::F64:
*dtype = DT_DOUBLE;
return Status::OK();
case mlir::StandardTypes::BF16:
*dtype = DT_BFLOAT16;
return Status::OK();
case mlir::StandardTypes::Integer: {
const auto& itype = type.cast<mlir::IntegerType>();
switch (itype.getWidth()) {
case 1:
*dtype = DT_BOOL;
return Status::OK();
case 8:
*dtype = itype.isUnsigned() ? DT_UINT8 : DT_INT8;
return Status::OK();
case 16:
*dtype = itype.isUnsigned() ? DT_UINT16 : DT_INT16;
return Status::OK();
case 32:
*dtype = itype.isUnsigned() ? DT_UINT32 : DT_INT32;
return Status::OK();
case 64:
*dtype = itype.isUnsigned() ? DT_UINT64 : DT_INT64;
return Status::OK();
default:
return errors::Unimplemented(
absl::StrCat("Converting ", debugString(type), " to DataType"));
}
}
case mlir::StandardTypes::Complex: {
auto etype = type.cast<mlir::ComplexType>().getElementType();
if (etype.isF32()) {
*dtype = DT_COMPLEX64;
return Status::OK();
} else if (etype.isF64()) {
*dtype = DT_COMPLEX128;
return Status::OK();
}
return errors::Unimplemented(
absl::StrCat("Converting ", debugString(type), " to DataType"));
}
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
case mlir::TF::TensorFlowTypes::enumerant: \
*dtype = DT_##enumerant; \
if (type.isF16()) {
*dtype = DT_HALF;
return Status::OK();
} else if (type.isF32()) {
*dtype = DT_FLOAT;
return Status::OK();
} else if (type.isF64()) {
*dtype = DT_DOUBLE;
return Status::OK();
} else if (type.isBF16()) {
*dtype = DT_BFLOAT16;
return Status::OK();
} else if (auto itype = type.dyn_cast<mlir::IntegerType>()) {
switch (itype.getWidth()) {
case 1:
*dtype = DT_BOOL;
return Status::OK();
case 8:
*dtype = itype.isUnsigned() ? DT_UINT8 : DT_INT8;
return Status::OK();
case 16:
*dtype = itype.isUnsigned() ? DT_UINT16 : DT_INT16;
return Status::OK();
case 32:
*dtype = itype.isUnsigned() ? DT_UINT32 : DT_INT32;
return Status::OK();
case 64:
*dtype = itype.isUnsigned() ? DT_UINT64 : DT_INT64;
return Status::OK();
default:
return errors::Unimplemented(
absl::StrCat("Converting ", debugString(type), " to DataType"));
}
} else if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
auto etype = complex_type.getElementType();
if (etype.isF32()) {
*dtype = DT_COMPLEX64;
return Status::OK();
} else if (etype.isF64()) {
*dtype = DT_COMPLEX128;
return Status::OK();
}
return errors::Unimplemented(
absl::StrCat("Converting ", debugString(type), " to DataType"));
}
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
if (type.isa<mlir::TF::tftype##Type>()) { \
*dtype = DT_##enumerant; \
return Status::OK(); \
}
// NOLINTNEXTLINE
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
default:
return errors::Unimplemented(
absl::StrCat("Converting ", debugString(type), " to DataType"));
}
return errors::Unimplemented(
absl::StrCat("Converting ", debugString(type), " to DataType"));
}
Status ConvertToDataType(Type type, DataType* dtype) {
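
A small usage sketch of the rewritten helper above; the include path and the availability of the declaration in that header are assumptions based on this file's location:

#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"

tensorflow::Status Demo() {
  mlir::MLIRContext context;
  mlir::Builder builder(&context);
  tensorflow::DataType dtype;
  tensorflow::Status status =
      tensorflow::ConvertScalarTypeToDataType(builder.getF32Type(), &dtype);
  // On success, dtype == tensorflow::DT_FLOAT with the isa-based dispatch.
  return status;
}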

View File

@ -9,103 +9,41 @@ package(
package_group(
name = "friends",
includes = ["//third_party/mlir:subpackages"],
packages = [
"//tensorflow/compiler/mlir/...",
"//tensorflow/core/kernels/mlir_generated/...",
],
packages = ["//tensorflow/compiler/mlir/..."],
)
cc_library(
name = "passes",
srcs = ["passes.cc"],
hdrs = ["passes.h"],
name = "cubin_creator",
srcs = ["cubin_creator.cc"],
hdrs = ["cubin_creator.h"],
copts = if_cuda(["-DGOOGLE_CUDA=1"]),
deps = [
"@com_google_absl//absl/memory",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:LLVMDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:TargetNVVMIR",
"@llvm-project//mlir:Transforms",
"//tensorflow/compiler/mlir/hlo",
"//tensorflow/compiler/mlir/hlo:materialize_broadcasts", # buildcleaner: keep
"//tensorflow/compiler/xla/service/gpu:stream_executor_util",
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/mlir/hlo:unfuse_batch_norm", # buildcleaner: keep
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla/service/gpu:target_constants",
"//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend",
"//tensorflow/core:cuda_libdevice_path",
"//tensorflow/core:lib",
] + if_cuda(["//tensorflow/stream_executor/gpu:asm_compiler"]),
)
cc_library(
name = "kernel_creator",
srcs = ["kernel_creator.cc"],
hdrs = ["kernel_creator.h"],
copts = if_cuda(["-DGOOGLE_CUDA=1"]),
deps = [
":passes",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:AffineToStandardTransforms",
"//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_dialect_registration",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:CFGTransforms",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:GPUToNVVMTransforms",
"@llvm-project//mlir:GPUTransforms",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LLVMDialect",
"@llvm-project//mlir:LLVMTransforms",
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:LinalgToLLVM",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:NVVMDialect",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:GPUToGPURuntimeTransforms",
"@llvm-project//mlir:SCFToGPUPass",
"@llvm-project//mlir:SCFTransforms",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TargetNVVMIR",
"@llvm-project//mlir:Transforms",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/hlo",
"//tensorflow/compiler/mlir/hlo:all_passes",
"//tensorflow/compiler/mlir/hlo:hlo_dialect_force_registration",
"//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo",
"//tensorflow/compiler/mlir/hlo:legalize_tanh_to_approximation",
"//tensorflow/compiler/mlir/hlo:legalize_to_linalg",
"//tensorflow/compiler/mlir/hlo:lhlo",
"//tensorflow/compiler/mlir/hlo:lhlo_copy_removal",
"//tensorflow/compiler/mlir/hlo:lhlo_fuse_linalg",
"//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_affine",
"//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_gpu",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
"//tensorflow/compiler/mlir/hlo:materialize_broadcasts", # buildcleaner: keep
"//tensorflow/compiler/mlir/hlo:unfuse_batch_norm", # buildcleaner: keep
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tools/kernel_gen/transforms:passes",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
"//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/service/gpu:stream_executor_util",
"//tensorflow/compiler/xla/service/gpu:target_constants",
"//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend",
"//tensorflow/compiler/xla/service/mlir_gpu:kernel_lowering",
"//tensorflow/compiler/xla/service/mlir_gpu:passes",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:cuda_libdevice_path",
"//tensorflow/core:lib",
"//tensorflow/compiler/xla:util",
] + if_cuda(["//tensorflow/stream_executor/gpu:asm_compiler"]),
)
@ -114,36 +52,11 @@ tf_cc_binary(
srcs = ["tf_to_cubin.cc"],
visibility = ["//tensorflow/core/kernels/mlir_generated:__pkg__"],
deps = [
":kernel_creator",
":cubin_creator",
"//tensorflow/compiler/mlir:init_mlir",
"//tensorflow/core:lib",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Pass",
],
)
tf_cc_binary(
name = "tf_to_kernel",
srcs = ["tf_to_kernel.cc"],
visibility = ["//tensorflow/core/kernels/mlir_generated:__pkg__"],
deps = [
":kernel_creator",
"//tensorflow/compiler/mlir:init_mlir",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Analysis",
"@llvm-project//llvm:CodeGen",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:Target",
"@llvm-project//llvm:X86CodeGen", # fixdeps: keep
"@llvm-project//llvm:X86Disassembler", # fixdeps: keep
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:TargetLLVMIR",
],
)
@ -159,7 +72,6 @@ tf_cc_binary(
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MlirOptLib",
"@llvm-project//mlir:MlirOptMain",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
],

View File

@ -13,33 +13,59 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tools/kernel_gen/passes.h"
//===- cubin_creator.cc -----------------------------------------*- C++ -*-===//
//
// This file implements the function to compile a TF kernel function to a cubin.
//
//===----------------------------------------------------------------------===//
#include "tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h"
#include <string>
#include <utility>
#include <vector>
#include "absl/memory/memory.h"
#include "absl/strings/escaping.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Parser.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Target/NVVMIR.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/gpu/target_constants.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h"
#include "tensorflow/core/platform/cuda_libdevice_path.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/path.h"
#if GOOGLE_CUDA
#include "tensorflow/stream_executor/gpu/asm_compiler.h"
#endif
namespace mlir {
namespace kernel_gen {
namespace {
using tensorflow::Status;
using xla::InternalError;
using xla::StatusOr;
xla::StatusOr<std::string> GetLibdeviceDir(
StatusOr<std::string> GetLibdeviceDir(
const xla::HloModuleConfig& hlo_module_config) {
for (const std::string& cuda_root : tensorflow::CandidateCudaRoots(
hlo_module_config.debug_options().xla_gpu_cuda_data_dir())) {
@ -51,7 +77,7 @@ xla::StatusOr<std::string> GetLibdeviceDir(
return libdevice_dir;
}
}
return xla::InternalError(
return InternalError(
"Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice");
}
@ -87,11 +113,34 @@ struct UnfuseBatchNormPass
}
};
struct PropagateTensorFlowABIKnowledgePass
: public mlir::PassWrapper<PropagateTensorFlowABIKnowledgePass,
Status LowerTfOpToLhloWithDynamicShapes(mlir::ModuleOp module) {
mlir::PassManager pm(module.getContext());
auto enable_if_vlog_is_on = [](mlir::Pass* pass, mlir::Operation* op) {
return VLOG_IS_ON(1);
};
pm.enableIRPrinting(/*shouldPrintBeforePass=*/{},
/*shouldPrintAfterPass=*/enable_if_vlog_is_on,
/*printModuleScope=*/false,
/*printAfterOnlyOnChange=*/false, llvm::dbgs());
pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(false));
pm.addNestedPass<mlir::FuncOp>(
absl::make_unique<MaterializeBroadcastsPass>());
pm.addNestedPass<mlir::FuncOp>(absl::make_unique<UnfuseBatchNormPass>());
pm.addPass(mlir::mhlo::createLegalizeToLhloPass(
/*results_escape_functions=*/true));
pm.addNestedPass<mlir::FuncOp>(mlir::lmhlo::createLhloCopyRemovalPass());
if (failed(pm.run(module))) {
return InternalError("Lowering TF to LHLO failed.");
}
return Status::OK();
}
struct PropagateTensorFlowABIKnowledge
: public mlir::PassWrapper<PropagateTensorFlowABIKnowledge,
mlir::OperationPass<mlir::LLVM::LLVMFuncOp>> {
explicit PropagateTensorFlowABIKnowledgePass(
mlir::FunctionType type, llvm::ArrayRef<uint32_t> same_shape_)
explicit PropagateTensorFlowABIKnowledge(mlir::FunctionType type,
llvm::ArrayRef<uint32_t> same_shape_)
: func_type(type), same_shape(same_shape_) {}
void runOnOperation() override {
@ -125,7 +174,8 @@ struct PropagateTensorFlowABIKnowledgePass
for (mlir::Type arg_type : arg_types) {
if (!arg_type.isa<mlir::MemRefType>()) {
func.emitError() << "argument of surrounding func is not ranked memref";
return signalPassFailure();
signalPassFailure();
return;
}
positions.push_back(arg_pos);
// Set alignment and aliasing on the pointers.
@ -154,7 +204,8 @@ struct PropagateTensorFlowABIKnowledgePass
func.emitOpError() << "same shape constraints on arguments with "
"non-matching shapes: #"
<< first << " and #" << same;
return signalPassFailure();
signalPassFailure();
continue;
}
for (uint32_t i = 0; i < 2 * rank; ++i) {
@ -171,93 +222,91 @@ struct PropagateTensorFlowABIKnowledgePass
llvm::ArrayRef<uint32_t> same_shape;
};
class GpuKernelToBlobPass
: public mlir::PassWrapper<GpuKernelToBlobPass,
mlir::OperationPass<mlir::gpu::GPUModuleOp>> {
public:
GpuKernelToBlobPass(mlir::StringRef blob_annotation,
std::pair<int32_t, int32_t> compute_capability)
: blob_annotation_(blob_annotation),
compute_capability_(compute_capability) {}
Status PropagateTensorFlowABIKnowledgeToKernel(
mlir::ModuleOp module, llvm::ArrayRef<uint32_t> same_shape) {
// Grab the original signature from the single function.
auto func = *module.getBody()->op_begin<mlir::FuncOp>();
void runOnOperation() override {
mlir::gpu::GPUModuleOp module = getOperation();
mlir::PassManager pm(module.getContext());
auto enable_if_vlog_is_on = [](mlir::Pass*, mlir::Operation*) {
return VLOG_IS_ON(1);
};
pm.enableIRPrinting(/*shouldPrintBeforePass=*/{},
/*shouldPrintAfterPass=*/enable_if_vlog_is_on,
/*printModuleScope=*/false,
/*printAfterOnlyOnChange=*/false, llvm::dbgs());
auto& kernel_pm = pm.nest<::mlir::gpu::GPUModuleOp>();
kernel_pm.addNestedPass<mlir::LLVM::LLVMFuncOp>(
absl::make_unique<PropagateTensorFlowABIKnowledge>(func.getType(),
same_shape));
llvm::LLVMContext llvmContext;
auto llvmModule = mlir::translateModuleToNVVMIR(module, llvmContext);
if (!llvmModule) {
return signalPassFailure();
}
llvmModule->setModuleIdentifier("acme");
llvmModule->setDataLayout(xla::gpu::nvptx::kDataLayout);
xla::HloModuleConfig config;
config.set_debug_options(xla::GetDebugOptionsFromFlags());
auto enable_fusion = [](llvm::TargetMachine* target) {
target->Options.AllowFPOpFusion = llvm::FPOpFusion::FPOpFusionMode::Fast;
};
auto libdevice_dir_or = GetLibdeviceDir(config);
if (!libdevice_dir_or.ok()) {
return signalPassFailure();
}
auto ptx_or = xla::gpu::nvptx::CompileToPtx(
llvmModule.get(), compute_capability_, config,
libdevice_dir_or.ValueOrDie(), enable_fusion);
if (!ptx_or.ok()) {
return signalPassFailure();
}
auto ptx = ptx_or.ValueOrDie();
#if GOOGLE_CUDA
auto blob_or = tensorflow::se::CompileGpuAsm(
std::get<0>(compute_capability_), std::get<1>(compute_capability_),
ptx.c_str(), xla::gpu::PtxOptsFromConfig(config));
if (blob_or.ok()) {
const auto& blob = blob_or.ValueOrDie();
std::string blob_string(blob.begin(), blob.end());
module.setAttr(blob_annotation_,
mlir::StringAttr::get(blob_string, &getContext()));
return;
} else {
return signalPassFailure();
}
#endif
return signalPassFailure();
if (failed(pm.run(module))) {
return InternalError("Static knowledge propagation failed.");
}
return Status::OK();
}
private:
mlir::StringRef blob_annotation_;
std::pair<int32_t, int32_t> compute_capability_;
};
void RegisterDialects() {
static bool init_once = []() {
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
return true;
}();
(void)init_once;
}
} // namespace
std::unique_ptr<mlir::FunctionPass> createMaterializeBroadcastsPass() {
return absl::make_unique<MaterializeBroadcastsPass>();
}
StatusOr<std::vector<uint8_t>> tensorflow::kernel_gen::GenerateCubinForTfCode(
llvm::StringRef tf_code, std::pair<int32_t, int32_t> compute_capability,
llvm::ArrayRef<uint32_t> tile_sizes, llvm::ArrayRef<uint32_t> same_shape,
llvm::ArrayRef<uint32_t> unroll_factors) {
RegisterDialects();
mlir::MLIRContext context;
mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context);
std::unique_ptr<mlir::FunctionPass> createUnfuseBatchNormPass() {
return absl::make_unique<UnfuseBatchNormPass>();
}
TF_RETURN_IF_ERROR(LowerTfOpToLhloWithDynamicShapes(module.get()));
{
xla::mlir_gpu::LowerLHLOToGPUOptions options;
options.tile_sizes = tile_sizes;
options.unroll_factors = unroll_factors;
options.collapse_parallel_loops = false;
options.use_approximations = true;
TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerLHLOToGPU(module.get(), options));
}
TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get()));
TF_RETURN_IF_ERROR(
PropagateTensorFlowABIKnowledgeToKernel(module.get(), same_shape));
std::unique_ptr<mlir::OperationPass<mlir::LLVM::LLVMFuncOp>>
createPropagateTensorFlowABIKnowledgePass(mlir::FunctionType type,
llvm::ArrayRef<uint32_t> same_shape) {
return absl::make_unique<PropagateTensorFlowABIKnowledgePass>(type,
same_shape);
}
mlir::OwningModuleRef kernel_module =
xla::mlir_gpu::ExtractKernelModule(*module).ValueOrDie();
llvm::LLVMContext llvmContext;
auto llvmModule = mlir::translateModuleToNVVMIR(*kernel_module, llvmContext);
if (!llvmModule) {
return InternalError("Could not translate MLIR module to NVVM");
}
std::unique_ptr<mlir::OperationPass<mlir::gpu::GPUModuleOp>>
createGpuKernelToBlobPass(
mlir::StringRef blob_annotation,
const std::pair<int32_t, int32_t>& compute_capability) {
return absl::make_unique<GpuKernelToBlobPass>(blob_annotation,
compute_capability);
}
llvmModule->setModuleIdentifier("acme");
llvmModule->setDataLayout(xla::gpu::nvptx::kDataLayout);
} // namespace kernel_gen
} // namespace mlir
xla::HloModuleConfig config;
config.set_debug_options(xla::GetDebugOptionsFromFlags());
auto enable_fusion = [](llvm::TargetMachine* target) {
target->Options.AllowFPOpFusion = llvm::FPOpFusion::FPOpFusionMode::Fast;
};
TF_ASSIGN_OR_RETURN(std::string libdevice_dir, GetLibdeviceDir(config));
TF_ASSIGN_OR_RETURN(
std::string ptx,
xla::gpu::nvptx::CompileToPtx(llvmModule.get(), compute_capability,
config, libdevice_dir, enable_fusion));
VLOG(1) << ptx;
#if GOOGLE_CUDA
return tensorflow::se::CompileGpuAsm(
std::get<0>(compute_capability), std::get<1>(compute_capability),
ptx.c_str(), xla::gpu::PtxOptsFromConfig(config));
#else
return InternalError(
"GOOGLE_CUDA not defined. Did you specify --config=cuda ?");
#endif
}

View File

@ -13,43 +13,30 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
//===- kernel_creator.h -----------------------------------------*- C++ -*-===//
//===- cubin_creator.h ------------------------------------------*- C++ -*-===//
//
// This file declares the function to compile a TF kernel function to a cubin.
//
//===----------------------------------------------------------------------===//
#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_KERNEL_CREATOR_H_
#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_KERNEL_CREATOR_H_
#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CUBIN_CREATOR_H_
#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CUBIN_CREATOR_H_
#include <utility>
#include <vector>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "tensorflow/compiler/xla/statusor.h"
namespace tensorflow {
namespace kernel_gen {
// Registers necessary dialects. It should be called before creating
// MLIRContext.
void RegisterDialects();
// Converts TF code to LLVM/NVVM. If `cubin_only` is true, the conversion
// stops after the cubin binary blob is generated. If `cubin_only` is false, it
// also lowers the host side to the LLVM dialect.
xla::StatusOr<mlir::OwningModuleRef> GenerateKernelForTfCode(
mlir::MLIRContext& mlir_context, llvm::StringRef tf_code, bool cubin_only,
xla::StatusOr<std::vector<uint8_t>> GenerateCubinForTfCode(
llvm::StringRef tf_code,
std::pair<int32_t, int32_t> compute_capability = {7, 5},
llvm::ArrayRef<uint32_t> tile_sizes = {16, 64},
llvm::ArrayRef<uint32_t> same_shape = {},
llvm::ArrayRef<uint32_t> unroll_factors = {});
// Extracts cubin from the converted module.
xla::StatusOr<std::string> ExtractGpuBinary(mlir::ModuleOp module);
} // namespace kernel_gen
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_KERNEL_CREATOR_H_
#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CUBIN_CREATOR_H_

View File

@ -48,13 +48,11 @@ Type TFFrameworkDialect::parseType(DialectAsmParser &parser) const {
/// Print a type registered to this dialect.
void TFFrameworkDialect::printType(Type type, DialectAsmPrinter &os) const {
switch (type.getKind()) {
case TFFrameworkTypes::OpKernelContextType:
os << "op_kernel_context";
return;
default:
llvm_unreachable("unexpected TF Framework type kind");
if (type.isa<OpKernelContextType>()) {
os << "op_kernel_context";
return;
}
llvm_unreachable("unexpected TF Framework type kind");
}
template <typename OpTy>

View File

@ -1,258 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
//===- kernel_creator.cc ----------------------------------------*- C++ -*-===//
//
// This file implements the function to compile a TF kernel function to a cubin.
//
//===----------------------------------------------------------------------===//
#include "tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" // from @llvm-project
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // from @llvm-project
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" // from @llvm-project
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h" // from @llvm-project
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" // from @llvm-project
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project
#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project
#include "mlir/Dialect/GPU/ParallelLoopMapper.h" // from @llvm-project
#include "mlir/Dialect/GPU/Passes.h" // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
#include "mlir/Dialect/LLVMIR/NVVMDialect.h" // from @llvm-project
#include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project
#include "mlir/Dialect/SCF/Passes.h" // from @llvm-project
#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
#include "mlir/Dialect/SCF/Transforms.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/Parser.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Transforms/BufferPlacement.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/passes.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h"
#include "tensorflow/compiler/xla/service/mlir_gpu/passes.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/path.h"
namespace tensorflow {
namespace kernel_gen {
namespace {
using tensorflow::Status;
using xla::InternalError;
using xla::StatusOr;
constexpr llvm::StringRef kGpuBinaryAttrName = "nvvm.cubin";
Status LowerTFtoGPU(mlir::ModuleOp module, bool cubin_only,
llvm::ArrayRef<uint32_t> tile_sizes,
llvm::ArrayRef<uint32_t> unroll_factors) {
mlir::PassManager pm(module.getContext());
applyPassManagerCLOptions(pm);
pm.addPass(mlir::mhlo::createLegalizeTFPass(false));
if (cubin_only) {
pm.addNestedPass<mlir::FuncOp>(
mlir::kernel_gen::createMaterializeBroadcastsPass());
pm.addNestedPass<mlir::FuncOp>(
mlir::kernel_gen::createUnfuseBatchNormPass());
pm.addPass(mlir::mhlo::createLegalizeToLhloPass(
/*results_escape_functions=*/true));
// Moving `AllocOp`s and inserting missing `DeallocOp`s
pm.addPass(::mlir::createBufferPlacementPass());
pm.addNestedPass<mlir::FuncOp>(mlir::lmhlo::createLhloCopyRemovalPass());
} else {
pm.addPass(mlir::mhlo::createTransformUnrankedHloPass());
pm.addPass(mlir::kernel_gen::transforms::CreateShapeToDescriptorsPass());
pm.addPass(mlir::kernel_gen::transforms::CreateBufferizePass());
pm.addPass(mlir::createCanonicalizerPass());
}
// We have to anticipate later unrolling in tiling to make sure that we get
// the requested tiling after unrolling. Compute the new tiling here if
// needed.
llvm::SmallVector<unsigned, 4> tiling_for_unrolling;
llvm::SmallVector<int64_t, 4> as_int64;
if (!unroll_factors.empty()) {
tiling_for_unrolling.reserve(tile_sizes.size());
for (auto pair : llvm::zip(tile_sizes, unroll_factors)) {
tiling_for_unrolling.push_back(std::get<0>(pair) * std::get<1>(pair));
as_int64.push_back(std::get<1>(pair));
}
} else {
tiling_for_unrolling.append(tile_sizes.begin(), tile_sizes.end());
}
// Transform LHLO operations to LinAlg.
pm.addPass(::mlir::lmhlo::createLegalizeLhloToLinalgPass());
// Fuse linalg operations.
pm.addPass(::mlir::lmhlo::createLhloFuseLinalgPass(
/*use_parallel_loops=*/true, tiling_for_unrolling));
// Transform the Linalg operations inside of the loop nest into parallel
// loops.
pm.addPass(::mlir::createConvertLinalgToParallelLoopsPass());
// Canonicalize the code to simplify index computations. This is needed so
// that loop bounds have the same value.
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
// Fuse the inner-most loops.
pm.addPass(xla::mlir_gpu::createFuseInnerParallelLoopsPass());
// Run CSE to ensure that loads and stores to the same subview get
// recognized as such.
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
// Forward stores to buffers to loads.
pm.addPass(xla::mlir_gpu::createStoreForwardingPass());
// Remove now unused temporary buffers.
pm.addPass(xla::mlir_gpu::createDeadTempBufferRemovalPass());
if (!unroll_factors.empty()) {
pm.addPass(::mlir::createParallelLoopTilingPass(as_int64));
}
// Some basic cleanup.
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
// Greedily map the remaining loop to GPU hardware dimensions.
pm.addPass(xla::mlir_gpu::createMapParallelLoopsPass());
// Apply the mapping.
pm.addPass(mlir::createParallelLoopToGpuPass());
// Embed TF Framework ops.
if (!cubin_only) {
pm.addPass(mlir::kernel_gen::tf_framework::createEmbedTFFrameworkPass());
}
// Some basic cleanup.
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
// Make loops with min bounds into a conditional plus static bounds.
// Only do this if we unrolled in the first place.
if (!unroll_factors.empty()) {
pm.addNestedPass<::mlir::FuncOp>(mlir::createForLoopSpecializationPass());
}
// Approximate Tanh using standard operations.
pm.addNestedPass<::mlir::FuncOp>(
::mlir::mhlo::createLegalizeTanhToApproximationPass());
// Move scalar operations into the launch to ensure smaller signatures.
pm.addPass(xla::mlir_gpu::createMoveScalarComputationsIntoGpuLaunchPass());
// Take launches to launches with kernels.
pm.addPass(::mlir::createGpuKernelOutliningPass());
if (cubin_only) {
// Make kernel signature deterministic so that we can call it externally.
pm.addPass(xla::mlir_gpu::createRewriteKernelSignaturePass());
}
pm.addPass(::mlir::createLowerAffinePass());
pm.addPass(::mlir::createLowerToCFGPass());
if (failed(pm.run(module))) {
return InternalError("Lowering to GPU kernels failed.");
}
return Status::OK();
}
Status PropagateTensorFlowABIKnowledgeToKernel(
mlir::ModuleOp module, llvm::ArrayRef<uint32_t> same_shape) {
// Grab the original signature from the single function.
auto func = *module.getBody()->op_begin<mlir::FuncOp>();
mlir::PassManager pm(module.getContext());
applyPassManagerCLOptions(pm);
auto& kernel_pm = pm.nest<::mlir::gpu::GPUModuleOp>();
kernel_pm.addNestedPass<mlir::LLVM::LLVMFuncOp>(
mlir::kernel_gen::createPropagateTensorFlowABIKnowledgePass(
func.getType(), same_shape));
if (failed(pm.run(module))) {
return InternalError("Static knowledge propagation failed.");
}
return Status::OK();
}
Status LowerGPUToLLVM(mlir::ModuleOp module, bool cubin_only,
llvm::ArrayRef<uint32_t> same_shape,
llvm::StringRef gpu_binary_attr_name,
std::pair<int32_t, int32_t> compute_capability) {
mlir::PassManager pm(module.getContext());
applyPassManagerCLOptions(pm);
auto& kernel_pm = pm.nest<mlir::gpu::GPUModuleOp>();
if (cubin_only) {
// Grab the original signature from the single function.
auto func = *module.getBody()->op_begin<mlir::FuncOp>();
kernel_pm.addNestedPass<mlir::LLVM::LLVMFuncOp>(
mlir::kernel_gen::createPropagateTensorFlowABIKnowledgePass(
func.getType(), same_shape));
}
kernel_pm.addPass(mlir::createStripDebugInfoPass());
kernel_pm.addPass(mlir::kernel_gen::createGpuKernelToBlobPass(
gpu_binary_attr_name, compute_capability));
if (!cubin_only) {
pm.addPass(mlir::kernel_gen::tf_framework::
createTestTFFrameworkLegalizeToLLVMPass());
pm.addPass(mlir::createGpuToLLVMConversionPass(gpu_binary_attr_name));
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());
}
return failed(pm.run(module)) ? InternalError("Lowering to LLVM IR failed.")
: Status::OK();
}
} // namespace
void RegisterDialects() {
static bool init_once = []() {
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
return true;
}();
(void)init_once;
}
StatusOr<mlir::OwningModuleRef> GenerateKernelForTfCode(
mlir::MLIRContext& context, llvm::StringRef tf_code, bool cubin_only,
std::pair<int32_t, int32_t> compute_capability,
llvm::ArrayRef<uint32_t> tile_sizes, llvm::ArrayRef<uint32_t> same_shape,
llvm::ArrayRef<uint32_t> unroll_factors) {
mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context);
TF_RETURN_IF_ERROR(
LowerTFtoGPU(module.get(), cubin_only, tile_sizes, unroll_factors));
TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get()));
TF_RETURN_IF_ERROR(LowerGPUToLLVM(module.get(), cubin_only, same_shape,
kGpuBinaryAttrName, compute_capability));
return module;
}
StatusOr<std::string> ExtractGpuBinary(mlir::ModuleOp module) {
auto gpu_modules = module.getOps<mlir::gpu::GPUModuleOp>();
if (std::distance(gpu_modules.begin(), gpu_modules.end()) != 1) {
return InternalError("There should be exactly one GPU Module");
}
mlir::gpu::GPUModuleOp gpu_mod = *gpu_modules.begin();
auto blob = gpu_mod.getAttrOfType<mlir::StringAttr>(kGpuBinaryAttrName);
if (!blob) {
return InternalError("No binary blob found in the module");
}
return blob.getValue().str();
}
} // namespace kernel_gen
} // namespace tensorflow
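
The tile-size adjustment in LowerTFtoGPU above multiplies each tile size by its unroll factor so the requested tiling survives later unrolling. A plain worked example with hypothetical values:

#include <cstddef>
#include <cstdint>
#include <vector>

void ComputeTilingForUnrolling(const std::vector<uint32_t>& tile_sizes,
                               const std::vector<uint32_t>& unroll_factors,
                               std::vector<unsigned>& tiling_for_unrolling,
                               std::vector<int64_t>& as_int64) {
  for (size_t i = 0; i < tile_sizes.size(); ++i) {
    tiling_for_unrolling.push_back(tile_sizes[i] * unroll_factors[i]);
    as_int64.push_back(unroll_factors[i]);
  }
  // tile_sizes = {16, 64}, unroll_factors = {2, 4}
  //   -> tiling_for_unrolling = {32, 256}, as_int64 = {2, 4}:
  // tiling by {32, 256} and then unrolling by {2, 4} yields the requested
  // {16, 64} tiles.
}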

View File

@ -1,43 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_PASSES_H_
#include <memory>
#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
namespace mlir {
namespace kernel_gen {
std::unique_ptr<mlir::FunctionPass> createMaterializeBroadcastsPass();
std::unique_ptr<mlir::FunctionPass> createUnfuseBatchNormPass();
std::unique_ptr<mlir::OperationPass<mlir::LLVM::LLVMFuncOp>>
createPropagateTensorFlowABIKnowledgePass(mlir::FunctionType type,
llvm::ArrayRef<uint32_t> same_shape);
std::unique_ptr<mlir::OperationPass<mlir::gpu::GPUModuleOp>>
createGpuKernelToBlobPass(
mlir::StringRef blob_annotation,
const std::pair<int32_t, int32_t>& compute_capability);
} // namespace kernel_gen
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_PASSES_H_

View File

@ -23,47 +23,10 @@
#include "absl/strings/string_view.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "tensorflow/compiler/mlir/init_mlir.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow {
namespace kernel_gen {
namespace {
xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file,
int32_t architecture, llvm::ArrayRef<uint32_t> tile_sizes,
llvm::ArrayRef<uint32_t> same_shape,
llvm::ArrayRef<uint32_t> unroll_factors) {
std::pair<int32_t, int32_t> compute_capability(architecture / 10,
architecture % 10);
// Read TF code.
std::string tf_code;
TF_RETURN_IF_ERROR(
ReadFileToString(Env::Default(), input_file.str(), &tf_code));
// Compile.
RegisterDialects();
mlir::MLIRContext mlir_context;
TF_ASSIGN_OR_RETURN(
mlir::OwningModuleRef module,
GenerateKernelForTfCode(mlir_context, tf_code, /*cubin_only=*/true,
compute_capability, tile_sizes, same_shape,
unroll_factors));
// Extract cubin.
TF_ASSIGN_OR_RETURN(std::string cubin, ExtractGpuBinary(*module));
// Write cubin binary blob.
TF_RETURN_IF_ERROR(
WriteStringToFile(Env::Default(), output_file.str(), cubin));
return xla::Status::OK();
}
} // namespace
} // namespace kernel_gen
} // namespace tensorflow
int main(int argc, char** argv) {
llvm::cl::opt<std::string> input_file("input", llvm::cl::desc("input file"),
@ -88,15 +51,38 @@ int main(int argc, char** argv) {
llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated);
tensorflow::InitMlir y(&argc, &argv);
mlir::registerPassManagerCLOptions();
llvm::cl::ParseCommandLineOptions(argc, argv, "TF op GPU kernel generator\n");
auto status =
tensorflow::kernel_gen::Run(input_file, output_file, architecture,
tile_sizes, same_shape, unroll_factors);
std::pair<int32_t, int32_t> compute_capability(architecture / 10,
architecture % 10);
std::string tf_code;
auto read_status = tensorflow::ReadFileToString(tensorflow::Env::Default(),
input_file, &tf_code);
if (!read_status.ok()) {
LOG(ERROR) << read_status;
return 1;
}
auto cubin = tensorflow::kernel_gen::GenerateCubinForTfCode(
tf_code, compute_capability, tile_sizes, same_shape, unroll_factors);
if (!cubin.ok()) {
LOG(ERROR) << cubin.status();
return 1;
}
std::vector<uint8_t> cubin_data = cubin.ConsumeValueOrDie();
auto status = tensorflow::WriteStringToFile(
tensorflow::Env::Default(), output_file,
absl::string_view{reinterpret_cast<char*>(cubin_data.data()),
cubin_data.size()});
if (!status.ok()) {
LOG(ERROR) << status;
return 1;
}
return 0;
}
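
For clarity, the `architecture` flag above encodes the SM version as a two-digit number that is split into a compute-capability pair; a trivial sketch:

#include <cstdint>
#include <utility>

std::pair<int32_t, int32_t> ToComputeCapability(int32_t architecture) {
  // e.g. --arch=75 (sm_75) -> {7, 5}.
  return {architecture / 10, architecture % 10};
}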

View File

@ -1,164 +0,0 @@
// Copyright 2020 The TensorFlow Runtime Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//===- tf_to_kernel.cc ------------------------------------------*- C++ -*-===//
//
// This file implements the entry point to compile a tf op into a host-side kernel binary.
//
//===----------------------------------------------------------------------===//
#include <string>
#include <utility>
#include <vector>
#include "absl/strings/string_view.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/CodeGen/CommandFlags.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Host.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Target/LLVMIR.h" // from @llvm-project
#include "tensorflow/compiler/mlir/init_mlir.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow {
namespace kernel_gen {
namespace {
static llvm::codegen::RegisterCodeGenFlags CGF;
std::unique_ptr<llvm::TargetMachine> GetTargetMachine(llvm::Module* module) {
llvm::Triple triple(module->getTargetTriple());
if (triple.getTriple().empty()) {
triple = llvm::Triple(llvm::sys::getDefaultTargetTriple());
module->setTargetTriple(triple.getTriple());
}
std::string error;
const llvm::Target* target =
llvm::TargetRegistry::lookupTarget("", triple, error);
if (!target) {
return nullptr;
}
llvm::TargetOptions target_options =
llvm::codegen::InitTargetOptionsFromCodeGenFlags();
return std::unique_ptr<llvm::TargetMachine>(target->createTargetMachine(
triple.str(), "generic", "", target_options, llvm::Reloc::Model::PIC_));
}
// Compiles the given MLIR module via LLVM into an executable binary format.
xla::StatusOr<std::string> EmitToBinary(mlir::ModuleOp module) {
// Translate the module.
llvm::LLVMContext llvm_context;
std::unique_ptr<llvm::Module> llvm_module =
mlir::translateModuleToLLVMIR(module, llvm_context);
// Set up the output stream.
llvm::SmallString<8> outstr;
llvm::raw_svector_ostream ostream(outstr);
ostream.SetUnbuffered();
llvm::legacy::PassManager codegen_passes;
codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass(
llvm::Triple(llvm_module->getTargetTriple())));
// TODO(b/163818770): Apply optimizations before dumping .a file.
auto target_machine = GetTargetMachine(llvm_module.get());
llvm_module->setDataLayout(target_machine->createDataLayout());
if (target_machine->addPassesToEmitFile(codegen_passes, ostream, nullptr,
llvm::CGFT_ObjectFile, false)) {
return xla::InternalError("Failed to add passes to emit file");
}
codegen_passes.run(*llvm_module);
return ostream.str().str();
}
xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file,
int32_t architecture, llvm::ArrayRef<uint32_t> tile_sizes,
llvm::ArrayRef<uint32_t> same_shape,
llvm::ArrayRef<uint32_t> unroll_factors) {
std::pair<int32_t, int32_t> compute_capability(architecture / 10,
architecture % 10);
// Read TF code.
std::string tf_code;
TF_RETURN_IF_ERROR(
ReadFileToString(Env::Default(), input_file.str(), &tf_code));
// Compile.
RegisterDialects();
mlir::MLIRContext mlir_context;
TF_ASSIGN_OR_RETURN(
mlir::OwningModuleRef module,
GenerateKernelForTfCode(mlir_context, tf_code, /*cubin_only=*/false,
compute_capability, tile_sizes, same_shape,
unroll_factors));
// Get binary.
TF_ASSIGN_OR_RETURN(std::string binary, EmitToBinary(*module));
// Write .a file.
TF_RETURN_IF_ERROR(
WriteStringToFile(Env::Default(), output_file.str(), binary));
return xla::Status::OK();
}
} // namespace
} // namespace kernel_gen
} // namespace tensorflow
int main(int argc, char** argv) {
llvm::cl::opt<std::string> input_file("input", llvm::cl::desc("input file"),
llvm::cl::value_desc("filename"),
llvm::cl::init("foo.mlir"));
llvm::cl::opt<std::string> output_file(
"output", llvm::cl::desc("output file"), llvm::cl::value_desc("filename"),
llvm::cl::init("foo.bin"));
llvm::cl::opt<int32_t> architecture(
"arch", llvm::cl::desc("target architecture (e.g. 50 for sm_50)"),
llvm::cl::init(50));
llvm::cl::list<uint32_t> tile_sizes(
"tile_sizes", llvm::cl::desc("tile sizes to use"), llvm::cl::ZeroOrMore,
llvm::cl::CommaSeparated);
llvm::cl::list<uint32_t> unroll_factors(
"unroll_factors",
llvm::cl::desc("factors to unroll by, separated by commas"),
llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated);
llvm::cl::list<uint32_t> same_shape(
"same_shape",
llvm::cl::desc("arguments with same shape, separated by commas"),
llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated);
tensorflow::InitMlir y(&argc, &argv);
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
mlir::registerPassManagerCLOptions();
llvm::cl::ParseCommandLineOptions(argc, argv, "TF op GPU kernel generator\n");
auto status =
tensorflow::kernel_gen::Run(input_file, output_file, architecture,
tile_sizes, same_shape, unroll_factors);
if (!status.ok()) {
LOG(ERROR) << status;
return 1;
}
return 0;
}

View File

@ -77,7 +77,6 @@ cc_library(
"//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_llvm",
"//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LLVMDialect",
"@llvm-project//mlir:LLVMTransforms",

View File

@ -15,7 +15,6 @@ limitations under the License.
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project
#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
@ -53,11 +52,10 @@ class TestTFFrameworkToLLVMPass
// Set target.
ConversionTarget target(getContext());
target.addLegalDialect<LLVM::LLVMDialect>();
target.addLegalDialect<gpu::GPUDialect>();
target.addIllegalDialect<tf_framework::TFFrameworkDialect>();
target.addIllegalOp<LLVM::DialectCastOp>();
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
if (failed(applyPartialConversion(m, target, patterns))) {
if (failed(applyFullConversion(m, target, patterns))) {
signalPassFailure();
}
}

View File

@ -55,6 +55,7 @@ cc_library(
"transforms/passes.h",
],
deps = [
":attribute_importer",
":type_to_shape",
":xla_legalize_tf_with_tf2xla",
"//tensorflow/compiler/mlir/hlo",
@ -69,7 +70,7 @@ cc_library(
"//tensorflow/compiler/xla/client/lib:conv_grad_size_util",
"//tensorflow/core:framework",
"//tensorflow/core/kernels:conv_grad_shape_utils",
"//tensorflow/core/lib/bfloat16",
"//tensorflow/core/platform:bfloat16",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:Dialect",

View File

@ -22,7 +22,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/platform/bfloat16.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@ -83,6 +83,9 @@ StatusOr<llvm::SmallVector<AffineMap, 1>> GetPermutationIfAvailable(
strides[dim] = accumulated_stride;
accumulated_stride *= shape.dimensions(dim);
}
if (accumulated_stride == 0) {
return llvm::SmallVector<AffineMap, 1>{};
}
return llvm::SmallVector<AffineMap, 1>{
makeStridedLinearLayoutMap(strides, /*offset=*/0, builder.getContext())};
}
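
The new zero-check above guards shapes with a zero-sized dimension: assuming the accumulator starts at one, as is usual for stride computation, its final value is the product of all dimensions, so any zero dimension drives it to zero and no strided layout map should be emitted. A tiny illustration:

#include <cstdint>
#include <vector>

int64_t FinalAccumulatedStride(const std::vector<int64_t>& dims) {
  int64_t accumulated_stride = 1;
  for (int64_t dim : dims) accumulated_stride *= dim;
  return accumulated_stride;  // e.g. {2, 0} -> 0, while {3, 2} -> 6
}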

View File

@ -8,6 +8,6 @@ HloModule TestModule
ENTRY TestComputation {
x = f32[3, 2]{1,0} parameter(0)
// CHECK: "lmhlo.copy"(%{{.*}}, %{{.*}}) : (memref<3x2xf32>, memref<3x2xf32, #[[MAP]]>) -> ()
// CHECK: "lmhlo.copy"(%{{.*}}, %{{.*}}) {name = "copy.1"} : (memref<3x2xf32>, memref<3x2xf32, #[[MAP]]>) -> ()
ROOT x.copy = f32[3, 2]{0,1} copy(x)
}

View File

@ -44,7 +44,7 @@ attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} {
// CHECK-LABEL: func @case
// CHECK-SAME: %[[BRANCH_INDEX:.*]]: tensor<i32>, %[[ARG0:.*]]: tensor<f32>, %[[ARG1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>)
func @case(%index: tensor<i32>, %arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
%0:2 = "tf.Case"(%index, %arg0, %arg1) {branches = [@exponential, @log, @floor]} : (tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
%0:2 = "tf.Case"(%index, %arg0, %arg1) {branches = [@exponential, @log, @floor], is_stateless = true} : (tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// CHECK: %[[TUPLE_INPUT:.*]] = "mhlo.tuple"(%[[ARG0]], %[[ARG1]]) : (tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>>
// CHECK: %[[CASE:.*]]:2 = "mhlo.case"(%[[BRANCH_INDEX]], %[[TUPLE_INPUT]], %[[TUPLE_INPUT]], %[[TUPLE_INPUT]]) ( {
// CHECK: ^bb0(%[[TUPLE_ARG:.*]]: tuple<tensor<f32>, tensor<f32>>):

View File

@ -265,6 +265,31 @@ func @non_max_suppression_v4(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2
return %0#0 : tensor<2xi32>
}
// CHECK-LABEL: bessel_i0e
func @bessel_i0e(%arg0: tensor<3xf16>, %arg1: tensor<3xf32>, %arg2: tensor<3xf64>) -> (tensor<3xf16>, tensor<3xf32>, tensor<3xf64>) {
// CHECK-NOT: tf.BesselI0e
%0 = "tf.BesselI0e"(%arg0) : (tensor<3xf16>) -> (tensor<3xf16>)
%1 = "tf.BesselI0e"(%arg1) : (tensor<3xf32>) -> (tensor<3xf32>)
%2 = "tf.BesselI0e"(%arg2) : (tensor<3xf64>) -> (tensor<3xf64>)
return %0, %1, %2 : tensor<3xf16>, tensor<3xf32>, tensor<3xf64>
}
// CHECK-LABEL: bessel_i1e
func @bessel_i1e(%arg0: tensor<3xf16>, %arg1: tensor<3xf32>, %arg2: tensor<3xf64>) -> (tensor<3xf16>, tensor<3xf32>, tensor<3xf64>) {
// CHECK-NOT: tf.BesselI1e
%0 = "tf.BesselI1e"(%arg0) : (tensor<3xf16>) -> (tensor<3xf16>)
%1 = "tf.BesselI1e"(%arg1) : (tensor<3xf32>) -> (tensor<3xf32>)
%2 = "tf.BesselI1e"(%arg2) : (tensor<3xf64>) -> (tensor<3xf64>)
return %0, %1, %2 : tensor<3xf16>, tensor<3xf32>, tensor<3xf64>
}
// CHECK-LABEL: diag
func @diag(%arg0: tensor<2xf32>) -> tensor<2x2xf32> {
// CHECK-NOT: tf.Diag
%0 = "tf.Diag"(%arg0) : (tensor<2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
// TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
// available but doesn't support this instance.
}

View File

@ -3484,6 +3484,20 @@ func @conv3d_backprop_filter(%input: tensor<2x8x8x8x1xf32>, %out_backprop: tenso
return %result : tensor<2x8x8x8x1xf32>
}
// CHECK-LABEL: @collective_permute
func @collective_permute(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> {
%source_target_pairs = "tf.Const" () {
value = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi32>
} : () -> tensor<3x2xi32>
// CHECK: "mhlo.collective_permute"
// CHECK-SAME: source_target_pairs = dense<{{\[}}[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>
%0 = "tf.CollectivePermute"(%arg0, %source_target_pairs) {
} : (tensor<128x32xf32>, tensor<3x2xi32>) -> tensor<128x32xf32>
return %0 : tensor<128x32xf32>
}
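The CHECK lines above expect the i32 `source_target_pairs` constant to be carried over as an i64 attribute on `mhlo.collective_permute`. For intuition about what those pairs describe at runtime, here is a rough scalar sketch of the data movement (my own illustration, not code from this change; the zero-fill for replicas that are not a target in any pair is my reading of the op's semantics):

```c++
#include <array>
#include <cstdint>
#include <vector>

// Each (source, target) pair sends replica `source`'s buffer to replica
// `target`; replicas that never appear as a target receive zeros.
std::vector<std::vector<float>> CollectivePermute(
    const std::vector<std::vector<float>>& per_replica,
    const std::vector<std::array<int64_t, 2>>& source_target_pairs) {
  std::vector<std::vector<float>> out(
      per_replica.size(),
      std::vector<float>(per_replica.empty() ? 0 : per_replica[0].size(), 0.f));
  for (const auto& pair : source_target_pairs) {
    out[pair[1]] = per_replica[pair[0]];  // data flows source -> target
  }
  return out;
}
```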
// CHECK-LABEL: @cross_replica_sum
func @cross_replica_sum(%input: tensor<10xf32>) -> tensor<10xf32> {
%replica_groups = "tf.Const" () {
@ -3775,7 +3789,7 @@ func @unsorted_segment_prod(%data: tensor<8x?x64xf32>, %segment_ids : tensor<?x1
// CHECK-LABEL: @unsorted_segment_min
func @unsorted_segment_min(%data: tensor<8x?x64xf32>, %segment_ids : tensor<?x16xi32>) -> (tensor<4x?xf32>) {
%num_segments = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
// CHECK: mhlo.constant dense<0x7F800000> : tensor<f32>
// CHECK: mhlo.constant dense<3.40282347E+38> : tensor<f32>
// CHECK: mhlo.scatter
// CHECK: mhlo.minimum
%0 = "tf.UnsortedSegmentMin"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor<?x16xi32>, tensor<i32>) -> (tensor<4x?xf32>)
@ -3785,7 +3799,7 @@ func @unsorted_segment_min(%data: tensor<8x?x64xf32>, %segment_ids : tensor<?x16
// CHECK-LABEL: @unsorted_segment_max
func @unsorted_segment_max(%data: tensor<8x?x64xf32>, %segment_ids : tensor<?x16xi32>) -> (tensor<4x?xf32>) {
%num_segments = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
// CHECK: mhlo.constant dense<0xFF800000> : tensor<f32>
// CHECK: mhlo.constant dense<-3.40282347E+38> : tensor<f32>
// CHECK: mhlo.scatter
// CHECK: mhlo.maximum
%0 = "tf.UnsortedSegmentMax"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor<?x16xi32>, tensor<i32>) -> (tensor<4x?xf32>)
@ -4668,6 +4682,20 @@ func @cumsum_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<i32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
//===----------------------------------------------------------------------===//
// Cumprod op legalizations.
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @cumprod
func @cumprod(%arg0: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: [[INIT:%.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
// CHECK: "mhlo.reduce_window"({{.*}}, [[INIT]]) ( {
// CHECK: mhlo.mul
%0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32>
%1 = "tf.Cumprod"(%arg0, %0) {exclusive = false, reverse = false} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32>
return %1 : tensor<4xf32>
}
//===----------------------------------------------------------------------===//
// Qr op legalization
//===----------------------------------------------------------------------===//
@ -4766,3 +4794,20 @@ func @softplus_f64(%arg0: tensor<8x16xf64>) -> tensor<8x16xf64> {
// CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf64>
return %0 : tensor<8x16xf64>
}
// CHECK-LABEL: @xla_gather
func @xla_gather(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10x1x300xf32> {
%cst = "tf.Const"() { value = dense<[1, 1, 300]> : tensor<3xi64> } : () -> tensor<3xi64>
// CHECK: "mhlo.gather"
// CHECK-SAME: dimension_numbers =
// CHECK-SAME: collapsed_slice_dims = dense<0> : tensor<1xi64>
// CHECK-SAME: index_vector_dim = 1 : i64
// CHECK-SAME: offset_dims = dense<1> : tensor<1xi64>
// CHECK-SAME: start_index_map = dense<0> : tensor<1xi64>
// CHECK-SAME: indices_are_sorted = true
// CHECK-SAME: slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>
%0 = "tf.XlaGather"(%arg0, %arg1, %cst) {dimension_numbers = "\0A\01\01\12\01\00\1A\01\00 \01", indices_are_sorted = true} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi64>) -> tensor<10x1x300xf32>
return %0 : tensor<10x1x300xf32>
}
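The `dimension_numbers` string attribute on `tf.XlaGather` holds a serialized `xla::GatherDimensionNumbers` proto, and the CHECK-SAME lines above spell out the decoded fields the conversion must recover. A hedged sketch of how such an attribute value could be built for this test case, assuming the standard generated protobuf API from `xla_data.pb.h`:

```c++
#include <string>

#include "tensorflow/compiler/xla/xla_data.pb.h"

// Builds dimension numbers matching the expectations above:
// offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0],
// index_vector_dim = 1, then serializes them into the raw byte string that
// is stored in the op's `dimension_numbers` attribute.
std::string MakeGatherDimensionNumbersAttr() {
  xla::GatherDimensionNumbers dims;
  dims.add_offset_dims(1);
  dims.add_collapsed_slice_dims(0);
  dims.add_start_index_map(0);
  dims.set_index_vector_dim(1);
  return dims.SerializeAsString();
}
```

Serializing those fields should yield the escaped byte string used in the test function above.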

View File

@ -50,6 +50,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
#include "tensorflow/compiler/mlir/xla/attribute_importer.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/xla/client/lib/conv_grad_size_util.h"
#include "tensorflow/compiler/xla/client/padding.h"
@ -57,7 +58,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/platform/bfloat16.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
@ -262,49 +263,21 @@ tensorflow::TensorShape ToTensorShape(
sizes.begin(), sizes.end()));
}
// Returns minimal value for the given int or float element type.
static ConstOp GetMinValueForType(Type ty, Location loc,
PatternRewriter *rewriter) {
RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
DenseElementsAttr attr;
if (auto float_ty = ty.dyn_cast_or_null<FloatType>()) {
APFloat neg_inf =
APFloat::getInf(float_ty.getFloatSemantics(), /*negative=*/true);
attr = DenseElementsAttr::get(scalar_ty, neg_inf);
} else {
auto int_ty = ty.cast<IntegerType>();
APInt min_val = APInt::getSignedMinValue(int_ty.getWidth());
attr = DenseElementsAttr::get(scalar_ty, min_val);
}
return rewriter->create<ConstOp>(loc, attr);
}
// Returns maximal value for the given int or float element type.
static ConstOp GetMaxValueForType(Type ty, Location loc,
PatternRewriter *rewriter) {
RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
DenseElementsAttr attr;
if (auto float_ty = ty.dyn_cast_or_null<FloatType>()) {
APFloat pos_inf =
APFloat::getInf(float_ty.getFloatSemantics(), /*negative=*/false);
attr = DenseElementsAttr::get(scalar_ty, pos_inf);
} else {
auto int_ty = ty.cast<IntegerType>();
APInt max_val = APInt::getSignedMaxValue(int_ty.getWidth());
attr = DenseElementsAttr::get(scalar_ty, max_val);
}
return rewriter->create<ConstOp>(loc, attr);
}
// Returns int or float scalar DenseElementsAttr attribute with the given
// element type and the value.
// Returns int, float, or complex scalar DenseElementsAttr attribute with the
// given element type and the value.
static ConstOp GetScalarConstOfType(Type ty, Location loc, int64_t raw_value,
OpBuilder *builder) {
return builder->create<ConstOp>(loc, hlo::GetScalarOfType(ty, raw_value));
}
// Returns a limit scalar const op for the given type.
// Requires FloatType or IntegerType
static ConstOp GetScalarLimitConstOfType(Type ty, Location loc,
hlo::ScalarLimit limit,
OpBuilder *builder) {
return builder->create<ConstOp>(loc, hlo::GetScalarLimitOfType(ty, limit));
}
// Creates an mhlo::SliceOp where the major dimensions have full size, and
// the minor dimensions have the provided offsets and sizes.
static Value SliceInMinorDims(Location loc, Value v,
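The helpers above replace the removed per-op min/max constructors with `GetScalarLimitConstOfType`, which distinguishes infinities (`hlo::kInfinityLowest` / `hlo::kInfinityMax`) from finite extremes (`hlo::kLowest` / `hlo::kMax`); that distinction is why the unsorted-segment test expectations earlier in this diff moved from the ±infinity bit patterns to ±3.40282347E+38. A plain C++ illustration of the four f32 limits (my own sketch of the intended values, not the hlo helper itself):

```c++
#include <cstdio>
#include <limits>

int main() {
  // Infinity-style limits, used e.g. as MaxPool / Min / Max reduction inits.
  float infinity_lowest = -std::numeric_limits<float>::infinity();
  float infinity_max = std::numeric_limits<float>::infinity();
  // Finite limits, used e.g. as UnsortedSegmentMax / UnsortedSegmentMin inits.
  float lowest = std::numeric_limits<float>::lowest();  // -3.40282347e+38f
  float max = std::numeric_limits<float>::max();        //  3.40282347e+38f
  std::printf("%g %g %.9g %.9g\n", infinity_lowest, infinity_max, lowest, max);
  return 0;
}
```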
@ -1065,6 +1038,21 @@ static void BuildSortComparisonBody(llvm::ArrayRef<Type> element_types,
builder->create<mhlo::ReturnOp>(loc, compare);
}
//===----------------------------------------------------------------------===//
// XlaGather op utilities.
//===----------------------------------------------------------------------===//
bool HasValidGatherDims(StringAttr attr) {
::xla::GatherDimensionNumbers dims;
return dims.ParseFromString(attr.getValue().str());
}
GatherDimensionNumbers GetGatherDimNumsAttr(StringAttr attr, Builder *builder) {
::xla::GatherDimensionNumbers dims;
if (!dims.ParseFromString(attr.getValue().str())) return {};
return ::xla::ConvertGatherDimensionNumbers(dims, builder);
}
//===----------------------------------------------------------------------===//
// Op converters.
//===----------------------------------------------------------------------===//
@ -2385,15 +2373,16 @@ class ConvertMaxPoolOp : public OpRewritePattern<OpTy> {
op.input().getType().template cast<TensorType>().getElementType();
if (!element_type.isSignlessIntOrFloat()) return failure();
Location loc = op.getLoc();
ConstOp init = GetMinValueForType(element_type, loc, &rewriter);
ConstOp init = GetScalarLimitConstOfType(element_type, loc,
hlo::kInfinityLowest, &rewriter);
auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
if (!input_ty) return failure();
DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr<num_dims>(
input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter);
auto reduce = rewriter.create<ReduceWindowOp>(
loc, op.getType(), op.input(), init.getResult(),
GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()),
loc, op.getType(), op.input(), init, GetI64ElementsAttr(op.ksize()),
GetI64ElementsAttr(op.strides()),
/*base_dilations=*/DenseIntElementsAttr(),
/*window_dilations=*/DenseIntElementsAttr(), paddings_attr);
BuildReduceBody<MaxOp>(element_type, &reduce.body(), &rewriter);
@ -3636,7 +3625,8 @@ class ConvertMaxOp
static Value GetInitialValue(Type reduce_element_type, Location loc,
PatternRewriter *rewriter) {
return GetMinValueForType(reduce_element_type, loc, rewriter);
return GetScalarLimitConstOfType(reduce_element_type, loc,
hlo::kInfinityLowest, rewriter);
}
};
@ -3653,7 +3643,8 @@ class ConvertMinOp
static Value GetInitialValue(Type reduce_element_type, Location loc,
PatternRewriter *rewriter) {
return GetMaxValueForType(reduce_element_type, loc, rewriter);
return GetScalarLimitConstOfType(reduce_element_type, loc,
hlo::kInfinityMax, rewriter);
}
};
@ -3789,7 +3780,8 @@ class ConvertArgMaxOp
static Value GetInitialValue(Type reduce_element_type, Location loc,
PatternRewriter &rewriter) {
return GetMinValueForType(reduce_element_type, loc, &rewriter);
return GetScalarLimitConstOfType(reduce_element_type, loc,
hlo::kInfinityLowest, &rewriter);
}
static StringRef GetDirection() { return "GT"; }
@ -4728,7 +4720,7 @@ class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern<OpTy> {
auto output_type =
RankedTensorType::get(output_shape, data_type.getElementType());
// Broadccast the initial value for reduction. This will become the
// Broadcast the initial value for reduction. This will become the
// 'operand' parameter to scatter to for the final scatter op.
Value init = ConcreteClass::GetInitialValue(data_type.getElementType(),
op.getLoc(), &rewriter);
@ -4768,7 +4760,8 @@ class ConvertUnsortedSegmentMaxOp
static Value GetInitialValue(Type reduce_element_type, Location loc,
PatternRewriter *rewriter) {
return GetMinValueForType(reduce_element_type, loc, rewriter);
return GetScalarLimitConstOfType(reduce_element_type, loc, hlo::kLowest,
rewriter);
}
};
@ -4781,7 +4774,8 @@ class ConvertUnsortedSegmentMinOp
static Value GetInitialValue(Type reduce_element_type, Location loc,
PatternRewriter *rewriter) {
return GetMaxValueForType(reduce_element_type, loc, rewriter);
return GetScalarLimitConstOfType(reduce_element_type, loc, hlo::kMax,
rewriter);
}
};
@ -5092,17 +5086,19 @@ class ConvertXlaDynamicUpdateSliceOp
}
};
/// Converts the Cumsum TensorFlow op to the HLO ReduceWindow op by setting
/// appropriate window dimensions, with 'add' as the reduction function. The
/// input tensor needs to have a static shape, and 'axis' must be const. The
/// TableGen pattern is not used for this rewrite because it involves regions.
class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
using OpRewritePattern<TF::CumsumOp>::OpRewritePattern;
// Converts the Cumsum or Cumprod TensorFlow op to the HLO ReduceWindow op by
// setting appropriate window dimensions, with the given aggregation op as the
// reduction function. The input tensor needs to have a static shape, and 'axis'
// must be const. The TableGen pattern is not used for this rewrite because it
// involves regions.
template <typename OpT, typename AggregationOp>
class ConvertCumOp : public OpRewritePattern<OpT> {
using OpRewritePattern<OpT>::OpRewritePattern;
LogicalResult matchAndRewrite(TF::CumsumOp op,
LogicalResult matchAndRewrite(OpT op,
PatternRewriter &rewriter) const override {
auto input = op.x();
auto input_type = input.getType().dyn_cast<ShapedType>();
auto input_type = input.getType().template dyn_cast<ShapedType>();
if (!input_type || !input_type.hasStaticShape()) {
return failure();
}
@ -5135,6 +5131,10 @@ class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
// Convert if we need to enlarge the element type's bitwidth to avoid
// precision loss.
Type input_element_type = input_type.getElementType();
// TODO(hinsu): Handle complex element types.
if (!input_element_type.isIntOrFloat()) return failure();
Type sum_element_type = GetSumAccumulationType(input_element_type);
input = rewriter.create<ConvertOp>(op.getLoc(), input, sum_element_type);
@ -5148,8 +5148,9 @@ class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
RankedTensorType::get({rank, 2}, rewriter.getIntegerType(64)),
paddings);
Value init =
GetScalarConstOfType(sum_element_type, op.getLoc(), 0, &rewriter);
int64_t init_value = (std::is_same<AggregationOp, AddOp>::value) ? 0 : 1;
Value init = GetScalarConstOfType(sum_element_type, op.getLoc(), init_value,
&rewriter);
auto reduce = rewriter.create<ReduceWindowOp>(
op.getLoc(), input_type, input, init,
@ -5157,7 +5158,7 @@ class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
GetI64ElementsAttr(rewriter.getI64ArrayAttr(window_strides)),
/*base_dilations=*/DenseIntElementsAttr(),
/*window_dilations=*/DenseIntElementsAttr(), paddings_attr);
BuildReduceBody<AddOp>(sum_element_type, &reduce.body(), &rewriter);
BuildReduceBody<AggregationOp>(sum_element_type, &reduce.body(), &rewriter);
Value result = reduce.getResult();
if (op.exclusive()) {
@ -5193,6 +5194,9 @@ class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
}
};
using ConvertCumsumOp = ConvertCumOp<TF::CumsumOp, AddOp>;
using ConvertCumprodOp = ConvertCumOp<TF::CumprodOp, MulOp>;
// Converts the Tensorflow ShapeOp to a sequence of Shape dialect and Standard
// dialect lowerings. This involves extracting the shape type, extracting and
// converting each dimension to a known integer type, and repacking into a final
@ -5857,7 +5861,7 @@ void PopulateLegalizeTfPatterns(MLIRContext *context,
ConvertConv2DOp, ConvertConv3DOp, ConvertDepthConv2DOp,
ConvertConv2DBackpropFilterOp, ConvertConv3DBackpropFilterOp,
ConvertConv2DBackpropInputOp, ConvertConv3DBackpropInputOp,
ConvertCumsumOp, ConvertDiagPartOp, ConvertEinsumOp,
ConvertCumprodOp, ConvertCumsumOp, ConvertDiagPartOp, ConvertEinsumOp,
ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op,
ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV2Op,
ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp,
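As the ConvertCumOp comment above describes, Cumsum and Cumprod now share one lowering and differ only in the aggregation op and its identity (0 for add, 1 for multiply), plus the `exclusive` and `reverse` attributes. A small scalar reference for that semantics, kept as a sketch for intuition rather than the ReduceWindow-based lowering itself:

```c++
#include <algorithm>
#include <functional>
#include <vector>

// Scalar reference for tf.Cumsum / tf.Cumprod along a 1-D input.
// `init` is the aggregation identity: 0 for addition, 1 for multiplication.
// `exclusive` makes element i aggregate only inputs [0, i);
// `reverse` scans from the end instead of the beginning.
std::vector<float> CumulativeScan(std::vector<float> x, float init,
                                  const std::function<float(float, float)>& agg,
                                  bool exclusive, bool reverse) {
  if (reverse) std::reverse(x.begin(), x.end());
  std::vector<float> out(x.size());
  float acc = init;
  for (size_t i = 0; i < x.size(); ++i) {
    if (exclusive) {
      out[i] = acc;
      acc = agg(acc, x[i]);
    } else {
      acc = agg(acc, x[i]);
      out[i] = acc;
    }
  }
  if (reverse) std::reverse(out.begin(), out.end());
  return out;
}
```

For example, `CumulativeScan({1, 2, 3, 4}, 1.f, std::multiplies<float>(), false, false)` yields `{1, 2, 6, 24}`, the cumulative product with `exclusive = false, reverse = false` as in the `@cumprod` test earlier in this diff.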

View File

@ -51,6 +51,10 @@ def GetHLOAxisFromTFAxisVariadic : NativeCodeCall<
"$0, (*$1.begin()).getType().cast<RankedTensorType>().getRank(), "
"&$_builder)">;
def CastElementsToI64Elements : NativeCodeCall<
"hlo::ConvertElementsAttr("
"$0, $_builder.getIntegerType(64)).cast<DenseIntElementsAttr>()">;
def : Pattern<
(TF_FusedBatchNormOp:$root $x, $scale, $offset, $mean, $variance, $epsilon,
$exponential_avg_factor, $data_format,
@ -255,12 +259,16 @@ def : Pat<(TF_ConcatV2Op $inputs, (TF_ConstOp OneElementAttr:$axis)),
[(HasRankedFirstOperand $inputs)]>;
//===----------------------------------------------------------------------===//
// CrossReplicaSum op patterns.
// CollectivePermute op patterns.
//===----------------------------------------------------------------------===//
def CastElementsToI64Elements : NativeCodeCall<
"hlo::ConvertElementsAttr("
"$0, $_builder.getIntegerType(64)).cast<DenseIntElementsAttr>()">;
def : Pat<(TF_CollectivePermuteOp $input, (TF_ConstOp $source_target_pairs)),
(HLO_CollectivePermuteOp $input,
(CastElementsToI64Elements $source_target_pairs))>;
//===----------------------------------------------------------------------===//
// CrossReplicaSum op patterns.
//===----------------------------------------------------------------------===//
def : Pat<(TF_CrossReplicaSumOp $input, (TF_ConstOp $group_assignment)),
(HLO_CrossReplicaSumOp $input,
@ -660,3 +668,18 @@ def : Pattern<(TF_SoftplusOp AnyTensor:$features),
),
(replaceWithValue $output)
]>;
//===----------------------------------------------------------------------===//
// XlaGather op.
//===----------------------------------------------------------------------===//
def ToGatherDimNumsAttr : NativeCodeCall<"GetGatherDimNumsAttr($0, &$_builder)">;
def HasValidGatherDims : Constraint<CPred<"HasValidGatherDims($0)">>;
def : Pat<(TF_XlaGatherOp $operand, $start_indices, (TF_ConstOp $slice_sizes),
$dimension_numbers, $indices_are_sorted),
(HLO_GatherOp $operand, $start_indices,
(ToGatherDimNumsAttr $dimension_numbers),
$slice_sizes, $indices_are_sorted),
[(HasValidGatherDims $dimension_numbers)]>;

View File

@ -102,6 +102,8 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
TypeID::get<TF::BatchMatMulV2Op>(),
TypeID::get<TF::BatchToSpaceNDOp>(),
TypeID::get<TF::BatchToSpaceOp>(),
TypeID::get<TF::BesselI0eOp>(),
TypeID::get<TF::BesselI1eOp>(),
TypeID::get<TF::BiasAddGradOp>(),
TypeID::get<TF::BiasAddOp>(),
TypeID::get<TF::BitwiseAndOp>(),
@ -116,6 +118,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
TypeID::get<TF::CrossOp>(),
TypeID::get<TF::DataFormatDimMapOp>(),
TypeID::get<TF::DataFormatVecPermuteOp>(),
TypeID::get<TF::DiagOp>(),
TypeID::get<TF::DigammaOp>(),
TypeID::get<TF::DivNoNanOp>(),
TypeID::get<TF::EluGradOp>(),

View File

@ -34,7 +34,6 @@ limitations under the License.
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassOptions.h" // from @llvm-project
#include "mlir/Translation.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
@ -182,7 +181,10 @@ template <typename OpType>
StatusOr<OpType> LhloDialectEmitter::CreateOpWithoutAttrs(
HloInstruction* instr) {
Location loc = getLocation(instr);
ArrayRef<std::pair<Identifier, Attribute>> attrs;
std::pair<Identifier, Attribute> attrs[] = {
{Identifier::get("name", builder_.getContext()),
builder_.getStringAttr(instr->name())},
};
ArrayRef<Type> rets{};
llvm::SmallVector<Value, 4> operands;
@ -252,15 +254,14 @@ Status LhloDialectEmitter::DefaultAction(HloInstruction* instr) {
return Status::OK();
}
StatusOr<mlir::Operation*> LhloDialectEmitter::EmitSortOp(
HloInstruction* instr) {
StatusOr<lmhlo::SortOp> LhloDialectEmitter::EmitSortOp(HloInstruction* instr) {
TF_ASSIGN_OR_RETURN(auto sort, CreateOpWithoutAttrs<lmhlo::SortOp>(instr));
auto* sort_instr = ::xla::Cast<::xla::HloSortInstruction>(instr);
sort.dimensionAttr(builder_.getI64IntegerAttr(sort_instr->sort_dimension()));
sort.is_stableAttr(builder_.getBoolAttr(sort_instr->is_stable()));
TF_RETURN_IF_ERROR(::xla::HloFunctionImporter::ImportAsRegion(
*sort_instr->called_computations()[0], &sort.comparator(), &builder_));
return sort.getOperation();
return sort;
}
Status LhloDialectEmitter::HandleSort(HloInstruction* instr) {

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@ -41,7 +42,7 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
builder_(module.getContext()),
i8_type_(builder_.getIntegerType(8)) {}
::xla::StatusOr<mlir::Operation*> EmitSortOp(::xla::HloInstruction* instr);
::xla::StatusOr<lmhlo::SortOp> EmitSortOp(::xla::HloInstruction* instr);
private:
template <typename OpType>

Some files were not shown because too many files have changed in this diff.