Merge branch 'master' into yunfeimao/matmul_tanh_fusion
This commit is contained in:
commit
3f9d668ed0
16
.github/bot_config.yml
vendored
16
.github/bot_config.yml
vendored
@ -40,6 +40,22 @@ segfault_memory:
|
||||
# assignees
|
||||
filesystem_security_assignee:
|
||||
- mihaimaruseac
|
||||
|
||||
tflite_micro_path:
|
||||
- tensorflow/lite/micro
|
||||
|
||||
tflite_micro_comment: >
|
||||
Thanks for contributing to TensorFlow Lite Micro.
|
||||
|
||||
|
||||
To keep this process moving along, we'd like to make sure that you have completed the items on this list:
|
||||
* Read the [contributing guidelines for TensorFlow Lite Micro](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/CONTRIBUTING.md)
|
||||
* Created a [TF Lite Micro Github issue](https://github.com/tensorflow/tensorflow/issues/new?labels=comp%3Amicro&template=70-tflite-micro-issue.md)
|
||||
* Linked to the issue from the PR description
|
||||
|
||||
|
||||
We would like to have a discussion on the Github issue first to determine the best path forward, and then proceed to the PR review.
|
||||
|
||||
# Cuda Comment
|
||||
cuda_comment: >
|
||||
From the template it looks like you are installing **TensorFlow** (TF) prebuilt binaries:
|
||||
|
12
RELEASE.md
12
RELEASE.md
@ -37,6 +37,9 @@
|
||||
* XLA:CPU and XLA:GPU devices are no longer registered by default. Use
|
||||
`TF_XLA_FLAGS=--tf_xla_enable_xla_devices` if you really need them (to be
|
||||
removed).
|
||||
* `tf.raw_ops.Max` and `tf.raw_ops.Min` no longer accept inputs of type
|
||||
`tf.complex64` or `tf.complex128`, because the behavior of these ops is not
|
||||
well defined for complex types.
|
||||
|
||||
## Known Caveats
|
||||
|
||||
@ -120,6 +123,13 @@
|
||||
customization of how gradients are aggregated across devices, as well as
|
||||
`gradients_transformers` to allow for custom gradient transformations
|
||||
(such as gradient clipping).
|
||||
* The `steps_per_execution` argument in `compile()` is no longer
|
||||
experimental; if you were passing `experimental_steps_per_execution`,
|
||||
rename it to `steps_per_execution` in your code. This argument controls
|
||||
the number of batches to run during each `tf.function` call when calling
|
||||
`fit()`. Running multiple batches inside a single `tf.function` call can
|
||||
greatly improve performance on TPUs or small models with a large Python
|
||||
overhead.
|
||||
* `tf.function` / AutoGraph:
|
||||
* Added `experimental_follow_type_hints` argument for `tf.function`. When
|
||||
True, the function may use type annotations to optimize the tracing
|
||||
@ -147,6 +157,8 @@
|
||||
* Deprecate `Interpreter::UseNNAPI(bool)` C++ API
|
||||
* Prefer using `NnApiDelegate()` and related delegate configuration methods directly.
|
||||
* Add NNAPI Delegation support for requantization use cases by converting the operation into a dequantize-quantize pair.
|
||||
* TFLite Profiler for Android is available. See the detailed
|
||||
[guide](https://www.tensorflow.org/lite/performance/measurement#trace_tensorflow_lite_internals_in_android).
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
* `tf.random`:
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
|
@ -387,6 +387,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core/common_runtime/eager:eager_operation",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
"//tensorflow/core/platform",
|
||||
"//tensorflow/core/platform:blocking_counter",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
alwayslink = 1,
|
||||
|
@ -35,6 +35,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/platform/blocking_counter.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/platform/net.h"
|
||||
@ -560,6 +561,21 @@ TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
|
||||
collective_executor_handle->get()->StartAbort(status->status);
|
||||
}
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx,
|
||||
const char* task,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
auto collective_executor_handle = context->GetCollectiveExecutorHandle();
|
||||
tensorflow::Notification done;
|
||||
collective_executor_handle->get()->remote_access()->CheckPeerHealth(
|
||||
task, [&done, status](const Status& s) {
|
||||
status->status = s;
|
||||
done.Notify();
|
||||
});
|
||||
done.WaitForNotification();
|
||||
}
|
||||
|
||||
TF_ShapeAndTypeList* TF_NewShapeAndTypeList(int num_items) {
|
||||
TF_ShapeAndTypeList* result = new TF_ShapeAndTypeList;
|
||||
result->num_items = num_items;
|
||||
|
@ -238,6 +238,13 @@ TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
|
||||
TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
|
||||
TF_Status* status);
|
||||
|
||||
// Checks the health of collective ops peers. Explicit health check is needed in
|
||||
// multi worker collective ops to detect failures in the cluster. If a peer is
|
||||
// down, collective ops may hang.
|
||||
TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx,
|
||||
const char* task,
|
||||
TF_Status* status);
|
||||
|
||||
// Information about the shape of a Tensor and its type.
|
||||
struct TF_ShapeAndType {
|
||||
// Number of dimensions. -1 indicates unknown rank.
|
||||
|
@ -1704,66 +1704,5 @@ TEST_F(CApiFunctionTest, GetFunctionsFromGraph) {
|
||||
TF_DeleteFunction(func1);
|
||||
}
|
||||
|
||||
// This test only works when the TF build includes XLA compiler. One way to set
|
||||
// this up is via bazel build option "--define with_xla_support=true".
|
||||
//
|
||||
// FIXME: generalize the macro name TENSORFLOW_EAGER_USE_XLA to
|
||||
// something like TENSORFLOW_CAPI_USE_XLA.
|
||||
#ifdef TENSORFLOW_EAGER_USE_XLA
|
||||
TEST_F(CApiFunctionTest, StatelessIf_XLA) {
|
||||
TF_Function* func;
|
||||
const std::string funcName = "BranchFunc";
|
||||
DefineFunction(funcName.c_str(), &func);
|
||||
TF_GraphCopyFunction(host_graph_, func, nullptr, s_);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
|
||||
|
||||
TF_Operation* feed = Placeholder(host_graph_, s_);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
|
||||
|
||||
TF_Operation* true_cond = ScalarConst(true, host_graph_, s_);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
|
||||
|
||||
TF_OperationDescription* desc =
|
||||
TF_NewOperation(host_graph_, "StatelessIf", "IfNode");
|
||||
TF_AddInput(desc, {true_cond, 0});
|
||||
TF_Output inputs[] = {{feed, 0}};
|
||||
TF_AddInputList(desc, inputs, TF_ARRAYSIZE(inputs));
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
|
||||
TF_SetAttrType(desc, "Tcond", TF_BOOL);
|
||||
TF_DataType inputType = TF_INT32;
|
||||
TF_SetAttrTypeList(desc, "Tin", &inputType, 1);
|
||||
TF_SetAttrTypeList(desc, "Tout", &inputType, 1);
|
||||
TF_SetAttrFuncName(desc, "then_branch", funcName.data(), funcName.size());
|
||||
TF_SetAttrFuncName(desc, "else_branch", funcName.data(), funcName.size());
|
||||
TF_SetDevice(desc, "/device:XLA_CPU:0");
|
||||
auto op = TF_FinishOperation(desc, s_);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
|
||||
ASSERT_NE(op, nullptr);
|
||||
|
||||
// Create a session for this graph.
|
||||
CSession csession(host_graph_, s_, /*use_XLA*/ true);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
|
||||
|
||||
// Run the graph.
|
||||
csession.SetInputs({{feed, Int32Tensor(17)}});
|
||||
csession.SetOutputs({op});
|
||||
csession.Run(s_);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
|
||||
TF_Tensor* out = csession.output_tensor(0);
|
||||
ASSERT_TRUE(out != nullptr);
|
||||
EXPECT_EQ(TF_INT32, TF_TensorType(out));
|
||||
EXPECT_EQ(0, TF_NumDims(out)); // scalar
|
||||
ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out));
|
||||
int32* output_contents = static_cast<int32*>(TF_TensorData(out));
|
||||
EXPECT_EQ(-17, *output_contents);
|
||||
|
||||
// Clean up
|
||||
csession.CloseAndDelete(s_);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
|
||||
|
||||
TF_DeleteFunction(func);
|
||||
}
|
||||
#endif // TENSORFLOW_EAGER_USE_XLA
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -6,7 +6,6 @@ load(
|
||||
"tf_copts",
|
||||
"tf_cuda_cc_test",
|
||||
"tf_cuda_library",
|
||||
"tfe_xla_copts",
|
||||
)
|
||||
load(
|
||||
"//tensorflow/core/platform:build_config.bzl",
|
||||
@ -31,7 +30,7 @@ tf_cuda_library(
|
||||
"c_api_unified_experimental.h",
|
||||
],
|
||||
hdrs = ["c_api.h"],
|
||||
copts = tf_copts() + tfe_xla_copts(),
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
@ -72,13 +71,6 @@ tf_cuda_library(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
],
|
||||
}) + select({
|
||||
"//tensorflow:with_xla_support": [
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/jit",
|
||||
"//tensorflow/compiler/jit:xla_device",
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}) + [
|
||||
"@com_google_absl//absl/memory",
|
||||
"//tensorflow/core/common_runtime/eager:eager_operation",
|
||||
@ -228,7 +220,6 @@ tf_cuda_cc_test(
|
||||
"gradients_test.cc",
|
||||
],
|
||||
args = ["--heap_check=local"],
|
||||
extra_copts = tfe_xla_copts(),
|
||||
linkstatic = tf_kernel_tests_linkstatic(),
|
||||
tags = tf_cuda_tests_tags() + ["nomac"],
|
||||
deps = [
|
||||
@ -278,6 +269,7 @@ cc_library(
|
||||
"//tensorflow/c/experimental/ops:math_ops",
|
||||
"//tensorflow/c/experimental/ops:nn_ops",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
"//tensorflow/core/platform:status",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
@ -290,12 +282,9 @@ tf_cuda_cc_test(
|
||||
"mnist_gradients_test.cc",
|
||||
],
|
||||
args = ["--heap_check=local"],
|
||||
extra_copts = tfe_xla_copts(),
|
||||
linkstatic = tf_kernel_tests_linkstatic(),
|
||||
tags = tf_cuda_tests_tags() + [
|
||||
"nomac",
|
||||
"notap", # TODO(b/166150182): Enable
|
||||
"no_oss", # TODO(b/166150182): Enable
|
||||
],
|
||||
deps = [
|
||||
":abstract_tensor_handle",
|
||||
@ -553,7 +542,6 @@ tf_cuda_cc_test(
|
||||
"c_api_debug_test.cc",
|
||||
"c_api_test.cc",
|
||||
],
|
||||
extra_copts = tfe_xla_copts(),
|
||||
tags = [
|
||||
"noguitar", # TODO(b/155445984): flaky
|
||||
#"guitar",
|
||||
@ -608,7 +596,6 @@ tf_cuda_cc_test(
|
||||
],
|
||||
# TODO(b/136478427): Figure out how to correctly shut the server down
|
||||
args = ["--heap_check=local"],
|
||||
extra_copts = tfe_xla_copts(),
|
||||
tags = [
|
||||
"no_windows",
|
||||
],
|
||||
@ -641,7 +628,6 @@ tf_cuda_cc_test(
|
||||
],
|
||||
# TODO(b/136478427): Figure out how to correctly shut the server down
|
||||
args = ["--heap_check=local"],
|
||||
extra_copts = tfe_xla_copts(),
|
||||
tags = [
|
||||
"no_windows",
|
||||
],
|
||||
@ -660,7 +646,6 @@ tf_cuda_cc_test(
|
||||
],
|
||||
# TODO(b/136478427): Figure out how to correctly shut the server down
|
||||
args = ["--heap_check=local"],
|
||||
extra_copts = tfe_xla_copts(),
|
||||
tags = [
|
||||
"no_windows",
|
||||
"noasan", # leaks gRPC server instances
|
||||
@ -694,7 +679,6 @@ tf_cuda_cc_test(
|
||||
],
|
||||
# TODO(b/136478427): Figure out how to correctly shut the server down
|
||||
args = ["--heap_check=local"],
|
||||
extra_copts = tfe_xla_copts(),
|
||||
tags = [
|
||||
"no_windows",
|
||||
],
|
||||
@ -729,7 +713,7 @@ tf_cuda_library(
|
||||
"c_api_experimental.h",
|
||||
"c_api_unified_experimental.h",
|
||||
],
|
||||
copts = tf_copts() + tfe_xla_copts(),
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
@ -801,7 +785,6 @@ tf_cuda_cc_test(
|
||||
"c_api_experimental_test.cc",
|
||||
],
|
||||
args = ["--heap_check=local"],
|
||||
extra_copts = tfe_xla_copts(),
|
||||
linkstatic = tf_kernel_tests_linkstatic(),
|
||||
tags = tf_cuda_tests_tags() + ["nomac"],
|
||||
deps = [
|
||||
@ -825,7 +808,6 @@ tf_cuda_cc_test(
|
||||
"c_api_unified_experimental_test.cc",
|
||||
],
|
||||
args = ["--heap_check=local"],
|
||||
extra_copts = tfe_xla_copts(),
|
||||
linkstatic = tf_kernel_tests_linkstatic(),
|
||||
tags = tf_cuda_tests_tags() + ["nomac"],
|
||||
deps = [
|
||||
|
@ -51,9 +51,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/protobuf/device_filters.pb.h"
|
||||
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
#ifdef TENSORFLOW_EAGER_USE_XLA
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#endif // TENSORFLOW_EAGER_USE_XLA
|
||||
#include "tensorflow/core/common_runtime/copy_tensor.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
@ -1148,26 +1145,23 @@ void TFE_DeleteOp(TFE_Op* op) {
|
||||
tensorflow::unwrap(op)->Release();
|
||||
}
|
||||
|
||||
const char* TFE_OpGetName(const TFE_Op* op, TF_Status* status) {
|
||||
return tensorflow::unwrap(op)->Name().c_str();
|
||||
}
|
||||
|
||||
TFE_Context* TFE_OpGetContext(const TFE_Op* op, TF_Status* status) {
|
||||
return tensorflow::wrap(
|
||||
&(OperationFromInterface(tensorflow::unwrap(op))->EagerContext()));
|
||||
}
|
||||
|
||||
void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
|
||||
status->status = tensorflow::unwrap(op)->SetDeviceName(device_name);
|
||||
}
|
||||
|
||||
const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
|
||||
const char* TFE_OpGetDevice(const TFE_Op* op, TF_Status* status) {
|
||||
return tensorflow::unwrap(op)->DeviceName().c_str();
|
||||
}
|
||||
|
||||
void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
|
||||
#ifdef TENSORFLOW_EAGER_USE_XLA
|
||||
tensorflow::Status s = tensorflow::unwrap(op)->SetUseXla(enable);
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << "Could not enable XLA compilation for op: " << s;
|
||||
}
|
||||
#else
|
||||
LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not "
|
||||
"built with XLA support.";
|
||||
#endif // TENSORFLOW_EAGER_USE_XLA
|
||||
}
|
||||
|
||||
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
|
||||
status->status = tensorflow::unwrap(op)->AddInput(tensorflow::unwrap(input));
|
||||
}
|
||||
@ -1180,6 +1174,15 @@ void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
|
||||
static_cast<size_t>(num_inputs)});
|
||||
}
|
||||
|
||||
extern int TFE_OpGetFlatInputCount(const TFE_Op* op, TF_Status* status) {
|
||||
return tensorflow::unwrap(op)->GetInputs().size();
|
||||
}
|
||||
|
||||
extern TFE_TensorHandle* TFE_OpGetFlatInput(const TFE_Op* op, int index,
|
||||
TF_Status* status) {
|
||||
return tensorflow::wrap(tensorflow::unwrap(op)->GetInputs()[index]);
|
||||
}
|
||||
|
||||
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
|
||||
unsigned char* is_list, TF_Status* status) {
|
||||
TF_AttrType ret = TF_ATTR_INT;
|
||||
@ -1485,7 +1488,7 @@ void TFE_ContextEndStep(TFE_Context* ctx) {
|
||||
tensorflow::unwrap(ctx)->EndStep();
|
||||
}
|
||||
|
||||
const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op) {
|
||||
const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op) {
|
||||
return tensorflow::wrap(
|
||||
&OperationFromInterface(tensorflow::unwrap(op))->Attrs());
|
||||
}
|
||||
@ -1611,19 +1614,12 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
|
||||
return status.status;
|
||||
}
|
||||
|
||||
tensorflow::Status Execute(tensorflow::EagerOperation* op,
|
||||
tensorflow::Status Execute(const tensorflow::EagerOperation* op,
|
||||
tensorflow::TensorHandle** retvals,
|
||||
int* num_retvals) override {
|
||||
std::vector<TFE_TensorHandle*> inputs;
|
||||
inputs.reserve(op->Inputs().size());
|
||||
for (int i = 0; i < op->Inputs().size(); ++i) {
|
||||
op->Inputs()[i]->Ref();
|
||||
inputs.push_back(tensorflow::wrap(op->Inputs()[i]));
|
||||
}
|
||||
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
|
||||
TF_Status status;
|
||||
device_.execute(context_, inputs.size(), inputs.data(), op->Name().c_str(),
|
||||
wrap(&op->Attrs()), num_retvals, outputs.data(), &status,
|
||||
device_.execute(tensorflow::wrap(op), num_retvals, outputs.data(), &status,
|
||||
info_);
|
||||
if (status.status.ok()) {
|
||||
for (int i = 0; i < *num_retvals; ++i) {
|
||||
@ -1633,10 +1629,6 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
|
||||
TFE_DeleteTensorHandle(outputs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto inp : inputs) {
|
||||
TFE_DeleteTensorHandle(inp);
|
||||
}
|
||||
return status.status;
|
||||
}
|
||||
|
||||
|
@ -248,22 +248,22 @@ typedef struct TFE_Op TFE_Op;
|
||||
TF_CAPI_EXPORT extern TFE_Op* TFE_NewOp(TFE_Context* ctx,
|
||||
const char* op_or_function_name,
|
||||
TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_DeleteOp(TFE_Op* op);
|
||||
|
||||
// Returns the op or function name `op` will execute.
|
||||
//
|
||||
// The returned string remains valid throughout the lifetime of 'op'.
|
||||
TF_CAPI_EXPORT extern const char* TFE_OpGetName(const TFE_Op* op,
|
||||
TF_Status* status);
|
||||
TF_CAPI_EXPORT extern TFE_Context* TFE_OpGetContext(const TFE_Op* op,
|
||||
TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_OpSetDevice(TFE_Op* op, const char* device_name,
|
||||
TF_Status* status);
|
||||
// The returned string remains valid throughout the lifetime of 'op'.
|
||||
TF_CAPI_EXPORT extern const char* TFE_OpGetDevice(TFE_Op* op,
|
||||
TF_CAPI_EXPORT extern const char* TFE_OpGetDevice(const TFE_Op* op,
|
||||
TF_Status* status);
|
||||
|
||||
// When 'enable' is set to 1, and if TensorFlow library is built with XLA
|
||||
// support, a subsequent TFE_Execute() call on `op` will run the op via XLA.
|
||||
//
|
||||
// If the library is not built with XLA support, this call would be a no-op.
|
||||
TF_CAPI_EXPORT extern void TFE_OpSetXLACompilation(TFE_Op* op,
|
||||
unsigned char enable);
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input,
|
||||
TF_Status* status);
|
||||
|
||||
@ -272,6 +272,23 @@ TF_CAPI_EXPORT extern void TFE_OpAddInputList(TFE_Op* op,
|
||||
int num_inputs,
|
||||
TF_Status* status);
|
||||
|
||||
// Fetches the current number of inputs attached to `op`.
|
||||
//
|
||||
// Does not use the operation's definition to determine how many inputs should
|
||||
// be attached. It is intended for use with TFE_OpGetFlatInput to inspect an
|
||||
// already-finalized operation.
|
||||
//
|
||||
// Note that TFE_OpGetFlatInputCount and TFE_OpGetFlatInput operate on a flat
|
||||
// sequence of inputs, unlike TFE_OpGetInputLength (for getting the length of a
|
||||
// particular named input list, which may only be part of the op's inputs).
|
||||
TF_CAPI_EXPORT extern int TFE_OpGetFlatInputCount(const TFE_Op* op,
|
||||
TF_Status* status);
|
||||
// Returns a borrowed reference to one of `op`'s inputs. Use
|
||||
// `TFE_TensorHandleCopySharingTensor` to make a new reference.
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_OpGetFlatInput(const TFE_Op* op,
|
||||
int index,
|
||||
TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op,
|
||||
const char* attr_name,
|
||||
unsigned char* is_list,
|
||||
|
@ -22,9 +22,6 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_status_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#ifdef TENSORFLOW_EAGER_USE_XLA
|
||||
#include "tensorflow/compiler/jit/xla_device.h"
|
||||
#endif // TENSORFLOW_EAGER_USE_XLA
|
||||
|
||||
using tensorflow::string;
|
||||
|
||||
@ -64,87 +61,6 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
#ifdef TENSORFLOW_EAGER_USE_XLA
|
||||
auto* device = absl::get<tensorflow::Device*>(handle->device());
|
||||
|
||||
// If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
|
||||
auto* xla_device = dynamic_cast<tensorflow::XlaDevice*>(device);
|
||||
if (xla_device != nullptr) {
|
||||
tensorflow::XlaDevice::PaddedShapeFn shape_fn =
|
||||
xla_device->metadata().padded_shape_fn();
|
||||
xla::Shape padded_shape;
|
||||
status->status = shape_fn(*tensor, &padded_shape);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
if (VLOG_IS_ON(3)) {
|
||||
std::vector<tensorflow::int64> shape_to_log =
|
||||
TensorShapeAsVector(*handle, &status->status);
|
||||
if (!status->status.ok()) {
|
||||
// Ignore the status here as we are simply logging.
|
||||
status->status = tensorflow::Status::OK();
|
||||
} else {
|
||||
VLOG(3) << "Fully padded shape of ["
|
||||
<< absl::StrJoin(shape_to_log, ", ") << "] is "
|
||||
<< padded_shape.DebugString();
|
||||
}
|
||||
}
|
||||
|
||||
if (padded_shape.IsTuple()) {
|
||||
if (xla::ShapeUtil::TupleElementCount(padded_shape) != 2) {
|
||||
// Currently, the only case of XlaTensor containing a tuple shape is to
|
||||
// represent 64 bit ints, doubles, and complex numbers (we don't support
|
||||
// 64bit complex numbers).
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"XlaTensors should only contain tuples of size 2. Shape: ",
|
||||
padded_shape.DebugString());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// shape0 is not a const& because we will assign it to padded_shape below.
|
||||
// It is illegal to assign a part of a message to itself.
|
||||
xla::Shape shape0 = xla::ShapeUtil::GetTupleElementShape(padded_shape, 0);
|
||||
const xla::Shape& shape1 =
|
||||
xla::ShapeUtil::GetTupleElementShape(padded_shape, 1);
|
||||
if (shape0.IsTuple() || shape1.IsTuple()) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"XlaTensors should not contain nested tuples. Shape: ",
|
||||
padded_shape.DebugString());
|
||||
return nullptr;
|
||||
}
|
||||
if (!xla::ShapeUtil::Equal(shape0, shape1)) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Subshapes of XlaTensors should be the same. Shape: ",
|
||||
padded_shape.DebugString());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Since the only case we handle here are two equal subshapes, we
|
||||
// simply return one of them. The caller will interpret it as this
|
||||
// shape directly storing the 64bit types. This approximation is good
|
||||
// enough for this API's debugging use case.
|
||||
padded_shape = shape0;
|
||||
}
|
||||
|
||||
int rank = padded_shape.dimensions_size();
|
||||
std::vector<tensorflow::int64> dev_dims;
|
||||
dev_dims.reserve(rank);
|
||||
if (rank == 1) {
|
||||
// Rank 1 tensors might not have padded_shape.layout.minor_to_major set,
|
||||
dev_dims.push_back(padded_shape.dimensions(0));
|
||||
} else {
|
||||
for (int i = rank - 1; i >= 0; --i) {
|
||||
tensorflow::int64 dim_index = padded_shape.layout().minor_to_major(i);
|
||||
dev_dims.push_back(padded_shape.dimensions(dim_index));
|
||||
}
|
||||
}
|
||||
status->status = tensorflow::Status::OK();
|
||||
return new TFE_TensorDebugInfo(dev_dims);
|
||||
}
|
||||
#endif // TENSORFLOW_EAGER_USE_XLA
|
||||
|
||||
// If the tensor is not an XLA tensor, the device shape is
|
||||
// the same as regular tensor shape.
|
||||
std::vector<tensorflow::int64> dev_dims =
|
||||
TensorShapeAsVector(*handle, &status->status);
|
||||
if (!status->status.ok()) {
|
||||
|
@ -414,7 +414,7 @@ typedef struct TFE_OpAttrs TFE_OpAttrs;
|
||||
|
||||
// Fetch a reference to `op`'s attributes. The returned reference is only valid
|
||||
// while `op` is alive.
|
||||
const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op);
|
||||
TF_CAPI_EXPORT extern const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op);
|
||||
// Add attributes in `attrs` to `op`.
|
||||
//
|
||||
// Does not overwrite or update existing attributes, but adds new ones.
|
||||
@ -435,7 +435,11 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrValueProto(const TFE_Op* op,
|
||||
size_t proto_len,
|
||||
TF_Status* status);
|
||||
|
||||
#define TFE_CUSTOM_DEVICE_VERSION 2
|
||||
// TODO(b/166642410): It would be nice, for custom devices and for other users,
|
||||
// to have a non-string representation of devices (TF_Device) extracted from
|
||||
// tensors/ops/etc. and usable in APIs like OpSetDevice/ResetOp/etc.
|
||||
|
||||
#define TFE_CUSTOM_DEVICE_VERSION 3
|
||||
|
||||
// Struct to be filled in
|
||||
typedef struct TFE_CustomDevice {
|
||||
@ -454,9 +458,16 @@ typedef struct TFE_CustomDevice {
|
||||
void* device_info);
|
||||
|
||||
// Method to execute an operation.
|
||||
void (*execute)(TFE_Context* context, int num_inputs,
|
||||
TFE_TensorHandle** inputs, const char* operation_name,
|
||||
const TFE_OpAttrs* attributes, int* num_outputs,
|
||||
//
|
||||
// Arguments provide enough information to reconstruct the original `TFE_Op`,
|
||||
// or construct a transformed version, by inspecting the passed `op`.
|
||||
//
|
||||
// TFE_OpGetDevice(op) records the original placement of the operation. It may
|
||||
// be an empty string if no device was explicitly requested, but will
|
||||
// otherwise be the name of this custom device. Ops are placed onto a custom
|
||||
// device if any of their inputs are on that custom device, but custom devices
|
||||
// are free to set a bad status in order to require explicit placement.
|
||||
void (*execute)(const TFE_Op* op, int* num_outputs,
|
||||
TFE_TensorHandle** outputs, TF_Status* s, void* device_info);
|
||||
|
||||
// Method to delete a device.
|
||||
|
@ -316,86 +316,6 @@ TEST(CAPI, Function_ident_CPU) {
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
#ifdef TENSORFLOW_EAGER_USE_XLA
|
||||
TEST(CAPI, Function_ident_XLA_CPU) {
|
||||
// First create a simple identity function.
|
||||
TF_Graph* function_graph = TF_NewGraph();
|
||||
TF_OperationDescription* arg_descr =
|
||||
TF_NewOperation(function_graph, "Placeholder", "arg");
|
||||
TF_SetAttrType(arg_descr, "dtype", TF_INT32);
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_Operation* arg = TF_FinishOperation(arg_descr, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TF_OperationDescription* id_descr =
|
||||
TF_NewOperation(function_graph, "Identity", "id");
|
||||
TF_SetAttrType(id_descr, "T", TF_INT32);
|
||||
TF_AddInput(id_descr, {arg, 0});
|
||||
TF_Operation* id = TF_FinishOperation(id_descr, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TF_Output input{arg, 0};
|
||||
TF_Output output{id, 0};
|
||||
TF_Function* fn =
|
||||
TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1,
|
||||
&output, nullptr, nullptr, "test", status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TF_DeleteGraph(function_graph);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
TFE_ContextAddFunction(ctx, fn, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TF_DeleteFunction(fn);
|
||||
|
||||
for (bool async : {false, true, false}) {
|
||||
TFE_Executor* old_executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_Executor* executor = TFE_NewExecutor(async);
|
||||
TFE_ContextSetExecutorForThread(ctx, executor);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK);
|
||||
TF_Tensor* t =
|
||||
TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32));
|
||||
*reinterpret_cast<tensorflow::int32*>(TF_TensorData(t)) = 42;
|
||||
TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TF_DeleteTensor(t);
|
||||
|
||||
TFE_Op* op = TFE_NewOp(ctx, "ident", status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TFE_OpAddInput(op, h, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
|
||||
// Now run it via XLA.
|
||||
TFE_OpSetXLACompilation(op, true);
|
||||
|
||||
std::vector<TFE_TensorHandle*> result;
|
||||
result.push_back(nullptr);
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op, result.data(), &num_retvals, status);
|
||||
TFE_DeleteOp(op);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
ASSERT_EQ(num_retvals, 1);
|
||||
|
||||
TF_Tensor* r = TFE_TensorHandleResolve(result[0], status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
EXPECT_EQ(*reinterpret_cast<tensorflow::int32*>(TF_TensorData(r)), 42);
|
||||
TFE_ContextSetExecutorForThread(ctx, old_executor);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
TFE_DeleteExecutor(old_executor);
|
||||
TFE_DeleteTensorHandle(h);
|
||||
TF_DeleteTensor(r);
|
||||
TFE_DeleteTensorHandle(result[0]);
|
||||
}
|
||||
TFE_ContextRemoveFunction(ctx, "ident", status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TFE_DeleteContext(ctx);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
#endif // TENSORFLOW_EAGER_USE_XLA
|
||||
|
||||
void Executor_MatMul_CPU(bool async) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
|
@ -876,89 +876,6 @@ TEST(CAPI, Execute_Min_CPU) {
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
#ifdef TENSORFLOW_EAGER_USE_XLA
|
||||
void Execute_MatMul_XLA_CPU(bool async) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
|
||||
TFE_Op* matmul = MatMulOp(ctx, m, m);
|
||||
|
||||
TFE_OpSetXLACompilation(matmul, true);
|
||||
|
||||
TFE_TensorHandle* retvals[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
|
||||
// Running a primitive TF operator via XLA is not yet supported.
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_DeleteOp(matmul);
|
||||
TFE_DeleteTensorHandle(m);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
EXPECT_EQ(1, num_retvals);
|
||||
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
float product[4] = {0};
|
||||
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
|
||||
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
|
||||
TF_DeleteTensor(t);
|
||||
EXPECT_EQ(7, product[0]);
|
||||
EXPECT_EQ(10, product[1]);
|
||||
EXPECT_EQ(15, product[2]);
|
||||
EXPECT_EQ(22, product[3]);
|
||||
TFE_DeleteContext(ctx);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
TEST(CAPI, Execute_MatMul_XLA_CPU) { Execute_MatMul_XLA_CPU(false); }
|
||||
TEST(CAPI, Execute_MatMul_XLA_CPUAsync) { Execute_MatMul_XLA_CPU(true); }
|
||||
|
||||
void Execute_Min_XLA_CPU(bool async) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* input = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* axis = TestAxisTensorHandle(ctx);
|
||||
TFE_Op* minOp = MinOp(ctx, input, axis);
|
||||
|
||||
TFE_OpSetXLACompilation(minOp, true);
|
||||
|
||||
TFE_TensorHandle* retvals[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(minOp, &retvals[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteOp(minOp);
|
||||
TFE_DeleteTensorHandle(input);
|
||||
TFE_DeleteTensorHandle(axis);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
ASSERT_EQ(1, num_retvals);
|
||||
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
float output[2] = {0};
|
||||
EXPECT_EQ(sizeof(output), TF_TensorByteSize(t));
|
||||
memcpy(&output[0], TF_TensorData(t), TF_TensorByteSize(t));
|
||||
TF_DeleteTensor(t);
|
||||
EXPECT_EQ(1, output[0]);
|
||||
EXPECT_EQ(3, output[1]);
|
||||
TFE_DeleteContext(ctx);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
TEST(CAPI, Execute_Min_XLA_CPU) { Execute_Min_XLA_CPU(false); }
|
||||
TEST(CAPI, Execute_Min_XLA_CPUAsync) { Execute_Min_XLA_CPU(true); }
|
||||
#endif // TENSORFLOW_EAGER_USE_XLA
|
||||
|
||||
void ExecuteWithTracing(bool async) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
@ -1620,4 +1537,91 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) {
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
// Needs to work with a const TFE_Op since custom devices should not modify the
|
||||
// op they are called with.
|
||||
TFE_Op* CloneOp(const TFE_Op* other) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_Context* context = TFE_OpGetContext(other, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
const char* op_name = TFE_OpGetName(other, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_Op* ret = TFE_NewOp(context, op_name, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
const char* device = TFE_OpGetDevice(other, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpSetDevice(ret, device, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddAttrs(ret, TFE_OpGetAttrs(other));
|
||||
int num_inputs = TFE_OpGetFlatInputCount(other, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
for (int input_index = 0; input_index < num_inputs; ++input_index) {
|
||||
TFE_TensorHandle* input = TFE_OpGetFlatInput(other, input_index, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(ret, input, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
}
|
||||
TF_DeleteStatus(status);
|
||||
return ret;
|
||||
}
|
||||
|
||||
TEST(CAPI, TestTFE_OpRecreation) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
// Clone an op with attributes and a device set.
|
||||
TFE_Op* original_var_op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpSetAttrType(original_var_op, "dtype", TF_INT64);
|
||||
TFE_OpSetAttrShape(original_var_op, "shape", {}, 0, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_EQ("", std::string(TFE_OpGetDevice(original_var_op, status)));
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpSetDevice(original_var_op,
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0", status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_Op* cloned = CloneOp(original_var_op);
|
||||
|
||||
EXPECT_EQ("/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
std::string(TFE_OpGetDevice(cloned, status)));
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_EQ("VarHandleOp", std::string(TFE_OpGetName(cloned, status)));
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
int num_retvals = 1;
|
||||
TFE_TensorHandle* ret;
|
||||
TFE_Execute(cloned, &ret, &num_retvals, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteTensorHandle(ret);
|
||||
|
||||
// Clone an op with inputs and no device set.
|
||||
TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx);
|
||||
TFE_Op* original_identity = TFE_NewOp(ctx, "IdentityN", status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_TensorHandle* inputs[] = {input1, input2};
|
||||
TFE_OpAddInputList(original_identity, inputs, 2, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_Op* cloned_identity = CloneOp(original_identity);
|
||||
EXPECT_EQ("", std::string(TFE_OpGetDevice(cloned_identity, status)));
|
||||
TFE_TensorHandle* identity_ret[] = {nullptr, nullptr};
|
||||
num_retvals = 2;
|
||||
TFE_Execute(cloned_identity, identity_ret, &num_retvals, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_DeleteTensorHandle(input1);
|
||||
TFE_DeleteTensorHandle(input2);
|
||||
TFE_DeleteTensorHandle(identity_ret[0]);
|
||||
TFE_DeleteTensorHandle(identity_ret[1]);
|
||||
|
||||
TFE_DeleteOp(cloned_identity);
|
||||
TFE_DeleteOp(original_identity);
|
||||
TFE_DeleteOp(original_var_op);
|
||||
TFE_DeleteOp(cloned);
|
||||
TF_DeleteStatus(status);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -36,7 +36,8 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
|
||||
bool arrived = false;
|
||||
bool executed = false;
|
||||
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
RegisterLoggingDevice(context, name, &arrived, &executed, status.get());
|
||||
RegisterLoggingDevice(context, name, /*strict_scope_placement=*/true,
|
||||
&arrived, &executed, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(context);
|
||||
ASSERT_FALSE(arrived);
|
||||
@ -73,7 +74,8 @@ TEST(CUSTOM_DEVICE, ResetOperation) {
|
||||
bool executed = false;
|
||||
const char* custom_device_name =
|
||||
"/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
RegisterLoggingDevice(context.get(), custom_device_name, &arrived, &executed,
|
||||
RegisterLoggingDevice(context.get(), custom_device_name,
|
||||
/*strict_scope_placement=*/true, &arrived, &executed,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
@ -103,7 +105,8 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
|
||||
bool arrived = false;
|
||||
bool executed = false;
|
||||
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
|
||||
RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true,
|
||||
&arrived, &executed, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Create a variable handle placed on the custom device.
|
||||
@ -187,7 +190,8 @@ TEST(CUSTOM_DEVICE, AccessVariableOnCustomDevice) {
|
||||
bool arrived = false;
|
||||
bool executed = false;
|
||||
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
|
||||
RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/false,
|
||||
&arrived, &executed, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Create a variable handle placed on the custom device.
|
||||
@ -264,10 +268,12 @@ TEST(CUSTOM_DEVICE, InputBasedPlacement) {
|
||||
const char* custom1 = "/job:localhost/replica:0/task:0/device:CUSTOM:1";
|
||||
bool arrived = false;
|
||||
bool executed = false;
|
||||
RegisterLoggingDevice(context.get(), custom0, &arrived, &executed,
|
||||
RegisterLoggingDevice(context.get(), custom0,
|
||||
/*strict_scope_placement=*/false, &arrived, &executed,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
RegisterLoggingDevice(context.get(), custom1, &arrived, &executed,
|
||||
RegisterLoggingDevice(context.get(), custom1,
|
||||
/*strict_scope_placement=*/true, &arrived, &executed,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
@ -314,14 +320,34 @@ TEST(CUSTOM_DEVICE, InputBasedPlacement) {
|
||||
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
|
||||
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom1));
|
||||
|
||||
// Custom device: mix of custom/physical fails.
|
||||
// Custom device: mix of custom/physical places the op on the custom device.
|
||||
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get()));
|
||||
num_retvals = 1;
|
||||
executed = false;
|
||||
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
|
||||
ASSERT_NE(TF_OK, TF_GetCode(status.get()));
|
||||
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
|
||||
ASSERT_TRUE(
|
||||
absl::StrContains(TF_Message(status.get()), "[]")); // kVariantDeviceNull
|
||||
EXPECT_TRUE(executed);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_DeleteTensorHandle(retval);
|
||||
|
||||
// Explicit placement still forces the op onto the requested device
|
||||
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get()));
|
||||
TFE_OpSetDevice(matmul.get(), "/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
num_retvals = 1;
|
||||
executed = false;
|
||||
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
|
||||
EXPECT_FALSE(executed);
|
||||
ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK);
|
||||
|
||||
// Custom devices can refuse to do type-based dispatch (as hcustom1 is
|
||||
// configured to do)
|
||||
matmul.reset(MatMulOp(context.get(), hcustom1.get(), hcpu.get()));
|
||||
num_retvals = 1;
|
||||
executed = false;
|
||||
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
|
||||
EXPECT_FALSE(executed);
|
||||
ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK);
|
||||
}
|
||||
|
||||
TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
|
||||
@ -334,21 +360,24 @@ TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
bool arrived = false;
|
||||
bool executed = false;
|
||||
RegisterLoggingDevice(context.get(), "/device:CUSTOM:0", &arrived, &executed,
|
||||
RegisterLoggingDevice(context.get(), "/device:CUSTOM:0",
|
||||
/*strict_scope_placement=*/true, &arrived, &executed,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
|
||||
<< TF_Message(status.get());
|
||||
|
||||
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
|
||||
RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true,
|
||||
&arrived, &executed, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
|
||||
<< TF_Message(status.get());
|
||||
|
||||
RegisterLoggingDevice(context.get(),
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true,
|
||||
&arrived, &executed, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
|
||||
<< TF_Message(status.get());
|
||||
|
||||
RegisterLoggingDevice(
|
||||
context.get(), "/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
/*strict_scope_placement=*/true, &arrived, &executed, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
|
||||
<< TF_Message(status.get());
|
||||
}
|
||||
|
@ -33,6 +33,9 @@ struct LoggingDevice {
|
||||
bool* arrived_flag;
|
||||
// Set to true whenever an operation is executed
|
||||
bool* executed_flag;
|
||||
// If true, only explicit op placements are accepted. If false, uses
|
||||
// type-based dispatch.
|
||||
bool strict_scope_placement;
|
||||
};
|
||||
|
||||
struct LoggedTensor {
|
||||
@ -84,18 +87,35 @@ TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
|
||||
TFE_TensorHandle** inputs, const char* operation_name,
|
||||
const TFE_OpAttrs* attributes, int* num_outputs,
|
||||
void LoggingDeviceExecute(const TFE_Op* original_op, int* num_outputs,
|
||||
TFE_TensorHandle** outputs, TF_Status* s,
|
||||
void* device_info) {
|
||||
const char* requested_placement = TFE_OpGetDevice(original_op, s);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
|
||||
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
|
||||
if (dev->strict_scope_placement && *requested_placement == '\0') {
|
||||
TF_SetStatus(s, TF_INTERNAL,
|
||||
"Ops must be placed on the device explicitly, or their inputs "
|
||||
"first copied to other devices.");
|
||||
return;
|
||||
}
|
||||
TFE_Context* context = TFE_OpGetContext(original_op, s);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
const char* operation_name = TFE_OpGetName(original_op, s);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op);
|
||||
|
||||
TFE_Op* op(TFE_NewOp(context, operation_name, s));
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
TFE_OpAddAttrs(op, attributes);
|
||||
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
int num_inputs = TFE_OpGetFlatInputCount(original_op, s);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
for (int j = 0; j < num_inputs; ++j) {
|
||||
TFE_TensorHandle* input = inputs[j];
|
||||
TFE_TensorHandle* input = TFE_OpGetFlatInput(original_op, j, s);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
const char* input_device = TFE_TensorHandleDeviceName(input, s);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
if (dev->device_name == input_device) {
|
||||
@ -131,8 +151,8 @@ void DeleteLoggingDevice(void* device_info) {
|
||||
} // namespace
|
||||
|
||||
void RegisterLoggingDevice(TFE_Context* context, const char* name,
|
||||
bool* arrived_flag, bool* executed_flag,
|
||||
TF_Status* status) {
|
||||
bool strict_scope_placement, bool* arrived_flag,
|
||||
bool* executed_flag, TF_Status* status) {
|
||||
TFE_CustomDevice custom_device;
|
||||
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
|
||||
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
|
||||
@ -143,6 +163,7 @@ void RegisterLoggingDevice(TFE_Context* context, const char* name,
|
||||
device->executed_flag = executed_flag;
|
||||
device->device_name = name;
|
||||
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
device->strict_scope_placement = strict_scope_placement;
|
||||
TFE_RegisterCustomDevice(context, custom_device, name, device, status);
|
||||
}
|
||||
|
||||
@ -168,5 +189,6 @@ void AllocateLoggingDevice(const char* name, bool* arrived_flag,
|
||||
logging_device->device_name = name;
|
||||
logging_device->underlying_device =
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
logging_device->strict_scope_placement = true;
|
||||
*device_info = reinterpret_cast<void*>(logging_device);
|
||||
}
|
||||
|
@ -25,8 +25,8 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
void RegisterLoggingDevice(TFE_Context* context, const char* name,
|
||||
bool* arrived_flag, bool* executed_flag,
|
||||
TF_Status* status);
|
||||
bool strict_scope_placement, bool* arrived_flag,
|
||||
bool* executed_flag, TF_Status* status);
|
||||
void AllocateLoggingDevice(const char* name, bool* arrived_flag,
|
||||
bool* executed_flag, TFE_CustomDevice** device,
|
||||
void** device_info);
|
||||
|
@ -242,6 +242,7 @@ namespace internal {
|
||||
Status Reset(AbstractOperation* op_, const char* op,
|
||||
const char* raw_device_name, ForwardOperation* forward_op_) {
|
||||
forward_op_->op_name = op;
|
||||
forward_op_->attrs.Reset(op);
|
||||
return op_->Reset(op, raw_device_name);
|
||||
}
|
||||
Status AddInput(AbstractOperation* op_, AbstractTensorHandle* input,
|
||||
@ -418,6 +419,11 @@ Status Execute(AbstractOperation* op_, AbstractContext* ctx,
|
||||
// TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs.
|
||||
forward_op_->outputs.push_back(retvals[i]);
|
||||
}
|
||||
// TODO(b/166669239): This is needed to support AttrBuilder::Get for string
|
||||
// attributes. Number type attrs and DataType attrs work fine without this.
|
||||
// Consider getting rid of this and making the behavior between number types
|
||||
// and string consistent.
|
||||
forward_op_->attrs.BuildNodeDef();
|
||||
std::vector<TapeTensor> tape_tensors;
|
||||
for (auto t : retvals) {
|
||||
tape_tensors.push_back(TapeTensor(t, ctx));
|
||||
|
@ -507,6 +507,57 @@ TEST_P(CppGradients, TestIdentityNGrad) {
|
||||
result_tensor = nullptr;
|
||||
}
|
||||
|
||||
TEST_P(CppGradients, TestSetAttrString) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
AbstractContextPtr ctx;
|
||||
{
|
||||
AbstractContext* ctx_raw = nullptr;
|
||||
Status s =
|
||||
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
ctx.reset(ctx_raw);
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr t;
|
||||
{
|
||||
AbstractTensorHandle* x_raw = nullptr;
|
||||
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
t.reset(x_raw);
|
||||
}
|
||||
|
||||
AbstractOperationPtr check_numerics_op(ctx->CreateOperation());
|
||||
ForwardOperation forward_op;
|
||||
forward_op.ctx = ctx.get();
|
||||
Status s = Reset(check_numerics_op.get(), "CheckNumerics",
|
||||
/*raw_device_name=*/nullptr, &forward_op);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
if (isa<TracingOperation>(check_numerics_op.get())) {
|
||||
s = dyn_cast<TracingOperation>(check_numerics_op.get())
|
||||
->SetOpName("check_numerics");
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
}
|
||||
s = AddInput(check_numerics_op.get(), t.get(), &forward_op);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
string message = "This is the way!";
|
||||
s = SetAttrString(check_numerics_op.get(), "message", message.data(),
|
||||
message.length(), &forward_op);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
int num_retvals = 1;
|
||||
std::vector<AbstractTensorHandle*> outputs(1);
|
||||
GradientRegistry registry;
|
||||
std::unique_ptr<Tape> tape(new Tape(/*persistent=*/false));
|
||||
s = Execute(check_numerics_op.get(), ctx.get(), absl::MakeSpan(outputs),
|
||||
&num_retvals, &forward_op, tape.get(), registry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
string read_message;
|
||||
s = forward_op.attrs.Get("message", &read_message);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
ASSERT_EQ(read_message, message);
|
||||
}
|
||||
|
||||
// TODO(b/164171226): Enable this test with tfrt after AddInputList is
|
||||
// supported. It is needed for IdentityN.
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
|
@ -47,9 +47,6 @@ class ImmediateExecutionOperation : public AbstractOperation {
|
||||
virtual Status InputLength(const char* input_name, int* length) = 0;
|
||||
virtual Status OutputLength(const char* output_name, int* length) = 0;
|
||||
|
||||
// Experimental
|
||||
virtual Status SetUseXla(bool enable) = 0;
|
||||
|
||||
// Set stack trace to be used for potential async error reporting.
|
||||
virtual void SetStackTrace(AbstractStackTrace stack_trace) = 0;
|
||||
|
||||
|
@ -765,13 +765,13 @@ TEST_P(CppGradients, TestMNIST_Training) {
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
UnifiedCAPI, CppGradients,
|
||||
::testing::Combine(::testing::Values("graphdef"),
|
||||
::testing::Combine(::testing::Values("graphdef", "mlir"),
|
||||
/*tfrt*/ ::testing::Values(false),
|
||||
/*executing_eagerly*/ ::testing::Values(true, false)));
|
||||
#else
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
UnifiedCAPI, CppGradients,
|
||||
::testing::Combine(::testing::Values("graphdef"),
|
||||
::testing::Combine(::testing::Values("graphdef", "mlir"),
|
||||
/*tfrt*/ ::testing::Values(false),
|
||||
/*executing_eagerly*/ ::testing::Values(true, false)));
|
||||
#endif
|
||||
|
@ -31,11 +31,15 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
|
||||
using std::vector;
|
||||
using tracing::TracingOperation;
|
||||
|
||||
// ========================== Tape Ops ==============================
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
namespace internal {
|
||||
|
||||
using std::vector;
|
||||
using tensorflow::tracing::TracingOperation;
|
||||
|
||||
// Computes `inputs[0] + inputs[1]` and records it on the tape.
|
||||
Status Add(AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
@ -272,6 +276,7 @@ Status MNISTForwardModel(AbstractContext* ctx,
|
||||
|
||||
AbstractTensorHandle* scores = temp_outputs[0];
|
||||
|
||||
temp_outputs.resize(2);
|
||||
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(
|
||||
ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
|
||||
"softmax_loss", registry)); // Compute Softmax(Scores,labels)
|
||||
@ -592,3 +597,7 @@ Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
|
||||
TFE_DeleteContextOptions(opts);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
|
@ -27,13 +27,13 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
|
||||
using namespace tensorflow;
|
||||
using namespace tensorflow::gradients;
|
||||
using namespace tensorflow::gradients::internal;
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
// ========================== Tape Ops ==============================
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
namespace internal {
|
||||
// Computes `inputs[0] + inputs[1]` and records it on the tape.
|
||||
Status Add(AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
@ -144,3 +144,7 @@ Status RunModel(Model model, AbstractContext* ctx,
|
||||
const GradientRegistry& registry);
|
||||
|
||||
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);
|
||||
|
||||
} // namespace internal
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
|
@ -255,28 +255,44 @@ TFE_TensorHandle* CopyTensorFromParallelDevice(TFE_Context* context,
|
||||
// Since this function is used to satisfy the TFE_CustomDevice C API,
|
||||
// device_info is passed in using a C-style generic. It must always be a
|
||||
// ParallelDevice.
|
||||
void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
|
||||
TFE_TensorHandle** inputs,
|
||||
const char* operation_name,
|
||||
const TFE_OpAttrs* attributes, int* num_outputs,
|
||||
void ParallelDeviceExecute(const TFE_Op* original_op, int* num_outputs,
|
||||
TFE_TensorHandle** outputs, TF_Status* status,
|
||||
void* device_info) {
|
||||
const char* requested_placement = TFE_OpGetDevice(original_op, status);
|
||||
if (*requested_placement == '\0') {
|
||||
TF_SetStatus(
|
||||
status, TF_INTERNAL,
|
||||
"Ops must be placed on the parallel device explicitly, or their inputs "
|
||||
"first un-packed. Got an un-placed op with an input placed on the "
|
||||
"parallel device.");
|
||||
return;
|
||||
}
|
||||
TFE_Context* context = TFE_OpGetContext(original_op, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
const char* operation_name = TFE_OpGetName(original_op, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op);
|
||||
|
||||
NamedParallelDevice* named_device =
|
||||
reinterpret_cast<NamedParallelDevice*>(device_info);
|
||||
std::vector<MaybeParallelTensorUnowned> typed_inputs;
|
||||
int num_inputs = TFE_OpGetFlatInputCount(original_op, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
typed_inputs.reserve(num_inputs);
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
TFE_TensorHandle* input = TFE_OpGetFlatInput(original_op, i, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
const char* tensor_handle_device =
|
||||
TFE_TensorHandleDeviceName(inputs[i], status);
|
||||
TFE_TensorHandleDeviceName(input, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
if (named_device->name() == tensor_handle_device) {
|
||||
// We assume that any tensors already placed on this device are
|
||||
// ParallelTensors.
|
||||
typed_inputs.emplace_back(reinterpret_cast<ParallelTensor*>(
|
||||
TFE_TensorHandleDevicePointer(inputs[i], status)));
|
||||
TFE_TensorHandleDevicePointer(input, status)));
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
} else {
|
||||
typed_inputs.emplace_back(inputs[i]);
|
||||
typed_inputs.emplace_back(input);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -29,6 +29,7 @@ cc_library(
|
||||
":gcs_helper",
|
||||
":ram_file_block_cache",
|
||||
"//tensorflow/c:env",
|
||||
"//tensorflow/c:logging",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c/experimental/filesystem:filesystem_interface",
|
||||
"@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "google/cloud/storage/client.h"
|
||||
#include "tensorflow/c/env.h"
|
||||
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
|
||||
#include "tensorflow/c/logging.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
// Implementation of a filesystem for GCS environments.
|
||||
@ -134,6 +135,8 @@ static int64_t LoadBufferFromGCS(const std::string& path, size_t offset,
|
||||
}
|
||||
// `TF_OUT_OF_RANGE` isn't considered an error, so we clear it here.
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
TF_VLog(1, "Successful read of %s @ %u of size: %u", path.c_str(), offset,
|
||||
read);
|
||||
stream.read(buffer, read);
|
||||
read = stream.gcount();
|
||||
if (read < buffer_size) {
|
||||
@ -146,6 +149,8 @@ static int64_t LoadBufferFromGCS(const std::string& path, size_t offset,
|
||||
path, " @ ", offset)
|
||||
.c_str());
|
||||
}
|
||||
TF_VLog(2, "Successful integrity check for: %s @ %u", path.c_str(),
|
||||
offset);
|
||||
}
|
||||
}
|
||||
return read;
|
||||
@ -284,6 +289,8 @@ static void SyncImpl(const std::string& bucket, const std::string& object,
|
||||
TF_SetStatusFromGCSStatus(metadata.status(), status);
|
||||
return;
|
||||
}
|
||||
TF_VLog(3, "AppendObject: gs://%s/%s to gs://%s/%s", bucket.c_str(),
|
||||
temporary_object.c_str(), bucket.c_str(), object.c_str());
|
||||
const std::vector<gcs::ComposeSourceObject> source_objects = {
|
||||
{object, {}, {}}, {temporary_object, {}, {}}};
|
||||
metadata = gcs_client->ComposeObject(bucket, source_objects, object);
|
||||
@ -321,6 +328,8 @@ void Append(const TF_WritableFile* file, const char* buffer, size_t n,
|
||||
"The internal temporary file is not writable.");
|
||||
return;
|
||||
}
|
||||
TF_VLog(3, "Append: gs://%s/%s size %u", gcs_file->bucket.c_str(),
|
||||
gcs_file->object.c_str(), n);
|
||||
gcs_file->sync_need = true;
|
||||
gcs_file->outfile.write(buffer, n);
|
||||
if (!gcs_file->outfile)
|
||||
@ -346,6 +355,8 @@ int64_t Tell(const TF_WritableFile* file, TF_Status* status) {
|
||||
void Flush(const TF_WritableFile* file, TF_Status* status) {
|
||||
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
|
||||
if (gcs_file->sync_need) {
|
||||
TF_VLog(3, "Flush started: gs://%s/%s", gcs_file->bucket.c_str(),
|
||||
gcs_file->object.c_str());
|
||||
if (!gcs_file->outfile) {
|
||||
TF_SetStatus(status, TF_INTERNAL,
|
||||
"Could not append to the internal temporary file.");
|
||||
@ -353,6 +364,8 @@ void Flush(const TF_WritableFile* file, TF_Status* status) {
|
||||
}
|
||||
SyncImpl(gcs_file->bucket, gcs_file->object, &gcs_file->offset,
|
||||
&gcs_file->outfile, gcs_file->gcs_client, status);
|
||||
TF_VLog(3, "Flush finished: gs://%s/%s", gcs_file->bucket.c_str(),
|
||||
gcs_file->object.c_str());
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
gcs_file->sync_need = false;
|
||||
} else {
|
||||
@ -361,11 +374,16 @@ void Flush(const TF_WritableFile* file, TF_Status* status) {
|
||||
}
|
||||
|
||||
void Sync(const TF_WritableFile* file, TF_Status* status) {
|
||||
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
|
||||
TF_VLog(3, "Sync: gs://%s/%s", gcs_file->bucket.c_str(),
|
||||
gcs_file->object.c_str());
|
||||
Flush(file, status);
|
||||
}
|
||||
|
||||
void Close(const TF_WritableFile* file, TF_Status* status) {
|
||||
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
|
||||
TF_VLog(3, "Close: gs://%s/%s", gcs_file->bucket.c_str(),
|
||||
gcs_file->object.c_str());
|
||||
if (gcs_file->sync_need) {
|
||||
Flush(file, status);
|
||||
}
|
||||
@ -428,6 +446,8 @@ GCSFile::GCSFile(google::cloud::storage::Client&& gcs_client)
|
||||
if (absl::SimpleAtoi(std::getenv(kMaxStaleness), &value)) {
|
||||
max_staleness = value;
|
||||
}
|
||||
TF_VLog(1, "GCS cache max size = %u ; block size = %u ; max staleness = %u",
|
||||
max_bytes, block_size, max_staleness);
|
||||
|
||||
file_block_cache = std::make_unique<RamFileBlockCache>(
|
||||
block_size, max_bytes, max_staleness,
|
||||
@ -511,6 +531,10 @@ static void UncachedStatForObject(const std::string& bucket,
|
||||
stat->base.mtime_nsec =
|
||||
metadata->time_storage_class_updated().time_since_epoch().count();
|
||||
stat->base.is_directory = object.back() == '/';
|
||||
TF_VLog(1,
|
||||
"Stat of: gs://%s/%s -- length: %u generation: %u; mtime_nsec: %u;",
|
||||
bucket.c_str(), object.c_str(), stat->base.length,
|
||||
stat->generation_number, stat->base.mtime_nsec);
|
||||
return TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
|
||||
@ -545,9 +569,10 @@ void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
|
||||
if (TF_GetCode(status) != TF_OK) return -1;
|
||||
if (!gcs_file->file_block_cache->ValidateAndUpdateFileSignature(
|
||||
path, stat.generation_number)) {
|
||||
std::cout
|
||||
<< "File signature has been changed. Refreshing the cache. Path: "
|
||||
<< path;
|
||||
TF_VLog(
|
||||
1,
|
||||
"File signature has been changed. Refreshing the cache. Path: %s",
|
||||
path.c_str());
|
||||
}
|
||||
read = gcs_file->file_block_cache->Read(path, offset, n, buffer, status);
|
||||
} else {
|
||||
@ -579,6 +604,7 @@ void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
|
||||
(gcs_file->compose ? 0 : -1)});
|
||||
// We are responsible for freeing the pointer returned by TF_GetTempFileName
|
||||
free(temp_file_name);
|
||||
TF_VLog(3, "GcsWritableFile: %s", path);
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
|
||||
@ -624,7 +650,8 @@ void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
TF_VLog(3, "GcsWritableFile: %s with existing file %s", path,
|
||||
temp_file_name.c_str());
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
|
||||
@ -812,6 +839,10 @@ void CreateDir(const TF_Filesystem* filesystem, const char* path,
|
||||
TF_Status* status) {
|
||||
std::string dir = path;
|
||||
MaybeAppendSlash(&dir);
|
||||
TF_VLog(3,
|
||||
"CreateDir: creating directory with path: %s and "
|
||||
"path_with_slash: %s",
|
||||
path, dir.c_str());
|
||||
std::string bucket, object;
|
||||
ParseGCSPath(dir, true, &bucket, &object, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
@ -826,8 +857,11 @@ void CreateDir(const TF_Filesystem* filesystem, const char* path,
|
||||
}
|
||||
|
||||
PathExists(filesystem, dir.c_str(), status);
|
||||
if (TF_GetCode(status) == TF_OK)
|
||||
if (TF_GetCode(status) == TF_OK) {
|
||||
// Use the original name for a correct error here.
|
||||
TF_VLog(3, "CreateDir: directory already exists, not uploading %s", path);
|
||||
return TF_SetStatus(status, TF_ALREADY_EXISTS, path);
|
||||
}
|
||||
|
||||
auto metadata = gcs_file->gcs_client.InsertObject(
|
||||
bucket, object, "",
|
||||
@ -933,6 +967,7 @@ bool IsDirectory(const TF_Filesystem* filesystem, const char* path,
|
||||
static void RenameObject(const TF_Filesystem* filesystem,
|
||||
const std::string& src, const std::string& dst,
|
||||
TF_Status* status) {
|
||||
TF_VLog(3, "RenameObject: started %s to %s", src.c_str(), dst.c_str());
|
||||
std::string bucket_src, object_src;
|
||||
ParseGCSPath(src, false, &bucket_src, &object_src, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
@ -946,6 +981,7 @@ static void RenameObject(const TF_Filesystem* filesystem,
|
||||
bucket_src, object_src, bucket_dst, object_dst);
|
||||
TF_SetStatusFromGCSStatus(metadata.status(), status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TF_VLog(3, "RenameObject: finished %s to %s", src.c_str(), dst.c_str());
|
||||
|
||||
ClearFileCaches(gcs_file, dst);
|
||||
DeleteFile(filesystem, src.c_str(), status);
|
||||
|
@ -43,8 +43,8 @@ class ConcreteFunction {
|
||||
virtual ~ConcreteFunction() = default;
|
||||
|
||||
// This method returns the "Call" Op used to execute the function.
|
||||
virtual Status GetCallOp(absl::Span<AbstractTensorHandle* const> inputs,
|
||||
ImmediateOpPtr* out) = 0;
|
||||
virtual Status MakeCallOp(absl::Span<AbstractTensorHandle* const> inputs,
|
||||
ImmediateOpPtr* out) const = 0;
|
||||
|
||||
virtual const FunctionMetadata& GetFunctionMetadata() const = 0;
|
||||
};
|
||||
|
@ -28,6 +28,26 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "flat_tensor_function",
|
||||
srcs = [
|
||||
"flat_tensor_function.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"flat_tensor_function.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_operation",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "variable",
|
||||
srcs = [
|
||||
@ -68,7 +88,7 @@ cc_library(
|
||||
"tf_concrete_function.h",
|
||||
],
|
||||
deps = [
|
||||
":tensorhandle_convertible",
|
||||
":flat_tensor_function",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_operation",
|
||||
@ -81,3 +101,26 @@ cc_library(
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf_signature_def_function",
|
||||
srcs = [
|
||||
"tf_signature_def_function.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"tf_signature_def_function.h",
|
||||
],
|
||||
deps = [
|
||||
":flat_tensor_function",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_operation",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/c/experimental/saved_model/core:signature_def_function",
|
||||
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
@ -0,0 +1,85 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||
#include "tensorflow/core/protobuf/struct.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
FlatTensorFunction::FlatTensorFunction(
|
||||
const std::string& name,
|
||||
std::vector<ImmediateExecutionTensorHandle*> captures,
|
||||
ImmediateExecutionContext* ctx)
|
||||
: name_(name), captures_(std::move(captures)), ctx_(ctx) {}
|
||||
|
||||
FlatTensorFunction::~FlatTensorFunction() {
|
||||
Status status = ctx_->RemoveFunction(name_);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Failed to remove functiondef " << name_ << ". "
|
||||
<< status.error_message();
|
||||
}
|
||||
}
|
||||
|
||||
Status FlatTensorFunction::Create(
|
||||
const FunctionDef* function_def,
|
||||
std::vector<ImmediateExecutionTensorHandle*> captures,
|
||||
ImmediateExecutionContext* ctx, std::unique_ptr<FlatTensorFunction>* out) {
|
||||
TF_RETURN_IF_ERROR(ctx->AddFunctionDef(*function_def));
|
||||
out->reset(new FlatTensorFunction(function_def->signature().name(),
|
||||
std::move(captures), ctx));
|
||||
return Status();
|
||||
}
|
||||
|
||||
Status FlatTensorFunction::MakeCallOp(
|
||||
absl::Span<AbstractTensorHandle* const> inputs, ImmediateOpPtr* out) const {
|
||||
out->reset(ctx_->CreateOperation());
|
||||
// In eager mode, TF2 python executes functions by constructing an op with
|
||||
// the name of the functiondef:
|
||||
// https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L545
|
||||
// In graph mode, we create a PartitionedCallOp instead:
|
||||
// https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L573
|
||||
|
||||
// TODO(bmzhao): After discussing with Allen, we should execute this via a
|
||||
// PartitionedCallOp for compatibility with "tooling that assumes functions in
|
||||
// graphs are PartitionedCallOps".
|
||||
TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr));
|
||||
|
||||
// Adding the user-provided inputs to the function.
|
||||
TF_RETURN_IF_ERROR((*out)->AddInputList(inputs));
|
||||
|
||||
absl::Span<AbstractTensorHandle* const> captures(
|
||||
reinterpret_cast<AbstractTensorHandle* const*>(captures_.data()),
|
||||
captures_.size());
|
||||
|
||||
// Adding the captures of the function.
|
||||
TF_RETURN_IF_ERROR((*out)->AddInputList(captures));
|
||||
return Status();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,84 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_FLAT_TENSOR_FUNCTION_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_FLAT_TENSOR_FUNCTION_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// FlatTensorFunction models a TF2 eager runtime view of a callable function,
// taking + returning flat lists of tensors, including any captures.
// Effectively, it is a thin wrapper around a FunctionDef owned by the
// EagerContext, and any TensorHandle captures associated with the function.
// The MakeCallOp method automatically handles marshaling the captures after
// the user-provided inputs.
// Note(bmzhao): This class is mainly intended to house low-level reusable
// function logic shared between SignatureDefFunction and ConcreteFunction,
// which present higher-level interfaces. This type does *not* hold any
// "function metadata".
|
||||
class FlatTensorFunction {
|
||||
public:
|
||||
// Factory for creating a FlatTensorFunction.
|
||||
//
|
||||
// Params:
|
||||
// function_def - The function_def associated with the created
|
||||
// FlatTensorFunction. FlatTensorFunction will register this
|
||||
// function_def with `ctx` on creation, and de-register it on
|
||||
// destruction. function_def must be non-null, but
|
||||
// otherwise has no lifetime requirements.
|
||||
// captures - The captured TensorHandles associated with this
|
||||
// FlatTensorFunction.
|
||||
// ctx - A handle to the Tensorflow runtime. This MUST be non-null and
|
||||
// outlive the FlatTensorFunction.
|
||||
// out - The output FlatTensorFunction.
|
||||
static Status Create(const FunctionDef* function_def,
|
||||
std::vector<ImmediateExecutionTensorHandle*> captures,
|
||||
ImmediateExecutionContext* ctx,
|
||||
std::unique_ptr<FlatTensorFunction>* out);
|
||||
|
||||
// This method creates a "Call" Op used to execute the function.
|
||||
Status MakeCallOp(absl::Span<AbstractTensorHandle* const> inputs,
|
||||
ImmediateOpPtr* out) const;
|
||||
|
||||
~FlatTensorFunction();
|
||||
|
||||
private:
|
||||
FlatTensorFunction(const std::string& name,
|
||||
std::vector<ImmediateExecutionTensorHandle*> captures,
|
||||
ImmediateExecutionContext* ctx);
|
||||
|
||||
FlatTensorFunction(const FlatTensorFunction&) = delete;
|
||||
FlatTensorFunction& operator=(const FlatTensorFunction&) = delete;
|
||||
|
||||
// Name of the FunctionDef corresponding to this FlatTensorFunction
|
||||
std::string name_;
|
||||
std::vector<ImmediateExecutionTensorHandle*> captures_;
|
||||
ImmediateExecutionContext* ctx_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_FLAT_TENSOR_FUNCTION_H_
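To make the intended call sequence concrete, here is a minimal usage sketch for this class. `function_def`, `captures`, `ctx`, and `inputs` are placeholders supplied by a hypothetical loader; only Create and MakeCallOp come from the header above.

  std::unique_ptr<tensorflow::FlatTensorFunction> func;
  TF_RETURN_IF_ERROR(tensorflow::FlatTensorFunction::Create(
      function_def, std::move(captures), ctx, &func));

  // MakeCallOp appends the stored captures after the user-provided inputs,
  // so the caller only passes the function's "public" arguments.
  tensorflow::ImmediateOpPtr call_op;
  TF_RETURN_IF_ERROR(func->MakeCallOp(inputs, &call_op));
  // The resulting op is then executed like any other eager operation.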
|
@ -22,7 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
@ -33,32 +33,20 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
TFConcreteFunction::TFConcreteFunction(
|
||||
const std::string& name,
|
||||
std::vector<ImmediateExecutionTensorHandle*> captures,
|
||||
FunctionMetadata metadata, ImmediateExecutionContext* ctx)
|
||||
: name_(name),
|
||||
captures_(std::move(captures)),
|
||||
metadata_(std::move(metadata)),
|
||||
ctx_(ctx) {}
|
||||
|
||||
TFConcreteFunction::~TFConcreteFunction() {
|
||||
Status status = ctx_->RemoveFunction(name_);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Failed to remove functiondef " << name_ << ". "
|
||||
<< status.error_message();
|
||||
}
|
||||
}
|
||||
TFConcreteFunction::TFConcreteFunction(std::unique_ptr<FlatTensorFunction> func,
|
||||
FunctionMetadata metadata)
|
||||
: func_(std::move(func)), metadata_(std::move(metadata)) {}
|
||||
|
||||
Status TFConcreteFunction::Create(
|
||||
const FunctionDef* function_def,
|
||||
std::vector<ImmediateExecutionTensorHandle*> captures,
|
||||
FunctionMetadata metadata, ImmediateExecutionContext* ctx,
|
||||
std::unique_ptr<TFConcreteFunction>* out) {
|
||||
TF_RETURN_IF_ERROR(ctx->AddFunctionDef(*function_def));
|
||||
out->reset(new TFConcreteFunction(function_def->signature().name(),
|
||||
std::move(captures), std::move(metadata),
|
||||
ctx));
|
||||
std::unique_ptr<FlatTensorFunction> func;
|
||||
TF_RETURN_IF_ERROR(FlatTensorFunction::Create(
|
||||
function_def, std::move(captures), ctx, &func));
|
||||
|
||||
out->reset(new TFConcreteFunction(std::move(func), std::move(metadata)));
|
||||
return Status();
|
||||
}
|
||||
|
||||
@ -66,30 +54,9 @@ const FunctionMetadata& TFConcreteFunction::GetFunctionMetadata() const {
|
||||
return metadata_;
|
||||
}
|
||||
|
||||
Status TFConcreteFunction::GetCallOp(
|
||||
absl::Span<AbstractTensorHandle* const> inputs, ImmediateOpPtr* out) {
|
||||
out->reset(ctx_->CreateOperation());
|
||||
// In eager mode, TF2 python executes functions by constructing an op with
|
||||
// the name of the functiondef:
|
||||
// https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L545
|
||||
// In graph mode, we create a PartitionedCallOp instead:
|
||||
// https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L573
|
||||
|
||||
// TODO(bmzhao): After discussing with Allen, we should execute this via a
|
||||
// PartitionedCallOp for compatibility with "tooling that assumes functions in
|
||||
// graphs are PartitionedCallOps".
|
||||
TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr));
|
||||
|
||||
// Adding the user-provided inputs to the function.
|
||||
TF_RETURN_IF_ERROR((*out)->AddInputList(inputs));
|
||||
|
||||
absl::Span<AbstractTensorHandle* const> captures(
|
||||
reinterpret_cast<AbstractTensorHandle**>(captures_.data()),
|
||||
captures_.size());
|
||||
|
||||
// Adding the captures of the function.
|
||||
TF_RETURN_IF_ERROR((*out)->AddInputList(captures));
|
||||
return Status();
|
||||
Status TFConcreteFunction::MakeCallOp(
|
||||
absl::Span<AbstractTensorHandle* const> inputs, ImmediateOpPtr* out) const {
|
||||
return func_->MakeCallOp(inputs, out);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -27,7 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||
|
||||
@ -58,26 +58,22 @@ class TFConcreteFunction : public ConcreteFunction {
|
||||
std::unique_ptr<TFConcreteFunction>* out);
|
||||
|
||||
// This method returns the "Call" Op used to execute the function.
|
||||
Status GetCallOp(absl::Span<AbstractTensorHandle* const> inputs,
|
||||
ImmediateOpPtr* out) override;
|
||||
Status MakeCallOp(absl::Span<AbstractTensorHandle* const> inputs,
|
||||
ImmediateOpPtr* out) const override;
|
||||
|
||||
const FunctionMetadata& GetFunctionMetadata() const override;
|
||||
|
||||
~TFConcreteFunction() override;
|
||||
~TFConcreteFunction() override = default;
|
||||
|
||||
private:
|
||||
TFConcreteFunction(const std::string& name,
|
||||
std::vector<ImmediateExecutionTensorHandle*> captures,
|
||||
FunctionMetadata metadata, ImmediateExecutionContext* ctx);
|
||||
TFConcreteFunction(std::unique_ptr<FlatTensorFunction> func,
|
||||
FunctionMetadata metadata);
|
||||
|
||||
TFConcreteFunction(const TFConcreteFunction&) = delete;
|
||||
TFConcreteFunction& operator=(const TFConcreteFunction&) = delete;
|
||||
|
||||
// Name of the FunctionDef corresponding to this TFConcreteFunction
|
||||
std::string name_;
|
||||
std::vector<ImmediateExecutionTensorHandle*> captures_;
|
||||
std::unique_ptr<FlatTensorFunction> func_;
|
||||
FunctionMetadata metadata_;
|
||||
ImmediateExecutionContext* ctx_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -0,0 +1,64 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||
#include "tensorflow/core/protobuf/struct.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
TFSignatureDefFunction::TFSignatureDefFunction(
|
||||
std::unique_ptr<FlatTensorFunction> func,
|
||||
SignatureDefFunctionMetadata metadata)
|
||||
: func_(std::move(func)), metadata_(std::move(metadata)) {}
|
||||
|
||||
Status TFSignatureDefFunction::Create(
|
||||
const FunctionDef* function_def,
|
||||
std::vector<ImmediateExecutionTensorHandle*> captures,
|
||||
SignatureDefFunctionMetadata metadata, ImmediateExecutionContext* ctx,
|
||||
std::unique_ptr<TFSignatureDefFunction>* out) {
|
||||
std::unique_ptr<FlatTensorFunction> func;
|
||||
TF_RETURN_IF_ERROR(FlatTensorFunction::Create(
|
||||
function_def, std::move(captures), ctx, &func));
|
||||
|
||||
out->reset(new TFSignatureDefFunction(std::move(func), std::move(metadata)));
|
||||
return Status();
|
||||
}
|
||||
|
||||
const SignatureDefFunctionMetadata&
|
||||
TFSignatureDefFunction::GetFunctionMetadata() const {
|
||||
return metadata_;
|
||||
}
|
||||
|
||||
Status TFSignatureDefFunction::MakeCallOp(
|
||||
absl::Span<AbstractTensorHandle* const> inputs, ImmediateOpPtr* out) const {
|
||||
return func_->MakeCallOp(inputs, out);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,85 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SIGNATURE_DEF_FUNCTION_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SIGNATURE_DEF_FUNCTION_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// This is the TF eager runtime implementation of SignatureDefFunction (separate
|
||||
// from the TFRT implementation). The user-facing API of SignatureDefFunctions
|
||||
// and their semantic differences from ConcreteFunction are described here:
|
||||
// https://github.com/tensorflow/tensorflow/blob/e2db60c9d9598ebae0b7741587ce6f5d473584d9/tensorflow/cc/saved_model/experimental/public/signature_def_function.h#L30-L59
|
||||
// Additional implementation notes are available here:
|
||||
// https://github.com/tensorflow/tensorflow/blob/e2db60c9d9598ebae0b7741587ce6f5d473584d9/tensorflow/c/experimental/saved_model/core/signature_def_function.h#L31-L48
|
||||
class TFSignatureDefFunction : public SignatureDefFunction {
|
||||
public:
|
||||
// Factory function for creating a TFSignatureDefFunction.
|
||||
//
|
||||
// Params:
|
||||
// function_def - The function_def associated with the created
|
||||
// TFSignatureDefFunction. TFSignatureDefFunction will
|
||||
// register this function_def with `ctx` on creation, and
|
||||
// de-register it on destruction. function_def must be
|
||||
// non-null, but otherwise has no lifetime requirements.
|
||||
// captures - The captured TensorHandles associated with this
|
||||
// TFSignatureDefFunction.
|
||||
// metadata - FunctionMetadata associated with this TFSignatureDefFunction.
|
||||
// ctx - A handle to the Tensorflow runtime. This MUST be non-null and
|
||||
// outlive TFSignatureDefFunction.
|
||||
// out - The output TFSignatureDefFunction.
|
||||
static Status Create(const FunctionDef* function_def,
|
||||
std::vector<ImmediateExecutionTensorHandle*> captures,
|
||||
SignatureDefFunctionMetadata metadata,
|
||||
ImmediateExecutionContext* ctx,
|
||||
std::unique_ptr<TFSignatureDefFunction>* out);
|
||||
|
||||
// This method creates a "Call" Op used to execute the function.
|
||||
Status MakeCallOp(absl::Span<AbstractTensorHandle* const> inputs,
|
||||
ImmediateOpPtr* out) const override;
|
||||
|
||||
const SignatureDefFunctionMetadata& GetFunctionMetadata() const override;
|
||||
|
||||
~TFSignatureDefFunction() override = default;
|
||||
|
||||
private:
|
||||
TFSignatureDefFunction(std::unique_ptr<FlatTensorFunction> func,
|
||||
SignatureDefFunctionMetadata metadata);
|
||||
|
||||
TFSignatureDefFunction(const TFSignatureDefFunction&) = delete;
|
||||
TFSignatureDefFunction& operator=(const TFSignatureDefFunction&) = delete;
|
||||
|
||||
std::unique_ptr<FlatTensorFunction> func_;
|
||||
SignatureDefFunctionMetadata metadata_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SIGNATURE_DEF_FUNCTION_H_
|
@ -34,15 +34,15 @@ TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(TF_ConcreteFunction* func) {
|
||||
&tensorflow::unwrap(func)->GetFunctionMetadata()));
|
||||
}
|
||||
|
||||
TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func,
|
||||
TFE_TensorHandle** inputs, int num_inputs,
|
||||
TF_Status* status) {
|
||||
TFE_Op* TF_ConcreteFunctionMakeCallOp(TF_ConcreteFunction* func,
|
||||
TFE_TensorHandle** inputs, int num_inputs,
|
||||
TF_Status* status) {
|
||||
tensorflow::ImmediateOpPtr call_op;
|
||||
absl::Span<tensorflow::AbstractTensorHandle* const> input_span(
|
||||
reinterpret_cast<tensorflow::AbstractTensorHandle**>(
|
||||
tensorflow::unwrap(inputs)),
|
||||
static_cast<size_t>(num_inputs));
|
||||
status->status = tensorflow::unwrap(func)->GetCallOp(input_span, &call_op);
|
||||
status->status = tensorflow::unwrap(func)->MakeCallOp(input_span, &call_op);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -107,7 +107,7 @@ TEST_P(CSavedModelAPITest, LoadsSavedModel) {
|
||||
compute_fn_inputs.push_back(input_a);
|
||||
compute_fn_inputs.push_back(input_b);
|
||||
|
||||
TFE_Op* compute_fn_op = TF_ConcreteFunctionGetCallOp(
|
||||
TFE_Op* compute_fn_op = TF_ConcreteFunctionMakeCallOp(
|
||||
compute_fn, compute_fn_inputs.data(), compute_fn_inputs.size(), status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
|
@ -47,7 +47,7 @@ TF_CAPI_EXPORT extern TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(
|
||||
// high-level API here. A strawman for what this interface could look like:
|
||||
// TF_Value* TF_ExecuteFunction(TFE_Context*, TF_ConcreteFunction*, TF_Value*
|
||||
// inputs, int num_inputs, TF_Status* status);
|
||||
TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionGetCallOp(
|
||||
TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionMakeCallOp(
|
||||
TF_ConcreteFunction* func, TFE_TensorHandle** inputs, int num_inputs,
|
||||
TF_Status* status);
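A usage sketch for the renamed entry point; `func`, `inputs`, and `num_inputs` are assumed to come from a loaded SavedModel (as in the C API test further below), and error handling is abbreviated:

  TF_Status* status = TF_NewStatus();
  TFE_Op* call_op =
      TF_ConcreteFunctionMakeCallOp(func, inputs, num_inputs, status);
  if (TF_GetCode(status) == TF_OK) {
    TFE_TensorHandle* retvals[1] = {nullptr};
    int num_retvals = 1;
    // The returned op is executed like any other eager op.
    TFE_Execute(call_op, retvals, &num_retvals, status);
    TFE_DeleteOp(call_op);
    if (retvals[0] != nullptr) TFE_DeleteTensorHandle(retvals[0]);
  }
  TF_DeleteStatus(status);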
|
||||
|
||||
|
@ -132,6 +132,23 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "summary_op_benchmark_test",
|
||||
size = "small",
|
||||
srcs = ["summary_op_benchmark_test.cc"],
|
||||
deps = [
|
||||
":summary_op",
|
||||
"//tensorflow/c:kernels",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensor_shape_utils",
|
||||
srcs = ["tensor_shape_utils.cc"],
|
||||
|
@ -93,11 +93,13 @@ void HistogramSummaryOp_Compute(void* kernel, TF_OpKernelContext* ctx) {
|
||||
std::ostringstream err;
|
||||
err << "Nan in summary histogram for: " << k->op_node_name;
|
||||
TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, err.str().c_str());
|
||||
TF_OpKernelContext_Failure(ctx, status.get());
|
||||
return;
|
||||
} else if (Eigen::numext::isinf(double_val)) {
|
||||
std::ostringstream err;
|
||||
err << "Infinity in Histogram for: " << k->op_node_name;
|
||||
TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, err.str().c_str());
|
||||
TF_OpKernelContext_Failure(ctx, status.get());
|
||||
return;
|
||||
}
|
||||
histo.Add(double_val);
|
||||
|
tensorflow/c/kernels/summary_op_benchmark_test.cc
@ -0,0 +1,71 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
Graph* BM_ScalarSummaryOp(TensorShape shape, std::string tag, float value) {
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
Tensor tags(DT_STRING, shape);
|
||||
Tensor values(DT_FLOAT, shape);
|
||||
for (int i = 0; i < tags.NumElements(); ++i) {
|
||||
tags.flat<tstring>()(i) = tag;
|
||||
values.flat<float>()(i) = value;
|
||||
}
|
||||
Node* ret;
|
||||
TF_CHECK_OK(NodeBuilder(g->NewName("dummy"), "ScalarSummary")
|
||||
.Input(test::graph::Constant(g, tags))
|
||||
.Input(test::graph::Constant(g, values))
|
||||
.Attr("T", DT_FLOAT)
|
||||
.Finalize(g, &ret));
|
||||
return g;
|
||||
}
|
||||
|
||||
// Macro used to parse an initializer list for a TensorShape
#define DIMARGS(...) \
  { __VA_ARGS__ }
// Random parameters for testing
|
||||
constexpr char longTagParam[] = "LONGTAG____________________________";
|
||||
constexpr float largeValueParam = 2352352.2623433;
|
||||
|
||||
#define BM_ScalarSummaryDev(device, dims, name, tag, value) \
|
||||
void BM_ScalarSummary##name##device(int iters) { \
|
||||
testing::StopTiming(); \
|
||||
TensorShape tensorshape(DIMARGS dims); \
|
||||
auto g = BM_ScalarSummaryOp(tensorshape, #tag, value); \
|
||||
testing::StartTiming(); \
|
||||
test::Benchmark("cpu", g).Run(iters); \
|
||||
} \
|
||||
BENCHMARK(BM_ScalarSummary##name##device);
|
||||
|
||||
BM_ScalarSummaryDev(Cpu, (5, 10, 100), Base, Tag, 5.2);
|
||||
// Benchmark for large shapes
|
||||
BM_ScalarSummaryDev(Cpu, (500, 100, 100), LargeShape, Tag, 5.2);
|
||||
// Benchmark for large tag tstring
|
||||
BM_ScalarSummaryDev(Cpu, (5, 10, 100), LongTag, longTagParam, 5.2);
|
||||
// Benchmark for large values
|
||||
BM_ScalarSummaryDev(Cpu, (500, 100, 100), LargeValue, Tag, largeValueParam);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
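For readers tracing the macro indirection, the first instantiation, BM_ScalarSummaryDev(Cpu, (5, 10, 100), Base, Tag, 5.2), expands to roughly the following (a sketch; all names come from the macros above and the iters-style benchmark API this file already uses):

  void BM_ScalarSummaryBaseCpu(int iters) {
    testing::StopTiming();                  // exclude graph construction
    TensorShape tensorshape({5, 10, 100});  // DIMARGS (5, 10, 100)
    auto g = BM_ScalarSummaryOp(tensorshape, "Tag", 5.2);
    testing::StartTiming();
    test::Benchmark("cpu", g).Run(iters);   // time only the op execution
  }
  BENCHMARK(BM_ScalarSummaryBaseCpu);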
|
@ -329,6 +329,7 @@ cc_library(
|
||||
srcs = ["xla_compilation_cache.cc"],
|
||||
hdrs = ["xla_compilation_cache.h"],
|
||||
deps = [
|
||||
":flags",
|
||||
":xla_activity_listener",
|
||||
":xla_activity_proto_cc",
|
||||
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
|
||||
@ -361,8 +362,11 @@ tf_cc_test(
|
||||
"xla_compilation_cache_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":flags",
|
||||
":xla_compilation_cache",
|
||||
":xla_cpu_jit",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
@ -918,6 +922,7 @@ tf_cc_test(
|
||||
":xla_cpu_jit",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:function_ops",
|
||||
"//tensorflow/cc:functional_ops",
|
||||
"//tensorflow/cc:ops",
|
||||
"//tensorflow/cc:scope",
|
||||
"//tensorflow/compiler/tf2xla:test_util",
|
||||
|
@ -518,10 +518,15 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
|
||||
}
|
||||
}
|
||||
|
||||
// Returns `true` iff the node has the given `attr` set to `true`. Returns
// `false` both when the attr is missing and when it is set to `false`.
|
||||
static bool HasBoolAttr(const NodeDef& node, const char* attr) {
|
||||
const auto& it = node.attr().find(attr);
|
||||
return it != node.attr().end() && it->second.b();
|
||||
}
|
||||
|
||||
bool CanCreateXlaKernel(const NodeDef& node_def) {
|
||||
// If kXlaMustCompileAttr is set on the node_def, use its value.
|
||||
const auto& it = node_def.attr().find(kXlaMustCompileAttr);
|
||||
return it != node_def.attr().end() && it->second.b();
|
||||
return HasBoolAttr(node_def, kXlaMustCompileAttr);
|
||||
}
|
||||
|
||||
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
|
||||
@ -564,4 +569,58 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
static auto const ops_triggering_xla_compilation =
|
||||
new absl::flat_hash_set<std::string>{"XlaBroadcastHelper",
|
||||
"XlaConv",
|
||||
"XlaDequantize",
|
||||
"XlaDot",
|
||||
"XlaDynamicSlice",
|
||||
"XlaDynamicUpdateSlice",
|
||||
"XlaEinsum",
|
||||
"XlaGather",
|
||||
"XlaIf",
|
||||
"XlaKeyValueSort",
|
||||
"XlaPad",
|
||||
"XlaRecv",
|
||||
"XlaReduce",
|
||||
"XlaReduceWindow",
|
||||
"XlaReplicaId",
|
||||
"XlaScatter",
|
||||
"XlaSelectAndScatter",
|
||||
"XlaSelfAdjointEig",
|
||||
"XlaSend",
|
||||
"XlaSharding",
|
||||
"XlaSort",
|
||||
"XlaSpmdFullToShardShape",
|
||||
"XlaSpmdShardToFullShape",
|
||||
"XlaSvd",
|
||||
"XlaWhile"};
|
||||
|
||||
static bool NodeCanTriggerXlaCompilation(const NodeDef& node) {
|
||||
return node.attr().find(kXlaClusterIdAttr) != node.attr().end() ||
|
||||
HasBoolAttr(node, kXlaMustCompileAttr) ||
|
||||
HasBoolAttr(node, kXlaCompileAttr) ||
|
||||
HasBoolAttr(node, kXlaScopeAttr) ||
|
||||
HasBoolAttr(node, kXlaInternalScopeAttr) ||
|
||||
ops_triggering_xla_compilation->count(node.op());
|
||||
}
|
||||
|
||||
bool CanTriggerXlaCompilation(const GraphDef& graph) {
|
||||
for (const FunctionDef& function : graph.library().function()) {
|
||||
for (const NodeDef& node : function.node_def()) {
|
||||
if (NodeCanTriggerXlaCompilation(node)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const NodeDef& node : graph.node()) {
|
||||
if (NodeCanTriggerXlaCompilation(node)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -126,9 +126,10 @@ class RecursiveCompilabilityChecker {
|
||||
bool allow_inaccurate_ops = false;
|
||||
};
|
||||
|
||||
RecursiveCompilabilityChecker(const OperationFilter* op_filter,
|
||||
const DeviceType* jit_device_type)
|
||||
: op_filter_(*op_filter), jit_device_type_(*jit_device_type) {}
|
||||
RecursiveCompilabilityChecker(OperationFilter op_filter,
|
||||
DeviceType jit_device_type)
|
||||
: op_filter_(std::move(op_filter)),
|
||||
jit_device_type_(std::move(jit_device_type)) {}
|
||||
|
||||
using UncompilableNodesMap =
|
||||
std::map<std::string,
|
||||
@ -259,8 +260,8 @@ class RecursiveCompilabilityChecker {
|
||||
// Make sure we don't recurse infinitely on recursive functions.
|
||||
const size_t kMaxRecursionDepth = 10;
|
||||
|
||||
const OperationFilter& op_filter_;
|
||||
const DeviceType& jit_device_type_;
|
||||
const OperationFilter op_filter_;
|
||||
const DeviceType jit_device_type_;
|
||||
};
|
||||
|
||||
RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
|
||||
@ -282,6 +283,9 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
|
||||
// set.
|
||||
bool CanCreateXlaKernel(const NodeDef& node_def);
|
||||
|
||||
// Check whether graph can trigger XLA compilation.
|
||||
bool CanTriggerXlaCompilation(const GraphDef& graph);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/cc/framework/scope.h"
|
||||
#include "tensorflow/cc/ops/function_ops.h"
|
||||
#include "tensorflow/cc/ops/functional_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
@ -75,8 +76,8 @@ class CompilabilityCheckUtilTest : public ::testing::Test {
|
||||
op_filter_.allow_inaccurate_ops = false;
|
||||
op_filter_.allow_slow_ops = false;
|
||||
|
||||
checker_ = absl::make_unique<RecursiveCompilabilityChecker>(&op_filter_,
|
||||
&device_type_);
|
||||
checker_ = absl::make_unique<RecursiveCompilabilityChecker>(op_filter_,
|
||||
device_type_);
|
||||
}
|
||||
|
||||
FunctionLibraryRuntime* GetFunctionLibraryRuntime() {
|
||||
@ -354,5 +355,110 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) {
|
||||
"unsupported op"));
|
||||
}
|
||||
|
||||
TEST_F(CompilabilityCheckUtilTest, TestCanNotTriggerXlaCompilation) {
|
||||
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
FunctionDefLibrary library;
|
||||
|
||||
FunctionDef identity_func = FunctionDefHelper::Create(
|
||||
"IdentityFunc",
|
||||
/*in_def=*/{"x:float"},
|
||||
/*out_def=*/{"res:float"},
|
||||
/*attr_def=*/{},
|
||||
/*node_def=*/{{{"t0"}, "Identity", {"x"}, {{"T", DT_FLOAT}}}},
|
||||
/*ret_def*/ {{"res", "t0:output"}});
|
||||
|
||||
*library.add_function() = identity_func;
|
||||
|
||||
Output in = ops::Placeholder(root, DT_FLOAT);
|
||||
NameAttrList b_name_attr;
|
||||
b_name_attr.set_name("IdentityFunc");
|
||||
ops::PartitionedCall call(root.WithOpName("call"), {in}, {DT_FLOAT},
|
||||
b_name_attr);
|
||||
|
||||
GraphDef graph_def;
|
||||
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library));
|
||||
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||
|
||||
EXPECT_FALSE(CanTriggerXlaCompilation(graph_def));
|
||||
}
|
||||
|
||||
TEST_F(CompilabilityCheckUtilTest, TestXlaOpsCanTriggerXlaCompilation) {
|
||||
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
FunctionDefLibrary library;
|
||||
|
||||
FunctionDef sort_func = FunctionDefHelper::Create(
|
||||
"SortFunc",
|
||||
/*in_def=*/{"x:float"},
|
||||
/*out_def=*/{"res:float"},
|
||||
/*attr_def=*/{},
|
||||
/*node_def=*/{{{"t0"}, "XlaSort", {"x"}, {{"T", DT_FLOAT}}}},
|
||||
/*ret_def*/ {{"res", "t0:output"}});
|
||||
|
||||
*library.add_function() = sort_func;
|
||||
|
||||
Output in = ops::Placeholder(root, DT_FLOAT);
|
||||
NameAttrList b_name_attr;
|
||||
b_name_attr.set_name("SortFunc");
|
||||
ops::PartitionedCall call(root.WithOpName("call"), {in}, {DT_FLOAT},
|
||||
b_name_attr);
|
||||
|
||||
GraphDef graph_def;
|
||||
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library));
|
||||
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||
|
||||
EXPECT_TRUE(CanTriggerXlaCompilation(graph_def));
|
||||
}
|
||||
|
||||
TEST_F(CompilabilityCheckUtilTest, TestCanTriggerXlaCompilation) {
|
||||
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
FunctionDefLibrary library;
|
||||
|
||||
AttrValue true_attribute;
|
||||
true_attribute.set_b(true);
|
||||
|
||||
FunctionDef identity_func = FunctionDefHelper::Create(
|
||||
"IdentityFunc",
|
||||
/*in_def=*/{"x:float"},
|
||||
/*out_def=*/{"res:float"},
|
||||
/*attr_def=*/{},
|
||||
/*node_def=*/{{{"t0"}, "Identity", {"x"}, {{"T", DT_FLOAT}}}},
|
||||
/*ret_def*/ {{"res", "t0:output"}});
|
||||
|
||||
(*identity_func.mutable_attr())[kXlaMustCompileAttr] = true_attribute;
|
||||
|
||||
FunctionDef call_identity = FunctionDefHelper::Create(
|
||||
"CallIdentity",
|
||||
/*in_def=*/{"x:float"},
|
||||
/*out_def=*/{"z:float"}, /*attr_def=*/{},
|
||||
/*node_def=*/
|
||||
{{{"func_call"},
|
||||
"PartitionedCall",
|
||||
{"x"},
|
||||
{{"Tin", DataTypeSlice({DT_FLOAT})},
|
||||
{"Tout", DataTypeSlice({DT_FLOAT})},
|
||||
{"f",
|
||||
FunctionDefHelper::FunctionRef("IdentityRef", {{"T", DT_FLOAT}})},
|
||||
{kXlaMustCompileAttr, true}}}},
|
||||
/*ret_def=*/{{"z", "func_call:output:0"}});
|
||||
|
||||
*library.add_function() = identity_func;
|
||||
*library.add_function() = call_identity;
|
||||
|
||||
Output in = ops::Placeholder(root, DT_FLOAT);
|
||||
NameAttrList b_name_attr;
|
||||
b_name_attr.set_name("CallIdentity");
|
||||
ops::PartitionedCall call(root.WithOpName("call"), {in}, {DT_FLOAT},
|
||||
b_name_attr);
|
||||
|
||||
GraphDef graph_def;
|
||||
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library));
|
||||
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||
|
||||
EXPECT_TRUE(CanTriggerXlaCompilation(graph_def));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -28,4 +28,6 @@ const char* const kXlaScopeAttr = "_XlaScope";
|
||||
// only when auto_jit is ON.
|
||||
const char* const kXlaInternalScopeAttr = "_XlaInternalScope";
|
||||
|
||||
const char* const kXlaClusterIdAttr = "_xla_compile_id";
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -35,6 +35,9 @@ extern const char* const kXlaCompileAttr; // "_XlaCompile"
|
||||
extern const char* const kXlaScopeAttr; // "_XlaScope"
|
||||
extern const char* const kXlaInternalScopeAttr; // "_XlaInternalScope"
|
||||
|
||||
// The id of the compiled cluster.
|
||||
extern const char* const kXlaClusterIdAttr; // "_xla_compile_id"
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_DEFS_H_
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/ascii.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
@ -34,9 +35,6 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
const char* const EncapsulateXlaComputationsPass::kXlaClusterAttr =
|
||||
"_xla_compile_id";
|
||||
|
||||
namespace {
|
||||
|
||||
const char* const kXlaClusterOutput = "XlaClusterOutput";
|
||||
@ -45,10 +43,7 @@ bool IsCpuGpuCompile(const Graph* graph) {
|
||||
for (Node* n : graph->nodes()) {
|
||||
string name;
|
||||
// Only consider nodes being compiled.
|
||||
if (!GetNodeAttr(n->attrs(),
|
||||
EncapsulateXlaComputationsPass::kXlaClusterAttr, &name)
|
||||
.ok())
|
||||
continue;
|
||||
if (!GetNodeAttr(n->attrs(), kXlaClusterIdAttr, &name).ok()) continue;
|
||||
// Early return for any node with a device that is not a CPU or GPU.
|
||||
DeviceNameUtils::ParsedName parsed;
|
||||
if (DeviceNameUtils::ParseFullName(n->requested_device(), &parsed)) {
|
||||
@ -180,8 +175,7 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
|
||||
retvals[i]->AddAttr("index", i);
|
||||
}
|
||||
|
||||
AddNodeAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, call_def->name(),
|
||||
call_def);
|
||||
AddNodeAttr(kXlaClusterIdAttr, call_def->name(), call_def);
|
||||
AddNodeAttr("_variable_start_index", variable_start_index, call_def);
|
||||
|
||||
// Uniquify the function name.
|
||||
@ -216,8 +210,8 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
|
||||
// O(n) pass over the edges.
|
||||
for (const Edge* e : (*graph)->edges()) {
|
||||
if (!e->IsControlEdge() &&
|
||||
e->src()->attrs().Find(kXlaClusterAttr) != nullptr &&
|
||||
e->dst()->attrs().Find(kXlaClusterAttr) == nullptr &&
|
||||
e->src()->attrs().Find(kXlaClusterIdAttr) != nullptr &&
|
||||
e->dst()->attrs().Find(kXlaClusterIdAttr) == nullptr &&
|
||||
e->dst()->type_string() != kXlaClusterOutput) {
|
||||
return errors::InvalidArgument(
|
||||
"Undeclared output of XLA computation. Some common causes of this "
|
||||
@ -232,9 +226,9 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
|
||||
|
||||
auto output = absl::make_unique<Graph>((*graph)->op_registry());
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
EncapsulateSubgraphsInFunctions(kXlaClusterAttr, **graph, RewriteSubgraph,
|
||||
/*reuse_existing_functions=*/true,
|
||||
&output, flib_def),
|
||||
EncapsulateSubgraphsInFunctions(
|
||||
kXlaClusterIdAttr, **graph, RewriteSubgraph,
|
||||
/*reuse_existing_functions=*/true, &output, flib_def),
|
||||
"EncapsulateXlaComputationsPass failed");
|
||||
graph->swap(output);
|
||||
return Status::OK();
|
||||
@ -246,7 +240,7 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
|
||||
// while iterating.
|
||||
std::vector<Node*> launch_nodes;
|
||||
for (Node* n : graph->nodes()) {
|
||||
const string& name = GetNodeAttrString(n->attrs(), kXlaClusterAttr);
|
||||
const string& name = GetNodeAttrString(n->attrs(), kXlaClusterIdAttr);
|
||||
if (!name.empty()) {
|
||||
launch_nodes.push_back(n);
|
||||
}
|
||||
|
@ -34,8 +34,6 @@ namespace tensorflow {
|
||||
// XlaLaunch operators.
|
||||
class EncapsulateXlaComputationsPass : public GraphOptimizationPass {
|
||||
public:
|
||||
static const char* const kXlaClusterAttr; // _xla_compile_id
|
||||
|
||||
Status Run(const GraphOptimizationPassOptions& options) override;
|
||||
|
||||
// The following methods are public only for unit tests.
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/ops/function_ops.h"
|
||||
#include "tensorflow/cc/ops/resource_variable_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
|
||||
#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h"
|
||||
#include "tensorflow/compiler/tf2xla/test_util.h"
|
||||
@ -46,19 +47,18 @@ static std::unique_ptr<Graph> MakeOuterGraph(
|
||||
auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
|
||||
|
||||
NodeDef def;
|
||||
TF_CHECK_OK(
|
||||
NodeDefBuilder("launch0", function, &flib_def)
|
||||
.Input(a.node()->name(), 0, DT_INT32)
|
||||
.Input(b.node()->name(), 0, DT_FLOAT)
|
||||
.Input(c.node()->name(), 0, DT_INT32)
|
||||
.Input(d.node()->name(), 0, DT_FLOAT)
|
||||
.Input(u.node()->name(), 0, DT_RESOURCE)
|
||||
.Input(v.node()->name(), 0, DT_RESOURCE)
|
||||
.Input(w.node()->name(), 0, DT_RESOURCE)
|
||||
.Device("/gpu:0")
|
||||
.Attr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0")
|
||||
.Attr("_variable_start_index", 4)
|
||||
.Finalize(&def));
|
||||
TF_CHECK_OK(NodeDefBuilder("launch0", function, &flib_def)
|
||||
.Input(a.node()->name(), 0, DT_INT32)
|
||||
.Input(b.node()->name(), 0, DT_FLOAT)
|
||||
.Input(c.node()->name(), 0, DT_INT32)
|
||||
.Input(d.node()->name(), 0, DT_FLOAT)
|
||||
.Input(u.node()->name(), 0, DT_RESOURCE)
|
||||
.Input(v.node()->name(), 0, DT_RESOURCE)
|
||||
.Input(w.node()->name(), 0, DT_RESOURCE)
|
||||
.Device("/gpu:0")
|
||||
.Attr(kXlaClusterIdAttr, "launch0")
|
||||
.Attr("_variable_start_index", 4)
|
||||
.Finalize(&def));
|
||||
|
||||
Status status;
|
||||
Node* launch = scope.graph()->AddNode(def, &status);
|
||||
@ -107,7 +107,7 @@ static std::unique_ptr<Graph> MakeBodyGraph() {
|
||||
auto arg6 = ops::_Arg(scope.WithOpName("w_0_arg"), DT_RESOURCE, 6);
|
||||
|
||||
auto add_attrs = [](Node* node) {
|
||||
node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0");
|
||||
node->AddAttr(kXlaClusterIdAttr, "launch0");
|
||||
node->set_requested_device("/gpu:0");
|
||||
};
|
||||
|
||||
@ -155,8 +155,7 @@ TEST(EncapsulateXlaComputations, DeterministicEncapsulate) {
|
||||
: ops::Add(scope.WithOpName("E"), a1, a0);
|
||||
|
||||
auto add_attrs = [](Node* node) {
|
||||
node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr,
|
||||
"launch0");
|
||||
node->AddAttr(kXlaClusterIdAttr, "launch0");
|
||||
};
|
||||
add_attrs(e.node());
|
||||
|
||||
@ -216,7 +215,7 @@ TEST(EncapsulateXlaComputations, Encapsulate) {
|
||||
auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
|
||||
|
||||
auto add_attrs = [](Node* node) {
|
||||
node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0");
|
||||
node->AddAttr(kXlaClusterIdAttr, "launch0");
|
||||
node->set_requested_device("/gpu:0");
|
||||
};
|
||||
|
||||
|
@ -268,4 +268,10 @@ void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
  AppendMarkForCompilationPassFlagsInternal(flag_list);
}

static std::atomic<bool> xla_compilation_disabled(false);

void DisableXlaCompilation() { xla_compilation_disabled = true; }

bool FailOnXlaCompilation() { return xla_compilation_disabled; }

}  // namespace tensorflow
@ -162,6 +162,13 @@ MlirCommonFlags* GetMlirCommonFlags();
void AppendMarkForCompilationPassFlags(
    std::vector<tensorflow::Flag>* flag_list);

// Disables XLA compilation, forces it to return an error message instead. Can
// be used by a server to ensure that JIT compilation is opt-in.
void DisableXlaCompilation();

// Returns `false` unless `DisableXlaCompilation` was called.
bool FailOnXlaCompilation();

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_FLAGS_H_
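A minimal usage sketch for the new hooks above (the ServerInit wrapper and its
allow_jit flag are hypothetical; only DisableXlaCompilation and
FailOnXlaCompilation come from this change):

#include "tensorflow/compiler/jit/flags.h"

// Called once at process startup; makes JIT compilation opt-in.
void ServerInit(bool allow_jit) {
  if (!allow_jit) {
    // After this call, XlaCompilationCache::Compile returns an Internal
    // error ("XLA compilation disabled") instead of compiling.
    tensorflow::DisableXlaCompilation();
  }
}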
@ -158,7 +158,7 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
|
||||
constants_(constants),
|
||||
resources_(resources),
|
||||
function_(function),
|
||||
platform_info_(XlaPlatformInfoFromContext(ctx)),
|
||||
platform_info_(XlaPlatformInfoFromDevice(ctx->device())),
|
||||
has_ref_vars_(has_ref_vars) {}
|
||||
|
||||
static Status CompileToLocalExecutable(
|
||||
@ -180,7 +180,7 @@ static Status CompileToLocalExecutable(
|
||||
TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
|
||||
rm->default_container(), "xla_cache", &cache,
|
||||
[&](XlaCompilationCache** cache) {
|
||||
return BuildXlaCompilationCache(ctx, platform_info, cache);
|
||||
return BuildXlaCompilationCache(ctx->device(), platform_info, cache);
|
||||
}));
|
||||
// Hold the reference to the JIT during evaluation. (We could probably
|
||||
// free it sooner because the ResourceMgr will retain a reference, but
|
||||
@ -191,7 +191,9 @@ static Status CompileToLocalExecutable(
|
||||
|
||||
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
|
||||
XlaCompiler::Options options = GenerateCompilerOptions(
|
||||
*cache, ctx, platform_info, has_ref_vars, &tf_allocator_adapter);
|
||||
*cache, *ctx->function_library(), ctx->device(),
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
|
||||
platform_info, has_ref_vars, &tf_allocator_adapter);
|
||||
|
||||
std::map<int, Tensor> constant_args;
|
||||
for (int i : constants) {
|
||||
@ -248,8 +250,10 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
VLOG(1) << "Executing XLA Computation...";
|
||||
|
||||
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
|
||||
se::DeviceMemoryAllocator* allocator =
|
||||
GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
|
||||
se::DeviceMemoryAllocator* allocator = GetAllocator(
|
||||
&tf_allocator_adapter, ctx->device(),
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
|
||||
platform_info_);
|
||||
int device_ordinal = stream ? stream->parent()->device_ordinal()
|
||||
: client->default_device_ordinal();
|
||||
XlaComputationLaunchContext launch_context(
|
||||
@ -373,7 +377,7 @@ XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
|
||||
constants_(ConstantsVector(ctx)),
|
||||
resources_(ResourcesVector(ctx)),
|
||||
function_(FunctionAttr(ctx)),
|
||||
platform_info_(XlaPlatformInfoFromContext(ctx)),
|
||||
platform_info_(XlaPlatformInfoFromDevice(ctx->device())),
|
||||
must_compile_(MustCompileAttr(ctx)),
|
||||
has_ref_vars_(HasRefVars(ctx)) {}
|
||||
|
||||
@ -461,7 +465,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
|
||||
}
|
||||
|
||||
XlaRunOp::XlaRunOp(OpKernelConstruction* ctx)
|
||||
: OpKernel(ctx), platform_info_(XlaPlatformInfoFromContext(ctx)) {}
|
||||
: OpKernel(ctx), platform_info_(XlaPlatformInfoFromDevice(ctx->device())) {}
|
||||
|
||||
void XlaRunOp::Compute(OpKernelContext* ctx) {
|
||||
VLOG(3) << "XlaRunOp " << def().name();
|
||||
@ -472,8 +476,10 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
|
||||
XlaExecutableClosureStore::Global()->Consume(key);
|
||||
|
||||
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
|
||||
se::DeviceMemoryAllocator* allocator =
|
||||
GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
|
||||
se::DeviceMemoryAllocator* allocator = GetAllocator(
|
||||
&tf_allocator_adapter, ctx->device(),
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
|
||||
platform_info_);
|
||||
se::Stream* stream =
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
|
||||
int device_ordinal = stream ? stream->parent()->device_ordinal()
|
||||
|
@ -1196,12 +1196,9 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
|
||||
continue;
|
||||
}
|
||||
|
||||
DeviceType jit_device_type(registration->compilation_device_name);
|
||||
|
||||
RecursiveCompilabilityChecker::OperationFilter op_filter =
|
||||
CreateOperationFilter(*registration);
|
||||
|
||||
if (!RecursiveCompilabilityChecker{&op_filter, &jit_device_type}
|
||||
if (!RecursiveCompilabilityChecker{
|
||||
CreateOperationFilter(*registration),
|
||||
DeviceType{registration->compilation_device_name}}
|
||||
.IsCompilableNode(*node, lib_runtime)) {
|
||||
continue;
|
||||
}
|
||||
@ -1718,7 +1715,6 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
|
||||
const XlaOpRegistry::DeviceRegistration* registration;
|
||||
CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(),
|
||||
®istration));
|
||||
DeviceType jit_device_type(registration->compilation_device_name);
|
||||
|
||||
// We can always *compile* resource operations, stateful RNGs and dummy ops,
|
||||
// even if we are sometimes unable to auto-cluster them.
|
||||
@ -1733,7 +1729,8 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
|
||||
op_filter.allow_slow_ops = true;
|
||||
op_filter.allow_inaccurate_ops = true;
|
||||
|
||||
RecursiveCompilabilityChecker checker{&op_filter, &jit_device_type};
|
||||
RecursiveCompilabilityChecker checker{
|
||||
op_filter, DeviceType{registration->compilation_device_name}};
|
||||
if (!uncompilable_node_info) {
|
||||
// We do not need uncompilable node info. Just return the result.
|
||||
return checker.IsCompilableCall(ndef, flr);
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "absl/base/call_once.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/jit/xla_activity.pb.h"
|
||||
#include "tensorflow/compiler/jit/xla_activity_listener.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
|
||||
@ -323,6 +324,10 @@ Status XlaCompilationCache::CompileImpl(
    absl::optional<int64> compile_threshold,
    const XlaCompiler::CompilationResult** out_compilation_result,
    xla::LocalExecutable** out_executable) {
  if (FailOnXlaCompilation()) {
    return errors::Internal("XLA compilation disabled");
  }

  DCHECK_NE(out_executable, nullptr);
  VLOG(2) << "XlaCompilationCache::Compile " << DebugString();
@ -15,7 +15,9 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
|
||||
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
|
||||
@ -52,6 +54,30 @@ TEST(XlaCompilationCacheTest, SignatureEquality) {
  }
}

TEST(XlaCompilationCacheTest, TestDisabledXlaCompilation) {
  NameAttrList fn;
  fn.set_name("afunction");

  DisableXlaCompilation();

  xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie();
  DeviceType device_type = DeviceType(DEVICE_CPU_XLA_JIT);

  const XlaCompiler::CompilationResult* compilation_result;
  xla::LocalExecutable* executable;

  auto cache = new XlaCompilationCache(client, device_type);
  core::ScopedUnref cache_ref(cache);

  Status status = cache->Compile(XlaCompiler::Options{}, fn, {},
                                 XlaCompiler::CompileOptions{},
                                 XlaCompilationCache::CompileMode::kStrict,
                                 &compilation_result, &executable);
  EXPECT_FALSE(status.ok());
  EXPECT_TRUE(
      absl::StrContains(status.error_message(), "XLA compilation disabled"));
}

static void BM_BuildSignature(int iters, int n_args) {
  NameAttrList fn;
  fn.set_name("afunction");
@ -49,8 +49,10 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
|
||||
xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
|
||||
|
||||
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
|
||||
se::DeviceMemoryAllocator* allocator =
|
||||
GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
|
||||
se::DeviceMemoryAllocator* allocator = GetAllocator(
|
||||
&tf_allocator_adapter, ctx->device(),
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
|
||||
platform_info_);
|
||||
XlaComputationLaunchContext launch_context(
|
||||
client, allocator, client->default_device_ordinal(),
|
||||
/*allocate_xla_tensors=*/platform_info_.xla_device_metadata() != nullptr,
|
||||
@ -157,13 +159,16 @@ Status XlaCompileOnDemandOp::Compile(
|
||||
TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
|
||||
rm->default_container(), "xla_cache", cache,
|
||||
[&](XlaCompilationCache** write_into_cache) {
|
||||
return BuildXlaCompilationCache(ctx, platform_info_, write_into_cache);
|
||||
return BuildXlaCompilationCache(ctx->device(), platform_info_,
|
||||
write_into_cache);
|
||||
}));
|
||||
|
||||
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
|
||||
XlaCompiler::Options options =
|
||||
GenerateCompilerOptions(**cache, ctx, platform_info_,
|
||||
/*has_ref_vars=*/true, &tf_allocator_adapter);
|
||||
XlaCompiler::Options options = GenerateCompilerOptions(
|
||||
**cache, *ctx->function_library(), ctx->device(),
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
|
||||
platform_info_,
|
||||
/*has_ref_vars=*/true, &tf_allocator_adapter);
|
||||
|
||||
XlaCompiler::CompileOptions compile_options;
|
||||
compile_options.is_entry_computation = true;
|
||||
|
@ -37,7 +37,8 @@ namespace tensorflow {
|
||||
class XlaCompileOnDemandOp : public OpKernel {
|
||||
public:
|
||||
explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx)
|
||||
: OpKernel(ctx), platform_info_(XlaPlatformInfoFromContext(ctx)) {}
|
||||
: OpKernel(ctx),
|
||||
platform_info_(XlaPlatformInfoFromDevice(ctx->device())) {}
|
||||
void Compute(OpKernelContext* ctx) override;
|
||||
|
||||
private:
|
||||
|
@ -51,7 +51,7 @@ Status XlaCpuDeviceFactory::CreateDevices(
|
||||
std::vector<std::unique_ptr<Device>>* devices) {
|
||||
XlaDeviceFlags* flags = GetXlaDeviceFlags();
|
||||
if (!flags->tf_xla_enable_xla_devices) {
|
||||
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
|
||||
VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
|
||||
return Status::OK();
|
||||
}
|
||||
bool compile_on_demand = flags->tf_xla_compile_on_demand;
|
||||
|
@ -94,6 +94,11 @@ class XlaDevice : public LocalDevice {
|
||||
static Status GetMetadata(OpKernelConstruction* ctx,
|
||||
const Metadata** metadata);
|
||||
|
||||
// Sets `*metadata` to the XlaDevice Metadata in the XLA device used by
|
||||
// `device`.
|
||||
static Status GetMetadataFromDevice(DeviceBase* device,
|
||||
const XlaDevice::Metadata** metadata);
|
||||
|
||||
struct Options {
|
||||
// The StreamExecutor platform. Not owned. Must be non-null.
|
||||
se::Platform* platform = nullptr;
|
||||
@ -196,8 +201,6 @@ class XlaDevice : public LocalDevice {
|
||||
xla::StatusOr<std::pair<XlaDeviceContext*, XlaDeviceContext*>>
|
||||
GetDeviceContextLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
static Status GetMetadataFromDevice(DeviceBase* device,
|
||||
const XlaDevice::Metadata** metadata);
|
||||
|
||||
Status MakeTensorFromProto(XlaDeviceContext* device_context,
|
||||
const TensorProto& tensor_proto,
|
||||
|
@ -66,7 +66,7 @@ class XlaGpuDeviceFactory : public DeviceFactory {
|
||||
Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
|
||||
XlaDeviceFlags* flags = GetXlaDeviceFlags();
|
||||
if (!flags->tf_xla_enable_xla_devices) {
|
||||
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
|
||||
VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -44,12 +44,6 @@ namespace {
|
||||
using xla::ScopedShapedBuffer;
|
||||
using xla::ShapedBuffer;
|
||||
|
||||
const char kPossibleNonVariableResourceHintMessage[] =
|
||||
"If the error is similar to `Trying to access resource using the wrong "
|
||||
"type`, this is likely because XLA only accepts Resource Variables as "
|
||||
"inputs by snapshotting their values. Other TensorFlow resource types like "
|
||||
"TensorList/TensorArray/Stack are not supported. Try removing non-variable "
|
||||
"resource inputs to XLA.";
|
||||
} // anonymous namespace
|
||||
|
||||
VariableInfo::VariableInfo(int index, absl::string_view name, Var* var)
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Status BuildXlaCompilationCache(OpKernelContext* ctx,
|
||||
Status BuildXlaCompilationCache(DeviceBase* device,
|
||||
const XlaPlatformInfo& platform_info,
|
||||
XlaCompilationCache** cache) {
|
||||
if (platform_info.xla_device_metadata()) {
|
||||
@ -59,7 +59,7 @@ Status BuildXlaCompilationCache(OpKernelContext* ctx,
|
||||
xla::LocalClientOptions client_options;
|
||||
client_options.set_platform(platform.ValueOrDie());
|
||||
client_options.set_intra_op_parallelism_threads(
|
||||
ctx->device()->tensorflow_cpu_worker_threads()->num_threads);
|
||||
device->tensorflow_cpu_worker_threads()->num_threads);
|
||||
auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
|
||||
if (!client.ok()) {
|
||||
return client.status();
|
||||
@ -75,21 +75,21 @@ Status BuildXlaCompilationCache(OpKernelContext* ctx,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx) {
|
||||
DeviceType device_type = ctx->device_type();
|
||||
XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) {
|
||||
auto device = static_cast<Device*>(device_base);
|
||||
se::Platform::Id platform_id = nullptr;
|
||||
const XlaDevice::Metadata* xla_device_metadata = nullptr;
|
||||
se::DeviceMemoryAllocator* custom_allocator = nullptr;
|
||||
|
||||
if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
|
||||
if (device->device_type() == DEVICE_CPU) {
|
||||
platform_id = se::host::kHostPlatformId;
|
||||
} else if (ctx->device_type() == DeviceType(DEVICE_GPU)) {
|
||||
platform_id = ctx->device()
|
||||
->tensorflow_gpu_device_info()
|
||||
} else if (device->device_type() == DEVICE_GPU) {
|
||||
platform_id = device->tensorflow_gpu_device_info()
|
||||
->stream->parent()
|
||||
->platform()
|
||||
->id();
|
||||
} else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) {
|
||||
} else if (XlaDevice::GetMetadataFromDevice(device, &xla_device_metadata)
|
||||
.ok()) {
|
||||
// If we are on an XlaDevice, use the underlying XLA platform's allocator
|
||||
// directly. We could use the StreamExecutor's allocator which may
|
||||
// theoretically be more correct, but XLA returns a nice OOM message in a
|
||||
@ -104,47 +104,46 @@ XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx) {
|
||||
xla_device_metadata->client()->backend().memory_allocator();
|
||||
}
|
||||
|
||||
return XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
|
||||
custom_allocator);
|
||||
return XlaPlatformInfo(DeviceType(device->device_type()), platform_id,
|
||||
xla_device_metadata, custom_allocator);
|
||||
}
|
||||
|
||||
se::DeviceMemoryAllocator* GetAllocator(
|
||||
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
|
||||
OpKernelContext* ctx, const XlaPlatformInfo& platform_info) {
|
||||
DeviceBase* device, se::Stream* stream,
|
||||
const XlaPlatformInfo& platform_info) {
|
||||
if (platform_info.custom_allocator()) {
|
||||
return platform_info.custom_allocator();
|
||||
}
|
||||
if (!ctx->op_device_context()) {
|
||||
if (!stream) {
|
||||
// Stream is not set for the host platform.
|
||||
se::Platform* platform =
|
||||
se::MultiPlatformManager::PlatformWithId(platform_info.platform_id())
|
||||
.ValueOrDie();
|
||||
tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}), platform);
|
||||
tf_allocator_adapter->emplace(device->GetAllocator({}), platform);
|
||||
return &tf_allocator_adapter->value();
|
||||
}
|
||||
tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}),
|
||||
ctx->op_device_context()->stream());
|
||||
tf_allocator_adapter->emplace(device->GetAllocator({}), stream);
|
||||
return &tf_allocator_adapter->value();
|
||||
}
|
||||
|
||||
XlaCompiler::Options GenerateCompilerOptions(
|
||||
const XlaCompilationCache& cache, OpKernelContext* ctx,
|
||||
const XlaPlatformInfo& platform_info, bool has_ref_vars,
|
||||
const XlaCompilationCache& cache,
|
||||
const FunctionLibraryRuntime& function_library, DeviceBase* device,
|
||||
se::Stream* stream, const XlaPlatformInfo& platform_info, bool has_ref_vars,
|
||||
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter) {
|
||||
CHECK(ctx->function_library());
|
||||
XlaCompiler::Options options;
|
||||
options.client = static_cast<xla::LocalClient*>(cache.client());
|
||||
if (ctx->op_device_context() != nullptr) {
|
||||
options.device_ordinal =
|
||||
ctx->op_device_context()->stream()->parent()->device_ordinal();
|
||||
if (stream != nullptr) {
|
||||
options.device_ordinal = stream->parent()->device_ordinal();
|
||||
}
|
||||
options.device_type = cache.device_type();
|
||||
options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
|
||||
options.graph_def_version = ctx->function_library()->graph_def_version();
|
||||
options.flib_def = function_library.GetFunctionLibraryDefinition();
|
||||
options.graph_def_version = function_library.graph_def_version();
|
||||
options.allow_cpu_custom_calls =
|
||||
(platform_info.platform_id() == se::host::kHostPlatformId);
|
||||
options.device_allocator =
|
||||
GetAllocator(tf_allocator_adapter, ctx, platform_info);
|
||||
GetAllocator(tf_allocator_adapter, device, stream, platform_info);
|
||||
if (platform_info.xla_device_metadata()) {
|
||||
options.shape_representation_fn =
|
||||
platform_info.xla_device_metadata()->shape_representation_fn();
|
||||
|
@ -80,27 +80,31 @@ class XlaPlatformInfo {
};

// Returns created XLA compilation cache.
Status BuildXlaCompilationCache(OpKernelContext* ctx,
Status BuildXlaCompilationCache(DeviceBase* dev,
                                const XlaPlatformInfo& platform_info,
                                XlaCompilationCache** cache);

// Returns information about the platform from kernel context.
XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx);
XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device);

// Returns allocator from platform info if non-null, or populate and return a
// pointer to the allocator adapter with allocator from context.
//
// This is necessary because for XLA devices the underlying TF allocator returns
// dummy tensors.
//
// `stream` parameter is nullable when running on host.
se::DeviceMemoryAllocator* GetAllocator(
    absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
    OpKernelContext* ctx, const XlaPlatformInfo& platform_info);
    DeviceBase* device, se::Stream* stream,
    const XlaPlatformInfo& platform_info);

// Returns created options for the XLA compiler, and writes the used allocator
// into `tf_allocator_adapter`.
XlaCompiler::Options GenerateCompilerOptions(
    const XlaCompilationCache& cache, OpKernelContext* ctx,
    const XlaPlatformInfo& platform_info, bool has_ref_vars,
    const XlaCompilationCache& cache,
    const FunctionLibraryRuntime& function_library, DeviceBase* device,
    se::Stream* stream, const XlaPlatformInfo& platform_info, bool has_ref_vars,
    absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter);

}  // namespace tensorflow
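For reference, the call sites updated in this change now follow roughly this
shape (condensed from the kernels earlier in this diff; cache, ctx and
platform_info_ stand for the kernel's locals/members, and pointer levels vary
per call site):

  absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
  se::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
  se::DeviceMemoryAllocator* allocator = GetAllocator(
      &tf_allocator_adapter, ctx->device(), stream, platform_info_);
  XlaCompiler::Options options = GenerateCompilerOptions(
      *cache, *ctx->function_library(), ctx->device(), stream, platform_info_,
      /*has_ref_vars=*/true, &tf_allocator_adapter);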
@ -24,11 +24,40 @@ filegroup(
|
||||
srcs = glob(["**/*.td"]),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "string_container_utils",
|
||||
hdrs = ["utils/string_container_utils.h"],
|
||||
deps = [
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "array_container_utils",
|
||||
hdrs = ["utils/array_container_utils.h"],
|
||||
deps = [
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@llvm-project//llvm:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "name_utils",
|
||||
srcs = ["utils/name_utils.cc"],
|
||||
hdrs = ["utils/name_utils.h"],
|
||||
deps = [
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "op_or_arg_name_mapper",
|
||||
srcs = ["op_or_arg_name_mapper.cc"],
|
||||
hdrs = ["op_or_arg_name_mapper.h"],
|
||||
deps = [
|
||||
":name_utils",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
|
@ -341,6 +341,7 @@ cc_library(
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:SCFDialect",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
@ -348,6 +349,22 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mhlo_control_flow_to_scf",
|
||||
srcs = ["lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc"],
|
||||
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:SCFDialect",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "map_lmhlo_to_scalar_op",
|
||||
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"],
|
||||
@ -800,6 +817,7 @@ cc_library(
|
||||
":lhlo_legalize_to_affine",
|
||||
":lhlo_legalize_to_gpu",
|
||||
":lhlo_legalize_to_parallel_loops",
|
||||
":mhlo_control_flow_to_scf",
|
||||
":mhlo_fusion",
|
||||
":mhlo_to_mhlo_lowering_patterns",
|
||||
":sink_constants_to_control_flow",
|
||||
|
@ -30,6 +30,11 @@ def LegalizeControlFlowPass : Pass<"mhlo-legalize-control-flow", "FuncOp"> {
  let constructor = "createLegalizeControlFlowPass()";
}

def LegalizeControlFlowToScfPass : Pass<"mhlo-control-flow-to-scf", "FuncOp"> {
  let summary = "Legalize from MHLO control flow to SCF control flow.";
  let constructor = "createControlFlowToScfPass()";
}

def LegalizeGatherToTorchIndexSelectPass : Pass<"mhlo-legalize-gather-to-torch-index-select", "FuncOp"> {
  let summary = "Legalizes gathers to a torch index select.";
  let constructor = "createLegalizeGatherToTorchIndexSelectPass()";
@ -35,6 +35,9 @@ namespace mhlo {
/// Lowers HLO control flow ops to the Standard dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeControlFlowPass();

/// Lowers MHLO control flow ops to the SCF dialect.
std::unique_ptr<OperationPass<FuncOp>> createControlFlowToScfPass();

/// Lowers from HLO dialect to Standard dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();
@ -93,6 +93,7 @@ add_mlir_library(MhloToLhloConversion
add_mlir_library(MhloToStandard
  legalize_control_flow.cc
  legalize_to_standard.cc
  mhlo_control_flow_to_scf.cc

  DEPENDS
  MLIRhlo_opsIncGen
@ -0,0 +1,199 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
#define DEBUG_TYPE "mhlo-control-flow-to-scf"
|
||||
|
||||
namespace mlir {
|
||||
namespace mhlo {
|
||||
|
||||
namespace {
|
||||
|
||||
/// Convert MHLO While to SCF.
|
||||
void MatchAndRewrite(WhileOp whileOp);
|
||||
|
||||
/// Pass that converts MHLO control flow to SCF.
|
||||
class ControlFlowToScfPass
|
||||
: public mlir::PassWrapper<ControlFlowToScfPass, FunctionPass> {
|
||||
void getDependentDialects(DialectRegistry& registry) const override {
|
||||
registry.insert<scf::SCFDialect>();
|
||||
}
|
||||
void runOnFunction() override {
|
||||
getFunction().walk([&](WhileOp whileOp) { MatchAndRewrite(whileOp); });
|
||||
}
|
||||
};
|
||||
|
||||
// TODO(jpienaar): Look into reformulating as a pattern.
|
||||
void MatchAndRewrite(WhileOp whileOp) {
|
||||
// Handle pattern:
|
||||
// x = start
|
||||
// step = ...
|
||||
// limit = ...
|
||||
// while (x < limit) { ... x += step; }
|
||||
|
||||
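  // For illustration, a sketch (not taken verbatim from a test) of what this
  // rewrite produces for the pattern above, once the induction variable is
  // converted to index type:
  //   %lb = extract_element(index_cast(start)), %ub = ..., %st = ...
  //   scf.for %iv = %lb to %ub step %st iter_args(...) { ... scf.yield ... }
  // See tests/legalize_to_scf.mlir added in this change for a concrete case.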
// Only handling multi-value while loops at the moment.
|
||||
auto tupleOp = whileOp.getOperand().getDefiningOp<TupleOp>();
|
||||
if (!tupleOp) return;
|
||||
auto bodyReturn = whileOp.body()
|
||||
.front()
|
||||
.getTerminator()
|
||||
->getOperand(0)
|
||||
.getDefiningOp<mhlo::TupleOp>();
|
||||
// Note: due to the shape restrictions on While, if the operand to While is a
|
||||
// tuple, then so is the return type of the body. But the verifier isn't
|
||||
// checking that at the moment, so just bail out here if this doesn't hold.
|
||||
if (!bodyReturn) return;
|
||||
|
||||
Value result = whileOp.cond().front().getTerminator()->getOperand(0);
|
||||
// TODO(jpienaar): Expand to handle more than simple case with LT compare and
|
||||
// constant step.
|
||||
auto cmp = result.getDefiningOp<mhlo::CompareOp>();
|
||||
if (!cmp || cmp.comparison_direction() != "LT") return;
|
||||
|
||||
const int kConstant = -1;
|
||||
auto getValueAndIndex = [&](Value val) -> std::pair<Value, int> {
|
||||
if (matchPattern(val, m_Constant())) return {val, kConstant};
|
||||
// If it is defined by a tuple, then the tuple has to have been fed in and
|
||||
// the external value is captured.
|
||||
if (auto gte = val.getDefiningOp<GetTupleElementOp>()) {
|
||||
if (!gte.getOperand().isa<mlir::BlockArgument>()) return {nullptr, 0};
|
||||
int index = gte.index().getSExtValue();
|
||||
return {tupleOp.getOperand(index), index};
|
||||
}
|
||||
return {nullptr, 0};
|
||||
};
|
||||
|
||||
using ValueIndex = std::pair<Value, int>;
|
||||
ValueIndex loopIndVar = getValueAndIndex(cmp.lhs());
|
||||
ValueIndex max = getValueAndIndex(cmp.rhs());
|
||||
if (!loopIndVar.first || !max.first) return;
|
||||
auto add =
|
||||
bodyReturn.getOperand(loopIndVar.second).getDefiningOp<mhlo::AddOp>();
|
||||
if (!add) return;
|
||||
ValueIndex step = getValueAndIndex(add.rhs());
|
||||
if (step.second != kConstant || !step.first) return;
|
||||
|
||||
// Only handle case where tuple isn't propagated as is for now.
|
||||
// TODO(jpienaar): Remove this when a tuple is also created inside the loop
|
||||
// to propagate.
|
||||
for (auto* use : whileOp.body().front().getArgument(0).getUsers())
|
||||
if (!isa<GetTupleElementOp>(use)) return;
|
||||
|
||||
LLVM_DEBUG(llvm::dbgs() << "Found for (" << whileOp.getLoc() << "):\n";
|
||||
llvm::dbgs() << " loopIndVar = " << loopIndVar.second << " max = "
|
||||
<< max.second << " step = " << step.second << "\n";
|
||||
llvm::dbgs() << " loopIndVar = " << loopIndVar.first << " max = "
|
||||
<< max.first << " step = " << step.first << "\n";);
|
||||
OpBuilder b(whileOp);
|
||||
// Inputs to new for loop.
|
||||
llvm::SmallVector<Value, 4> input;
|
||||
input.reserve(tupleOp.getNumOperands());
|
||||
for (auto r : tupleOp.getOperands().take_front(loopIndVar.second))
|
||||
input.push_back(r);
|
||||
for (auto r : tupleOp.getOperands().drop_front(loopIndVar.second + 1))
|
||||
input.push_back(r);
|
||||
|
||||
auto tensorIndexType = RankedTensorType::get({}, b.getIndexType());
|
||||
auto getAsIndex = [&](Value val) {
|
||||
auto loc = whileOp.getLoc();
|
||||
return b.create<ExtractElementOp>(
|
||||
loc, b.create<IndexCastOp>(loc, tensorIndexType, val), ValueRange());
|
||||
};
|
||||
|
||||
// SCF for uses index type, so convert these.
|
||||
auto forloopIndVar = getAsIndex(loopIndVar.first);
|
||||
auto forMax = getAsIndex(max.first);
|
||||
auto forStep = getAsIndex(step.first);
|
||||
auto forOp = b.create<mlir::scf::ForOp>(whileOp.getLoc(), forloopIndVar,
|
||||
forMax, forStep, input);
|
||||
// Transfer the body without the block arguments.
|
||||
forOp.getLoopBody().front().getOperations().splice(
|
||||
forOp.getLoopBody().front().getOperations().end(),
|
||||
whileOp.body().front().getOperations());
|
||||
|
||||
b.setInsertionPointToStart(&forOp.getLoopBody().front());
|
||||
auto loopIndVarElType =
|
||||
loopIndVar.first.getType().cast<ShapedType>().getElementType();
|
||||
Value indVar = b.create<SplatOp>(
|
||||
whileOp.getLoc(), RankedTensorType::get({}, loopIndVarElType),
|
||||
b.create<IndexCastOp>(whileOp.getLoc(), loopIndVarElType,
|
||||
forOp.getInductionVar()));
|
||||
// Update all block argument users to the SCF For args.
|
||||
for (auto* use :
|
||||
llvm::make_early_inc_range(whileOp.body().getArgument(0).getUsers())) {
|
||||
// TODO(jpienaar): Expand here too when we allow using the tuple in the
|
||||
// loop.
|
||||
auto gte = cast<GetTupleElementOp>(use);
|
||||
// If this is the loop induction var, refer to the loop induction variable,
// as this operand is not updated.
|
||||
if (gte.index() == loopIndVar.second) {
|
||||
use->getResult(0).replaceAllUsesWith(indVar);
|
||||
use->erase();
|
||||
continue;
|
||||
}
|
||||
int index = gte.index().getSExtValue();
|
||||
// If after the loop induction variable, then decrement as we don't include
|
||||
// the loop induction variable in the for iter operands.
|
||||
if (index > loopIndVar.second) --index;
|
||||
use->getResult(0).replaceAllUsesWith(forOp.getIterOperands()[index]);
|
||||
use->erase();
|
||||
}
|
||||
|
||||
// Create new yield op without induction var update.
|
||||
SmallVector<Value, 4> newYieldOps;
|
||||
newYieldOps.reserve(bodyReturn.getNumOperands() - 1);
|
||||
for (auto r : bodyReturn.getOperands().take_front(loopIndVar.second))
|
||||
newYieldOps.push_back(r);
|
||||
for (auto r : bodyReturn.getOperands().drop_front(loopIndVar.second + 1))
|
||||
newYieldOps.push_back(r);
|
||||
// Delete return & tuple op.
|
||||
forOp.getLoopBody().front().back().erase();
|
||||
forOp.getLoopBody().front().back().erase();
|
||||
b.setInsertionPointToEnd(&forOp.getLoopBody().front());
|
||||
b.create<scf::YieldOp>(whileOp.getLoc(), newYieldOps);
|
||||
|
||||
// Recombine output tuple with max value of induction variable.
|
||||
llvm::SmallVector<Value, 4> loopOut;
|
||||
loopOut.reserve(forOp.getNumResults() + 1);
|
||||
for (auto r : forOp.getResults().take_front(loopIndVar.second))
|
||||
loopOut.push_back(r);
|
||||
loopOut.push_back(max.first);
|
||||
for (auto r : forOp.getResults().drop_front(loopIndVar.second))
|
||||
loopOut.push_back(r);
|
||||
b.setInsertionPoint(whileOp);
|
||||
auto newRes = b.create<mhlo::TupleOp>(whileOp.getLoc(), loopOut);
|
||||
whileOp.replaceAllUsesWith(newRes.getOperation());
|
||||
whileOp.erase();
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createControlFlowToScfPass() {
|
||||
return std::make_unique<ControlFlowToScfPass>();
|
||||
}
|
||||
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
38
tensorflow/compiler/mlir/hlo/tests/legalize_to_scf.mlir
Normal file
@ -0,0 +1,38 @@
|
||||
// RUN: mlir-hlo-opt --mhlo-control-flow-to-scf %s | FileCheck %s
|
||||
|
||||
func @lt_loop(%arg0: tensor<4xf32>, %arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<4xf32>, %arg4: tensor<f32>, %arg5: tensor<f32>, %arg6: tensor<f32>, %arg7: tensor<f32>, %arg8: tensor<i32>) -> (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) {
|
||||
%cst = constant dense<-1> : tensor<i32>
|
||||
%cst_0 = constant dense<1> : tensor<i32>
|
||||
%cst_1 = constant dense<0> : tensor<i32>
|
||||
%cst_2 = constant dense<1000> : tensor<i32>
|
||||
%0 = "mhlo.tuple"(%cst_1, %cst, %cst_2) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
|
||||
%1 = "mhlo.while"(%0) ( {
|
||||
^bb0(%arg9: tuple<tensor<i32>, tensor<i32>, tensor<i32>>): // no predecessors
|
||||
%2 = "mhlo.get_tuple_element"(%arg9) {index = 0 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
|
||||
%3 = "mhlo.get_tuple_element"(%arg9) {index = 2 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
|
||||
%4 = "mhlo.compare"(%2, %3) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||
"mhlo.return"(%4) : (tensor<i1>) -> ()
|
||||
}, {
|
||||
^bb0(%arg9: tuple<tensor<i32>, tensor<i32>, tensor<i32>>): // no predecessors
|
||||
%2 = "mhlo.get_tuple_element"(%arg9) {index = 0 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
|
||||
%3 = mhlo.add %2, %cst_0 : tensor<i32>
|
||||
%4 = "mhlo.get_tuple_element"(%arg9) {index = 1 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
|
||||
%5 = "mhlo.get_tuple_element"(%arg9) {index = 2 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
|
||||
%6 = "mhlo.tuple"(%3, %4, %5) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
|
||||
"mhlo.return"(%6) : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> ()
|
||||
}) : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
|
||||
return %1 : tuple<tensor<i32>, tensor<i32>, tensor<i32>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @lt_loop(
|
||||
// CHECK: %[[VAL_9:.*]] = constant dense<-1> : tensor<i32>
|
||||
// CHECK: %[[VAL_10:.*]] = constant dense<1> : tensor<i32>
|
||||
// CHECK: %[[VAL_11:.*]] = constant dense<0> : tensor<i32>
|
||||
// CHECK: %[[VAL_12:.*]] = constant dense<1000> : tensor<i32>
|
||||
// CHECK: %[[VAL_14:.*]] = index_cast %[[VAL_11]] : tensor<i32> to tensor<index>
|
||||
// CHECK: %[[VAL_15:.*]] = extract_element %[[VAL_14]][] : tensor<index>
|
||||
// CHECK: %[[VAL_16:.*]] = index_cast %[[VAL_12]] : tensor<i32> to tensor<index>
|
||||
// CHECK: %[[VAL_17:.*]] = extract_element %[[VAL_16]][] : tensor<index>
|
||||
// CHECK: %[[VAL_18:.*]] = index_cast %[[VAL_10]] : tensor<i32> to tensor<index>
|
||||
// CHECK: %[[VAL_19:.*]] = extract_element %[[VAL_18]][] : tensor<index>
|
||||
// CHECK: scf.for %[[VAL_21:.*]] = %[[VAL_15]] to %[[VAL_17]] step %[[VAL_19]] iter_args(%[[VAL_22:.*]] = %[[VAL_9]], %[[VAL_23:.*]] = %[[VAL_12]])
|
@ -1029,14 +1029,49 @@ func @splitv(%arg0: tensor<1x4x3x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<i32
|
||||
// CHECK: "tfl.split_v"(%arg0, %arg1, %arg2) {num_splits = 2 : i32} : (tensor<1x4x3x3xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<1x4x2x3xf32>, tensor<1x4x1x3xf32>)
|
||||
}
|
||||
|
||||
func @matmul_transposed(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
|
||||
func @matmul(%arg0: tensor<40x37xf32>, %arg1: tensor<37x40xf32>) -> tensor<40x40xf32> {
|
||||
%0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = false, transpose_b = false} :
|
||||
(tensor<40x37xf32>, tensor<37x40xf32>) -> tensor<40x40xf32>
|
||||
return %0 : tensor<40x40xf32>
|
||||
// CHECK-LABEL: matmul
|
||||
// CHECK: %[[CST:.*]] = constant dense<[1, 0]> : tensor<2xi32>
|
||||
// CHECK: %[[ARG:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
|
||||
// CHECK: %[[CST_0:.*]] = constant unit
|
||||
// CHECK: "tfl.fully_connected"(%arg0, %[[ARG]], %[[CST_0]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
|
||||
}
|
||||
|
||||
func @matmul_transposed_a(%arg0: tensor<37x40xf32>, %arg1: tensor<37x40xf32>) -> tensor<40x40xf32> {
|
||||
%0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = true, transpose_b = false} :
|
||||
(tensor<37x40xf32>, tensor<37x40xf32>) -> tensor<40x40xf32>
|
||||
return %0 : tensor<40x40xf32>
|
||||
// CHECK-LABEL: matmul_transposed_a
|
||||
// CHECK: %[[CST_0:.*]] = constant dense<[1, 0]> : tensor<2xi32>
|
||||
// CHECK: %[[ARG_0:.*]] = "tfl.transpose"(%arg0, %[[CST_0]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
|
||||
// CHECK: %[[CST_1:.*]] = constant dense<[1, 0]> : tensor<2xi32>
|
||||
// CHECK: %[[ARG_1:.*]] = "tfl.transpose"(%arg1, %[[CST_1]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
|
||||
// CHECK: %[[CST_2:.*]] = constant unit
|
||||
// CHECK: "tfl.fully_connected"(%[[ARG_0]], %[[ARG_1]], %[[CST_2]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
|
||||
}
|
||||
|
||||
func @matmul_transposed_b(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
|
||||
%0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = false, transpose_b = true} :
|
||||
(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32>
|
||||
return %0 : tensor<40x40xf32>
|
||||
// CHECK-LABEL: matmul_transposed
|
||||
// CHECK-LABEL: matmul_transposed_b
|
||||
// CHECK: "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
|
||||
}
|
||||
|
||||
func @matmul_transposed_ab(%arg0: tensor<37x40xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
|
||||
%0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = true, transpose_b = true} :
|
||||
(tensor<37x40xf32>, tensor<40x37xf32>) -> tensor<40x40xf32>
|
||||
return %0 : tensor<40x40xf32>
|
||||
// CHECK-LABEL: matmul_transposed_ab
|
||||
// CHECK: %[[CST_0:.*]] = constant dense<[1, 0]> : tensor<2xi32>
|
||||
// CHECK: %[[ARG_0:.*]] = "tfl.transpose"(%arg0, %[[CST_0]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
|
||||
// CHECK: %[[CST_1:.*]] = constant unit
|
||||
// CHECK: "tfl.fully_connected"(%[[ARG_0]], %arg1, %[[CST_1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
|
||||
}
|
||||
|
||||
func @concatv2With3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
|
||||
%0 = "tf.Const"() { value = dense<-1> : tensor<i32> } : () -> tensor<i32>
|
||||
%1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor<i32>) -> tensor<2x3xi32>
|
||||
|
@ -66,7 +66,6 @@ namespace TFL {
|
||||
// The actual LegalizeTF Pass.
|
||||
namespace {
|
||||
|
||||
using xla::Status;
|
||||
using xla::StatusOr;
|
||||
|
||||
constexpr char kUnidirectionalSequenceLstm[] = "tf.UnidirectionalSequenceLstm";
|
||||
@ -232,26 +231,47 @@ LogicalResult ConvertTFConcatV2Op::matchAndRewrite(
|
||||
return success();
|
||||
}
|
||||
|
||||
// The following is effectively:
|
||||
// def : Pat<
|
||||
// (TF_MatMulOp $a, $b, ConstBoolAttrFalse:$transpose_a,
|
||||
// ConstBoolAttrTrue:$transpose_b),
|
||||
// (TFL_FullyConnectedOp:$__0 $a, $b,
|
||||
// NoInput.pattern, TFL_AF_None, TFL_FCWO_Default, ConstBoolAttrFalse)>;
|
||||
LogicalResult ConvertTFMatMulOp::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
auto tf_matmul_op = cast<TF::MatMulOp>(op);
|
||||
if (tf_matmul_op.transpose_a()) return failure();
|
||||
if (!tf_matmul_op.transpose_b()) return failure();
|
||||
auto lhs = op->getOperand(0);
|
||||
auto rhs = op->getOperand(1);
|
||||
auto transpose = [&](Value input) -> std::pair<LogicalResult, Value> {
|
||||
RankedTensorType type =
|
||||
input.getType().dyn_cast_or_null<RankedTensorType>();
|
||||
if (!type || type.getRank() != 2) return {failure(), nullptr};
|
||||
|
||||
auto permute_attr = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({2}, rewriter.getI32Type()), {1, 0});
|
||||
auto permute = rewriter.create<ConstantOp>(
|
||||
op->getLoc(), permute_attr.getType(), permute_attr);
|
||||
llvm::SmallVector<int64_t, 2> new_shape{type.getShape()[1],
|
||||
type.getShape()[0]};
|
||||
auto output = rewriter.create<TFL::TransposeOp>(
|
||||
op->getLoc(), RankedTensorType::get(new_shape, type.getElementType()),
|
||||
input, permute);
|
||||
return {success(), output};
|
||||
};
|
||||
|
||||
// TODO(jpienaar): Remove once handled via dialect conversion.
|
||||
if (tf_matmul_op.transpose_a()) {
|
||||
LogicalResult result = success();
|
||||
std::tie(result, lhs) = transpose(lhs);
|
||||
if (failed(result)) return failure();
|
||||
}
|
||||
if (!tf_matmul_op.transpose_b()) {
|
||||
LogicalResult result = success();
|
||||
std::tie(result, rhs) = transpose(rhs);
|
||||
if (failed(result)) return failure();
|
||||
}
|
||||
|
||||
Type output_type = tf_matmul_op.getResult().getType();
|
||||
// TODO(jpienaar): Follow up post shuffle discussion.
|
||||
auto no_input = rewriter.create<ConstantOp>(
|
||||
op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr());
|
||||
auto fc_op = rewriter.create<FullyConnectedOp>(
|
||||
op->getLoc(), ArrayRef<Type>{output_type}, op->getOperand(0),
|
||||
op->getOperand(1), no_input, rewriter.getStringAttr("NONE"),
|
||||
rewriter.getStringAttr("DEFAULT"), rewriter.getBoolAttr(false));
|
||||
op->getLoc(), ArrayRef<Type>{output_type}, lhs, rhs, no_input,
|
||||
rewriter.getStringAttr("NONE"), rewriter.getStringAttr("DEFAULT"),
|
||||
rewriter.getBoolAttr(false));
|
||||
rewriter.replaceOp(op, {fc_op.getResult(0)});
|
||||
return success();
|
||||
}
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/utils/name_utils.h"
|
||||
|
||||
static inline absl::string_view StringRefToView(llvm::StringRef ref) {
|
||||
return absl::string_view(ref.data(), ref.size());
|
||||
@ -103,62 +104,16 @@ int OpOrArgNameMapper::InitOpName(OpOrVal op_or_val, llvm::StringRef name) {
|
||||
|
||||
bool OpOrArgNameMapper::IsUnique(llvm::StringRef name) { return true; }
|
||||
|
||||
namespace {
|
||||
// Derives name from location.
|
||||
std::string GetNameFromLoc(mlir::Location loc) {
|
||||
llvm::SmallVector<llvm::StringRef, 8> loc_names;
|
||||
llvm::SmallVector<mlir::Location, 8> locs;
|
||||
locs.push_back(loc);
|
||||
bool names_is_nonempty = false;
|
||||
|
||||
while (!locs.empty()) {
|
||||
mlir::Location curr_loc = locs.pop_back_val();
|
||||
|
||||
if (auto name_loc = curr_loc.dyn_cast<mlir::NameLoc>()) {
|
||||
// Add name in NameLoc. For NameLoc we also account for names due to ops
|
||||
// in functions where the op's name is first.
|
||||
auto name = name_loc.getName().strref().split('@').first;
|
||||
loc_names.push_back(name);
|
||||
if (!name.empty()) names_is_nonempty = true;
|
||||
continue;
|
||||
} else if (auto call_loc = curr_loc.dyn_cast<mlir::CallSiteLoc>()) {
|
||||
// Add name if CallSiteLoc's callee has a NameLoc (as should be the
|
||||
// case if imported with DebugInfo).
|
||||
if (auto name_loc = call_loc.getCallee().dyn_cast<mlir::NameLoc>()) {
|
||||
auto name = name_loc.getName().strref().split('@').first;
|
||||
loc_names.push_back(name);
|
||||
if (!name.empty()) names_is_nonempty = true;
|
||||
continue;
|
||||
}
|
||||
} else if (auto fused_loc = curr_loc.dyn_cast<mlir::FusedLoc>()) {
|
||||
// Push all locations in FusedLoc in reverse order, so locations are
|
||||
// visited based on order in FusedLoc.
|
||||
auto reversed_fused_locs = llvm::reverse(fused_loc.getLocations());
|
||||
locs.append(reversed_fused_locs.begin(), reversed_fused_locs.end());
|
||||
continue;
|
||||
}
|
||||
|
||||
// Location is not supported, so an empty StringRef is added.
|
||||
loc_names.push_back(llvm::StringRef());
|
||||
}
|
||||
|
||||
if (names_is_nonempty)
|
||||
return llvm::join(loc_names.begin(), loc_names.end(), ";");
|
||||
|
||||
return "";
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
std::string OpOrArgLocNameMapper::GetName(OpOrVal op_or_val) {
|
||||
if (auto* op = op_or_val.dyn_cast<mlir::Operation*>()) {
|
||||
auto name_from_loc = GetNameFromLoc(op->getLoc());
|
||||
auto name_from_loc = mlir::GetNameFromLoc(op->getLoc());
|
||||
if (!name_from_loc.empty()) return name_from_loc;
|
||||
// If the location is none of the expected types, then simply use name
|
||||
// generated using the op type.
|
||||
return std::string(op->getName().getStringRef());
|
||||
}
|
||||
auto val = op_or_val.dyn_cast<mlir::Value>();
|
||||
auto name_from_loc = GetNameFromLoc(val.getLoc());
|
||||
auto name_from_loc = mlir::GetNameFromLoc(val.getLoc());
|
||||
if (!name_from_loc.empty()) return name_from_loc;
|
||||
// If the location is none of the expected types, then simply use name
|
||||
// generated using the op type. Follow TF convention and append the result
|
||||
|
@ -794,6 +794,7 @@ cc_library(
|
||||
"transforms/tpu_identity_pruning.cc",
|
||||
"transforms/tpu_merge_variables_with_execute.cc",
|
||||
"transforms/tpu_outside_compilation_cluster.cc",
|
||||
"transforms/tpu_resource_read_for_write.cc",
|
||||
"transforms/tpu_rewrite_pass.cc",
|
||||
"transforms/tpu_sharding_identification_pass.cc",
|
||||
"transforms/tpu_space_to_depth_pass.cc",
|
||||
@ -960,6 +961,7 @@ cc_library(
|
||||
"//tensorflow/cc/saved_model:loader_lite",
|
||||
"//tensorflow/cc/saved_model:loader_util",
|
||||
"//tensorflow/compiler/jit:shape_inference_helpers",
|
||||
"//tensorflow/compiler/mlir:name_utils",
|
||||
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
|
||||
"//tensorflow/compiler/tf2xla:functionalize_control_flow",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
|
@ -2,7 +2,6 @@ load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_copts",
|
||||
"tf_cuda_library",
|
||||
"tfe_xla_copts",
|
||||
)
|
||||
|
||||
package(
|
||||
@ -20,7 +19,7 @@ tf_cuda_library(
|
||||
srcs = [
|
||||
"c_api_unified_experimental_mlir.cc",
|
||||
],
|
||||
copts = tf_copts() + tfe_xla_copts(),
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
|
@ -452,7 +452,8 @@ Status MlirAbstractOp::SetAttrFloat(const char* attr_name, float value) {
|
||||
return Unimplemented("SetAttrFloat has not been implemented yet.");
|
||||
}
|
||||
Status MlirAbstractOp::SetAttrBool(const char* attr_name, bool value) {
|
||||
return Unimplemented("SetAttrBool has not been implemented yet.");
|
||||
attrs_[attr_name] = BoolAttr::get(value, context_);
|
||||
return Status::OK();
|
||||
}
|
||||
Status MlirAbstractOp::SetAttrShape(const char* attr_name, const int64_t* dims,
|
||||
const int num_dims) {
|
||||
|
@ -250,33 +250,6 @@ ParseResult ParseGraphOp(OpAsmParser &parser, OperationState &result) {
|
||||
// tf_executor.fetch
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
|
||||
void Print(FetchOp fetch, OpAsmPrinter &p) {
|
||||
p << fetch.getOperationName();
|
||||
if (fetch.getNumOperands() > 0) {
|
||||
p << ' ';
|
||||
p.printOperands(fetch.operand_begin(), fetch.operand_end());
|
||||
p << " : ";
|
||||
interleaveComma(fetch.getOperandTypes(), p);
|
||||
}
|
||||
p.printOptionalAttrDict(fetch.getAttrs());
|
||||
}
|
||||
|
||||
ParseResult ParseFetchOp(OpAsmParser &parser, OperationState &result) {
|
||||
SmallVector<OpAsmParser::OperandType, 2> opInfo;
|
||||
SmallVector<Type, 2> types;
|
||||
llvm::SMLoc loc = parser.getCurrentLocation();
|
||||
return failure(parser.parseOperandList(opInfo) ||
|
||||
(!opInfo.empty() && parser.parseColonTypeList(types)) ||
|
||||
parser.resolveOperands(opInfo, types, loc, result.operands) ||
|
||||
parser.parseOptionalAttrDict(result.attributes)
|
||||
|
||||
);
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// tf_executor.island
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -411,31 +384,6 @@ ParseResult ParseIslandOp(OpAsmParser &parser, OperationState &result) {
|
||||
// tf_executor.yield
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
|
||||
void Print(YieldOp yield, OpAsmPrinter &p) {
|
||||
p << yield.getOperationName();
|
||||
if (yield.getNumOperands() > 0) {
|
||||
p << ' ';
|
||||
p.printOperands(yield.operand_begin(), yield.operand_end());
|
||||
p << " : ";
|
||||
interleaveComma(yield.getOperandTypes(), p);
|
||||
}
|
||||
p.printOptionalAttrDict(yield.getAttrs());
|
||||
}
|
||||
|
||||
ParseResult ParseYieldOp(OpAsmParser &parser, OperationState &result) {
|
||||
SmallVector<OpAsmParser::OperandType, 2> op_info;
|
||||
SmallVector<Type, 2> types;
|
||||
llvm::SMLoc loc = parser.getCurrentLocation();
|
||||
return failure(parser.parseOperandList(op_info) ||
|
||||
(!op_info.empty() && parser.parseColonTypeList(types)) ||
|
||||
parser.resolveOperands(op_info, types, loc, result.operands) ||
|
||||
parser.parseOptionalAttrDict(result.attributes));
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// tf_executor.Switch
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -848,23 +796,6 @@ LogicalResult Verify(NextIterationSourceOp source) {
|
||||
return success();
|
||||
}
|
||||
|
||||
void Print(NextIterationSourceOp next_iteration, OpAsmPrinter &p) {
|
||||
p << next_iteration.getOperationName() << " : " << next_iteration.getType(0);
|
||||
p.printOptionalAttrDict(next_iteration.getAttrs());
|
||||
}
|
||||
|
||||
ParseResult ParseNextIterationSourceOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
SmallVector<Type, 1> types;
|
||||
if (parser.parseColonTypeList(types)) return failure();
|
||||
|
||||
MLIRContext *context = parser.getBuilder().getContext();
|
||||
Type token_type = TokenType::get(context);
|
||||
Type control_type = ControlType::get(context);
|
||||
result.addTypes({types.front(), token_type, control_type});
|
||||
return parser.parseOptionalAttrDict(result.attributes);
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -891,36 +822,6 @@ LogicalResult Verify(NextIterationSinkOp sink) {
|
||||
return success();
|
||||
}
|
||||
|
||||
void Print(NextIterationSinkOp next_iteration, OpAsmPrinter &p) {
|
||||
p << next_iteration.getOperationName() << " [";
|
||||
p.printOperand(next_iteration.getOperand(0));
|
||||
p << "] ";
|
||||
p.printOperands(llvm::drop_begin(next_iteration.getOperands(), 1));
|
||||
p << " : " << next_iteration.getOperand(1).getType();
|
||||
p.printOptionalAttrDict(next_iteration.getAttrs());
|
||||
}
|
||||
|
||||
ParseResult ParseNextIterationSinkOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
SmallVector<OpAsmParser::OperandType, 2> op_infos;
|
||||
llvm::SMLoc loc = parser.getCurrentLocation();
|
||||
|
||||
// First type is always the token consumed from the NextIteration.source
|
||||
Type token_type = TokenType::get(parser.getBuilder().getContext());
|
||||
SmallVector<Type, 1> types = {token_type};
|
||||
|
||||
if (parser.parseOperandList(op_infos, 1, OpAsmParser::Delimiter::Square) ||
|
||||
parser.parseOperandList(op_infos) || parser.parseColonTypeList(types))
|
||||
return failure();
|
||||
|
||||
Type control_type = ControlType::get(parser.getBuilder().getContext());
|
||||
types.append(op_infos.size() - 2, control_type);
|
||||
if (parser.resolveOperands(op_infos, types, loc, result.operands))
|
||||
return failure();
|
||||
|
||||
return parser.parseOptionalAttrDict(result.attributes);
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -959,32 +860,6 @@ ParseResult ParseExitOp(OpAsmParser &parser, OperationState &result) {
|
||||
// tf_executor.ControlTrigger
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
|
||||
void Print(ControlTriggerOp trigger, OpAsmPrinter &p) {
|
||||
p << trigger.getOperationName() << ' ';
|
||||
p.printOperands(trigger.getOperands());
|
||||
p.printOptionalAttrDict(trigger.getAttrs());
|
||||
}
|
||||
|
||||
ParseResult ParseControlTriggerOp(OpAsmParser &parser, OperationState &result) {
|
||||
SmallVector<OpAsmParser::OperandType, 2> op_infos;
|
||||
SmallVector<Type, 1> types;
|
||||
llvm::SMLoc loc = parser.getCurrentLocation();
|
||||
|
||||
if (parser.parseOperandList(op_infos)) return failure();
|
||||
Type control_type = ControlType::get(parser.getBuilder().getContext());
|
||||
types.append(op_infos.size(), control_type);
|
||||
if (parser.resolveOperands(op_infos, types, loc, result.operands))
|
||||
return failure();
|
||||
|
||||
// Single control as the only output
|
||||
result.types.push_back(control_type);
|
||||
return parser.parseOptionalAttrDict(result.attributes);
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// tf_executor.LoopCond
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -47,10 +47,12 @@ def TfExecutor_Dialect : Dialect {
}

// Control type.
def TfeControlType : Type<CPred<"$_self.isa<ControlType>()">, "control">;
def TfeControlType : Type<CPred<"$_self.isa<ControlType>()">, "control">,
BuildableType<"$_builder.getType<ControlType>()">;

// Token type.
def TfeTokenType : Type<CPred<"$_self.isa<TokenType>()">, "token">;
def TfeTokenType : Type<CPred<"$_self.isa<TokenType>()">, "token">,
BuildableType<"$_builder.getType<TokenType>()">;

// TODO(hinsu): Define and use TensorType instead of AnyType for data operands
// and results. For example, MergeOp output type.
@ -148,7 +150,11 @@ def TfExecutor_FetchOp : TfExecutor_Op<"fetch",
}]>
];

let assemblyFormat = "($fetches^ `:` type($fetches))? attr-dict";

let verifier = ?;
let printer = ?;
let parser = ?;
}

def TfExecutor_IslandOp : TfExecutor_Op<"island",
@ -229,7 +235,11 @@ def TfExecutor_YieldOp : TfExecutor_Op<"yield",
}]>
];

let assemblyFormat = "($fetches^ `:` type($fetches))? attr-dict";

let verifier = ?;
let printer = ?;
let parser = ?;
}

def TfExecutor_SwitchOp : TfExecutor_Op<"Switch",
@ -466,6 +476,10 @@ def TfExecutor_NextIterationSourceOp : TfExecutor_Op<"NextIteration.Source",
}
}];

let assemblyFormat = "`:` type($output) attr-dict";

let printer = ?;
let parser = ?;
}


@ -527,6 +541,11 @@ def TfExecutor_NextIterationSinkOp : TfExecutor_Op<"NextIteration.Sink",
result.attributes.append(attributes.begin(), attributes.end());
}]>
];

let assemblyFormat = " `[` $token `]` $input (`,` $controlInputs^)? `:` type($input) attr-dict";

let printer = ?;
let parser = ?;
}

def TfExecutor_ExitOp : TfExecutor_Op<"Exit",
@ -552,7 +571,7 @@ def TfExecutor_ExitOp : TfExecutor_Op<"Exit",
.Attr("T: type")

For example:
%1:2 = tf_executor.Exit %0#0 {T: "tfdtype$DT_INT32"} : tensor<*xi32>
%1:2 = tf_executor.Exit %0#0 : tensor<*xi32> {T: "tfdtype$DT_INT32"}

Note: Additional result corresponds to the control output.
}];
@ -607,6 +626,11 @@ def TfExecutor_ControlTriggerOp : TfExecutor_Op<"ControlTrigger",
result.attributes.append(attributes.begin(), attributes.end());
}]>
];

let assemblyFormat = "$controlInputs attr-dict";

let printer = ?;
let parser = ?;
}

def TfExecutor_LoopCondOp : TfExecutor_Op<"LoopCond",
@ -52,6 +52,12 @@ an output element, this operation computes \\(y = |x|\\).
|
||||
def TF_AcosOp : TF_Op<"Acos", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = "Computes acos of x element-wise.";
|
||||
|
||||
let description = [{
|
||||
Provided an input tensor, the `tf.math.acos` operation returns the inverse cosine of each element of the tensor. If `y = tf.math.cos(x)` then, `x = tf.math.acos(y)`.
|
||||
|
||||
Input range is `[-1, 1]` and the output has a range of `[0, pi]`.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$x
|
||||
);
|
||||
@ -94,6 +100,10 @@ def TF_AddOp : TF_Op<"Add", [NoSideEffect, ResultsBroadcastableShape, TF_LayoutA
|
||||
let description = [{
|
||||
*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting
|
||||
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
|
||||
|
||||
Given two input tensors, the `tf.add` operation computes the sum for every element in the tensor.
|
||||
|
||||
Both input and output have a range `(-inf, inf)`.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
@ -136,31 +146,6 @@ Inputs must be of same size and shape.
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary, TF_LayoutAgnostic, TF_SameOperandsAndResultElementTypeResolveRef]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = "Returns x + y element-wise.";
|
||||
|
||||
let description = [{
|
||||
*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting
|
||||
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint8]>:$x,
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint8]>:$y
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint8]>:$z
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TF_AdjustContrastv2Op : TF_Op<"AdjustContrastv2", [NoSideEffect]> {
|
||||
let summary = "Adjust the contrast of one or more images.";
|
||||
|
||||
@ -1740,6 +1725,24 @@ def TF_ConcatV2Op : TF_Op<"ConcatV2", [NoSideEffect]> {
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TF_ConfigureDistributedTPUOp : TF_Op<"ConfigureDistributedTPU", []> {
|
||||
let summary = [{
|
||||
Sets up the centralized structures for a distributed TPU system.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
StrAttr:$embedding_config,
|
||||
StrAttr:$tpu_embedding_config,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$is_global_init,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$enable_whole_mesh_compilations,
|
||||
DefaultValuedAttr<BoolAttr, "true">:$compilation_failure_closes_chips
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_StrTensor:$topology
|
||||
);
|
||||
}
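For orientation, the attributes listed above surface one-to-one in the raw Python wrapper. A hedged sketch (illustrative only; the call does nothing useful unless the process actually has a distributed TPU system attached):

    import tensorflow as tf

    # Returns a serialized TopologyProto string tensor describing the TPU mesh.
    topology = tf.raw_ops.ConfigureDistributedTPU(
        embedding_config="", tpu_embedding_config="",
        is_global_init=False, enable_whole_mesh_compilations=False,
        compilation_failure_closes_chips=True)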
|
||||
|
||||
def TF_ConjOp : TF_Op<"Conj", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = "Returns the complex conjugate of a complex number.";
|
||||
|
||||
@ -2786,27 +2789,6 @@ def TF_DivOp : TF_Op<"Div", [NoSideEffect, ResultsBroadcastableShape, TF_SameOpe
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TF_DivNoNanOp : TF_Op<"DivNoNan", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = "Returns 0 if the denominator is zero.";
|
||||
|
||||
let description = [{
|
||||
*NOTE*: `DivNoNan` supports broadcasting. More about broadcasting
|
||||
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$x,
|
||||
TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$y
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$z
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_DynamicStitchOp : TF_Op<"DynamicStitch", [NoSideEffect, SameVariadicOperandSize]> {
|
||||
let summary = [{
|
||||
Interleave the values from the `data` tensors into a single tensor.
|
||||
@ -3853,6 +3835,95 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors.
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_FusedBatchNormV2Op : TF_Op<"FusedBatchNormV2", [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> {
|
||||
let summary = "Batch normalization.";
|
||||
|
||||
let description = [{
|
||||
Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW".
|
||||
The size of 1D Tensors matches the dimension C of the 4D Tensors.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[BF16, F16, F32]>:$x,
|
||||
F32Tensor:$scale,
|
||||
F32Tensor:$offset,
|
||||
F32Tensor:$mean,
|
||||
F32Tensor:$variance,
|
||||
|
||||
DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon,
|
||||
DefaultValuedAttr<F32Attr, "1.0f">:$exponential_avg_factor,
|
||||
DefaultValuedAttr<TF_ConvnetDataFormatAttr, "NHWC">:$data_format,
|
||||
DefaultValuedAttr<BoolAttr, "true">:$is_training
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[BF16, F16, F32]>:$y,
|
||||
F32Tensor:$batch_mean,
|
||||
F32Tensor:$batch_variance,
|
||||
F32Tensor:$reserve_space_1,
|
||||
F32Tensor:$reserve_space_2
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// TF_FoldOperandsTransposeInterface:
|
||||
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
|
||||
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
|
||||
LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
|
||||
|
||||
// TF_LayoutSensitiveInterface:
|
||||
StringRef GetOptimalLayout(const RuntimeDevices& devices);
|
||||
LogicalResult UpdateDataFormat(StringRef data_format);
|
||||
}];
|
||||
}
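As a rough sketch of how this op is reached from Python (via tf.raw_ops, assuming a TF build that exposes FusedBatchNormV2; inference mode is used so the supplied mean/variance are consumed, and the shapes follow the NHWC / 1-D channel convention from the description):

    import tensorflow as tf

    x = tf.random.normal([8, 4, 4, 3])                  # NHWC, C = 3
    scale, offset = tf.ones([3]), tf.zeros([3])
    mean, variance = tf.zeros([3]), tf.ones([3])
    y, batch_mean, batch_var, _, _ = tf.raw_ops.FusedBatchNormV2(
        x=x, scale=scale, offset=offset, mean=mean, variance=variance,
        epsilon=1e-4, data_format="NHWC", is_training=False)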
|
||||
|
||||
def TF_FusedBatchNormV3Op : TF_Op<"FusedBatchNormV3", [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> {
|
||||
let summary = "Batch normalization.";
|
||||
|
||||
let description = [{
|
||||
Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW".
|
||||
The size of 1D Tensors matches the dimension C of the 4D Tensors.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[BF16, F16, F32]>:$x,
|
||||
F32Tensor:$scale,
|
||||
F32Tensor:$offset,
|
||||
F32Tensor:$mean,
|
||||
F32Tensor:$variance,
|
||||
|
||||
DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon,
|
||||
DefaultValuedAttr<F32Attr, "1.0f">:$exponential_avg_factor,
|
||||
DefaultValuedAttr<TF_ConvnetDataFormatAttr, "NHWC">:$data_format,
|
||||
DefaultValuedAttr<BoolAttr, "true">:$is_training
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[BF16, F16, F32]>:$y,
|
||||
F32Tensor:$batch_mean,
|
||||
F32Tensor:$batch_variance,
|
||||
F32Tensor:$reserve_space_1,
|
||||
F32Tensor:$reserve_space_2,
|
||||
F32Tensor:$reserve_space_3
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// TF_FoldOperandsTransposeInterface:
|
||||
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
|
||||
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
|
||||
LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
|
||||
|
||||
// TF_LayoutSensitiveInterface:
|
||||
StringRef GetOptimalLayout(const RuntimeDevices& devices);
|
||||
LogicalResult UpdateDataFormat(StringRef data_format);
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_GatherOp : TF_Op<"Gather", [NoSideEffect]> {
|
||||
let summary = "Gather slices from `params` according to `indices`.";
|
||||
|
||||
@ -6213,14 +6284,14 @@ retained with length 1.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint16, TF_Qint32, TF_Qint8, TF_Quint16, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
|
||||
TF_I32OrI64Tensor:$reduction_indices,
|
||||
|
||||
DefaultValuedAttr<BoolAttr, "false">:$keep_dims
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint16, TF_Qint32, TF_Qint8, TF_Quint16, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
@ -6324,27 +6395,6 @@ def TF_MaxPoolGradOp : TF_Op<"MaxPoolGrad", [NoSideEffect]> {
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_MaximumOp : TF_Op<"Maximum", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = "Returns the max of x and y (i.e. x > y ? x : y) element-wise.";
|
||||
|
||||
let description = [{
|
||||
*NOTE*: `Maximum` supports broadcasting. More about broadcasting
|
||||
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Uint8]>:$x,
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Uint8]>:$y
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Uint8]>:$z
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_MeanOp : TF_Op<"Mean", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
|
||||
let summary = "Computes the mean of elements across dimensions of a tensor.";
|
||||
|
||||
@ -6440,14 +6490,14 @@ retained with length 1.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint16, TF_Qint32, TF_Qint8, TF_Quint16, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
|
||||
TF_I32OrI64Tensor:$reduction_indices,
|
||||
|
||||
DefaultValuedAttr<BoolAttr, "false">:$keep_dims
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint16, TF_Qint32, TF_Qint8, TF_Quint16, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
@ -7878,33 +7928,6 @@ tf.real(input) ==> [-2.25, 3.25]
|
||||
TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_RealDivOp : TF_Op<"RealDiv", [NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = "Returns x / y element-wise for real types.";
|
||||
|
||||
let description = [{
|
||||
If `x` and `y` are reals, this will return the floating-point division.
|
||||
|
||||
*NOTE*: `Div` supports broadcasting. More about broadcasting
|
||||
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$x,
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$y
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$z
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TF_ReciprocalOp : TF_Op<"Reciprocal", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = "Computes the reciprocal of x element-wise.";
|
||||
|
||||
@ -9314,6 +9337,18 @@ Generate a sharded filename. The filename is printf formatted as
|
||||
);
|
||||
}
|
||||
|
||||
def TF_ShutdownDistributedTPUOp : TF_Op<"ShutdownDistributedTPU", []> {
|
||||
let summary = "Shuts down a running distributed TPU system.";
|
||||
|
||||
let description = [{
|
||||
The op returns an error if no system is running.
|
||||
}];
|
||||
|
||||
let arguments = (ins);
|
||||
|
||||
let results = (outs);
|
||||
}
|
||||
|
||||
def TF_SigmoidOp : TF_Op<"Sigmoid", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = "Computes sigmoid of `x` element-wise.";
|
||||
|
||||
@ -9832,6 +9867,41 @@ backpropagation,
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
}

def TF_SparseMatMulOp : TF_Op<"SparseMatMul", [NoSideEffect]> {
let summary = [{
Multiply matrix "a" by matrix "b".
}];

let description = [{
The inputs must be two-dimensional matrices and the inner dimension of "a" must
match the outer dimension of "b". Both "a" and "b" must be `Tensor`s not
`SparseTensor`s. This op is optimized for the case where at least one of "a" or
"b" is sparse, in the sense that they have a large proportion of zero values.
The breakeven for using this versus a dense matrix multiply on one platform was
30% zero values in the sparse matrix.

The gradient computation of this operation will only take advantage of sparsity
in the input gradient when that gradient comes from a Relu.
}];

let arguments = (ins
TensorOf<[BF16, F32]>:$a,
TensorOf<[BF16, F32]>:$b,

DefaultValuedAttr<BoolAttr, "false">:$transpose_a,
DefaultValuedAttr<BoolAttr, "false">:$transpose_b,
DefaultValuedAttr<BoolAttr, "false">:$a_is_sparse,
DefaultValuedAttr<BoolAttr, "false">:$b_is_sparse
);

let results = (outs
F32Tensor:$product
);

TF_DerivedOperandTypeAttr Ta = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Tb = TF_DerivedOperandTypeAttr<1>;
}
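The sparsity hints in the description are easiest to see from the raw Python wrapper; a small illustrative sketch (the matrices and their zero fraction are made up for the example):

    import tensorflow as tf

    # Both inputs are ordinary dense Tensors; a_is_sparse merely hints that
    # "a" contains mostly zeros so the sparse-friendly kernel can be used.
    a = tf.constant([[1.0, 0.0, 0.0],
                     [0.0, 0.0, 2.0]])          # 2x3, mostly zeros
    b = tf.constant([[1.0], [2.0], [3.0]])      # 3x1
    product = tf.raw_ops.SparseMatMul(
        a=a, b=b, transpose_a=False, transpose_b=False,
        a_is_sparse=True, b_is_sparse=False)    # [[1.0], [6.0]], float32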

def TF_SparseReshapeOp : TF_Op<"SparseReshape", [NoSideEffect]> {
let summary = [{
Reshapes a SparseTensor to represent values in a new dense shape.
@ -11625,9 +11695,9 @@ array([[1, 2, 3, 1, 2, 3],
|
||||
TF_DerivedOperandTypeAttr Tmultiples = TF_DerivedOperandTypeAttr<1>;
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
|
||||
// TODO(parkers): Add folds for multiples = [1,...].
|
||||
// TODO(parkers): Add errors for negative multiples and multiples.size() !=
|
||||
// input.rank()
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TF_TopKV2Op : TF_Op<"TopKV2", [NoSideEffect]> {
|
||||
@ -12893,7 +12963,8 @@ create these operators.
|
||||
DefaultValuedAttr<I64ArrayAttr, "{1, 1, 1, 1}">:$dilations,
|
||||
DefaultValuedAttr<BoolAttr, "true">:$use_cudnn_on_gpu,
|
||||
DefaultValuedAttr<StrArrayAttr, "{}">:$fused_ops,
|
||||
DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon
|
||||
DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon,
|
||||
DefaultValuedAttr<F32Attr, "0.2f">:$leakyrelu_alpha
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
|
@ -157,20 +157,10 @@ class TF_TensorFlowType <string name, string description> :
|
||||
"TensorFlow " # description # " type">,
|
||||
BuildableType<"getType<mlir::TF::" # name # "Type>()">;
|
||||
|
||||
// Any tensor element type allowed in TensorFlow ops
|
||||
def TF_ElementType : Type<Or<[AnyFloat.predicate,
|
||||
AnySignlessInteger.predicate,
|
||||
AnyUnsignedInteger.predicate,
|
||||
AnyComplex.predicate,
|
||||
TF_TFDialectType.predicate]>,
|
||||
"tf.dtype">;
|
||||
|
||||
// Any TensorFlow tensor type
|
||||
def TF_Tensor : TensorOf<[TF_ElementType]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Integer types
|
||||
|
||||
// TODO(mgester) shouldn't this be SignedIntOfWidths?
|
||||
def TF_I32Or64 : SignlessIntOfWidths<[32, 64]>;
|
||||
|
||||
def TF_I32OrI64Tensor : TensorOf<[TF_I32Or64]>;
|
||||
@ -191,10 +181,11 @@ def TF_Uint64Tensor : TensorOf<[TF_Uint64]>;
|
||||
def TF_UInt : UnsignedIntOfWidths<[8, 16, 32, 64]>;
|
||||
|
||||
// Any signed integer type
|
||||
// TODO(mgester) shouldn't this be SignedIntOfWidths?
|
||||
def TF_SInt : SignlessIntOfWidths<[8, 16, 32, 64]>;
|
||||
|
||||
// Any integer type
|
||||
def TF_Int : AnyTypeOf<[TF_SInt, TF_UInt]>;
|
||||
def TF_Int : AnyTypeOf<[TF_SInt, TF_UInt], "integer">;
|
||||
|
||||
// Any integer tensor types
|
||||
def TF_IntTensor : TensorOf<[TF_Int]>;
|
||||
@ -208,8 +199,8 @@ def TF_Quint8 : TF_TensorFlowType<"Quint8", "quint8">;
|
||||
def TF_Quint16 : TF_TensorFlowType<"Quint16", "quint16">;
|
||||
|
||||
// Any quantized type
|
||||
def TF_AnyQuantized : AnyTypeOf<[TF_Qint8, TF_Qint16, TF_Qint32, TF_Quint8,
|
||||
TF_Quint16]>;
|
||||
def TF_Quantized : AnyTypeOf<[TF_Qint8, TF_Qint16, TF_Qint32, TF_Quint8,
|
||||
TF_Quint16], "quantized">;
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Floating-point types
|
||||
|
||||
@ -217,8 +208,10 @@ def TF_F32Or64 : FloatOfWidths<[32, 64]>;
|
||||
|
||||
def TF_F32OrF64Tensor : TensorOf<[TF_F32Or64]>;
|
||||
|
||||
def TF_Float : AnyTypeOf<[F16, F32, F64, BF16], "floating-point">;
|
||||
|
||||
// Any floating-point tensor types
|
||||
def TF_FpTensor : TensorOf<[AnyFloat]>;
|
||||
def TF_FpTensor : TensorOf<[TF_Float]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Complex types
|
||||
@ -231,10 +224,9 @@ def TF_Complex64Tensor : TensorOf<[TF_Complex64]>;
|
||||
def TF_Complex128 : Complex<F<64>>;
|
||||
def TF_Complex128Tensor : TensorOf<[TF_Complex128]>;
|
||||
|
||||
def TF_AnyComplex : AnyTypeOf<[TF_Complex64, TF_Complex128],
|
||||
"64/128-bit complex type">;
|
||||
def TF_Complex : AnyTypeOf<[TF_Complex64, TF_Complex128], "complex">;
|
||||
|
||||
def TF_ComplexTensor : TensorOf<[TF_AnyComplex]>;
|
||||
def TF_ComplexTensor : TensorOf<[TF_Complex]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// String/variant/resource types
|
||||
@ -248,28 +240,113 @@ def TF_VariantTensor : TensorOf<[TF_Variant]>;
|
||||
def TF_Resource : TF_TensorFlowType<"Resource", "resource">;
|
||||
def TF_ResourceTensor : TensorOf<[TF_Resource]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Reference types
|
||||
|
||||
// Float reference types
|
||||
def TF_F16Ref : TF_TensorFlowType<"HalfRef", "f16ref">;
|
||||
def TF_F32Ref : TF_TensorFlowType<"FloatRef", "f32ref">;
|
||||
def TF_F64Ref : TF_TensorFlowType<"DoubleRef", "f64ref">;
|
||||
def TF_Bfloat16Ref : TF_TensorFlowType<"Bfloat16Ref", "bf16ref">;
|
||||
|
||||
// Any float reference type
|
||||
def TF_FloatRef : AnyTypeOf<[TF_F16Ref, TF_F32Ref, TF_F64Ref, TF_Bfloat16Ref],
|
||||
"floating-point reference">;
|
||||
|
||||
// Complex reference types
|
||||
def TF_Complex64Ref : TF_TensorFlowType<"Complex64Ref", "complex64ref">;
|
||||
def TF_Complex128Ref : TF_TensorFlowType<"Complex128Ref", "complex128ref">;
|
||||
|
||||
// Any complex reference type
|
||||
def TF_ComplexRef : AnyTypeOf<[TF_Complex64Ref, TF_Complex128Ref], "complex reference">;
|
||||
|
||||
// Integer reference types
|
||||
def TF_Int8Ref : TF_TensorFlowType<"Int8Ref", "i8ref">;
|
||||
def TF_Int16Ref : TF_TensorFlowType<"Int16Ref", "i16ref">;
|
||||
def TF_Int32Ref : TF_TensorFlowType<"Int32Ref", "i32ref">;
|
||||
def TF_Int64Ref : TF_TensorFlowType<"Int64Ref", "i64ref">;
|
||||
|
||||
def TF_Uint8Ref : TF_TensorFlowType<"Uint8Ref", "ui8ref">;
|
||||
def TF_Uint16Ref : TF_TensorFlowType<"Uint16Ref", "ui16ref">;
|
||||
def TF_Uint32Ref : TF_TensorFlowType<"Uint32Ref", "ui32ref">;
|
||||
def TF_Uint64Ref : TF_TensorFlowType<"Uint64Ref", "ui64ref">;
|
||||
|
||||
// Any signed integer reference type
|
||||
def TF_SIntRef : AnyTypeOf<[TF_Int8Ref, TF_Int16Ref, TF_Int32Ref, TF_Int64Ref],
|
||||
"signed integer reference">;
|
||||
|
||||
// Any unsigned integer reference type
|
||||
def TF_UIntRef : AnyTypeOf<[TF_Uint8Ref, TF_Uint16Ref, TF_Uint32Ref,
|
||||
TF_Uint64Ref], "unsigned integer reference">;
|
||||
|
||||
// Any integer reference type
|
||||
def TF_IntRef : AnyTypeOf<[TF_SIntRef, TF_UIntRef], "integer reference">;
|
||||
|
||||
// Quantized reference types
|
||||
def TF_Qint8Ref : TF_TensorFlowType<"Qint8Ref", "qint8ref">;
|
||||
def TF_Qint16Ref : TF_TensorFlowType<"Qint16Ref", "qint16ref">;
|
||||
def TF_Qint32Ref : TF_TensorFlowType<"Qint32Ref", "qint32ref">;
|
||||
def TF_Quint8Ref : TF_TensorFlowType<"Quint8Ref", "quint8ref">;
|
||||
def TF_Quint16Ref : TF_TensorFlowType<"Quint16Ref", "quint16ref">;
|
||||
|
||||
// Any quantized reference type
|
||||
def TF_QuantizedRef : AnyTypeOf<[TF_Qint8Ref, TF_Qint16Ref, TF_Qint32Ref,
|
||||
TF_Quint8Ref, TF_Quint16Ref], "quantized reference">;
|
||||
|
||||
// Other reference types
|
||||
def TF_BoolRef : TF_TensorFlowType<"BoolRef", "boolref">;
|
||||
def TF_ResourceRef : TF_TensorFlowType<"ResourceRef", "resourceref">;
|
||||
def TF_StringRef : TF_TensorFlowType<"StringRef", "stringref">;
|
||||
def TF_VariantRef : TF_TensorFlowType<"VariantRef", "variantref">;
|
||||
|
||||
// Reference tensor types
|
||||
def TF_FpRefTensor : TensorOf<[TF_FloatRef]>;
|
||||
def TF_I32OrI64RefTensor : TensorOf<[TF_Int32Ref, TF_Int64Ref]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Multi-category type constraints
|
||||
|
||||
def TF_IntOrF32OrF64Tensor: TensorOf<[TF_Int, TF_F32Or64]>;
|
||||
|
||||
def TF_FpOrI32OrI64Tensor : TensorOf<[AnyFloat, TF_I32Or64]>;
|
||||
def TF_FpOrI32OrI64Tensor : TensorOf<[TF_Float, TF_I32Or64]>;
|
||||
|
||||
// Any integer or floating-point tensor types
|
||||
def TF_IntOrFpTensor : TensorOf<[TF_Int, AnyFloat]>;
|
||||
def TF_IntOrFpTensor : TensorOf<[TF_Int, TF_Float]>;
|
||||
|
||||
def TF_SintOrFpTensor : TensorOf<[TF_SInt, AnyFloat]>;
|
||||
def TF_SintOrFpTensor : TensorOf<[TF_SInt, TF_Float]>;
|
||||
|
||||
def TF_FpOrComplexTensor : TensorOf<[AnyFloat, TF_AnyComplex]>;
|
||||
def TF_FpOrComplexTensor : TensorOf<[TF_Float, TF_Complex]>;
|
||||
|
||||
def TF_AnyNumber : AnyTypeOf<[TF_Int, AnyFloat, TF_AnyQuantized, TF_AnyComplex],
|
||||
"number">;
|
||||
def TF_Number : AnyTypeOf<[TF_Int, TF_Float, TF_Quantized, TF_Complex],
|
||||
"number">;
|
||||
def TF_NumberRef : AnyTypeOf<[TF_IntRef, TF_FloatRef, TF_QuantizedRef,
|
||||
TF_ComplexRef], "number reference">;
|
||||
|
||||
def TF_NumberTensor : TensorOf<[TF_AnyNumber]>;
|
||||
def TF_NumberTensor : TensorOf<[TF_Number]>;
|
||||
def TF_NumberRefTensor : TensorOf<[TF_NumberRef]>;
|
||||
|
||||
def TF_NumberOrStr : AnyTypeOf<[AnyFloat, TF_SInt, TF_AnyComplex, TF_Uint8, TF_Str]>;
|
||||
def TF_NumberOrStr : AnyTypeOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8,
|
||||
TF_Str]>;
|
||||
def TF_NumberOrStrTensor : TensorOf<[TF_NumberOrStr]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Tensor and tensor element types
|
||||
|
||||
// Bool type
|
||||
def TF_Bool : I<1>;
|
||||
|
||||
// Any tensor element type allowed in TensorFlow ops
|
||||
// (see https://www.tensorflow.org/api_docs/python/tf/dtypes/DType)
|
||||
def TF_ElementType : Type<Or<[TF_Float.predicate,
|
||||
TF_Complex.predicate,
|
||||
TF_Int.predicate,
|
||||
TF_Bool.predicate,
|
||||
TF_TFDialectType.predicate]>,
|
||||
"tf.dtype">;
|
||||
|
||||
// Any TensorFlow tensor type
|
||||
def TF_Tensor : TensorOf<[TF_ElementType]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorFlow attribute definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -570,36 +570,6 @@ def TF_PlaceholderWithDefaultOp : TF_Op<"PlaceholderWithDefault", [NoSideEffect]
|
||||
DerivedAttr shape = TF_DerivedResultShapeAttr;
|
||||
}
|
||||
|
||||
def TF_SparseMatMulOp : TF_Op<"SparseMatMul", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
SparseMatMul is MatMul with hints on the sparseness of the matrices.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
Similar to MatMul, with a_is_sparse and b_is_sparse indicating whether a and b
|
||||
are sparse matrices.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[BF16, F32]>:$a,
|
||||
TensorOf<[BF16, F32]>:$b,
|
||||
|
||||
DefaultValuedAttr<BoolAttr, "true">:$a_is_sparse,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$b_is_sparse,
|
||||
|
||||
DefaultValuedAttr<BoolAttr, "false">:$transpose_a,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$transpose_b
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[F32]>:$product
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr Ta = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Tb = TF_DerivedOperandTypeAttr<1>;
|
||||
}
|
||||
|
||||
|
||||
def TF_StatefulPartitionedCallOp : TF_Op<"StatefulPartitionedCall",
|
||||
[CallOpInterface]> {
|
||||
let summary =
|
||||
@ -1213,63 +1183,6 @@ def TF_TPUPartitionedCallOp : TF_Op<"TPUPartitionedCall", [CallOpInterface]> {
|
||||
let verifier = [{ return VerifyPartitionedCall(*this); }];
|
||||
}
|
||||
|
||||
class TF_FusedBatchNormOpBase<string Name> : TF_Op<Name, [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> {
|
||||
let summary = "Batch normalization.";
|
||||
|
||||
let description = [{
|
||||
Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW".
|
||||
The size of 1D Tensors matches the dimension C of the 4D Tensors.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[BF16, F16, F32]>:$x,
|
||||
F32Tensor:$scale,
|
||||
F32Tensor:$offset,
|
||||
F32Tensor:$mean,
|
||||
F32Tensor:$variance,
|
||||
|
||||
DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon,
|
||||
DefaultValuedAttr<F32Attr, "1.0f">:$exponential_avg_factor,
|
||||
DefaultValuedAttr<TF_ConvnetDataFormatAttr, "NHWC">:$data_format,
|
||||
DefaultValuedAttr<BoolAttr, "true">:$is_training
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// TF_FoldOperandsTransposeInterface:
|
||||
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
|
||||
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
|
||||
LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
|
||||
|
||||
// TF_LayoutSensitiveInterface:
|
||||
StringRef GetOptimalLayout(const RuntimeDevices& devices);
|
||||
LogicalResult UpdateDataFormat(StringRef data_format);
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_FusedBatchNormV2Op : TF_FusedBatchNormOpBase<"FusedBatchNormV2"> {
|
||||
let results = (outs
|
||||
TensorOf<[BF16, F16, F32]>:$y,
|
||||
F32Tensor:$batch_mean,
|
||||
F32Tensor:$batch_variance,
|
||||
F32Tensor:$reserve_space_1,
|
||||
F32Tensor:$reserve_space_2
|
||||
);
|
||||
}
|
||||
|
||||
def TF_FusedBatchNormV3Op : TF_FusedBatchNormOpBase<"FusedBatchNormV3"> {
|
||||
let results = (outs
|
||||
TensorOf<[BF16, F16, F32]>:$y,
|
||||
F32Tensor:$batch_mean,
|
||||
F32Tensor:$batch_variance,
|
||||
F32Tensor:$reserve_space_1,
|
||||
F32Tensor:$reserve_space_2,
|
||||
F32Tensor:$reserve_space_3
|
||||
);
|
||||
}
|
||||
|
||||
def TF_BatchFunctionOp : TF_Op<"BatchFunction", [AttrSizedOperandSegments]> {
|
||||
let summary = [{
|
||||
Batches all the inputs tensors to the computation done by the function.
|
||||
@ -1341,4 +1254,98 @@ must be a Tensor or a list/tuple of Tensors.
|
||||
TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
|
||||
}

def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary, TF_LayoutAgnostic, TF_SameOperandsAndResultElementTypeResolveRef]>,
WithBroadcastableBinOpBuilder {
let summary = "Returns x + y element-wise.";

let description = [{
*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
}];

let arguments = (ins
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint8Ref]>:$x,
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint8Ref]>:$y
);

let results = (outs
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint8Ref]>:$z
);

TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;

let hasCanonicalizer = 1;

let hasFolder = 1;
}
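A minimal Python-level sketch of the broadcasting behaviour the description points to (shapes and values are illustrative):

    import tensorflow as tf

    x = tf.constant([[1, 2, 3],
                     [4, 5, 6]])          # shape [2, 3]
    y = tf.constant([10, 20, 30])         # shape [3], broadcast across rows
    z = tf.raw_ops.AddV2(x=x, y=y)        # [[11, 22, 33], [14, 25, 36]]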

def TF_DivNoNanOp : TF_Op<"DivNoNan", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>,
WithBroadcastableBinOpBuilder {
let summary = "Returns 0 if the denominator is zero.";

let description = [{
*NOTE*: `DivNoNan` supports broadcasting. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
}];

let arguments = (ins
TensorOf<[F16, F32, F64, TF_Complex, TF_F16Ref, TF_F32Ref, TF_F64Ref, TF_ComplexRef]>:$x,
TensorOf<[F16, F32, F64, TF_Complex, TF_F16Ref, TF_F32Ref, TF_F64Ref, TF_ComplexRef]>:$y
);

let results = (outs
TensorOf<[F16, F32, F64, TF_Complex, TF_F16Ref, TF_F32Ref, TF_F64Ref, TF_ComplexRef]>:$z
);

TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
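The zero-denominator behaviour in the summary, sketched through the raw Python wrapper (values are illustrative):

    import tensorflow as tf

    x = tf.constant([4.0, 2.0, 1.0])
    y = tf.constant([2.0, 0.0, 0.5])
    z = tf.raw_ops.DivNoNan(x=x, y=y)     # [2.0, 0.0, 2.0]; 2.0/0.0 yields 0 rather than inf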

def TF_MaximumOp : TF_Op<"Maximum", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>,
WithBroadcastableBinOpBuilder {
let summary = "Returns the max of x and y (i.e. x > y ? x : y) element-wise.";

let description = [{
*NOTE*: `Maximum` supports broadcasting. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
}];

let arguments = (ins
TensorOf<[TF_Float, I16, I32, I64, TF_Uint8, TF_FloatRef, TF_Int16Ref, TF_Int32Ref, TF_Int64Ref, TF_Uint8Ref]>:$x,
TensorOf<[TF_Float, I16, I32, I64, TF_Uint8, TF_FloatRef, TF_Int16Ref, TF_Int32Ref, TF_Int64Ref, TF_Uint8Ref]>:$y
);

let results = (outs
TensorOf<[TF_Float, I16, I32, I64, TF_Uint8, TF_FloatRef, TF_Int16Ref, TF_Int32Ref, TF_Int64Ref, TF_Uint8Ref]>:$z
);

TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
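And the element-wise, broadcasting semantics of Maximum, again as a hedged Python sketch:

    import tensorflow as tf

    x = tf.constant([[1.0, 5.0],
                     [3.0, 2.0]])
    y = tf.constant([2.0, 4.0])           # broadcast against each row of x
    z = tf.raw_ops.Maximum(x=x, y=y)      # [[2.0, 5.0], [3.0, 4.0]]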
|
||||
|
||||
def TF_RealDivOp : TF_Op<"RealDiv", [NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = "Returns x / y element-wise for real types.";
|
||||
|
||||
let description = [{
|
||||
If `x` and `y` are reals, this will return the floating-point division.
|
||||
|
||||
*NOTE*: `Div` supports broadcasting. More about broadcasting
|
||||
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint16, TF_Uint8, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint16Ref, TF_Uint8Ref]>:$x,
|
||||
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint16, TF_Uint8, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint16Ref, TF_Uint8Ref]>:$y
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint16, TF_Uint8, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint16Ref, TF_Uint8Ref]>:$z
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
#endif // TF_OPS
|
||||
|
@ -1783,6 +1783,87 @@ static LogicalResult Verify(TensorScatterUpdateOp op) {
return success();
}

//===----------------------------------------------------------------------===//
// TileOp
//===----------------------------------------------------------------------===//

// Verifies that,
//
// - input has at least rank 1
// - multiples is rank 1
// - multiples.size() == input.rank()
// - input.rank() == output.rank()
// - Elements in multiples are non-negative
// - input.shape[i] * multiples[i] == output.shape[i]
// for i in [0, input.rank() - 1]

static LogicalResult Verify(TileOp op) {
auto input_type = op.input().getType().dyn_cast<RankedTensorType>();
auto multiples_type = op.multiples().getType().dyn_cast<RankedTensorType>();
auto output_type = op.output().getType().dyn_cast<RankedTensorType>();

if (multiples_type && multiples_type.getRank() != 1) {
return op.emitOpError() << "expected multiples to be rank 1, got rank = "
<< multiples_type.getRank();
}

if (input_type && multiples_type && multiples_type.hasStaticShape() &&
(input_type.getRank() != multiples_type.getNumElements() ||
(input_type.getRank() == 0 && multiples_type.getNumElements() == 1))) {
return op.emitOpError()
<< "expected size of multiples equal to rank of input"
<< ", got multiples of size " << multiples_type.getNumElements()
<< ", and input of rank " << input_type.getRank();
}

if (input_type && output_type) {
if (input_type.getRank() != output_type.getRank()) {
return op.emitOpError()
<< "expected rank of input to equal to rank of output"
<< ", got input of rank " << input_type.getRank()
<< ", and output of rank " << output_type.getRank();
}

DenseIntElementsAttr multiples_attr;
if (matchPattern(op.multiples(), m_Constant(&multiples_attr))) {
for (int32_t i = 0, e = input_type.getRank(); i < e; ++i) {
const int64_t input_dim = input_type.getDimSize(i);
const int64_t output_dim = output_type.getDimSize(i);
const int64_t m = multiples_attr.getValue<APInt>(i).getSExtValue();

if (m < 0) {
return op.emitOpError()
<< "expected multiples to be non-negative, got "
<< "multiples[" << i << "] = " << m;
}

if (!ShapedType::isDynamic(input_dim) &&
!ShapedType::isDynamic(output_dim) && output_dim != input_dim * m) {
return op.emitOpError()
<< "requires input.shape[" << i << "] (" << input_dim << ")"
<< " * " << m << " to be equal to "
<< "output.shape[" << i << "] (" << output_dim << ")";
}
}
}
}

return success();
}

OpFoldResult TileOp::fold(ArrayRef<Attribute> operands) {
DenseIntElementsAttr multiples_attr;
if (matchPattern(multiples(), m_Constant(&multiples_attr))) {
// Return input directly when multiples are all ones,
// regardless what input is.
if (multiples_attr.isSplat() &&
multiples_attr.getSplatValue<APInt>().getSExtValue() == 1) {
return input();
}
}
return {};
}
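The shape rule the verifier enforces (output.shape[i] == input.shape[i] * multiples[i]) and the all-ones fold above correspond to the following Python-level behaviour (a sketch; the shapes are illustrative):

    import tensorflow as tf

    x = tf.zeros([2, 3])
    y = tf.tile(x, [2, 3])      # valid: output shape [4, 9] == [2*2, 3*3]
    same = tf.tile(x, [1, 1])   # multiples of all ones; the fold returns the input unchanged
    # tf.tile(x, [2])           # would be rejected: multiples has size 1 but input has rank 2
    # tf.tile(x, [-1, 1])       # would be rejected: negative multiples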

//===----------------------------------------------------------------------===//
// TopKV2Op
//===----------------------------------------------------------------------===//
@ -285,7 +285,7 @@ func @empty_island_multiple_data_results(%arg0: tensor<*xf32>, %arg1: tensor<*xi
|
||||
// and certain tf_executor ops are added correctly.
|
||||
|
||||
// CHECK: %[[CONTROL:[^ ,]*]] = tf_executor.island wraps "tf.Print"
|
||||
// CHECK: tf_executor.NextIteration.Sink [{{.*}}] {{.*}}, %[[CONTROL]]
|
||||
// CHECK: tf_executor.NextIteration.Sink[{{.*}}] {{.*}}, %[[CONTROL]]
|
||||
func @next_iteration_sink_control_input() {
|
||||
tf_executor.graph {
|
||||
%source:3 = tf_executor.NextIteration.Source : tensor<*xi32>
|
||||
|
@ -568,6 +568,14 @@ func @testSelectElseUnranked(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2:
|
||||
return %0: tensor<*xf16>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testTileMultiplesAllOnes
|
||||
func @testTileMultiplesAllOnes(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
||||
%cst = constant dense <[1, 1]> : tensor<2xi32>
|
||||
// CHECK: return %arg0
|
||||
%0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x3xf32>
|
||||
return %0: tensor<2x3xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testLogicalNotOfEqual
|
||||
func @testLogicalNotOfEqual(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xi1> {
|
||||
%0 = "tf.Equal"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1>
|
||||
|
@ -220,7 +220,7 @@ func @merge_islands_only() {
|
||||
%11:2 = tf_executor.island(%10#1) wraps "tf.opF"() : () -> tensor<i32>
|
||||
%12:2 = tf_executor.island wraps "tf.opG"(%10#0, %11#0) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
||||
%13 = tf_executor.ControlTrigger %2, %12#1, %9#1
|
||||
tf_executor.NextIteration.Sink [%3#1] %12#0, %13 : tensor<*xi32>
|
||||
tf_executor.NextIteration.Sink[%3#1] %12#0, %13 : tensor<*xi32>
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
@ -244,7 +244,7 @@ func @merge_islands_only() {
|
||||
// CHECK-NEXT: %[[OP_G:[0-9]*]] = "tf.opG"(%[[OP_E]], %[[OP_F]])
|
||||
// CHECK-NEXT: tf_executor.yield %[[OP_G]] : tensor<*xi32>
|
||||
// CHECK: %[[CT:.*]] = tf_executor.ControlTrigger %[[ISLAND_1]], %[[ISLAND_3_control]], %[[EXIT_control]]
|
||||
// CHECK-NEXT: tf_executor.NextIteration.Sink [%[[NEXTIT_SRC_token]]] %[[ISLAND_3]], %[[CT]]
|
||||
// CHECK-NEXT: tf_executor.NextIteration.Sink[%[[NEXTIT_SRC_token]]] %[[ISLAND_3]], %[[CT]]
|
||||
|
||||
|
||||
// Test no merging took place as cycle would be formed otherwise.
|
||||
|
@ -7,7 +7,7 @@
|
||||
# CHECK: %[[NEXTITERATION:[a-z0-9]+]], %[[NEXTITERATION_token:[a-z0-9]+]], {{.*}} = tf_executor.NextIteration.Source
|
||||
# CHECK: tf_executor.Merge {{.*}} %[[NEXTITERATION]]
|
||||
|
||||
# CHECK: tf_executor.NextIteration.Sink [%[[NEXTITERATION_token]]]
|
||||
# CHECK: tf_executor.NextIteration.Sink[%[[NEXTITERATION_token]]]
|
||||
|
||||
node {
|
||||
name: "Const"
|
||||
|
@ -3468,3 +3468,85 @@ func @testCumprod(%arg: tensor<8x16xf32>) -> tensor<8x16xf32> {
|
||||
%0 = "tf.Cumprod"(%arg, %axis) : (tensor<8x16xf32>, tensor<i32>) -> tensor<8x16xf32>
|
||||
return %0 : tensor<8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testTile(%arg0: tensor<2x3x?xf32>) {
|
||||
%cst = constant dense <[2, 3, 4]> : tensor<3xi32>
|
||||
%0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3x?xf32>, tensor<3xi32>) -> tensor<4x9x?xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testTileMultipleNotRank1(%arg0: tensor<2x3xf32>, %arg1: tensor<1x1xi32>) {
|
||||
// expected-error @+1 {{expected multiples to be rank 1, got rank = 2}}
|
||||
%0 = "tf.Tile"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<1x1xi32>) -> tensor<2x3xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testTileInputRankNotEqualToMultiplesSize(%arg0: tensor<2x3xf32>, %arg1: tensor<3xi32>) {
|
||||
// expected-error @+1 {{expected size of multiples equal to rank of input, got multiples of size 3, and input of rank 2}}
|
||||
%0 = "tf.Tile"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<3xi32>) -> tensor<2x3xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testTileInputRankNotEqualToOutputRank(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>) {
|
||||
// expected-error @+1 {{expected rank of input to equal to rank of output, got input of rank 2, and output of rank 3}}
|
||||
%0 = "tf.Tile"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x3x1xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testTileNegativeMultiples(%arg0: tensor<2x3xf32>) {
|
||||
%cst = constant dense <[-1, 1]> : tensor<2xi32>
|
||||
// expected-error @+1 {{expected multiples to be non-negative, got multiples[0] = -1}}
|
||||
%0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x3xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testTileInvalidOutputShape(%arg0: tensor<2x3xf32>) {
|
||||
%cst = constant dense <[2, 3]> : tensor<2xi32>
|
||||
// expected-error @+1 {{requires input.shape[1] (3) * 3 to be equal to output.shape[1] (6)}}
|
||||
%0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<4x6xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test reference variable support for some ops (no errors expected)

// CHECK-LABEL: @testMaximumWithRef
func @testMaximumWithRef(%arg0: tensor<!tf.f32ref>, %arg1: tensor<f32>) -> tensor<f32> {
// CHECK: tf.Maximum
%0 = "tf.Maximum"(%arg0, %arg1) : (tensor<!tf.f32ref>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}

// CHECK-LABEL: @testAddV2WithRef
func @testAddV2WithRef(%arg0: tensor<!tf.int16ref>, %arg1: tensor<i16>) -> tensor<i16> {
// CHECK: tf.AddV2
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<!tf.int16ref>, tensor<i16>) -> tensor<i16>
return %0 : tensor<i16>
}

// CHECK-LABEL: @testRealDivWithRef
func @testRealDivWithRef(%arg0: tensor<f64>, %arg1: tensor<!tf.f64ref>) -> tensor<f64> {
// CHECK: tf.RealDiv
%0 = "tf.RealDiv"(%arg0, %arg1) : (tensor<f64>, tensor<!tf.f64ref>) -> tensor<f64>
return %0 : tensor<f64>
}

// CHECK-LABEL: @testDivNoNanWithRef
func @testDivNoNanWithRef(%arg0: tensor<f32>, %arg1: tensor<!tf.f32ref>) -> tensor<f32> {
// CHECK: tf.DivNoNan
%0 = "tf.DivNoNan"(%arg0, %arg1) : (tensor<f32>, tensor<!tf.f32ref>) -> tensor<f32>
return %0 : tensor<f32>
}
|
||||
|
@ -433,7 +433,7 @@ func @nextiteration(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> {
|
||||
%1:3 = tf_executor.NextIteration.Source : tensor<*xf32>
|
||||
tf_executor.NextIteration.Sink[%1#1] %1#0 : tensor<*xf32>
|
||||
// CHECK: tf_executor.NextIteration.Source : tensor<*xf32>
|
||||
// CHECK: tf_executor.NextIteration.Sink [%{{.*}}] %{{.*}} : tensor<*xf32>
|
||||
// CHECK: tf_executor.NextIteration.Sink[%{{.*}}] %{{.*}} : tensor<*xf32>
|
||||
tf_executor.fetch %1#0 : tensor<*xf32>
|
||||
}
|
||||
return %0 : tensor<*xf32>
|
||||
@ -445,7 +445,7 @@ func @nextiteration_with_attributes(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*
|
||||
%1:3 = tf_executor.NextIteration.Source : tensor<*xf32> {attr3 = 32 : i64, tf_executor.attr_fetch = "some_value"}
|
||||
tf_executor.NextIteration.Sink[%1#1] %1#0 : tensor<*xf32> {attr4 = 42 : i64, tf_executor.attr_push = "other_value"}
|
||||
// CHECK: tf_executor.NextIteration.Source : tensor<*xf32> {attr3 = 32 : i64, tf_executor.attr_fetch = "some_value"}
|
||||
// CHECK: tf_executor.NextIteration.Sink [%{{.*}}] %{{.*}} : tensor<*xf32> {attr4 = 42 : i64, tf_executor.attr_push = "other_value"}
|
||||
// CHECK: tf_executor.NextIteration.Sink[%{{.*}}] %{{.*}} : tensor<*xf32> {attr4 = 42 : i64, tf_executor.attr_push = "other_value"}
|
||||
tf_executor.fetch %1#0 : tensor<*xf32>
|
||||
}
|
||||
return %0 : tensor<*xf32>
|
||||
@ -457,9 +457,9 @@ func @nextiteration_control(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<*
|
||||
%1:3 = tf_executor.Switch %arg0, %arg1 : tensor<*xf32>
|
||||
%2:2 = tf_executor.Enter %arg0, %1#2, %1#2 frame "some/frame" : tensor<*xf32>
|
||||
%3:3 = tf_executor.NextIteration.Source : tensor<*xf32>
|
||||
tf_executor.NextIteration.Sink [%3#1] %3#0, %1#2 : tensor<*xf32>
|
||||
tf_executor.NextIteration.Sink[%3#1] %3#0, %1#2 : tensor<*xf32>
|
||||
// CHECK: tf_executor.NextIteration.Source : tensor<*xf32>
|
||||
// CHECK: tf_executor.NextIteration.Sink [%{{.*}}] %{{.*}}, %{{.*}} : tensor<*xf32>
|
||||
// CHECK: tf_executor.NextIteration.Sink[%{{.*}}] %{{.*}}, %{{.*}} : tensor<*xf32>
|
||||
tf_executor.fetch %3#0 : tensor<*xf32>
|
||||
}
|
||||
return %0 : tensor<*xf32>
|
||||
|
@ -0,0 +1,64 @@
|
||||
// RUN: tf-opt -tf-tpu-resource-read-for-write %s | FileCheck %s --dump-input=always
|
||||
|
||||
// CHECK-LABEL: func @write_only_resource
|
||||
// CHECK-SAME: ([[ARG0:%.*]]: tensor<i32>, [[ARG1:%.*]]: tensor<f32>, [[ARG2:%.*]]: tensor<*x!tf.resource<tensor<i32>>>)
|
||||
func @write_only_resource(%arg0: tensor<i32>, %arg1: tensor<f32>, %arg2: tensor<*x!tf.resource<tensor<i32>>>) {
|
||||
// CHECK-NEXT: [[READ:%.*]] = "tf.ReadVariableOp"([[ARG2]])
|
||||
// CHECK-NEXT: [[CLUSTER:%.*]]:2 = "tf_device.cluster_func"([[ARG0]], [[ARG1]], [[READ]])
|
||||
// CHECK-SAME: _tpu_replicate = "write"
|
||||
%0:2 = "tf_device.cluster_func"(%arg0, %arg1) {_tpu_replicate = "write", func = @write_func} : (tensor<i32>, tensor<f32>) -> (tensor<f32>, tensor<i32>)
|
||||
// CHECK-NEXT: "tf.AssignVariableOp"([[ARG2]], [[CLUSTER]]#1)
|
||||
"tf.AssignVariableOp"(%arg2, %0#1) : (tensor<*x!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
|
||||
// CHECK-NEXT: return
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @write_func
|
||||
// CHECK-SAME: ({{%.*}}: tensor<i32>, {{%.*}}: tensor<f32>, {{%.*}}: tensor<i32>) -> (tensor<f32>, tensor<i32>)
|
||||
func @write_func(%arg0: tensor<i32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<i32>) {
|
||||
return %arg1, %arg0 : tensor<f32>, tensor<i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @read_write_resource
|
||||
func @read_write_resource(%arg0: tensor<i32>, %arg1: tensor<f32>, %arg2: tensor<*x!tf.resource<tensor<i32>>>) {
|
||||
// CHECK-COUNT-1: tf.ReadVariableOp
|
||||
%0 = "tf.ReadVariableOp"(%arg2) : (tensor<*x!tf.resource<tensor<i32>>>) -> tensor<i32>
|
||||
%1:2 = "tf_device.cluster_func"(%arg0, %arg1, %0) {_tpu_replicate = "read_write", func = @read_write_func} : (tensor<i32>, tensor<f32>, tensor<i32>) -> (tensor<f32>, tensor<i32>)
|
||||
"tf.AssignVariableOp"(%arg2, %1#1) : (tensor<*x!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @read_write_func
|
||||
// CHECK-SAME: ({{%.*}}: tensor<i32>, {{%.*}}: tensor<f32>) -> (tensor<f32>, tensor<i32>)
|
||||
func @read_write_func(%arg0: tensor<i32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<i32>) {
|
||||
return %arg1, %arg0 : tensor<f32>, tensor<i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @multiple_write_resource
|
||||
func @multiple_write_resource(%arg0: tensor<i32>, %arg1: tensor<*x!tf.resource<tensor<i32>>>) {
|
||||
// CHECK-NOT: tf.ReadVariableOp
|
||||
%0:2 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "multiple_write", func = @multiple_write_func} : (tensor<i32>) -> (tensor<i32>, tensor<i32>)
|
||||
"tf.AssignVariableOp"(%arg1, %0#0) : (tensor<*x!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
|
||||
"tf.AssignVariableOp"(%arg1, %0#1) : (tensor<*x!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @multiple_write_func
|
||||
// CHECK-SAME: ({{%.*}}: tensor<i32>) -> (tensor<i32>, tensor<i32>)
|
||||
func @multiple_write_func(%arg0: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
|
||||
return %arg0, %arg0 : tensor<i32>, tensor<i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @multiple_result_user
|
||||
func @multiple_result_user(%arg0: tensor<i32>, %arg1: tensor<*x!tf.resource<tensor<i32>>>) -> tensor<i32> {
|
||||
// CHECK-NOT: tf.ReadVariableOp
|
||||
%0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "multiple_uses", func = @multiple_result_user_func} : (tensor<i32>) -> tensor<i32>
|
||||
"tf.AssignVariableOp"(%arg1, %0) : (tensor<*x!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @multiple_result_user_func
|
||||
// CHECK-SAME: ({{%.*}}: tensor<i32>) -> tensor<i32>
|
||||
func @multiple_result_user_func(%arg0: tensor<i32>) -> tensor<i32> {
|
||||
return %arg0 : tensor<i32>
|
||||
}
|
@ -173,7 +173,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
func @tail_single_outside_compiled_op() {
|
||||
// CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster"
|
||||
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"
|
||||
// CHECK-NEXT: "tf.C"
|
||||
// CHECK-NEXT: "tf.NoOp"
|
||||
// CHECK-NEXT: tf_device.return %[[A_OUT]]
|
||||
// CHECK-NEXT: {
|
||||
// CHECK-DAG: num_cores_per_replica = 1
|
||||
@ -190,7 +190,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
"tf_device.cluster"() ( {
|
||||
%a = "tf.A"() : () -> tensor<i32>
|
||||
"tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> ()
|
||||
"tf.C"() : () -> ()
|
||||
"tf.NoOp"() : () -> ()
|
||||
tf_device.return
|
||||
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
|
||||
return
|
||||
@ -200,7 +200,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
func @tail_single_outside_compiled_op_user() -> tensor<i32> {
|
||||
// CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster"
|
||||
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"
|
||||
// CHECK-NEXT: "tf.C"
|
||||
// CHECK-NEXT: "tf.NoOp"
|
||||
// CHECK-NEXT: tf_device.return %[[A_OUT]]
|
||||
// CHECK-NEXT: {
|
||||
// CHECK-DAG: num_cores_per_replica = 1
|
||||
@ -217,7 +217,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
%cluster = "tf_device.cluster"() ( {
|
||||
%a = "tf.A"() : () -> tensor<i32>
|
||||
%b = "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> tensor<i32>
|
||||
"tf.C"() : () -> ()
|
||||
"tf.NoOp"() : () -> ()
|
||||
tf_device.return %b : tensor<i32>
|
||||
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> tensor<i32>
|
||||
// CHECK: return %[[LAUNCH_OUT]]
|
||||
@ -262,7 +262,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
%b = "tf.B"() : () -> tensor<i32>
|
||||
// CHECK: %[[CLUSTER_OUT:.*]]:2 = "tf_device.cluster"
|
||||
// CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"
|
||||
// CHECK-NEXT: %[[E_OUT:.*]] = "tf.E"
|
||||
// CHECK-NEXT: %[[E_OUT:.*]] = "tf.Const"
|
||||
// CHECK-NEXT: tf_device.return %[[C_OUT]], %[[E_OUT]]
|
||||
// CHECK-NEXT: {
|
||||
// CHECK-DAG: num_cores_per_replica = 1
|
||||
@ -279,7 +279,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
%cluster:5 = "tf_device.cluster"() ( {
|
||||
%c = "tf.C"() : () -> tensor<i32>
|
||||
%d = "tf.D"(%c, %a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
%e = "tf.E"() : () -> tensor<i32>
|
||||
%e = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
tf_device.return %a, %b, %c, %d, %e : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
|
||||
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>)
|
||||
// CHECK: return %[[A_OUT]], %[[B_OUT]], %[[CLUSTER_OUT]]#0, %[[LAUNCH_OUT]], %[[CLUSTER_OUT]]#1
|
||||
@ -320,14 +320,14 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
func @head_tail_no_extraction_middle_outside_compiled_ops(%arg0: tensor<i32>) {
|
||||
// CHECK-NOT: "tf_device.launch"
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK-NEXT: "tf.A"
|
||||
// CHECK-NEXT: "tf.Identity"
|
||||
// CHECK-NEXT: "tf.B"
|
||||
// CHECK-NEXT: "tf.C"
|
||||
// CHECK-NEXT: "tf.Identity"
|
||||
// CHECK-NEXT: tf_device.return
|
||||
"tf_device.cluster"() ( {
|
||||
%a = "tf.A"(%arg0) : (tensor<i32>) -> tensor<i32>
|
||||
%a = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
|
||||
%b = "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> tensor<i32>
|
||||
"tf.C"(%b) : (tensor<i32>) -> ()
|
||||
%c = "tf.Identity"(%b) : (tensor<i32>) -> tensor<i32>
|
||||
tf_device.return
|
||||
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
|
||||
return
|
||||
@ -379,7 +379,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster"
|
||||
// CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"
|
||||
// CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%[[RI]], %[[B_OUT]])
|
||||
// CHECK-NEXT: "tf.E"(%[[C_OUT]], %[[HEAD_LAUNCH_OUT]])
|
||||
// CHECK-NEXT: "tf.IdentityN"(%[[C_OUT]], %[[HEAD_LAUNCH_OUT]])
|
||||
// CHECK-NEXT: tf_device.return %[[C_OUT]]
|
||||
// CHECK-NEXT: {
|
||||
// CHECK-DAG: num_cores_per_replica = 1
|
||||
@ -399,11 +399,72 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
%b = "tf.B"() : () -> tensor<i32>
|
||||
%c = "tf.C"(%ri, %b) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
%d = "tf.D"(%a, %c, %ri) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
%e = "tf.E"(%c, %a) : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
%e:2 = "tf.IdentityN"(%c, %a) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
|
||||
tf_device.return
|
||||
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
|
||||
tf_device.return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @side_effect_middle
|
||||
func @side_effect_middle() {
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK-NEXT: "tf.A"
|
||||
// CHECK-NEXT: "tf.B"
|
||||
// CHECK-NEXT: "tf.C"
|
||||
// CHECK-NEXT: tf_device.return
|
||||
"tf_device.cluster"() ( {
|
||||
"tf.A"() : () -> ()
|
||||
"tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
|
||||
"tf.C"() : () -> ()
|
||||
tf_device.return
|
||||
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @side_effect_head_no_operand
|
||||
func @side_effect_head_no_operand() {
|
||||
// CHECK: %[[HEAD_LAUNCH_OUT:.*]] = "tf_device.launch"()
|
||||
// CHECK-NEXT: "tf.B"
|
||||
// CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"
|
||||
// CHECK-NEXT: tf_device.return %[[C_OUT]]
|
||||
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK-NEXT: "tf.Const"
|
||||
// CHECK-NEXT: "tf.D"(%[[HEAD_LAUNCH_OUT]])
|
||||
// CHECK-NEXT: tf_device.return
|
||||
|
||||
"tf_device.cluster"() ( {
|
||||
%cst = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
"tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
|
||||
%c = "tf.C"() {_xla_outside_compilation = "cluster1"} : () -> tensor<i32>
|
||||
"tf.D"(%c) : (tensor<i32>) -> ()
|
||||
tf_device.return
|
||||
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @side_effect_tail_no_operand
|
||||
func @side_effect_tail_no_operand() {
|
||||
// CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster"
|
||||
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"
|
||||
// CHECK-NEXT: "tf.Const"
|
||||
// CHECK-NEXT: tf_device.return %[[A_OUT]]
|
||||
|
||||
// CHECK: "tf_device.launch"()
|
||||
// CHECK-NEXT: "tf.B"(%[[CLUSTER_OUT]])
|
||||
// CHECK-NEXT: "tf.C"
|
||||
// CHECK-NEXT: tf_device.return
|
||||
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||
"tf_device.cluster"() ( {
|
||||
%a = "tf.A"() : () -> tensor<i32>
|
||||
"tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> ()
|
||||
"tf.C"() {_xla_outside_compilation = "cluster1"} : () -> ()
|
||||
%cst = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
tf_device.return
|
||||
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -110,6 +110,7 @@ void CreateTPUBridgePipeline(OpPassManager &pm) {
|
||||
pm.addPass(TF::CreateResourceDeviceInferencePass());
|
||||
pm.addPass(TFDevice::CreateClusterOutliningPass());
|
||||
pm.addPass(CreateTPUDynamicPaddingMapperPass());
|
||||
pm.addPass(CreateTPUResourceReadForWritePass());
|
||||
pm.addPass(CreateTPUShardingIdentificationPass());
|
||||
pm.addPass(TFDevice::CreateAnnotateParameterReplicationPass());
|
||||
pm.addPass(CreateTPURewritePass());
|
||||
|
@ -629,8 +629,7 @@ class Lower_UnaryOpsComposition
|
||||
LogicalResult matchAndRewrite(TF::_UnaryOpsCompositionOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value result = op.x();
|
||||
for (StringRef op_name :
|
||||
op.op_names().getAsRange<StringAttr, StringRef>()) {
|
||||
for (StringRef op_name : op.op_names().getAsValueRange<StringAttr>()) {
|
||||
std::string full_name = "tf." + op_name.str();
|
||||
// All ops in the sequences have the same result type as the original
|
||||
// result type.
|
||||
|
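As context for the hunk above (not part of this change; JoinOpNames is a hypothetical helper), a minimal sketch of the getAsValueRange pattern the code now uses, iterating an ArrayAttr of StringAttr directly as StringRef values:

#include <string>

#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project

// Joins the op names held in an ArrayAttr of StringAttr, e.g. for logging.
std::string JoinOpNames(mlir::ArrayAttr op_names) {
  std::string joined;
  for (llvm::StringRef name : op_names.getAsValueRange<mlir::StringAttr>()) {
    if (!joined.empty()) joined += ",";
    joined.append(name.begin(), name.end());
  }
  return joined;
}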
@ -287,6 +287,10 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateTPUDynamicLayoutPass();
|
||||
// `tf_device.launch_func` `padding_map` attribute to its encapsulated function.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUDynamicPaddingMapperPass();
|
||||
|
||||
// Creates a pass that adds `tf.ReadVariableOp` to a TPU cluster for resources
|
||||
// the cluster only writes to.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUResourceReadForWritePass();
|
||||
|
||||
// Creates a pass that rewrites `tf_device.launch_func` on TPUs into TPU runtime
|
||||
// ops.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateTPURewritePass();
|
||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Block.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/IR/Visitors.h" // from @llvm-project
|
||||
@ -34,6 +35,7 @@ limitations under the License.
|
||||
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
@ -118,7 +120,10 @@ tf_device::LaunchOp CreateLaunchForBlock(OpBuilder* builder, Operation* op,
|
||||
// computation or other ops that can be extracted, and have no operands from
|
||||
// other ops in the TPU computation that cannot be extracted.
|
||||
llvm::SmallVector<Operation*, 4> FindOutsideCompiledOpsAtHead(
|
||||
const TF::SideEffectAnalysis& side_effect_analysis,
|
||||
tf_device::ClusterOp cluster) {
|
||||
const auto& analysis = side_effect_analysis.GetAnalysisForFunc(
|
||||
cluster.getParentOfType<FuncOp>());
|
||||
Region* cluster_region = &cluster.body();
|
||||
llvm::SmallSetVector<Operation*, 4> head_outside_compiled_ops;
|
||||
|
||||
@ -127,6 +132,15 @@ llvm::SmallVector<Operation*, 4> FindOutsideCompiledOpsAtHead(
|
||||
if (!HasOutsideCompilationAttribute(&cluster_op)) continue;
|
||||
// An outside compiled op can be extracted if its operands are not from
|
||||
// other ops in the cluster that cannot be extracted.
|
||||
|
||||
// Check that the closest preceding side-effecting op, if any, has itself
|
||||
// been head extracted; otherwise, because ops are ordered by their side
|
||||
// effects, this op cannot be head extracted.
|
||||
auto predecessors = analysis.DirectControlPredecessors(&cluster_op);
|
||||
if (!predecessors.empty() &&
|
||||
!head_outside_compiled_ops.contains(predecessors.back()))
|
||||
continue;
|
||||
|
||||
auto walk_result = cluster_op.walk([&](Operation* op) {
|
||||
for (Value operand : op->getOperands()) {
|
||||
Operation* operand_op = GetOpOfValue(operand);
|
||||
@ -168,11 +182,11 @@ void CreateHeadComputation(OpBuilder* builder, tf_device::ClusterOp cluster,
|
||||
// Extracts and moves outside compiled ops that have no dependencies in the
|
||||
// cluster to before the cluster.
|
||||
mlir::LogicalResult LiftHeadOutsideCompiledOps(
|
||||
OpBuilder* builder, const mlir::TF::RuntimeDevices& devices,
|
||||
tf_device::ClusterOp cluster, std::string* host_device,
|
||||
bool* cluster_updated) {
|
||||
OpBuilder* builder, const TF::SideEffectAnalysis& side_effect_analysis,
|
||||
const mlir::TF::RuntimeDevices& devices, tf_device::ClusterOp cluster,
|
||||
std::string* host_device, bool* cluster_updated) {
|
||||
llvm::SmallVector<Operation*, 4> head_outside_compiled_ops =
|
||||
FindOutsideCompiledOpsAtHead(cluster);
|
||||
FindOutsideCompiledOpsAtHead(side_effect_analysis, cluster);
|
||||
if (head_outside_compiled_ops.empty()) return success();
|
||||
if (failed(tensorflow::GetHostDeviceOutsideComputation(devices, cluster,
|
||||
host_device)))
|
||||
@ -191,9 +205,12 @@ mlir::LogicalResult LiftHeadOutsideCompiledOps(
|
||||
// TPU computation or other ops that can be extracted, and have no results used
|
||||
// by other ops in the TPU computation that cannot be extracted.
|
||||
void FindOutsideCompiledOpsAtTailAndClusterResults(
|
||||
const TF::SideEffectAnalysis& side_effect_analysis,
|
||||
tf_device::ClusterOp cluster,
|
||||
llvm::SmallVectorImpl<Operation*>* tail_outside_compiled_ops,
|
||||
llvm::SmallVectorImpl<Value>* cluster_results) {
|
||||
const auto& analysis = side_effect_analysis.GetAnalysisForFunc(
|
||||
cluster.getParentOfType<FuncOp>());
|
||||
Region* cluster_region = &cluster.body();
|
||||
llvm::SmallSetVector<Operation*, 4> tail_outside_compiled_ops_set;
|
||||
Operation* terminator = cluster.GetBody().getTerminator();
|
||||
@ -205,6 +222,15 @@ void FindOutsideCompiledOpsAtTailAndClusterResults(
|
||||
for (Operation& cluster_op : cluster_ops) {
|
||||
if (!HasOutsideCompilationAttribute(&cluster_op)) continue;
|
||||
|
||||
// Check that the closest succeeding side-effecting op (ignoring the
|
||||
// terminator), if any, has itself been tail extracted; otherwise, because
|
||||
// ops are ordered by their side effects, this op cannot be tail extracted.
|
||||
auto successors = analysis.DirectControlSuccessors(
|
||||
&cluster_op, [&terminator](Operation* op) { return op != terminator; });
|
||||
if (!successors.empty() &&
|
||||
!tail_outside_compiled_ops_set.contains(successors.front()))
|
||||
continue;
|
||||
|
||||
llvm::SmallVector<int, 4> results_to_forward;
|
||||
bool can_be_extracted =
|
||||
llvm::all_of(cluster_op.getUsers(), [&](Operation* op) {
|
||||
@ -293,13 +319,14 @@ tf_device::ClusterOp UpdateClusterResults(
|
||||
// Extracts and moves outside compiled ops that do not create dependencies in the
|
||||
// cluster to after the cluster.
|
||||
mlir::LogicalResult LiftTailOutsideCompiledOps(
|
||||
OpBuilder* builder, const mlir::TF::RuntimeDevices& devices,
|
||||
std::string host_device, tf_device::ClusterOp* cluster,
|
||||
bool* cluster_updated) {
|
||||
OpBuilder* builder, const TF::SideEffectAnalysis& side_effect_analysis,
|
||||
const mlir::TF::RuntimeDevices& devices, std::string host_device,
|
||||
tf_device::ClusterOp* cluster, bool* cluster_updated) {
|
||||
llvm::SmallVector<Operation*, 4> tail_outside_compiled_ops;
|
||||
llvm::SmallVector<Value, 4> cluster_results;
|
||||
FindOutsideCompiledOpsAtTailAndClusterResults(
|
||||
*cluster, &tail_outside_compiled_ops, &cluster_results);
|
||||
FindOutsideCompiledOpsAtTailAndClusterResults(side_effect_analysis, *cluster,
|
||||
&tail_outside_compiled_ops,
|
||||
&cluster_results);
|
||||
if (tail_outside_compiled_ops.empty()) return success();
|
||||
|
||||
if (host_device.empty())
|
||||
@ -365,6 +392,7 @@ struct TPUExtractHeadTailOutsideCompilation
|
||||
};
|
||||
|
||||
void TPUExtractHeadTailOutsideCompilation::runOnOperation() {
|
||||
auto& side_effect_analysis = getAnalysis<TF::SideEffectAnalysis>();
|
||||
// Get runtime devices information from the closest parent module.
|
||||
auto module = getOperation();
|
||||
mlir::TF::RuntimeDevices devices;
|
||||
@ -379,10 +407,12 @@ void TPUExtractHeadTailOutsideCompilation::runOnOperation() {
|
||||
for (tf_device::ClusterOp cluster : clusters) {
|
||||
std::string host_device;
|
||||
bool cluster_updated = false;
|
||||
if (failed(LiftHeadOutsideCompiledOps(&builder, devices, cluster,
|
||||
&host_device, &cluster_updated)) ||
|
||||
failed(LiftTailOutsideCompiledOps(&builder, devices, host_device,
|
||||
&cluster, &cluster_updated)))
|
||||
if (failed(LiftHeadOutsideCompiledOps(&builder, side_effect_analysis,
|
||||
devices, cluster, &host_device,
|
||||
&cluster_updated)) ||
|
||||
failed(LiftTailOutsideCompiledOps(&builder, side_effect_analysis,
|
||||
devices, host_device, &cluster,
|
||||
&cluster_updated)))
|
||||
return signalPassFailure();
|
||||
if (cluster_updated) RemoveClusterAliasedOutputs(&builder, cluster);
|
||||
}
|
||||
|
@ -0,0 +1,140 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFTPU {
|
||||
|
||||
// A pass that finds TPU clusters with write-only resource access and adds an
|
||||
// associated resource read, so the resource can later be fused into TPUExecute.
|
||||
namespace {
|
||||
struct TPUResourceReadForWrite
|
||||
: public PassWrapper<TPUResourceReadForWrite, OperationPass<ModuleOp>> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
// Helper struct holding a resource value and its associated type.
|
||||
struct ResourceValueAndSubtype {
|
||||
Value resource;
|
||||
Type subtype;
|
||||
};
|
||||
|
||||
// Finds resource handle and type for result if result writes to a resource.
|
||||
ResourceValueAndSubtype GetResourceWriteResult(
|
||||
tf_device::ClusterFuncOp cluster_func, Value result) {
|
||||
ResourceValueAndSubtype resource;
|
||||
if (!result.hasOneUse()) return resource;
|
||||
Operation* result_user = *result.getUsers().begin();
|
||||
auto assign_var = dyn_cast<TF::AssignVariableOp>(result_user);
|
||||
if (!assign_var) return resource;
|
||||
|
||||
auto handle = assign_var.resource();
|
||||
// Skip result if cluster writes to the same variable via multiple results.
|
||||
for (Operation* handle_user : handle.getUsers()) {
|
||||
if (handle_user == assign_var) continue;
|
||||
auto assign_var_user = dyn_cast<TF::AssignVariableOp>(handle_user);
|
||||
if (!assign_var_user) continue;
|
||||
if (assign_var_user.value().getDefiningOp() == cluster_func)
|
||||
return resource;
|
||||
}
|
||||
|
||||
resource.resource = assign_var.resource();
|
||||
resource.subtype = assign_var.value().getType();
|
||||
return resource;
|
||||
}
|
||||
|
||||
// Checks if resource is read by TPU cluster.
|
||||
bool ClusterFuncHasResourceRead(tf_device::ClusterFuncOp cluster_func,
|
||||
Value resource) {
|
||||
for (Operation* resource_user : resource.getUsers())
|
||||
if (auto read = dyn_cast<TF::ReadVariableOp>(resource_user))
|
||||
for (Operation* read_user : read.value().getUsers())
|
||||
if (read_user == cluster_func) return true;
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void TPUResourceReadForWrite::runOnOperation() {
|
||||
SmallVector<tf_device::ClusterFuncOp, 4> cluster_funcs;
|
||||
getOperation().walk([&](tf_device::ClusterFuncOp cluster_func) {
|
||||
cluster_funcs.push_back(cluster_func);
|
||||
});
|
||||
|
||||
OpBuilder builder(&getContext());
|
||||
// Add resource reads for resources that the TPU cluster writes to but does
|
||||
// not read from.
|
||||
for (tf_device::ClusterFuncOp cluster_func : cluster_funcs) {
|
||||
builder.setInsertionPoint(cluster_func);
|
||||
|
||||
SmallVector<Value, 4> read_operands;
|
||||
for (Value result : cluster_func.getResults()) {
|
||||
// TODO(lyandy): Update pass to use resource alias analysis.
|
||||
auto resource_and_type = GetResourceWriteResult(cluster_func, result);
|
||||
if (!resource_and_type.resource) continue;
|
||||
if (ClusterFuncHasResourceRead(cluster_func, resource_and_type.resource))
|
||||
continue;
|
||||
auto new_read = builder.create<TF::ReadVariableOp>(
|
||||
resource_and_type.resource.getLoc(), resource_and_type.subtype,
|
||||
resource_and_type.resource);
|
||||
read_operands.push_back(new_read.value());
|
||||
}
|
||||
|
||||
if (read_operands.empty()) continue;
|
||||
|
||||
// Update caller and function types with new read operands.
|
||||
auto operands = llvm::to_vector<4>(cluster_func.getOperands());
|
||||
operands.append(read_operands.begin(), read_operands.end());
|
||||
|
||||
auto new_cluster_func = builder.create<tf_device::ClusterFuncOp>(
|
||||
cluster_func.getLoc(), cluster_func.getResultTypes(), operands,
|
||||
cluster_func.getAttrs());
|
||||
cluster_func.replaceAllUsesWith(new_cluster_func);
|
||||
FuncOp func = cluster_func.getFunc();
|
||||
Block& block = func.front();
|
||||
for (Value read_operand : read_operands)
|
||||
block.addArgument(read_operand.getType());
|
||||
|
||||
func.setType(FunctionType::get(block.getArgumentTypes(),
|
||||
func.getCallableResults(), &getContext()));
|
||||
cluster_func.erase();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUResourceReadForWritePass() {
|
||||
return std::make_unique<TPUResourceReadForWrite>();
|
||||
}
|
||||
|
||||
static PassRegistration<TPUResourceReadForWrite> pass(
|
||||
"tf-tpu-resource-read-for-write",
|
||||
"Inserts tf.ReadVariableOp inputs to a TPU cluster for resource writes "
|
||||
"with no reads");
|
||||
|
||||
} // namespace TFTPU
|
||||
} // namespace mlir
|
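A minimal sketch (not part of the diff; AddReadsForWriteOnlyResources is a hypothetical helper, and the pass factory is assumed to live in mlir::TFTPU as in the file above) of how the newly added pass could be run on its own through a pass manager:

#include "mlir/IR/Module.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"

// Runs only the resource-read-for-write pass over `module`.
mlir::LogicalResult AddReadsForWriteOnlyResources(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::TFTPU::CreateTPUResourceReadForWritePass());
  return pm.run(module);
}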
@ -49,6 +49,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
|
||||
#include "tensorflow/compiler/mlir/utils/name_utils.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/graph_to_functiondef.h"
|
||||
@ -80,46 +81,14 @@ constexpr char kInvalidExecutorGraphMsg[] =
|
||||
constexpr char kDeviceAttr[] = "tf.device";
|
||||
constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id";
|
||||
|
||||
bool IsLegalChar(char c, bool first_char) {
|
||||
if (isalpha(c)) return true;
|
||||
if (isdigit(c)) return true;
|
||||
if (c == '.') return true;
|
||||
if (c == '_') return true;
|
||||
|
||||
// First character of a node name can only be a letter, digit, dot or
|
||||
// underscore.
|
||||
if (first_char) return false;
|
||||
|
||||
if (c == '/') return true;
|
||||
if (c == '-') return true;
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// Convert characters in name that are considered illegal in TensorFlow Node
|
||||
// name to '.'.
|
||||
std::string LegalizeNodeName(llvm::StringRef name) {
|
||||
assert(!name.empty() && "expected non-empty name");
|
||||
|
||||
std::string legalized_name;
|
||||
bool first = true;
|
||||
for (auto c : name) {
|
||||
if (IsLegalChar(c, first)) {
|
||||
legalized_name += c;
|
||||
} else {
|
||||
legalized_name += '.';
|
||||
}
|
||||
first = false;
|
||||
}
|
||||
|
||||
return legalized_name;
|
||||
}
|
||||
|
||||
// OpOrArgLocNameMapper that legalizes the returned name.
|
||||
class LegalizedOpOrValLocNameMapper : public OpOrArgLocNameMapper {
|
||||
private:
|
||||
std::string GetName(OpOrVal op_or_val) override {
|
||||
return LegalizeNodeName(OpOrArgLocNameMapper::GetName(op_or_val));
|
||||
std::string name = OpOrArgLocNameMapper::GetName(op_or_val);
|
||||
assert(!name.empty() && "expected non-empty name");
|
||||
mlir::LegalizeNodeName(name);
|
||||
return name;
|
||||
}
|
||||
};
|
||||
|
||||
@ -523,13 +492,14 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
|
||||
if (index >= num_data_results) break;
|
||||
// TODO(jpienaar): If there is a result index specified, ensure only one
|
||||
// and that it matches the result index of the op.
|
||||
std::string orig_name(output_names[index]);
|
||||
auto tensor_id = ParseTensorName(orig_name);
|
||||
auto name = LegalizeNodeName(
|
||||
llvm::StringRef(tensor_id.node().data(), tensor_id.node().size()));
|
||||
std::string name(output_names[index]);
|
||||
auto tensor_id = ParseTensorName(name);
|
||||
std::string tensor_id_node(tensor_id.node());
|
||||
assert(!tensor_id_node.empty() && "expected non-empty name");
|
||||
mlir::LegalizeNodeName(tensor_id_node);
|
||||
|
||||
// Ensure name does not get reused.
|
||||
(void)exporter.op_to_name_.GetUniqueName(name);
|
||||
(void)exporter.op_to_name_.GetUniqueName(tensor_id_node);
|
||||
}
|
||||
}
|
||||
|
||||
@ -537,8 +507,9 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
|
||||
TF_RET_CHECK(input_names.size() == block.getNumArguments());
|
||||
for (const auto& it : llvm::enumerate(function.getArguments())) {
|
||||
// TODO(lyandy): Update when changing feed/fetch import.
|
||||
std::string orig_name(input_names[it.index()]);
|
||||
std::string name = LegalizeNodeName(orig_name);
|
||||
std::string name(input_names[it.index()]);
|
||||
assert(!name.empty() && "expected non-empty name");
|
||||
mlir::LegalizeNodeName(name);
|
||||
auto tensor_id = ParseTensorName(name);
|
||||
TF_RET_CHECK(tensor_id.index() == 0)
|
||||
<< "input port designation not supported";
|
||||
|
51
tensorflow/compiler/mlir/utils/array_container_utils.h
Normal file
@ -0,0 +1,51 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_UTILS_ARRAY_CONTAINER_UTILS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_UTILS_ARRAY_CONTAINER_UTILS_H_
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
template <typename T>
|
||||
inline llvm::ArrayRef<T> SpanToArrayRef(absl::Span<const T> span) {
|
||||
return llvm::ArrayRef<T>(span.data(), span.size());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline llvm::ArrayRef<T> SpanToArrayRef(absl::Span<T> span) {
|
||||
return llvm::ArrayRef<T>(span.data(), span.size());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline llvm::MutableArrayRef<T> SpanToMutableArrayRef(absl::Span<T> span) {
|
||||
return llvm::MutableArrayRef<T>(span.data(), span.size());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline absl::Span<const T> ArrayRefToSpan(llvm::ArrayRef<T> ref) {
|
||||
return absl::Span<const T>(ref.data(), ref.size());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline absl::Span<T> MutableArrayRefToSpan(llvm::MutableArrayRef<T> ref) {
|
||||
return absl::Span<T>(ref.data(), ref.size());
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_UTILS_ARRAY_CONTAINER_UTILS_H_
|
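A minimal usage sketch (not part of the diff; ArrayContainerUtilsExample is a hypothetical function) of the span/ArrayRef adaptors declared above; both directions produce non-owning views over the same buffer:

#include <vector>

#include "absl/types/span.h"
#include "llvm/ADT/ArrayRef.h"
#include "tensorflow/compiler/mlir/utils/array_container_utils.h"

void ArrayContainerUtilsExample() {
  std::vector<int> dims = {1, 2, 3};
  // absl::Span -> llvm::ArrayRef without copying the underlying data.
  llvm::ArrayRef<int> ref =
      mlir::SpanToArrayRef<int>(absl::MakeConstSpan(dims));
  // llvm::ArrayRef -> absl::Span, again a view over the same elements.
  absl::Span<const int> span = mlir::ArrayRefToSpan(ref);
  (void)span;
}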
99
tensorflow/compiler/mlir/utils/name_utils.cc
Normal file
@ -0,0 +1,99 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/utils/name_utils.h"
|
||||
|
||||
#include <cctype>
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "mlir/IR/Identifier.h" // from @llvm-project
|
||||
|
||||
namespace mlir {
|
||||
|
||||
namespace {
|
||||
// Checks if a character is legal for a TensorFlow node name, with special
|
||||
// handling if a character is at the beginning.
|
||||
bool IsLegalChar(char c, bool first_char) {
|
||||
if (isalpha(c)) return true;
|
||||
if (isdigit(c)) return true;
|
||||
if (c == '.') return true;
|
||||
if (c == '_') return true;
|
||||
|
||||
// First character of a node name can only be a letter, digit, dot or
|
||||
// underscore.
|
||||
if (first_char) return false;
|
||||
|
||||
if (c == '/') return true;
|
||||
if (c == '-') return true;
|
||||
|
||||
return false;
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
void LegalizeNodeName(std::string& name) {
|
||||
if (name.empty()) return;
|
||||
|
||||
if (!IsLegalChar(name[0], /*first_char=*/true)) name[0] = '.';
|
||||
|
||||
for (char& c : llvm::drop_begin(name, 1))
|
||||
if (!IsLegalChar(c, /*first_char=*/false)) c = '.';
|
||||
}
|
||||
|
||||
std::string GetNameFromLoc(Location loc) {
|
||||
llvm::SmallVector<llvm::StringRef, 8> loc_names;
|
||||
llvm::SmallVector<Location, 8> locs;
|
||||
locs.push_back(loc);
|
||||
bool names_is_nonempty = false;
|
||||
|
||||
while (!locs.empty()) {
|
||||
Location curr_loc = locs.pop_back_val();
|
||||
|
||||
if (auto name_loc = curr_loc.dyn_cast<NameLoc>()) {
|
||||
// Add name in NameLoc. For NameLoc we also account for names due to ops
|
||||
// in functions where the op's name is first.
|
||||
auto name = name_loc.getName().strref().split('@').first;
|
||||
loc_names.push_back(name);
|
||||
if (!name.empty()) names_is_nonempty = true;
|
||||
continue;
|
||||
} else if (auto call_loc = curr_loc.dyn_cast<CallSiteLoc>()) {
|
||||
// Add name if CallSiteLoc's callee has a NameLoc (as should be the
|
||||
// case if imported with DebugInfo).
|
||||
if (auto name_loc = call_loc.getCallee().dyn_cast<NameLoc>()) {
|
||||
auto name = name_loc.getName().strref().split('@').first;
|
||||
loc_names.push_back(name);
|
||||
if (!name.empty()) names_is_nonempty = true;
|
||||
continue;
|
||||
}
|
||||
} else if (auto fused_loc = curr_loc.dyn_cast<FusedLoc>()) {
|
||||
// Push all locations in FusedLoc in reverse order, so locations are
|
||||
// visited based on order in FusedLoc.
|
||||
auto reversed_fused_locs = llvm::reverse(fused_loc.getLocations());
|
||||
locs.append(reversed_fused_locs.begin(), reversed_fused_locs.end());
|
||||
continue;
|
||||
}
|
||||
|
||||
// Location is not a supported, so an empty StringRef is added.
|
||||
loc_names.push_back(llvm::StringRef());
|
||||
}
|
||||
|
||||
if (names_is_nonempty)
|
||||
return llvm::join(loc_names.begin(), loc_names.end(), ";");
|
||||
|
||||
return "";
|
||||
}
|
||||
|
||||
} // namespace mlir
|
35
tensorflow/compiler/mlir/utils/name_utils.h
Normal file
@ -0,0 +1,35 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_UTILS_NAME_UTILS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_UTILS_NAME_UTILS_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
|
||||
namespace mlir {
|
||||
|
||||
// Converts characters in name that are considered illegal in TensorFlow Node
|
||||
// name to '.'.
|
||||
void LegalizeNodeName(std::string& name);
|
||||
|
||||
// Creates a TensorFlow node name from a location.
|
||||
std::string GetNameFromLoc(Location loc);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_UTILS_NAME_UTILS_H_
|
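A minimal usage sketch (not part of the diff; NameUtilsExample is a hypothetical function) of the in-place LegalizeNodeName declared above:

#include <string>

#include "tensorflow/compiler/mlir/utils/name_utils.h"

void NameUtilsExample() {
  // ' ' and ':' are illegal in a TensorFlow node name and become '.'.
  std::string name = "fused op:0";
  mlir::LegalizeNodeName(name);
  // name is now "fused.op.0".
}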
34
tensorflow/compiler/mlir/utils/string_container_utils.h
Normal file
@ -0,0 +1,34 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_UTILS_STRING_CONTAINER_UTILS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_UTILS_STRING_CONTAINER_UTILS_H_
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
inline absl::string_view StringRefToView(llvm::StringRef ref) {
|
||||
return absl::string_view(ref.data(), ref.size());
|
||||
}
|
||||
|
||||
inline llvm::StringRef StringViewToRef(absl::string_view view) {
|
||||
return llvm::StringRef(view.data(), view.size());
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_UTILS_STRING_CONTAINER_UTILS_H_
|
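A minimal usage sketch (not part of the diff; StringContainerUtilsExample is a hypothetical function) of the StringRef/string_view adaptors declared above:

#include "absl/strings/string_view.h"
#include "llvm/ADT/StringRef.h"
#include "tensorflow/compiler/mlir/utils/string_container_utils.h"

void StringContainerUtilsExample() {
  llvm::StringRef ref = "tf.MatMul";
  // Both conversions wrap the same character data; nothing is copied.
  absl::string_view view = mlir::StringRefToView(ref);
  llvm::StringRef round_trip = mlir::StringViewToRef(view);
  (void)round_trip;
}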
@ -238,7 +238,6 @@ cc_library(
|
||||
deps = [
|
||||
":type_to_shape",
|
||||
"//tensorflow/compiler/mlir/hlo",
|
||||
"//tensorflow/compiler/mlir/hlo:hlo_dialect_force_registration",
|
||||
"//tensorflow/compiler/mlir/tensorflow:convert_type",
|
||||
"//tensorflow/compiler/mlir/tensorflow:error_util",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
@ -389,7 +388,6 @@ cc_library(
|
||||
":xla_legalize_tf_with_tf2xla",
|
||||
"//tensorflow/compiler/mlir/hlo",
|
||||
"//tensorflow/compiler/mlir/hlo:chlo_legalize_to_hlo",
|
||||
"//tensorflow/compiler/mlir/hlo:hlo_dialect_force_registration",
|
||||
"//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo",
|
||||
"//tensorflow/compiler/mlir/hlo:legalize_control_flow",
|
||||
"//tensorflow/compiler/mlir/hlo:legalize_tanh_to_approximation",
|
||||
|
Some files were not shown because too many files have changed in this diff.