Merge branch 'master' into toupstream/16x8_batch_matmul

Thibaut Goetghebuer 2020-09-21 07:37:14 +00:00 committed by GitHub
commit 0f082f6b40
357 changed files with 9560 additions and 4939 deletions


@ -619,6 +619,8 @@ tf_cuda_library(
":c_api",
":c_api_experimental",
"//tensorflow/c:c_test_util",
"//tensorflow/c:tf_datatype",
"//tensorflow/c:tf_tensor",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",


@ -17,12 +17,16 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
using tensorflow::string;
using tensorflow::tstring;
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value) {
float data[] = {value};
@ -36,6 +40,19 @@ TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value) {
return th;
}
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx,
const tensorflow::tstring& value) {
TF_Status* status = TF_NewStatus();
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_STRING, nullptr, 0, status);
tstring* data = static_cast<tstring*>(TF_TensorData(t));
*data = value;
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
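For reference, a minimal sketch of how the new tstring overload can be used (the TFE_Context* is an assumed, already-created context; error handling omitted):

    #include "tensorflow/c/eager/c_api.h"
    #include "tensorflow/c/eager/c_api_test_util.h"
    #include "tensorflow/core/platform/tstring.h"

    void Example(TFE_Context* ctx) {
      // Wraps the string in a TF_STRING scalar tensor and returns a handle.
      TFE_TensorHandle* h =
          TestScalarTensorHandle(ctx, tensorflow::tstring("some value"));
      // ... consume the handle, then release it.
      TFE_DeleteTensorHandle(h);
    }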
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, int value) {
int data[] = {value};
TF_Status* status = TF_NewStatus();


@ -16,6 +16,7 @@ limitations under the License.
#define TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
@ -28,6 +29,10 @@ TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, int value);
// Return a tensor handle containing a bool scalar
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, bool value);
// Return a tensor handle containing a tstring scalar
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx,
const tensorflow::tstring& value);
// Return a tensor handle containing a 2x2 matrix of doubles
TFE_TensorHandle* DoubleTestMatrixTensorHandle(TFE_Context* ctx);


@ -249,21 +249,36 @@ void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
}
void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
auto tf_dlm_context = GetDlContext(h, status);
if (!status->status.ok()) {
return nullptr;
}
auto* tf_dlm_data = TFE_TensorHandleDevicePointer(h, status);
if (!status->status.ok()) {
return nullptr;
}
const Tensor* tensor = GetTensorFromHandle(h, status);
TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
TensorReference tensor_ref(*tensor); // This will call buf_->Ref()
auto tf_dlm_type = GetDlDataType(data_type, status);
if (!status->status.ok()) {
return nullptr;
}
TensorReference tensor_ref(*tensor); // This will call buf_->Ref()
auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref);
tf_dlm_tensor_ctx->reference = tensor_ref;
DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
dlm_tensor->deleter = &DLManagedTensorDeleter;
dlm_tensor->dl_tensor.ctx = GetDlContext(h, status);
dlm_tensor->dl_tensor.ctx = tf_dlm_context;
int ndim = tensor->dims();
dlm_tensor->dl_tensor.ndim = ndim;
dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
dlm_tensor->dl_tensor.dtype = GetDlDataType(data_type, status);
dlm_tensor->dl_tensor.data = tf_dlm_data;
dlm_tensor->dl_tensor.dtype = tf_dlm_type;
std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
@ -276,13 +291,14 @@ void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
(*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
}
dlm_tensor->dl_tensor.shape = &(*shape_arr)[0];
dlm_tensor->dl_tensor.shape = shape_arr->data();
// There are two ways to represent compact row-major data:
// 1) nullptr indicates the tensor is compact and row-major.
// 2) fill in the strides array explicitly for compact row-major data.
// Here we choose option 2, since some frameworks did not handle a null
// strides argument properly.
dlm_tensor->dl_tensor.strides = &(*stride_arr)[0];
dlm_tensor->dl_tensor.strides = stride_arr->data();
dlm_tensor->dl_tensor.byte_offset =
0; // TF doesn't handle the strides and byte_offsets here
return static_cast<void*>(dlm_tensor);
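For context, a standalone sketch of the compact row-major stride computation performed by the loop above (the example shape is hypothetical):

    #include <cstdint>
    #include <vector>

    // For a shape of {2, 3, 4} this returns {12, 4, 1}: the innermost
    // dimension is contiguous, and each outer stride is the product of the
    // sizes of all inner dimensions.
    std::vector<int64_t> RowMajorStrides(const std::vector<int64_t>& shape) {
      std::vector<int64_t> strides(shape.size());
      if (shape.empty()) return strides;
      strides[shape.size() - 1] = 1;
      for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i) {
        strides[i] = shape[i + 1] * strides[i + 1];
      }
      return strides;
    }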


@ -62,6 +62,7 @@ cc_library(
":function_metadata",
"//tensorflow/c:tf_tensor_internal",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/experimental/saved_model/core/revived_types:asset",
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
"//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function",
"//tensorflow/c/experimental/saved_model/core/revived_types:variable",
@ -69,6 +70,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
@ -138,7 +140,6 @@ cc_library(
":saved_model_api",
":saved_model_utils",
":signature_def_function",
"//tensorflow/c:tensor_interface",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/experimental/saved_model/core/ops:restore_ops",
@ -148,10 +149,10 @@ cc_library(
"//tensorflow/c/experimental/saved_model/core/revived_types:variable",
"//tensorflow/cc/saved_model:bundle_v2",
"//tensorflow/cc/saved_model:constants",
"//tensorflow/cc/saved_model:loader_util",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",


@ -8,6 +8,25 @@ package(
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "asset",
srcs = [
"asset.cc",
],
hdrs = [
"asset.h",
],
deps = [
":tensorhandle_convertible",
"//tensorflow/c:tensor_interface",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/cc/saved_model:constants",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "constant",
srcs = [


@ -0,0 +1,49 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h"
#include <string>
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/path.h"
namespace tensorflow {
Asset::Asset(ImmediateTensorHandlePtr handle)
: TensorHandleConvertible(std::move(handle)) {}
Status Asset::Create(ImmediateExecutionContext* ctx,
const std::string& saved_model_dir,
const std::string& asset_filename,
std::unique_ptr<Asset>* output) {
std::string abs_path =
io::JoinPath(saved_model_dir, kSavedModelAssetsDirectory, asset_filename);
AbstractTensorPtr tensor(ctx->CreateStringScalar(abs_path));
if (tensor.get() == nullptr) {
return errors::Internal(
"Failed to create scalar string tensor for Asset at path ", abs_path);
}
ImmediateTensorHandlePtr handle(ctx->CreateLocalHandle(tensor.get()));
output->reset(new Asset(std::move(handle)));
return Status();
}
} // namespace tensorflow
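A sketch of a call site for the new Asset::Create helper; the context pointer, model directory, and "vocab.txt" filename are illustrative assumptions, not taken from this change:

    std::unique_ptr<tensorflow::Asset> asset;
    tensorflow::Status s = tensorflow::Asset::Create(
        ctx, "/tmp/my_saved_model", "vocab.txt", &asset);
    // On success, `asset` wraps a string scalar tensor handle holding
    // "/tmp/my_saved_model/assets/vocab.txt".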


@ -0,0 +1,50 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_ASSET_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_ASSET_H_
#include <string>
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/tensor.pb.h"
namespace tensorflow {
class Asset : public TensorHandleConvertible {
public:
static Status Create(ImmediateExecutionContext* ctx,
const std::string& saved_model_dir,
const std::string& asset_filename,
std::unique_ptr<Asset>* output);
// Asset is movable, but not copyable.
Asset(Asset&& other) = default;
Asset& operator=(Asset&& other) = default;
~Asset() override = default;
private:
explicit Asset(ImmediateTensorHandlePtr handle);
Asset(const Asset&) = delete;
Asset& operator=(const Asset&) = delete;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_ASSET_H_


@ -100,6 +100,20 @@ Status ValidateSavedFunctionCompatibleWithFunctionDef(
} // namespace
Status LoadSavedAsset(ImmediateExecutionContext* ctx, const SavedAsset& asset,
const std::string& saved_model_dir,
absl::Span<const AssetFileDef> assets,
std::unique_ptr<Asset>* output) {
int asset_index = asset.asset_file_def_index();
if (asset_index >= assets.size()) {
return errors::FailedPrecondition(
"SavedAsset contained asset index ", asset_index,
" but AssetFileDef only contains ", assets.size(), " # of assets");
}
const std::string& asset_filename = assets[asset_index].filename();
return Asset::Create(ctx, saved_model_dir, asset_filename, output);
}
Status TensorProtoToConstant(ImmediateExecutionContext* ctx,
const TensorProto& proto,
std::unique_ptr<Constant>* output) {
@ -211,7 +225,8 @@ Status FlattenSignature(const StructuredValue& signature,
}
const SavedObject* FindNodeAtPath(StringPiece path,
const SavedObjectGraph& object_graph) {
const SavedObjectGraph& object_graph,
int* node_id) {
const auto& nodes = object_graph.nodes();
if (nodes.empty()) {
return nullptr;
@ -231,6 +246,9 @@ const SavedObject* FindNodeAtPath(StringPiece path,
if (child_node_iter == current_node->children().end()) {
return nullptr;
}
if (node_id) {
*node_id = child_node_iter->node_id();
}
current_node = &nodes.Get(child_node_iter->node_id());
}


@ -22,7 +22,9 @@ limitations under the License.
#include <memory>
#include <unordered_map>
#include "absl/types/span.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
@ -31,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h"
@ -52,6 +55,11 @@ Status LoadSavedVariable(ImmediateExecutionContext* ctx,
const SavedVariable& variable,
std::unique_ptr<Variable>* output);
Status LoadSavedAsset(ImmediateExecutionContext* ctx, const SavedAsset& asset,
const std::string& saved_model_dir,
absl::Span<const AssetFileDef> assets,
std::unique_ptr<Asset>* output);
// Creates a TFConcreteFunction from a SavedConcreteFunction.
Status LoadTFConcreteFunction(
const SavedConcreteFunction& saved_concrete_function,
@ -70,9 +78,11 @@ Status FlattenSignature(const StructuredValue& signature,
// Find the SavedObject in `object_graph` at location `path`. `path` must be
// a dot-delimited string of object names relative to the root object. If no
// object is found, returns nullptr. Callers must ensure `object_graph`
// outlives the returned pointer.
// outlives the returned pointer. If not `nullptr`, `node_id` will contain the
// index of the returned object in the `SavedObjectGraph.nodes` array.
const SavedObject* FindNodeAtPath(StringPiece path,
const SavedObjectGraph& object_graph);
const SavedObjectGraph& object_graph,
int* node_id = nullptr);
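A sketch of how a caller might use the new `node_id` out-parameter (the path string and `object_graph` value are illustrative):

    int node_id = 0;
    const SavedObject* obj = FindNodeAtPath(
        "sub_module.uninitialized_variable", object_graph, &node_id);
    if (obj != nullptr) {
      // node_id now holds the index of `obj` within SavedObjectGraph.nodes,
      // e.g. for looking up the corresponding revived object.
    }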
// Maps each node in `graphdef` to its corresponding Attribute Map.
// Callers must ensure that `graphdef` outlives the returned map.


@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h"
#include "tensorflow/cc/saved_model/bundle_v2.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/loader_util.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
@ -108,12 +109,17 @@ Status ConstantFromSavedConstant(
// SavedResources. These are returned via the `out` parameter.
Status ReviveObjects(
const MetaGraphDef& metagraph, ImmediateExecutionContext* context,
const std::string& directory,
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>*
revived_objects) {
// This is needed to restore "Constant" nodes by looking up their
// "Value" attribute.
NodeAttrMap node_attr_map = internal::NodeToAttrMap(metagraph.graph_def());
// These are needed for creating "Assets", by looking up their filenames.
std::vector<AssetFileDef> assets;
TF_RETURN_IF_ERROR(internal::GetAssetFileDefs(metagraph, &assets));
// Iterate through all the saved objects, restoring objects as we go.
// We don't recreate functions until all other objects have been created.
for (int i = 0; i < metagraph.object_graph_def().nodes_size(); ++i) {
@ -129,12 +135,10 @@ Status ReviveObjects(
node_attr_map, &constant));
(*revived_objects)[i] = std::move(constant);
} else if (node.kind_case() == SavedObject::kAsset) {
// TODO(bmzhao): Implement Asset C++ class. This should be just recreating
// the full path to the asset file:
// https://github.com/tensorflow/tensorflow/blob/6a0bdbdb7c48a3491ae1277083ae3dafb4ab4d7a/tensorflow/python/saved_model/load.py#L395-L396
// and storing it as a string tensor:
// https://github.com/tensorflow/tensorflow/blob/6a0bdbdb7c48a3491ae1277083ae3dafb4ab4d7a/tensorflow/python/training/tracking/tracking.py#L324-L325
return errors::Unimplemented("SavedAsset loading is not implemented yet");
std::unique_ptr<Asset> asset;
TF_RETURN_IF_ERROR(internal::LoadSavedAsset(context, node.asset(),
directory, assets, &asset));
(*revived_objects)[i] = std::move(asset);
} else if (node.kind_case() == SavedObject::kResource) {
// TODO(bmzhao): Figure out how resource loading works and implement it
return errors::Unimplemented(
@ -264,6 +268,12 @@ Status RestoreCheckpoint(SavedModelV2Bundle* bundle,
}
const std::string& checkpoint_key = attribute->checkpoint_key();
if (!bundle->variable_reader()->Contains(checkpoint_key)) {
LOG(WARNING) << "No checkpoint entry found for " << checkpoint_key
<< ". Variable will be uninitialized.";
return Status();
}
std::string variables_path_prefix =
io::JoinPath(directory, kSavedModelVariablesDirectory,
kSavedModelVariablesFilename);
@ -321,6 +331,31 @@ std::vector<ConcreteFunction*> TFSavedModelAPI::ListFunctions() {
return result;
}
Status TFSavedModelAPI::GetVariable(const std::string& variable_path,
Variable** variable) {
int node_id;
const SavedObject* object = internal::FindNodeAtPath(
variable_path, bundle_.saved_object_graph(), &node_id);
if (object == nullptr) {
return errors::NotFound("No saved object found at path ", variable_path);
}
if (object->kind_case() == SavedObject::kVariable) {
auto iter = revived_objects_.find(node_id);
if (iter == revived_objects_.end()) {
return errors::Internal("Variable ", variable_path,
" was not properly revived.");
}
*variable = static_cast<Variable*>(iter->second.get());
return Status();
}
*variable = nullptr;
return errors::InvalidArgument(
variable_path, " is not a path to a Variable (kind=", object->kind_case(),
")");
}
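A sketch of a call site, mirroring the test added later in this commit (`api` is an assumed TFSavedModelAPI*):

    tensorflow::Variable* var = nullptr;
    tensorflow::Status s = api->GetVariable("uninitialized_variable", &var);
    if (s.ok()) {
      // `var` now points at the revived Variable; its dtype() etc. can be
      // inspected.
    }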
TFSavedModelAPI::TFSavedModelAPI(
const std::string& directory, SavedModelV2Bundle bundle,
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>
@ -352,8 +387,8 @@ Status TFSavedModelAPI::Load(
// https://github.com/tensorflow/tensorflow/blob/285b5fa15405c5e2c084080f52a1818be8648079/tensorflow/python/saved_model/function_deserialization.py#L438-L454
RevivedObjectMap revived_objects;
TF_RETURN_IF_ERROR(
ReviveObjects(bundle.meta_graph_def(), context, &revived_objects));
TF_RETURN_IF_ERROR(ReviveObjects(bundle.meta_graph_def(), context, directory,
&revived_objects));
// TODO(bmzhao): When we later add support for loading resources, we need to
// handle the case where materializing a function's captures requires invoking


@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h"
#include "tensorflow/cc/saved_model/bundle_v2.h"
@ -68,6 +69,8 @@ class TFSavedModelAPI : public SavedModelAPI {
~TFSavedModelAPI() override = default;
Status GetVariable(const std::string& variable_path, Variable** variable);
private:
TFSavedModelAPI(
const std::string& directory, SavedModelV2Bundle bundle,


@ -245,14 +245,17 @@ tf_cc_test(
"saved_model_api_test.cc",
],
data = [
"//tensorflow/c/experimental/saved_model/internal/testdata:saved_models",
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
],
deps = [
":saved_model_api_type",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_tensor",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/c/eager:c_api_test_util",
"//tensorflow/c/experimental/saved_model/core:tf_saved_model_api",
"//tensorflow/c/experimental/saved_model/public:concrete_function",
"//tensorflow/c/experimental/saved_model/public:saved_model_api",
"//tensorflow/core:lib",


@ -21,15 +21,20 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/tstring.h"
namespace {
using tensorflow::tstring;
constexpr char kTestData[] = "cc/saved_model/testdata";
const char* kServeTag[] = {"serve"};
@ -137,6 +142,103 @@ TEST_P(CSavedModelAPITest, LoadsSavedModel) {
TFE_DeleteContext(ctx);
}
TEST_P(CSavedModelAPITest, LoadsAssetSavedModel) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
bool use_tfrt = GetParam();
if (use_tfrt) {
TFE_DeleteContextOptions(opts);
TF_DeleteStatus(status);
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
}
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
TFE_Context* ctx = TFE_NewContext(opts, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
std::string model_dir = SavedModelPath("AssetModule");
TF_SavedModel* saved_model =
TF_LoadSavedModel(model_dir.c_str(), ctx, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TF_ConcreteFunction* read_file_fn =
TF_GetSavedModelConcreteFunction(saved_model, "read_file", status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_Op* read_file_op =
TF_ConcreteFunctionMakeCallOp(read_file_fn, nullptr, 0, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
// TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many
// inputs + outputs a function has.
TFE_TensorHandle* read_file_fn_outputs[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(read_file_op, &read_file_fn_outputs[0], &num_retvals, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TF_Tensor* result = TFE_TensorHandleResolve(read_file_fn_outputs[0], status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
EXPECT_EQ(TF_NumDims(result), 0);
tensorflow::tstring* output_value =
static_cast<tensorflow::tstring*>(TF_TensorData(result));
EXPECT_EQ(std::string(*output_value), "TEST ASSET FILE CONTENTS\n");
TF_DeleteTensor(result);
TFE_DeleteTensorHandle(read_file_fn_outputs[0]);
TFE_DeleteOp(read_file_op);
TF_DeleteSavedModel(saved_model);
TF_DeleteStatus(status);
TFE_DeleteContext(ctx);
}
TEST_P(CSavedModelAPITest, LoadSavedModelWithUninitializedVariable) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
bool use_tfrt = GetParam();
if (use_tfrt) {
TFE_DeleteContextOptions(opts);
TF_DeleteStatus(status);
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
}
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
TFE_Context* ctx = TFE_NewContext(opts, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
std::string model_dir = tensorflow::io::JoinPath(
tensorflow::testing::TensorFlowSrcRoot(),
"c/experimental/saved_model/internal/testdata/UninitializedVariable");
TF_SavedModel* saved_model =
TF_LoadSavedModel(model_dir.c_str(), ctx, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
tensorflow::TFSavedModelAPI* model_api =
tensorflow::down_cast<tensorflow::TFSavedModelAPI*>(
tensorflow::unwrap(saved_model));
tensorflow::Variable* uninitialized_variable;
ASSERT_EQ(tensorflow::Status::OK(),
model_api->GetVariable("uninitialized_variable",
&uninitialized_variable));
ASSERT_EQ(tensorflow::DT_FLOAT, uninitialized_variable->dtype());
ASSERT_EQ(tensorflow::Status::OK(),
model_api->GetVariable("sub_module.uninitialized_variable",
&uninitialized_variable));
ASSERT_EQ(tensorflow::DT_INT64, uninitialized_variable->dtype());
TF_DeleteSavedModel(saved_model);
TF_DeleteStatus(status);
TFE_DeleteContext(ctx);
}
INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticSavedModelTests, CSavedModelAPITest,
::testing::Bool());


@ -0,0 +1,36 @@
load("//tensorflow:tensorflow.bzl", "py_strict_binary")
package(
licenses = ["notice"], # Apache 2.0
)
# Run this binary manually, with an argument pointing to the testdata/
# directory, to generate the test files used by the filegroup rule below.
py_strict_binary(
name = "gen_saved_models",
srcs = ["gen_saved_models.py"],
python_version = "PY3",
deps = [
"//tensorflow/python:dtypes",
"//tensorflow/python:platform",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:tensor_spec",
"//tensorflow/python:variables",
"//tensorflow/python/compat:v2_compat",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/module",
"//tensorflow/python/saved_model",
"//tensorflow/python/saved_model:save_options",
],
)
# Files generated by the binary above.
filegroup(
name = "saved_models",
srcs = glob([
"UninitializedVariable/**",
]),
visibility = [
"//tensorflow/c/experimental/saved_model/internal:__pkg__",
],
)


@ -0,0 +1,84 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Creates saved models used for testing.
This executable should be run with an argument pointing to the testdata/ folder
in this directory. It will re-generate the saved models that are used for
testing.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import google_type_annotations
from __future__ import print_function
import os
from tensorflow.python.compat import v2_compat
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec
from tensorflow.python.module import module
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import app
from tensorflow.python.saved_model import saved_model
def _gen_uninitialized_variable(base_dir):
"""Generates a saved model with an uninitialized variable."""
class SubModule(module.Module):
"""A module with an UninitializedVariable."""
def __init__(self):
self.uninitialized_variable = resource_variable_ops.UninitializedVariable(
name="uninitialized_variable", dtype=dtypes.int64)
class Module(module.Module):
"""A module with an UninitializedVariable."""
def __init__(self):
super(Module, self).__init__()
self.sub_module = SubModule()
self.initialized_variable = variables.Variable(
1.0, name="initialized_variable")
# An UninitializedVariable with the same name as the variable in the
# SubModule, but with a different type.
self.uninitialized_variable = resource_variable_ops.UninitializedVariable(
name="uninitialized_variable", dtype=dtypes.float32)
@def_function.function(
input_signature=[tensor_spec.TensorSpec((), dtypes.float32)])
def compute(self, value):
return self.initialized_variable + value
to_save = Module()
saved_model.save(
to_save, export_dir=os.path.join(base_dir, "UninitializedVariable"))
def main(args):
if len(args) != 2:
raise app.UsageError("Expected one argument (base_dir).")
_, base_dir = args
_gen_uninitialized_variable(base_dir)
if __name__ == "__main__":
v2_compat.enable_v2_behavior()
app.run(main)


@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/stream_executor/executor_cache.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/stream_executor_pimpl.h"
#include "tensorflow/stream_executor/timer.h"


@ -209,9 +209,13 @@ tf_cc_test(
py_binary(
name = "testdata/generate_saved_models",
srcs = ["testdata/generate_saved_models.py"],
data = [
":saved_model_asset_data",
],
python_version = "PY3",
srcs_version = "PY3",
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:tensor_spec",
@ -221,6 +225,7 @@ py_binary(
"//tensorflow/python/module",
"//tensorflow/python/saved_model",
"//tensorflow/python/saved_model:save_options",
"//tensorflow/python/training/tracking",
"@absl_py//absl:app",
],
)
@ -229,6 +234,7 @@ py_binary(
filegroup(
name = "saved_model_test_files",
srcs = glob([
"testdata/AssetModule/**",
"testdata/half_plus_two_pbtxt/**",
"testdata/half_plus_two_main_op/**",
"testdata/half_plus_two/**",
@ -245,6 +251,13 @@ alias(
actual = ":saved_model_test_files",
)
filegroup(
name = "saved_model_asset_data",
srcs = [
"testdata/test_asset.txt",
],
)
exports_files(
glob([
"testdata/half_plus_two_pbtxt/**",


@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/cc/saved_model/loader_util.h"
#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/io/path.h"
@ -73,26 +74,41 @@ uint64 GetLatencyMicroseconds(const uint64 start_microseconds) {
// Ensure that constant tensors loaded from the saved model have valid shape.
// Also ensure that constant nodes have a value assigned to them.
// TODO(b/154763635): this is temporary and will be replaced with a better audit
static Status ValidateNode(const NodeDef& node) {
const auto node_iterator = node.attr().find("value");
if (node_iterator != node.attr().end()) {
AttrValue node_value = node_iterator->second;
if (node_value.has_tensor()) {
const PartialTensorShape node_shape(node_value.tensor().tensor_shape());
if (node_shape.num_elements() < 0) {
return errors::FailedPrecondition(
"Saved model contains node \"", node.name(), "\" (op \"", node.op(),
"\") which initializes from a tensor with ",
node_shape.num_elements(), " elements");
}
}
} else if (node.op() == "Const") {
return errors::FailedPrecondition(
"Saved model contains node \"", node.name(),
"\" which is a constant tensor but no value has been provided");
}
return Status::OK();
}
static Status ValidateSavedTensors(const GraphDef& graph_def) {
for (const auto& node : graph_def.node()) {
const auto node_iterator = node.attr().find("value");
if (node_iterator != node.attr().end()) {
AttrValue node_value = node_iterator->second;
if (node_value.has_tensor()) {
const PartialTensorShape node_shape(node_value.tensor().tensor_shape());
if (node_shape.num_elements() < 0) {
return errors::FailedPrecondition(
"Saved model contains node \"", node.name(), "\" (op \"",
node.op(), "\") which initializes from a tensor with ",
node_shape.num_elements(), " elements");
}
TF_RETURN_IF_ERROR(ValidateNode(node));
}
if (graph_def.has_library()) {
const FunctionDefLibrary& library = graph_def.library();
for (const auto& function : library.function()) {
for (const auto& node : function.node_def()) {
TF_RETURN_IF_ERROR(ValidateNode(node));
}
} else if (node.op() == "Const") {
return errors::FailedPrecondition(
"Saved model contains node \"", node.name(),
"\" which is a constant tensor but no value has been provided");
}
}
return Status::OK();
}


@ -45,6 +45,8 @@ constexpr char kTestFuzzGeneratedNegativeShape[] =
"cc/saved_model/testdata/fuzz_generated/negative_shape";
constexpr char kTestFuzzGeneratedConstWithNoValue[] =
"cc/saved_model/testdata/fuzz_generated/const_with_no_value";
constexpr char kTestFuzzGeneratedBadNodeAttr[] =
"cc/saved_model/testdata/fuzz_generated/bad_node_attr";
class LoaderTest : public ::testing::Test {
protected:
@ -328,5 +330,20 @@ TEST_F(LoaderTest, ConstNoValue) {
std::string::npos);
}
TEST_F(LoaderTest, BadNodeAttr) {
SavedModelBundle bundle;
RunOptions run_options;
SessionOptions session_options;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestFuzzGeneratedBadNodeAttr);
Status st = LoadSavedModel(session_options, run_options, export_dir,
{kSavedModelTagServe}, &bundle);
EXPECT_FALSE(st.ok());
EXPECT_NE(
st.error_message().find("constant tensor but no value has been provided"),
std::string::npos);
}
} // namespace
} // namespace tensorflow


@ -0,0 +1 @@
TEST ASSET FILE CONTENTS

Binary file not shown.


@ -29,9 +29,12 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.module import module
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.saved_model import save_options
from tensorflow.python.saved_model import saved_model
from tensorflow.python.training.tracking import tracking
class VarsAndArithmeticObjectGraph(module.Module):
@ -68,9 +71,21 @@ class CyclicModule(module.Module):
self.child = ReferencesParent(self)
class AssetModule(module.Module):
def __init__(self):
self.asset = tracking.Asset(
test.test_src_dir_path("cc/saved_model/testdata/test_asset.txt"))
@def_function.function(input_signature=[])
def read_file(self):
return io_ops.read_file(self.asset)
MODULE_CTORS = {
"VarsAndArithmeticObjectGraph": VarsAndArithmeticObjectGraph,
"CyclicModule": CyclicModule,
"AssetModule": AssetModule,
}


@ -0,0 +1 @@
TEST ASSET FILE CONTENTS


@ -1708,40 +1708,6 @@ std::atomic<int64>* GetPointerToFuel(int64 initial_value) {
}
} // anonymous namespace
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
RecursiveCompilabilityChecker::UncompilableNodesMap*
uncompilable_node_info) {
Device* device = flr->device();
const XlaOpRegistry::DeviceRegistration* registration;
CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(),
&registration));
// We can always *compile* resource operations, stateful RNGs and dummy ops,
// even if we are sometimes unable to auto-cluster them.
RecursiveCompilabilityChecker::OperationFilter op_filter;
op_filter.allow_resource_ops_in_called_functions = true;
op_filter.allow_stack_ops = true;
op_filter.allow_tensor_array_ops = true;
op_filter.allow_stateful_rng_ops = true;
op_filter.allow_control_trigger = true;
op_filter.allow_eliding_assert_and_checknumerics_ops = true;
op_filter.allow_ops_producing_or_consuming_variant = true;
op_filter.allow_slow_ops = true;
op_filter.allow_inaccurate_ops = true;
RecursiveCompilabilityChecker checker{
op_filter, DeviceType{registration->compilation_device_name}};
if (!uncompilable_node_info) {
// We do not need uncompilable node info. Just return the result.
return checker.IsCompilableCall(ndef, flr);
}
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_node_result =
checker.FindUncompilableNodes(ndef, flr);
uncompilable_node_info->swap(uncompilable_node_result);
return uncompilable_node_info->empty();
}
Status MarkForCompilationPass::Run(
const GraphOptimizationPassOptions& options) {
MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
@ -1951,6 +1917,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
"ParallelDynamicStitch",
"ParameterizedTruncatedNormal",
"PartitionedCall",
"Polygamma",
"PopulationCount",
"Qr",
"QuantizeAndDequantizeV2",
@ -2094,6 +2061,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
"XlaSpmdShardToFullShape",
"XlaSvd",
"XlaWhile",
"Zeta",
"_Arg",
"_ArrayToList",
"_ListToArray",

View File

@ -50,14 +50,6 @@ class MarkForCompilationPass : public GraphOptimizationPass {
friend class MarkForCompilationPassTestHelper;
};
// Returns true iff 'ndef' is a call to a function that is compilable. A
// function is compilable iff every operator in the function body is
// compilable. If 'ndef' is not compilable and 'uncompilable_node_info' is not
// null, we will populate 'uncompilable_node_info' with uncompilable node info.
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
RecursiveCompilabilityChecker::UncompilableNodesMap*
uncompilable_node_info = nullptr);
absl::flat_hash_map<string, std::vector<string>>* GetAllowlistTable();
namespace testing {


@ -294,4 +294,12 @@ se::Stream* XlaDeviceContext::GetDeviceToDeviceStream() {
return device_to_device_stream(stream);
}
Status XlaDeviceContext::ThenExecute(Device* device,
stream_executor::Stream* stream,
std::function<void()> func) {
VLOG(2) << "XlaDeviceContext::ThenExecute";
stream->ThenDoHostCallback(std::move(func));
return Status::OK();
}
} // namespace tensorflow
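A sketch of how the new override might be invoked (the device context, device, and stream pointers are assumed to exist; they are not part of this change):

    Status s = device_context->ThenExecute(device, stream, []() {
      // Runs on the host once all work previously enqueued on `stream`
      // has completed.
    });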


@ -86,6 +86,9 @@ class XlaDeviceContext : public DeviceContext {
// Returns a device-to-device stream, in round-robin fashion.
se::Stream* GetDeviceToDeviceStream();
Status ThenExecute(Device* device, stream_executor::Stream* stream,
std::function<void()> func) override;
private:
bool UseMultipleStreams() const { return stream_ != host_to_device_stream_; }


@ -21,7 +21,6 @@ limitations under the License.
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/function.h"
@ -32,6 +31,44 @@ limitations under the License.
namespace tensorflow {
// Returns true iff 'ndef' is a call to a function that is compilable. A
// function is compilable iff every operator in the function body is
// compilable. If 'ndef' is not compilable and 'uncompilable_node_info' is not
// null, we will populate 'uncompilable_node_info' with uncompilable node info.
static bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
RecursiveCompilabilityChecker::UncompilableNodesMap*
uncompilable_node_info) {
Device* device = flr->device();
const XlaOpRegistry::DeviceRegistration* registration;
CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(),
&registration));
// We can always *compile* resource operations, stateful RNGs and dummy ops,
// even if we are sometimes unable to auto-cluster them.
RecursiveCompilabilityChecker::OperationFilter op_filter;
op_filter.allow_resource_ops_in_called_functions = true;
op_filter.allow_stack_ops = true;
op_filter.allow_tensor_array_ops = true;
op_filter.allow_stateful_rng_ops = true;
op_filter.allow_control_trigger = true;
op_filter.allow_eliding_assert_and_checknumerics_ops = true;
op_filter.allow_ops_producing_or_consuming_variant = true;
op_filter.allow_slow_ops = true;
op_filter.allow_inaccurate_ops = true;
RecursiveCompilabilityChecker checker{
op_filter, DeviceType{registration->compilation_device_name}};
if (!uncompilable_node_info) {
// We do not need uncompilable node info. Just return the result.
return checker.IsCompilableCall(ndef, flr);
}
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_node_result =
checker.FindUncompilableNodes(ndef, flr);
uncompilable_node_info->swap(uncompilable_node_result);
return uncompilable_node_info->empty();
}
bool XlaKernelCreator::CanCreateKernel(
const FunctionLibraryRuntime& flr,
const std::shared_ptr<const NodeProperties>& props) const {


@ -366,6 +366,20 @@ def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", [],
}];
}
def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh", [],
HLO_FpOrComplexTensor> {
let summary = "Sinh operation";
let description = [{
Returns `Sinh(operand)` element-wise.
$$
\sinh(x) = (e^x - e^-x) / 2                     if |x| < 1
         = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
$$
}];
}
def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan", [],
HLO_FpOrComplexTensor> {
let summary = "Tan operation";


@ -49,6 +49,45 @@ def : Pat<(HLOClient_AcosOp $input),
),
(HLO_ConstantLike<"M_PI"> $input))>;
// Express `sinh` as
// sinh(x) = (e^x - e^-x) / 2                     if |x| < 1
//         = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
def : Pat<(HLOClient_SinhOp $input),
(HLO_SelectOp
(HLO_CompareOp
(HLO_AbsOp $input),
(HLO_ConstantLike<"1"> $input),
HLO_COMPARISON_DIRECTION_LT
),
(HLO_DivOp
(HLO_SubOp
(HLO_ExpOp $input),
(HLO_ExpOp
(HLO_NegOp $input)
)
),
(HLO_ConstantLike<"2"> $input)
),
(HLO_SubOp
(HLO_ExpOp
(HLO_AddOp
$input,
(HLO_LogOp
(HLO_ConstantLike<"0.5"> $input)
)
)
),
(HLO_ExpOp
(HLO_SubOp
(HLO_LogOp
(HLO_ConstantLike<"0.5"> $input)
),
$input
)
)
)
)>;
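Both branches of the pattern compute the same value; for the second branch, plain algebra gives

    e^{x + \log(1/2)} - e^{-x + \log(1/2)} = \tfrac{1}{2}e^{x} - \tfrac{1}{2}e^{-x} = \sinh(x),

so folding the factor 1/2 into the exponents is exact; presumably it keeps the intermediate exponentials smaller for large |x|, which is why that form is used when |x| >= 1.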
// Express tan in MHLO dialect as
// tan(x) = sin(x) / cos(x).
def : Pat<(HLOClient_TanOp $input),


@ -47,7 +47,8 @@ namespace {
sep fn(ShiftRightLogicalOp) sep fn(SubOp)
// TODO(herhut): Generate these out of op definitions.
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) fn(TanOp) sep fn(AcosOp)
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
fn(TanOp) sep fn(AcosOp) sep fn(SinhOp)
template <typename OpTy>
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {


@ -357,11 +357,7 @@ ResourceAliasAnalysisInfo::ResourceAliasAnalysisInfo(
while_region.body()));
} else if (auto case_op = dyn_cast<CaseOp>(op)) {
llvm::SmallVector<FuncOp, 4> functions;
functions.reserve(case_op.branches().size());
for (auto branch : case_op.branches())
functions.emplace_back(SymbolTable::lookupNearestSymbolFrom<FuncOp>(
case_op, branch.cast<SymbolRefAttr>()));
case_op.get_branch_functions(functions);
AnalyzeFunctionalCaseOrIfOp(case_op, functions, backtrack_analysis);
} else if (auto if_op = dyn_cast<IfOp>(op)) {
AnalyzeFunctionalCaseOrIfOp(


@ -653,9 +653,8 @@ Status MlirFunctionContext::Finalize(OutputList* outputs,
}
builder_.create<ReturnOp>(func_.getLoc(), ret_operands);
auto arg_types = llvm::to_vector<8>(body.getArgumentTypes());
auto result_types =
llvm::to_vector<8>(body.getTerminator()->getOperandTypes());
auto arg_types = body.getArgumentTypes();
auto result_types = body.getTerminator()->getOperandTypes();
func_.setType(FunctionType::get(arg_types, result_types, func_.getContext()));
*f = new MlirFunction(std::move(context_), std::move(module_), func_);
return Status::OK();

File diff suppressed because it is too large.


@ -110,27 +110,37 @@ def TF_StackResource : TF_ResourceBase<"Stack">;
def TF_TensorArrayResource : TF_ResourceBase<"TensorArray">;
def TF_SummaryResource : TF_ResourceBase<"Summary">;
def TF_LookupTableResource : TF_ResourceBase<"LookupTable">;
def TF_DatasetSeedGeneratorResource : TF_ResourceBase<"DatasetSeedGenerator">;
def TF_DatasetMemoryCacheResource : TF_ResourceBase<"DatasetMemoryCache">;
def TF_VariableRead : MemRead<TF_VariableResource>;
def TF_StackRead : MemRead<TF_StackResource>;
def TF_TensorArrayRead : MemRead<TF_TensorArrayResource>;
def TF_LookupTableRead : MemRead<TF_LookupTableResource>;
def TF_DatasetSeedGeneratorRead : MemRead<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheRead : MemRead<TF_DatasetMemoryCacheResource>;
def TF_VariableWrite : MemWrite<TF_VariableResource>;
def TF_StackWrite : MemWrite<TF_StackResource>;
def TF_TensorArrayWrite : MemWrite<TF_TensorArrayResource>;
def TF_SummaryWrite : MemWrite<TF_SummaryResource>;
def TF_LookupTableWrite : MemWrite<TF_LookupTableResource>;
def TF_DatasetSeedGeneratorWrite : MemWrite<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheWrite : MemWrite<TF_DatasetMemoryCacheResource>;
def TF_VariableAlloc : MemAlloc<TF_VariableResource>;
def TF_StackAlloc : MemAlloc<TF_StackResource>;
def TF_TensorArrayAlloc : MemAlloc<TF_TensorArrayResource>;
def TF_SummaryAlloc : MemAlloc<TF_SummaryResource>;
def TF_LookupTableAlloc : MemAlloc<TF_LookupTableResource>;
def TF_DatasetSeedGeneratorAlloc : MemAlloc<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheAlloc : MemAlloc<TF_DatasetMemoryCacheResource>;
def TF_StackFree : MemFree<TF_StackResource>;
def TF_TensorArrayFree : MemFree<TF_TensorArrayResource>;
def TF_SummaryFree : MemFree<TF_SummaryResource>;
def TF_DatasetSeedGeneratorFree : MemFree<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheFree : MemFree<TF_DatasetMemoryCacheResource>;
//===----------------------------------------------------------------------===//
// TensorFlow op definitions
@ -172,109 +182,19 @@ class TF_TensorFlowType <string name, string description> :
"TensorFlow " # description # " type">,
BuildableType<"getType<mlir::TF::" # name # "Type>()">;
//===----------------------------------------------------------------------===//
// Integer types
// TODO(mgester) shouldn't this be SignedIntOfWidths?
def TF_I32Or64 : SignlessIntOfWidths<[32, 64]>;
def TF_I32OrI64Tensor : TensorOf<[TF_I32Or64]>;
def TF_Uint8 : UI<8>;
def TF_Uint8Tensor : TensorOf<[TF_Uint8]>;
def TF_Uint16 : UI<16>;
def TF_Uint16Tensor : TensorOf<[TF_Uint16]>;
def TF_Uint32 : UI<32>;
def TF_Uint32Tensor : TensorOf<[TF_Uint32]>;
def TF_Uint64 : UI<64>;
def TF_Uint64Tensor : TensorOf<[TF_Uint64]>;
// Any unsigned integer type
def TF_UInt : UnsignedIntOfWidths<[8, 16, 32, 64]>;
// Any signed integer type
// TODO(mgester) shouldn't this be SignedIntOfWidths?
def TF_SInt : SignlessIntOfWidths<[8, 16, 32, 64]>;
// Any integer type
def TF_Int : AnyTypeOf<[TF_SInt, TF_UInt], "integer">;
// Any integer tensor types
def TF_IntTensor : TensorOf<[TF_Int]>;
//===----------------------------------------------------------------------===//
// Quantized types
def TF_Qint8 : TF_TensorFlowType<"Qint8", "qint8">;
def TF_Qint16 : TF_TensorFlowType<"Qint16", "qint16">;
def TF_Qint32 : TF_TensorFlowType<"Qint32", "qint32">;
def TF_Quint8 : TF_TensorFlowType<"Quint8", "quint8">;
def TF_Quint16 : TF_TensorFlowType<"Quint16", "quint16">;
// Any quantized type
def TF_Quantized : AnyTypeOf<[TF_Qint8, TF_Qint16, TF_Qint32, TF_Quint8,
TF_Quint16], "quantized">;
//===----------------------------------------------------------------------===//
// Floating-point types
def TF_F32Or64 : FloatOfWidths<[32, 64]>;
def TF_F32OrF64Tensor : TensorOf<[TF_F32Or64]>;
def TF_Float : AnyTypeOf<[F16, F32, F64, BF16], "floating-point">;
// Any floating-point tensor types
def TF_FpTensor : TensorOf<[TF_Float]>;
//===----------------------------------------------------------------------===//
// Complex types
// TODO(suderman): Remove TF_Complex64 and use a standard ops declaration, along
// with the associated cleanup.
def TF_Complex64 : Complex<F<32>>;
def TF_Complex64Tensor : TensorOf<[TF_Complex64]>;
def TF_Complex128 : Complex<F<64>>;
def TF_Complex128Tensor : TensorOf<[TF_Complex128]>;
def TF_Complex : AnyTypeOf<[TF_Complex64, TF_Complex128], "complex">;
def TF_ComplexTensor : TensorOf<[TF_Complex]>;
//===----------------------------------------------------------------------===//
// String/variant/resource types
def TF_Str : TF_TensorFlowType<"String", "string">;
def TF_StrTensor : TensorOf<[TF_Str]>;
def TF_Variant : TF_TensorFlowType<"Variant", "variant">;
def TF_VariantTensor : TensorOf<[TF_Variant]>;
def TF_Resource : TF_TensorFlowType<"Resource", "resource">;
def TF_ResourceTensor : TensorOf<[TF_Resource]>;
//===----------------------------------------------------------------------===//
// Reference types
// Float reference types
def TF_F16Ref : TF_TensorFlowType<"HalfRef", "f16ref">;
def TF_F32Ref : TF_TensorFlowType<"FloatRef", "f32ref">;
def TF_F64Ref : TF_TensorFlowType<"DoubleRef", "f64ref">;
def TF_Float16Ref : TF_TensorFlowType<"HalfRef", "f16ref">;
def TF_Float32Ref : TF_TensorFlowType<"FloatRef", "f32ref">;
def TF_Float64Ref : TF_TensorFlowType<"DoubleRef", "f64ref">;
def TF_Bfloat16Ref : TF_TensorFlowType<"Bfloat16Ref", "bf16ref">;
// Any float reference type
def TF_FloatRef : AnyTypeOf<[TF_F16Ref, TF_F32Ref, TF_F64Ref, TF_Bfloat16Ref],
"floating-point reference">;
// Complex reference types
def TF_Complex64Ref : TF_TensorFlowType<"Complex64Ref", "complex64ref">;
def TF_Complex128Ref : TF_TensorFlowType<"Complex128Ref", "complex128ref">;
// Any complex reference type
def TF_ComplexRef : AnyTypeOf<[TF_Complex64Ref, TF_Complex128Ref], "complex reference">;
// Integer reference types
def TF_Int8Ref : TF_TensorFlowType<"Int8Ref", "i8ref">;
def TF_Int16Ref : TF_TensorFlowType<"Int16Ref", "i16ref">;
@ -286,17 +206,6 @@ def TF_Uint16Ref : TF_TensorFlowType<"Uint16Ref", "ui16ref">;
def TF_Uint32Ref : TF_TensorFlowType<"Uint32Ref", "ui32ref">;
def TF_Uint64Ref : TF_TensorFlowType<"Uint64Ref", "ui64ref">;
// Any signed integer reference type
def TF_SIntRef : AnyTypeOf<[TF_Int8Ref, TF_Int16Ref, TF_Int32Ref, TF_Int64Ref],
"signed integer reference">;
// Any unsigned integer reference type
def TF_UIntRef : AnyTypeOf<[TF_Uint8Ref, TF_Uint16Ref, TF_Uint32Ref,
TF_Uint64Ref], "unsigned integer reference">;
// Any integer reference type
def TF_IntRef : AnyTypeOf<[TF_SIntRef, TF_UIntRef], "integer reference">;
// Quantized reference types
def TF_Qint8Ref : TF_TensorFlowType<"Qint8Ref", "qint8ref">;
def TF_Qint16Ref : TF_TensorFlowType<"Qint16Ref", "qint16ref">;
@ -304,54 +213,152 @@ def TF_Qint32Ref : TF_TensorFlowType<"Qint32Ref", "qint32ref">;
def TF_Quint8Ref : TF_TensorFlowType<"Quint8Ref", "quint8ref">;
def TF_Quint16Ref : TF_TensorFlowType<"Quint16Ref", "quint16ref">;
// Any quantized reference type
def TF_QuantizedRef : AnyTypeOf<[TF_Qint8Ref, TF_Qint16Ref, TF_Qint32Ref,
TF_Quint8Ref, TF_Quint16Ref], "quantized reference">;
// Other reference types
def TF_BoolRef : TF_TensorFlowType<"BoolRef", "boolref">;
def TF_ResourceRef : TF_TensorFlowType<"ResourceRef", "resourceref">;
def TF_StrRef : TF_TensorFlowType<"StringRef", "stringref">;
def TF_VariantRef : TF_TensorFlowType<"VariantRef", "variantref">;
// Reference tensor types
def TF_FpRefTensor : TensorOf<[TF_FloatRef]>;
def TF_I32OrI64RefTensor : TensorOf<[TF_Int32Ref, TF_Int64Ref]>;
//===----------------------------------------------------------------------===//
// Integer types (including corresponding reference types)
def TF_Bool : AnyTypeOf<[I<1>, TF_BoolRef], "bool">;
def TF_Int8 : AnyTypeOf<[I8, TF_Int8Ref], "8-bit integer">;
def TF_Int16 : AnyTypeOf<[I16, TF_Int16Ref], "16-bit integer">;
def TF_Int32 : AnyTypeOf<[I32, TF_Int32Ref], "32-bit integer">;
def TF_Int64 : AnyTypeOf<[I64, TF_Int64Ref], "64-bit integer">;
def TF_I32OrI64 : AnyTypeOf<[I32, I64, TF_Int32Ref, TF_Int64Ref],
"32/64-bit signed integer">;
def TF_Uint8 : AnyTypeOf<[UI<8>, TF_Uint8Ref], "8-bit unsigned integer">;
def TF_Uint16 : AnyTypeOf<[UI<16>, TF_Uint16Ref], "16-bit unsigned integer">;
def TF_Uint32 : AnyTypeOf<[UI<32>, TF_Uint32Ref], "32-bit unsigned integer">;
def TF_Uint64 : AnyTypeOf<[UI<64>, TF_Uint64Ref], "64-bit unsigned integer">;
// Any unsigned integer type
def TF_UInt : AnyTypeOf<[TF_Uint8, TF_Uint16, TF_Uint32, TF_Uint64],
"unsigned integer">;
// Any signed integer type
def TF_SInt : AnyTypeOf<[TF_Int8, TF_Int16, TF_Int32, TF_Int64],
"signed integer">;
// Any integer type
def TF_Int : AnyTypeOf<[TF_SInt, TF_UInt], "integer">;
// Tensor types
def TF_BoolTensor : TensorOf<[TF_Bool]>;
def TF_IntTensor : TensorOf<[TF_Int]>;
def TF_Int8Tensor : TensorOf<[TF_Int8]>;
def TF_Int16Tensor : TensorOf<[TF_Int16]>;
def TF_Int32Tensor : TensorOf<[TF_Int32]>;
def TF_Int64Tensor : TensorOf<[TF_Int64]>;
def TF_I32OrI64Tensor : TensorOf<[TF_I32OrI64]>;
def TF_Uint8Tensor : TensorOf<[TF_Uint8]>;
def TF_Uint16Tensor : TensorOf<[TF_Uint16]>;
def TF_Uint32Tensor : TensorOf<[TF_Uint32]>;
def TF_Uint64Tensor : TensorOf<[TF_Uint64]>;
//===----------------------------------------------------------------------===//
// Quantized types (including corresponding reference types)
def TF_Qint8 : AnyTypeOf<
[TF_TensorFlowType<"Qint8", "qint8">, TF_Qint8Ref],
"8-bit quantized integer">;
def TF_Qint16 : AnyTypeOf<
[TF_TensorFlowType<"Qint16", "qint16">, TF_Qint16Ref],
"16-bit quantized integer">;
def TF_Qint32 : AnyTypeOf<
[TF_TensorFlowType<"Qint32", "qint32">, TF_Qint32Ref],
"32-bit quantized integer">;
def TF_Quint8 : AnyTypeOf<
[TF_TensorFlowType<"Quint8", "quint8">, TF_Quint8Ref],
"8-bit quantized unsigned integer">;
def TF_Quint16 : AnyTypeOf<
[TF_TensorFlowType<"Quint16", "quint16">, TF_Quint16Ref],
"16-bit quantized unsigned integer">;
// Any quantized type
def TF_Quantized : AnyTypeOf<
[TF_Qint8, TF_Qint16, TF_Qint32, TF_Quint8, TF_Quint16], "quantized">;
//===----------------------------------------------------------------------===//
// Floating-point types (including corresponding reference types)
def TF_Float16 : AnyTypeOf<[F16, TF_Float16Ref], "16-bit float">;
def TF_Float32 : AnyTypeOf<[F32, TF_Float32Ref], "32-bit float">;
def TF_Float64 : AnyTypeOf<[F64, TF_Float64Ref], "64-bit float">;
def TF_Bfloat16 : AnyTypeOf<[BF16, TF_Bfloat16Ref], "bfloat16">;
def TF_F32OrF64 : AnyTypeOf<[TF_Float32, TF_Float64], "32/64-bit float">;
def TF_Float : AnyTypeOf<
[TF_Float16, TF_Float32, TF_Float64, TF_Bfloat16,
TF_Float16Ref, TF_Float32Ref, TF_Float64Ref, TF_Bfloat16Ref],
"floating-point">;
// Tensor types
def TF_FloatTensor : TensorOf<[TF_Float]>;
def TF_F32OrF64Tensor : TensorOf<[TF_F32OrF64]>;
def TF_Float16Tensor : TensorOf<[TF_Float16]>;
def TF_Float32Tensor : TensorOf<[TF_Float32]>;
def TF_Float64Tensor : TensorOf<[TF_Float64]>;
def TF_Bfloat16Tensor : TensorOf<[TF_Bfloat16]>;
//===----------------------------------------------------------------------===//
// Complex types (including corresponding reference types)
// TODO(suderman): Remove TF_Complex64 and use a standard ops declaration, along
// with the associated cleanup.
def TF_Complex64 : AnyTypeOf<[Complex<F<32>>, TF_Complex64Ref],
"64-bit complex">;
def TF_Complex128 : AnyTypeOf<[Complex<F<64>>, TF_Complex128Ref],
"128-bit complex">;
def TF_Complex : AnyTypeOf<[TF_Complex64, TF_Complex128], "complex">;
// Tensor types
def TF_ComplexTensor : TensorOf<[TF_Complex]>;
def TF_Complex64Tensor : TensorOf<[TF_Complex64]>;
def TF_Complex128Tensor : TensorOf<[TF_Complex128]>;
//===----------------------------------------------------------------------===//
// String/variant/resource types (including corresponding reference types)
def TF_Str : AnyTypeOf<
[TF_TensorFlowType<"String", "str">, TF_StrRef], "string">;
def TF_StrTensor : TensorOf<[TF_Str]>;
def TF_Variant : AnyTypeOf<
[TF_TensorFlowType<"Variant", "var">, TF_VariantRef], "variant">;
def TF_VariantTensor : TensorOf<[TF_Variant]>;
def TF_Resource : AnyTypeOf<
[TF_TensorFlowType<"Resource", "res">, TF_ResourceRef], "resource">;
def TF_ResourceTensor : TensorOf<[TF_Resource]>;
//===----------------------------------------------------------------------===//
// Multi-category type constraints
def TF_IntOrF32OrF64Tensor: TensorOf<[TF_Int, TF_F32Or64]>;
def TF_FpOrI32OrI64Tensor : TensorOf<[TF_Float, TF_I32Or64]>;
// Any integer or floating-point tensor types
def TF_IntOrF32OrF64Tensor: TensorOf<[TF_Int, TF_F32OrF64]>;
def TF_FpOrI32OrI64Tensor : TensorOf<[TF_Float, TF_I32OrI64]>;
def TF_IntOrFpTensor : TensorOf<[TF_Int, TF_Float]>;
def TF_SintOrFpTensor : TensorOf<[TF_SInt, TF_Float]>;
def TF_FpOrComplexTensor : TensorOf<[TF_Float, TF_Complex]>;
def TF_Number : AnyTypeOf<[TF_Int, TF_Float, TF_Quantized, TF_Complex],
"number">;
def TF_NumberRef : AnyTypeOf<[TF_IntRef, TF_FloatRef, TF_QuantizedRef,
TF_ComplexRef], "number reference">;
def TF_Number : AnyTypeOf<
[TF_Int, TF_Float, TF_Quantized, TF_Complex], "number">;
def TF_NumberTensor : TensorOf<[TF_Number]>;
def TF_NumberRefTensor : TensorOf<[TF_NumberRef]>;
def TF_NumberNotQuantizedOrStr :
AnyTypeOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Str]>;
def TF_NumberNotQuantizedOrStrRef :
AnyTypeOf<[TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint8Ref, TF_StrRef]>;
def TF_NumberNotQuantizedOrStrTensor : TensorOf<[TF_NumberNotQuantizedOrStr]>;
//===----------------------------------------------------------------------===//
// Tensor and tensor element types
// Bool type
def TF_Bool : I<1>;
// Any tensor element type allowed in TensorFlow ops
// (see https://www.tensorflow.org/api_docs/python/tf/dtypes/DType)
def TF_ElementType : Type<Or<[TF_Float.predicate,


@ -116,6 +116,24 @@ An n-way switch statement, implementing the following:
let verifier = [{
return Verify(*this);
}];
let extraClassDeclaration = [{
int num_branches() { return branches().size(); }
// Gets the function corresponding to branch `index`.
FuncOp branch_function(int index) {
auto flat_sym_ref = branches()[index].cast<FlatSymbolRefAttr>();
return SymbolTable::lookupNearestSymbolFrom<FuncOp>(*this, flat_sym_ref);
}
// Gets all branch functions.
void get_branch_functions(SmallVectorImpl<FuncOp> &functions) {
functions.reserve(num_branches());
for (int idx : llvm::seq<int>(0, num_branches()))
functions.push_back(branch_function(idx));
}
}];
}
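
For reviewers, a minimal C++ usage sketch of the accessors added above. The helper name `VisitCaseBranches` and the walk body are hypothetical and not part of this commit; only num_branches(), branch_function() and get_branch_functions() come from this change (includes and using-declarations omitted).

// Hypothetical sketch: iterate over every branch function of a tf.Case op.
void VisitCaseBranches(TF::CaseOp op) {
  llvm::SmallVector<FuncOp, 4> branches;
  op.get_branch_functions(branches);  // one FuncOp per branch, in order
  for (int idx = 0; idx < op.num_branches(); ++idx) {
    FuncOp branch = branches[idx];
    // branch_function() resolves the symbol via
    // SymbolTable::lookupNearestSymbolFrom, so guard against a null result.
    if (!branch) continue;
    branch.walk([](Operation *nested) { /* inspect nested ops */ });
  }
}
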
def TF_CaseRegionOp : TF_Op<"CaseRegion",
@ -206,12 +224,12 @@ source_target_pairs=`[[0,1],[1,2],[2,3],[3,0]]` gets the outputs:
}];
let arguments = (ins
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
I32Tensor:$source_target_pairs
);
let results = (outs
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -231,7 +249,7 @@ element_shape: a shape compatible with that of elements in the list.
let arguments = (ins
TF_I32OrI64Tensor:$element_shape,
I32Tensor:$max_num_elements
TF_Int32Tensor:$max_num_elements
);
}
@ -424,7 +442,7 @@ def TF_ParseExampleOp : TF_Op<"ParseExample",
TF_StrTensor:$names,
Variadic<TF_StrTensor>:$sparse_keys,
Variadic<TF_StrTensor>:$dense_keys,
Variadic<TensorOf<[F32, I64, TF_Str]>>:$dense_defaults,
Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$dense_defaults,
TF_ShapeAttrArray:$dense_shapes,
I32ElementsAttr:$result_segment_sizes,
@ -432,10 +450,10 @@ def TF_ParseExampleOp : TF_Op<"ParseExample",
);
let results = (outs
Variadic<I64Tensor>:$sparse_indices, // len(sparse_types)
Variadic<TensorOf<[F32, I64, TF_Str]>>:$sparse_values, // len(sparse_types)
Variadic<I64Tensor>:$sparse_shapes, // len(sparse_types)
Variadic<TensorOf<[F32, I64, TF_Str]>>:$dense_values // len(Tdense)
Variadic<TF_Int64Tensor>:$sparse_indices, // len(sparse_types)
Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$sparse_values, // len(sparse_types)
Variadic<TF_Int64Tensor>:$sparse_shapes, // len(sparse_types)
Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$dense_values // len(Tdense)
);
TF_DerivedOperandSizeAttr Nsparse = TF_DerivedOperandSizeAttr<2>;
@ -459,7 +477,7 @@ def TF_ParseExampleV2Op : TF_Op<"ParseExampleV2",
TF_StrTensor:$sparse_keys,
TF_StrTensor:$dense_keys,
TF_StrTensor:$ragged_keys,
Variadic<TensorOf<[F32, I64, TF_Str]>>:$dense_defaults,
Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$dense_defaults,
Confined<I64Attr, [IntMinValue<0>]>:$num_sparse,
TF_ShapeAttrArray:$dense_shapes,
@ -467,13 +485,13 @@ def TF_ParseExampleV2Op : TF_Op<"ParseExampleV2",
);
let results = (outs
Variadic<I64Tensor>:$sparse_indices, // len(sparse_types)
Variadic<TensorOf<[F32, I64, TF_Str]>>:$sparse_values, // len(sparse_types)
Variadic<I64Tensor>:$sparse_shapes, // len(sparse_types)
Variadic<TensorOf<[F32, I64, TF_Str]>>:$dense_values, // len(Tdense)
Variadic<TensorOf<[F32, I64, TF_Str]>>:$ragged_values, // len(ragged_value_types)
Variadic<TF_Int64Tensor>:$sparse_indices, // len(sparse_types)
Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$sparse_values, // len(sparse_types)
Variadic<TF_Int64Tensor>:$sparse_shapes, // len(sparse_types)
Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$dense_values, // len(Tdense)
Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$ragged_values, // len(ragged_value_types)
// = len(ragged_split_types)
Variadic<TensorOf<[I32, I64]>>:$ragged_row_splits // len(ragged_split_types)
Variadic<TensorOf<[TF_Int32, TF_Int64]>>:$ragged_row_splits // len(ragged_split_types)
// = len(ragged_value_types)
);
@ -735,7 +753,7 @@ element_dtype: the desired type of elements in the list.
let arguments = (ins
TF_I32OrI64Tensor:$element_shape,
I32Tensor:$num_elements
TF_Int32Tensor:$num_elements
);
}
@ -956,8 +974,8 @@ Creates a dataset that batches `batch_size` elements from `input_dataset`.
let arguments = (ins
TF_VariantTensor:$input_dataset,
I64Tensor:$batch_size,
I1Tensor:$drop_remainder,
TF_Int64Tensor:$batch_size,
TF_BoolTensor:$drop_remainder,
DefaultValuedAttr<BoolAttr, "false">:$parallel_copy,
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
@ -1006,9 +1024,9 @@ to `batch_size * num_parallel_batches` copies of `f` in parallel.
let arguments = (ins
TF_VariantTensor:$input_dataset,
Variadic<TF_Tensor>:$other_arguments,
I64Tensor:$batch_size,
I64Tensor:$num_parallel_calls,
I1Tensor:$drop_remainder,
TF_Int64Tensor:$batch_size,
TF_Int64Tensor:$num_parallel_calls,
TF_BoolTensor:$drop_remainder,
SymbolRefAttr:$f,
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
@ -1036,7 +1054,7 @@ def TF_ParallelMapDatasetOp : TF_Op<"ParallelMapDataset", [NoSideEffect]> {
let arguments = (ins
TF_VariantTensor:$input_dataset,
Variadic<TF_Tensor>:$other_arguments,
I32Tensor:$num_parallel_calls,
TF_Int32Tensor:$num_parallel_calls,
SymbolRefAttr:$f,
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
@ -1118,11 +1136,11 @@ This function is faster and numerically stabler than `bessel_i0(x)`.
}];
let arguments = (ins
TF_FpTensor:$x
TF_FloatTensor:$x
);
let results = (outs
TF_FpTensor:$y
TF_FloatTensor:$y
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -1139,11 +1157,11 @@ This function is faster and numerically stabler than `bessel_i1(x)`.
}];
let arguments = (ins
TF_FpTensor:$x
TF_FloatTensor:$x
);
let results = (outs
TF_FpTensor:$y
TF_FloatTensor:$y
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -1154,7 +1172,7 @@ def TF_TPUPartitionedCallOp : TF_Op<"TPUPartitionedCall", [CallOpInterface]> {
let arguments = (ins
Variadic<TF_Tensor>:$args,
I32Tensor:$device_ordinal,
TF_Int32Tensor:$device_ordinal,
SymbolRefAttr:$f,
DefaultValuedAttr<I64Attr, "0">:$autotuner_thresh
@ -1264,12 +1282,12 @@ def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastable
}];
let arguments = (ins
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Uint32, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint8Ref, TF_Uint32Ref]>:$x,
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Uint32, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint8Ref, TF_Uint32Ref]>:$y
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Uint32]>:$x,
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Uint32]>:$y
);
let results = (outs
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Uint32, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint8Ref, TF_Uint32Ref]>:$z
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Uint32]>:$z
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -1289,12 +1307,12 @@ def TF_DivNoNanOp : TF_Op<"DivNoNan", [NoSideEffect, ResultsBroadcastableShape,
}];
let arguments = (ins
TensorOf<[F16, F32, F64, TF_Complex, TF_F16Ref, TF_F32Ref, TF_F64Ref, TF_ComplexRef]>:$x,
TensorOf<[F16, F32, F64, TF_Complex, TF_F16Ref, TF_F32Ref, TF_F64Ref, TF_ComplexRef]>:$y
TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Complex]>:$x,
TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Complex]>:$y
);
let results = (outs
TensorOf<[F16, F32, F64, TF_Complex, TF_F16Ref, TF_F32Ref, TF_F64Ref, TF_ComplexRef]>:$z
TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Complex]>:$z
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -1310,12 +1328,12 @@ def TF_MaximumOp : TF_Op<"Maximum", [NoSideEffect, ResultsBroadcastableShape, TF
}];
let arguments = (ins
TensorOf<[TF_Float, I16, I32, I64, TF_Uint8, TF_FloatRef, TF_Int16Ref, TF_Int32Ref, TF_Int64Ref, TF_Uint8Ref]>:$x,
TensorOf<[TF_Float, I16, I32, I64, TF_Uint8, TF_FloatRef, TF_Int16Ref, TF_Int32Ref, TF_Int64Ref, TF_Uint8Ref]>:$y
TensorOf<[TF_Float, TF_Int16, TF_Int32, TF_Int64, TF_Uint8]>:$x,
TensorOf<[TF_Float, TF_Int16, TF_Int32, TF_Int64, TF_Uint8]>:$y
);
let results = (outs
TensorOf<[TF_Float, I16, I32, I64, TF_Uint8, TF_FloatRef, TF_Int16Ref, TF_Int32Ref, TF_Int64Ref, TF_Uint8Ref]>:$z
TensorOf<[TF_Float, TF_Int16, TF_Int32, TF_Int64, TF_Uint8]>:$z
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -1333,12 +1351,12 @@ If `x` and `y` are reals, this will return the floating-point division.
}];
let arguments = (ins
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint16, TF_Uint8, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint16Ref, TF_Uint8Ref]>:$x,
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint16, TF_Uint8, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint16Ref, TF_Uint8Ref]>:$y
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint16, TF_Uint8]>:$x,
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint16, TF_Uint8]>:$y
);
let results = (outs
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint16, TF_Uint8, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint16Ref, TF_Uint8Ref]>:$z
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint16, TF_Uint8]>:$z
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -1362,12 +1380,12 @@ Both input and output have a range `(-inf, inf)`.
}];
let arguments = (ins
TensorOf<[TF_NumberNotQuantizedOrStr, TF_NumberNotQuantizedOrStrRef]>:$x,
TensorOf<[TF_NumberNotQuantizedOrStr, TF_NumberNotQuantizedOrStrRef]>:$y
TensorOf<[TF_NumberNotQuantizedOrStr]>:$x,
TensorOf<[TF_NumberNotQuantizedOrStr]>:$y
);
let results = (outs
TensorOf<[TF_NumberNotQuantizedOrStr, TF_NumberNotQuantizedOrStrRef]>:$z
TensorOf<[TF_NumberNotQuantizedOrStr]>:$z
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -1383,13 +1401,13 @@ The generated values will have mean 0 and standard deviation 1.
}];
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_VariableRead, TF_VariableWrite]>:$resource,
I64Tensor:$algorithm,
Arg<TF_ResourceTensor, "", [TF_VariableRead,TF_VariableWrite]>:$resource,
TF_Int64Tensor:$algorithm,
TF_I32OrI64Tensor:$shape
);
let results = (outs
TF_FpTensor:$output
TF_FloatTensor:$output
);
TF_DerivedOperandTypeAttr shape_dtype = TF_DerivedOperandTypeAttr<2>;
@ -1407,12 +1425,12 @@ deviations from the mean are dropped and re-picked.
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_VariableRead,TF_VariableWrite]>:$resource,
I64Tensor:$algorithm,
TF_Int64Tensor:$algorithm,
TF_I32OrI64Tensor:$shape
);
let results = (outs
TF_FpTensor:$output
TF_FloatTensor:$output
);
TF_DerivedOperandTypeAttr shape_dtype = TF_DerivedOperandTypeAttr<2>;
@ -1429,12 +1447,12 @@ lower bound 0 is included in the range, while the upper bound 1 is excluded.
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_VariableRead,TF_VariableWrite]>:$resource,
I64Tensor:$algorithm,
TF_Int64Tensor:$algorithm,
TF_I32OrI64Tensor:$shape
);
let results = (outs
TF_FpTensor:$output
TF_FloatTensor:$output
);
TF_DerivedOperandTypeAttr shape_dtype = TF_DerivedOperandTypeAttr<2>;
@ -1450,12 +1468,12 @@ The generated values are uniform integers covering the whole range of `dtype`.
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_VariableRead,TF_VariableWrite]>:$resource,
I64Tensor:$algorithm,
TF_Int64Tensor:$algorithm,
TF_I32OrI64Tensor:$shape
);
let results = (outs
TensorOf<[I32, I64, TF_Uint32, TF_Uint64]>:$output
TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$output
);
TF_DerivedOperandTypeAttr shape_dtype = TF_DerivedOperandTypeAttr<2>;
@ -1480,14 +1498,14 @@ smaller than the range of the output (either `2^32` or `2^64`).
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_VariableRead,TF_VariableWrite]>:$resource,
I64Tensor:$algorithm,
TF_Int64Tensor:$algorithm,
TF_I32OrI64Tensor:$shape,
TensorOf<[I32, I64, TF_Uint32, TF_Uint64]>:$minval,
TensorOf<[I32, I64, TF_Uint32, TF_Uint64]>:$maxval
TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$minval,
TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$maxval
);
let results = (outs
TensorOf<[I32, I64, TF_Uint32, TF_Uint64]>:$output
TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$output
);
TF_DerivedOperandTypeAttr shape_dtype = TF_DerivedOperandTypeAttr<2>;
@ -1560,8 +1578,8 @@ filename_suffix: Every event file's name is suffixed with this suffix.
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
TF_StrTensor:$logdir,
I32Tensor:$max_queue,
I32Tensor:$flush_millis,
TF_Int32Tensor:$max_queue,
TF_Int32Tensor:$flush_millis,
TF_StrTensor:$filename_suffix
);
@ -1647,10 +1665,10 @@ max_outputs: Max number of batch elements to generate audio for.
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
I64Tensor:$step,
TF_Int64Tensor:$step,
TF_StrTensor:$tag,
F32Tensor:$tensor,
F32Tensor:$sample_rate,
TF_Float32Tensor:$tensor,
TF_Float32Tensor:$sample_rate,
Confined<DefaultValuedAttr<I64Attr, "3">, [IntMinValue<1>]>:$max_outputs
);
@ -1669,7 +1687,7 @@ tensor: A scalar string of the serialized tf.GraphDef proto.
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
I64Tensor:$step,
TF_Int64Tensor:$step,
TF_StrTensor:$tensor
);
@ -1694,7 +1712,7 @@ values: Any shape. Values to use to build the histogram.
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
I64Tensor:$step,
TF_Int64Tensor:$step,
TF_StrTensor:$tag,
TF_IntOrFpTensor:$values
);
@ -1753,9 +1771,9 @@ bad_color: Color to use for pixels with non-finite values.
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
I64Tensor:$step,
TF_Int64Tensor:$step,
TF_StrTensor:$tag,
TensorOf<[F16, F32, TF_Uint8]>:$tensor,
TensorOf<[TF_Float16, TF_Float32, TF_Uint8]>:$tensor,
TF_Uint8Tensor:$bad_color,
Confined<DefaultValuedAttr<I64Attr, "3">, [IntMinValue<1>]>:$max_images
@ -1777,7 +1795,7 @@ tensor: A tensor holding one or more serialized `Summary` protobufs to write.
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
I64Tensor:$step,
TF_Int64Tensor:$step,
TF_StrTensor:$tensor
);
@ -1798,7 +1816,7 @@ value: Value for the summary.
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
I64Tensor:$step,
TF_Int64Tensor:$step,
TF_StrTensor:$tag,
TF_IntOrFpTensor:$value
);
@ -1822,7 +1840,7 @@ summary_metadata: Serialized SummaryMetadata protocol buffer containing
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
I64Tensor:$step,
TF_Int64Tensor:$step,
TF_Tensor:$tensor,
TF_StrTensor:$tag,
TF_StrTensor:$summary_metadata
@ -1833,7 +1851,6 @@ summary_metadata: Serialized SummaryMetadata protocol buffer containing
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>;
}
// TODO(b/168035831): Model dataset read.
def TF_InitializeTableFromDatasetOp : TF_Op<"InitializeTableFromDataset", []> {
let summary = "";
@ -1875,4 +1892,22 @@ Where to extract the key and value from a line is specified by `key_index` and
let results = (outs);
}
// TODO(b/168035831): Model filename read.
def TF_CacheDatasetV2Op : TF_Op<"CacheDatasetV2", []> {
let summary = "";
let arguments = (ins
TF_VariantTensor:$input_dataset,
TF_StrTensor:$filename,
Arg<TF_ResourceTensor, "", [TF_DatasetMemoryCacheRead, TF_DatasetMemoryCacheWrite]>:$cache,
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
);
let results = (outs
TF_VariantTensor:$handle
);
}
#endif // TF_OPS


@ -546,8 +546,7 @@ LogicalResult FoldConstantCaseOp::matchAndRewrite(
if (!matchPattern(op.branch_index(), m_Constant(&branch))) return failure();
int index = *branch.getValues<int>().begin();
if (index < 0 || index >= op.branches().size())
index = op.branches().size() - 1;
if (index < 0 || index >= op.num_branches()) index = op.num_branches() - 1;
auto func = op.branches()[index].cast<SymbolRefAttr>();
auto empty = rewriter.getStringAttr("");


@ -1032,6 +1032,7 @@ static LogicalResult Verify(SizeOp op) {
OpFoldResult SizeOp::fold(ArrayRef<Attribute> operands) {
ShapedType output_type = getType().cast<ShapedType>();
if (!output_type.hasRank()) return {};
ShapedType input_type = getOperand().getType().cast<ShapedType>();
if (!input_type.hasStaticShape()) return {};
int size = input_type.getNumElements();


@ -43,6 +43,16 @@ struct LookupTable : ::mlir::SideEffects::Resource::Base<LookupTable> {
StringRef getName() final { return "LookupTable"; }
};
struct DatasetSeedGenerator
: ::mlir::SideEffects::Resource::Base<DatasetSeedGenerator> {
StringRef getName() final { return "DatasetSeedGenerator"; }
};
struct DatasetMemoryCache
: ::mlir::SideEffects::Resource::Base<DatasetMemoryCache> {
StringRef getName() final { return "DatasetMemoryCache"; }
};
} // namespace ResourceEffects
} // namespace TF
} // namespace mlir
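
As context, a hedged C++ sketch of how analysis code could observe the two new side-effect resources through the standard MemoryEffectOpInterface. The helper name is hypothetical and not part of this commit; only the resource name strings come from this change.

// Hypothetical helper: true if `op` declares an effect on DatasetMemoryCache.
bool TouchesDatasetMemoryCache(mlir::Operation *op) {
  auto iface = llvm::dyn_cast<mlir::MemoryEffectOpInterface>(op);
  if (!iface) return false;
  llvm::SmallVector<mlir::MemoryEffects::EffectInstance, 4> effects;
  iface.getEffects(effects);
  for (const mlir::MemoryEffects::EffectInstance &effect : effects)
    if (effect.getResource()->getName() == "DatasetMemoryCache") return true;
  return false;
}
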


@ -74,6 +74,17 @@ func @ignore_embedding_ops() -> () {
return
}
// CHECK-LABEL: func @ignore_stack_ops
func @ignore_stack_ops(%arg0: tensor<i32>) -> () {
"tf_device.cluster"() ( {
// CHECK: "tf.StackV2"
// CHECK-NOT: _xla_outside_compilation
%0 = "tf.StackV2"(%arg0) {elem_type = f32, stack_name = "s"} : (tensor<i32>) -> tensor<!tf.resource>
tf_device.return
}) {allow_soft_placement = true, num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
return
}
// CHECK-LABEL: func @op_string_result
func @op_string_result() -> tensor<i32> {
%0 = "tf_device.cluster"() ( {


@ -575,4 +575,15 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
func @pcall_resource_result_func(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource<tensor<f32>>> {
return %arg0 : tensor<*x!tf.resource<tensor<f32>>>
}
// Check that the fold for tf.Size does not crash with unranked output type.
// CHECK-LABEL: func @unranked_tf_size
func @unranked_tf_size() -> tensor<*xi32> {
%0 = "tf.Const"() {value = dense<[-1, 26]> : tensor<2xi32>} : () -> tensor<2xi32>
%add = "tf.AddV2"(%0, %0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<*xi32>
// CHECK: "tf.Size"
// CHECK-SAME: (tensor<2xi32>) -> tensor<i32>
%size = "tf.Size"(%add) {device = ""} : (tensor<*xi32>) -> tensor<*xi32>
return %size : tensor<*xi32>
}
}


@ -242,7 +242,7 @@ func @testReshape(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<1000
func @testReshape(tensor<*xf32>, tensor<*xf32>) -> (tensor<100x100xf32>) {
^bb0(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>):
%shape1 = constant dense<100.> : tensor<2xf32>
// expected-error @+1 {{must be tensor of 32/64-bit signless integer values}}
// expected-error @+1 {{must be tensor of 32/64-bit signed integer values}}
%r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<*xf32>, tensor<2xf32>) -> (tensor<100x100xf32>)
return %r1 : tensor<100x100xf32>
}
@ -1967,7 +1967,7 @@ func @testValidShape(tensor<1x32x32x16xf32>, tensor<*xf32>) -> (tensor<4xi32>, t
// -----
func @testShapeWrongResultElemType(%arg0: tensor<1x32x32x16xf32>) -> tensor<4xf32> {
// expected-error @+1 {{result #0 must be tensor of 32/64-bit signless integer values}}
// expected-error @+1 {{result #0 must be tensor of 32/64-bit signed integer values}}
%0 = "tf.Shape"(%arg0) : (tensor<1x32x32x16xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -2011,7 +2011,7 @@ func @testValidShapeN(%arg0 : tensor<1x32x32x16xf32>, %arg1 : tensor<*xf32>) ->
// -----
func @testShapeNWrongResultElemType(%arg0: tensor<1x32x32x16xf32>) -> tensor<4xf32> {
// expected-error @+1 {{result #1 must be tensor of 32/64-bit signless integer values}}
// expected-error @+1 {{result #1 must be tensor of 32/64-bit signed integer values}}
%0:2 = "tf.ShapeN"(%arg0, %arg0) : (tensor<1x32x32x16xf32>, tensor<1x32x32x16xf32>) -> (tensor<4xi32>, tensor<4xf32>)
return %0#1 : tensor<4xf32>
}
@ -2072,7 +2072,7 @@ func @testVariableShapeMultipleSubtypes(%arg0: tensor<*x!tf.resource<tensor<1x32
// -----
func @testVariableShapeWrongResultElemType(%arg0: tensor<*x!tf.resource<tensor<1x32x32x16xf32>>>) -> tensor<?xf32> {
// expected-error @+1 {{result #0 must be tensor of 32/64-bit signless integer values}}
// expected-error @+1 {{result #0 must be tensor of 32/64-bit signed integer values}}
%0 = "tf.VariableShape"(%arg0) : (tensor<*x!tf.resource<tensor<1x32x32x16xf32>>>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -2208,7 +2208,7 @@ func @testTranspose(tensor<2x3x4xf32>) -> tensor<3x2x4xf32> {
// Test invalid tf.Less
func @testLess(tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> {
^bb0(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>):
// expected-error @+1 {{op result #0 must be tensor of 1-bit signless integer values}}
// expected-error @+1 {{op result #0 must be tensor of bool values}}
%0 = "tf.Less"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %0 : tensor<4xi32>
}
@ -2225,7 +2225,7 @@ func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor<i32>) -> tensor<?xf32>
// tf.ConcatV2 with wrong 'axis' element type
func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor<f32>) -> tensor<?xf32> {
// expected-error @+1 {{operand #2 must be tensor of 32/64-bit signless integer values}}
// expected-error @+1 {{operand #2 must be tensor of 32/64-bit signed integer values}}
%0 = "tf.ConcatV2"(%arg, %arg, %axis) : (tensor<8x16xf32>, tensor<8x16xf32>, tensor<f32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@ -2258,7 +2258,7 @@ func @testAll64(%arg0: tensor<2x2xi1>, %arg1: tensor<i64>) -> tensor<i1> {
// -----
func @testAllFloat(%arg0: tensor<2x2xi1>, %arg1: tensor<f32>) -> tensor<i1> {
// expected-error @+1 {{'tf.All' op operand #1 must be tensor of 32/64-bit signless integer values}}
// expected-error @+1 {{'tf.All' op operand #1 must be tensor of 32/64-bit signed integer values}}
%0 = "tf.All"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor<f32>) -> tensor<i1>
return %0 : tensor<i1>
}
@ -2266,7 +2266,7 @@ func @testAllFloat(%arg0: tensor<2x2xi1>, %arg1: tensor<f32>) -> tensor<i1> {
// -----
func @testAllI32(%arg0: tensor<2x2xi32>, %arg1: tensor<f32>) -> tensor<i32> {
// expected-error @+1 {{'tf.All' op operand #0 must be tensor of 1-bit signless integer values}}
// expected-error @+1 {{'tf.All' op operand #0 must be tensor of bool values}}
%0 = "tf.All"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi32>, tensor<f32>) -> tensor<i32>
return %0 : tensor<i32>
}


@ -2,29 +2,81 @@
// Tests that the pass can correctly transform a training loop with 2 replicas.
!tf_res_f32 = type tensor<*x!tf.resource<tensor<f32>>>
!tf_res_md_f32 = type tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> // Multi-dim f32
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
// CHECK-LABEL: func @main
func @main(%arg0: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
%arg2: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"}) {
// CHECK-SAME: %[[ARG0:.*]]: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
// CHECK-SAME: %[[ARG1:.*]]: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
// CHECK-SAME: %[[ARG2:.*]]: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
// CHECK-SAME: %[[ARG3:.*]]: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"})
func @main(%arg0: !tf_res_f32 {tf.device = "/device:TPU:0"},
%arg1: !tf_res_f32 {tf.device = "/device:TPU:1"},
%arg2: !tf_res_md_f32 {tf.device = "/device:TPU:0"},
%arg3: !tf_res_md_f32 {tf.device = "/device:TPU:1"}) {
%0 = "tf.Const"() {value = dense<100> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[STATE0:.*]] = "tf.VarHandleOp"()
// CHECK-SAME: device = "/device:TPU:0"
// CHECK: %[[STATE1:.*]] = "tf.VarHandleOp"()
// CHECK-SAME: device = "/device:TPU:1"
// CHECK: %[[WHILE:.*]]:7 = "tf.While"(
// CHECK-SAME: %[[STATE0]], %[[STATE1]])
%1:5 = "tf.While"(%0, %arg0, %arg1, %arg2, %arg3)
{T = ["tfdtype$DT_INT32", "tfdtype$DT_RESOURCE",
"tfdtype$DT_RESOURCE", "tfdtype$DT_RESOURCE",
"tfdtype$DT_RESOURCE"], body = @while_body_7560,
cond = @while_cond_7550, device = "", is_stateless = false}
: (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
-> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
// CHECK: %[[WHILE:.*]] = "tf.WhileRegion"(
%1 = "tf.WhileRegion"(%0) ( {
// Condition region
// CHECK: ^bb
// CHECK: "tf.Yield"
^bb0(%carg0: tensor<i32>):
%c0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
%c1 = "tf.GreaterEqual"(%carg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"tf.Yield"(%c1) : (tensor<i1>) -> ()
}, {
// Body region
// CHECK: ^bb0
^bb0(%barg0: tensor<i32>):
%b0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
%b1 = "tf.AddV2"(%barg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
// CHECK: %[[COMPILE:.*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"()
%compile:2 = "tf_device.launch"() ( {
%b2:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64,
// The metadata encodes 2 parameters and 2 return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
tf_device.return %b2#0, %b2#1 : tensor<!tf.string>, tensor<2x!tf.string>
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
"tf_device.launch"() ( {
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
tf_device.return
}) {device = "/device:CPU:0"} : () -> ()
// CHECK: tf_device.replicate
// CHECK-SAME: [%[[ARG0]], %[[ARG1]]] as %[[R0:.*]]: tensor<*x!tf.resource<tensor<f32>>>,
// CHECK-SAME: [%[[ARG2]], %[[ARG3]]] as %[[R1:.*]]: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
// CHECK-SAME: [%[[STATE0]], %[[STATE1]]] as %[[R_STATE:.*]]: tensor<!tf.resource<tensor<2x!tf.string>>>
// CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]
%rep:2 = tf_device.replicate([%arg0, %arg1] as %arg30: tensor<*x!tf.resource<tensor<f32>>>,
[%arg2, %arg3] as %arg31: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
{_mirrored_variable_indices = [0, 1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
// CHECK: %[[ID:.*]] = "tf.Identity"(%[[R0]])
%id = "tf.Identity"(%arg30) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource<tensor<f32>>>
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUReshardVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1, %[[R_STATE]])
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
// CHECK: "tf.TPUExecuteAndUpdateVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1)
"tf_device.launch"() ( {
"tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1)
{device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<2x!tf.string>) -> ()
tf_device.return
}) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
%ret = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
tf_device.return %ret : tensor<i32>
}
// CHECK: "tf.Yield"
"tf.Yield"(%b1) : (tensor<i32>) -> ()
}) {device = "", is_stateless = false} : (tensor<i32>) -> (tensor<i32>)
// CHECK: %[[DEFAULT:.*]] = "tf.Const"()
// CHECK: tf_device.replicate
// CHECK-SAME: as %[[V0:.*]]: tensor<*x!tf.resource<tensor<f32>>>,
@ -37,165 +89,72 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
// CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
return
}
// CHECK-LABEL: func @while_body_7560
func @while_body_7560(%arg0: tensor<i32>,
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
%arg2: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
%arg4: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"})
-> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>) {
// CHECK-SAME: (%[[ITER:.*]]: tensor<i32>,
// CHECK-SAME: %[[BODY_ARG1:.*]]: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
// CHECK-SAME: %[[BODY_ARG2:.*]]: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
// CHECK-SAME: %[[BODY_ARG3:.*]]: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
// CHECK-SAME: %[[BODY_ARG4:.*]]: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"},
// CHECK-SAME: %[[STATE_ARG0:.*]]: tensor<!tf.resource<tensor<2x!tf.string>>> {tf.device = "/device:TPU:0"},
// CHECK-SAME: %[[STATE_ARG1:.*]]: tensor<!tf.resource<tensor<2x!tf.string>>> {tf.device = "/device:TPU:1"})
%0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
%1 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
// CHECK: %[[COMPILE:.*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"()
%compile:2 = "tf_device.launch"() ( {
%2:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64,
// The metadata encodes 2 parameters and 2 return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
tf_device.return %2#0, %2#1 : tensor<!tf.string>, tensor<2x!tf.string>
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
"tf_device.launch"() ( {
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
tf_device.return
}) {device = "/device:CPU:0"} : () -> ()
// CHECK: tf_device.replicate
// CHECK-SAME: [%[[BODY_ARG1]], %[[BODY_ARG2]]] as %[[R0:.*]]: tensor<*x!tf.resource<tensor<f32>>>,
// CHECK-SAME: [%[[BODY_ARG3]], %[[BODY_ARG4]]] as %[[R1:.*]]: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
// CHECK-SAME: [%[[STATE_ARG0]], %[[STATE_ARG1]]] as %[[R_STATE:.*]]: tensor<!tf.resource<tensor<2x!tf.string>>>
// CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]
%rep:2 = tf_device.replicate([%arg1, %arg2] as %arg30: tensor<*x!tf.resource<tensor<f32>>>,
[%arg3, %arg4] as %arg31: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
{_mirrored_variable_indices = [0, 1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
// CHECK: %[[ID:.*]] = "tf.Identity"(%[[R0]])
%id = "tf.Identity"(%arg30) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource<tensor<f32>>>
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUReshardVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1, %[[R_STATE]])
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
// CHECK: "tf.TPUExecuteAndUpdateVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1)
"tf_device.launch"() ( {
"tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1)
{device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<2x!tf.string>) -> ()
tf_device.return
}) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
%ret = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
tf_device.return %ret : tensor<i32>
}
return %1, %arg1, %arg2, %arg3, %arg4 : tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>,
tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>
}
// CHECK-LABEL: func @while_cond_7550
func @while_cond_7550(%arg0: tensor<i32>,
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
%arg2: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
%arg4: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"})
-> tensor<i1> {
%0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
%1 = "tf.GreaterEqual"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
return %1 : tensor<i1>
}
}
// -----
// Tests that the pass does not format variables with other uses.
!tf_res_f32 = type tensor<*x!tf.resource<tensor<f32>>>
!tf_res_md_f32 = type tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> // Multi-dim f32
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
// CHECK-LABEL: func @main
// CHECK-NOT: TPUReshardVariables
func @main(%arg0: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
%arg2: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"},
%arg4: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
%arg5: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"}) {
func @main(%arg0: !tf_res_f32 {tf.device = "/device:TPU:0"},
%arg1: !tf_res_f32 {tf.device = "/device:TPU:1"},
%arg2: !tf_res_md_f32 {tf.device = "/device:TPU:0"},
%arg3: !tf_res_md_f32 {tf.device = "/device:TPU:1"},
%arg4: !tf_res_f32 {tf.device = "/device:TPU:1"}) {
%0 = "tf.Const"() {value = dense<100> : tensor<i32>} : () -> tensor<i32>
%1:7 = "tf.While"(%0, %arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
{body = @while_body_7560,
cond = @while_cond_7550, device = "", is_stateless = false}
: (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>)
-> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>)
%1 = "tf.WhileRegion"(%0) ( {
// Condition region
^bb0(%carg0: tensor<i32>):
%c0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
%c1 = "tf.GreaterEqual"(%carg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"tf._UnknownOp1_"(%arg1) : (!tf_res_f32) -> ()
"tf.Yield"(%c1) : (tensor<i1>) -> ()
}, {
// Body region
^bb0(%barg0: tensor<i32>):
%b0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
%b1 = "tf.AddV2"(%barg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
%compile:2 = "tf_device.launch"() ( {
%b2:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64,
// The metadata encodes 2 parameter and two return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
tf_device.return %b2#0, %b2#1 : tensor<!tf.string>, tensor<2x!tf.string>
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
"tf_device.launch"() ( {
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
tf_device.return
}) {device = "/device:CPU:0"} : () -> ()
%id0 = "tf.Identity"(%arg3) : (!tf_res_md_f32) -> !tf_res_md_f32
"tf._Unknown_"(%id0) : (!tf_res_md_f32) -> ()
%newvar = "tf._SomeOp"() : () -> !tf_res_f32
%rep:2 = tf_device.replicate([%arg0, %arg1] as %arg30: !tf_res_f32,
[%arg2, %arg3] as %arg31: !tf_res_md_f32,
[%newvar, %arg4] as %arg32 : !tf_res_f32)
{_mirrored_variable_indices = [0, 1, 2], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
// %arg30 is used in the cond function, %arg31 has other uses (%id0), and
// %arg32 is not a pass-through.
"tf_device.launch"() ( {
"tf.TPUExecuteAndUpdateVariables"(%arg30, %arg31, %arg32, %compile#1)
{device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
: (!tf_res_f32, !tf_res_md_f32, !tf_res_f32, tensor<2x!tf.string>) -> ()
tf_device.return
}) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
%ret = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
tf_device.return %ret : tensor<i32>
}
"tf.Yield"(%b1) : (tensor<i32>) -> ()
}) {device = "", is_stateless = false} : (tensor<i32>) -> (tensor<i32>)
return
}
// CHECK-LABEL: func @while_body_7560
// CHECK-NOT: TPUReshardVariables
func @while_body_7560(%arg0: tensor<i32>,
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
%arg2: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
%arg4: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"},
%arg5: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
%arg6: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"})
-> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>) {
%0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
%1 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
%compile:2 = "tf_device.launch"() ( {
%2:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64,
// The metadata encodes 2 parameter and two return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
tf_device.return %2#0, %2#1 : tensor<!tf.string>, tensor<2x!tf.string>
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
"tf_device.launch"() ( {
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
tf_device.return
}) {device = "/device:CPU:0"} : () -> ()
%id0 = "tf.Identity"(%arg3) : (tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>) -> tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>
"tf._Unknown_"(%id0) : (tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>) -> ()
%newvar = "tf._SomeOp"() : () -> tensor<*x!tf.resource<tensor<f32>>>
tf_device.replicate([%arg1, %arg2] as %arg30: tensor<*x!tf.resource<tensor<f32>>>,
[%arg3, %arg4] as %arg31: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
[%newvar, %arg6] as %arg32: tensor<*x!tf.resource<tensor<f32>>>)
{_mirrored_variable_indices = [0, 1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
// %arg30 is used in the cond function, %arg31 has other uses (%id0), and
// %arg32 is not a pass-through.
"tf_device.launch"() ( {
"tf.TPUExecuteAndUpdateVariables"(%arg30, %arg31, %arg32, %compile#1)
{device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
tensor<*x!tf.resource<tensor<f32>>>, tensor<2x!tf.string>) -> ()
tf_device.return
}) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
tf_device.return
}
return %1, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6 : tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>,
tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>
}
// CHECK-LABEL: func @while_cond_7550
func @while_cond_7550(%arg0: tensor<i32>,
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
%arg2: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
%arg4: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"},
%arg5: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
%arg6: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"})
-> tensor<i1> {
%0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
%1 = "tf.GreaterEqual"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"tf._UnknownOp1_"(%arg1) : (tensor<*x!tf.resource<tensor<f32>>>) -> ()
return %1 : tensor<i1>
}
}
// -----
@ -203,81 +162,62 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
// Tests that the pass does not format variables when model parallelism is
// present.
!tf_res_f32 = type tensor<*x!tf.resource<tensor<f32>>>
!tf_res_md_f32 = type tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> // Multi-dim f32
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
// CHECK-LABEL: func @main
// CHECK-NOT: TPUReshardVariables
func @main(%arg0: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
%arg2: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"}) {
func @main(%arg0: !tf_res_f32 {tf.device = "/device:TPU:0"},
%arg1: !tf_res_f32 {tf.device = "/device:TPU:1"},
%arg2: !tf_res_md_f32 {tf.device = "/device:TPU:0"},
%arg3: !tf_res_md_f32 {tf.device = "/device:TPU:1"}) {
%0 = "tf.Const"() {value = dense<100> : tensor<i32>} : () -> tensor<i32>
%1:5 = "tf.While"(%0, %arg0, %arg1, %arg2, %arg3)
{T = ["tfdtype$DT_INT32", "tfdtype$DT_RESOURCE",
"tfdtype$DT_RESOURCE", "tfdtype$DT_RESOURCE",
"tfdtype$DT_RESOURCE"], body = @while_body_7560,
cond = @while_cond_7550, device = "", is_stateless = false}
: (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
-> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
return
}
// CHECK-LABEL: func @while_body_7560
// CHECK-NOT: TPUReshardVariables
func @while_body_7560(%arg0: tensor<i32>,
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
%arg2: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
%arg4: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"})
-> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>) {
%0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
%1 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
%compile:2 = "tf_device.launch"() ( {
%2:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64,
// The metadata encodes 2 parameter and two return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
tf_device.return %2#0, %2#1 : tensor<!tf.string>, tensor<2x!tf.string>
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
"tf_device.launch"() ( {
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
tf_device.return
}) {device = "/device:CPU:0"} : () -> ()
%rep:2 = tf_device.replicate([%arg1, %arg2] as %arg30: tensor<*x!tf.resource<tensor<f32>>>,
[%arg3, %arg4] as %arg31: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
{_mirrored_variable_indices = [0, 1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
%id = "tf.Identity"(%arg30) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource<tensor<f32>>>
"tf_device.parallel_execute"() ({
"tf_device.launch"() ( {
"tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1)
{device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<2x!tf.string>) -> ()
tf_device.return
}) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
tf_device.return
%1 = "tf.WhileRegion"(%0) ( {
// Condition region
^bb0(%carg0: tensor<i32>):
%c0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
%c1 = "tf.GreaterEqual"(%carg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"tf.Yield"(%c1) : (tensor<i1>) -> ()
}, {
tf_device.return
}) {} : () -> ()
%ret = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
tf_device.return %ret : tensor<i32>
}
return %1, %arg1, %arg2, %arg3, %arg4 : tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>,
tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>
}
// CHECK-LABEL: func @while_cond_7550
func @while_cond_7550(%arg0: tensor<i32>,
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
%arg2: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
%arg4: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"})
-> tensor<i1> {
%0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
%1 = "tf.GreaterEqual"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
return %1 : tensor<i1>
// Body region
^bb0(%barg0: tensor<i32>):
%b0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
%b1 = "tf.AddV2"(%barg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
%compile:2 = "tf_device.launch"() ( {
%b2:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64,
// The metadata encodes 2 parameters and 2 return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
tf_device.return %b2#0, %b2#1 : tensor<!tf.string>, tensor<2x!tf.string>
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
"tf_device.launch"() ( {
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
tf_device.return
}) {device = "/device:CPU:0"} : () -> ()
%rep:2 = tf_device.replicate([%arg0, %arg1] as %arg30: tensor<*x!tf.resource<tensor<f32>>>,
[%arg2, %arg3] as %arg31: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
{_mirrored_variable_indices = [0, 1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
%id = "tf.Identity"(%arg30) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource<tensor<f32>>>
"tf_device.parallel_execute"() ({
"tf_device.launch"() ( {
"tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1)
{device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<2x!tf.string>) -> ()
tf_device.return
}) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
tf_device.return
}, {
tf_device.return
}) {} : () -> ()
%ret = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
tf_device.return %ret : tensor<i32>
}
"tf.Yield"(%b1) : (tensor<i32>) -> ()
}) {device = "", is_stateless = false} : (tensor<i32>) -> (tensor<i32>)
return
}
}
@ -285,34 +225,83 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
// Tests that the pass can correctly transform a training loop with a packed
// variable.
!tf_res_f32 = type tensor<*x!tf.resource<tensor<f32>>>
!tf_res_md_f32 = type tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> // Multi-dim f32
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
// CHECK-LABEL: func @main
func @main(%arg0: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
%arg2: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:COMPOSITE:0"}) {
// CHECK-SAME: %[[ARG0:.*]]: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
// CHECK-SAME: %[[ARG1:.*]]: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
// CHECK-SAME: %[[ARG2:.*]]: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:COMPOSITE:0"})
func @main(%arg0: !tf_res_f32 {tf.device = "/device:TPU:0"},
%arg1: !tf_res_f32 {tf.device = "/device:TPU:1"},
%arg2: !tf_res_md_f32 {tf.device = "/device:COMPOSITE:0"}) {
%0 = "tf.Const"() {value = dense<100> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[STATE0:.*]] = "tf.VarHandleOp"()
// CHECK-SAME: device = "/device:TPU:0"
// CHECK: %[[STATE1:.*]] = "tf.VarHandleOp"()
// CHECK-SAME: device = "/device:TPU:1"
// CHECK: %[[WHILE:.*]]:6 = "tf.While"(
// CHECK-SAME: %[[STATE0]], %[[STATE1]])
%1:4 = "tf.While"(%0, %arg0, %arg1, %arg2)
{T = ["tfdtype$DT_INT32", "tfdtype$DT_RESOURCE",
"tfdtype$DT_RESOURCE", "tfdtype$DT_RESOURCE"],
body = @while_body_7560,
cond = @while_cond_7550, device = "", is_stateless = false}
: (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
-> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
// CHECK: %[[WHILE:.*]] = "tf.WhileRegion"(
%1 = "tf.WhileRegion"(%0) ( {
// Condition region
// CHECK: ^bb
// CHECK: "tf.Yield"
^bb0(%carg0: tensor<i32>):
%c0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
%c1 = "tf.GreaterEqual"(%carg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"tf.Yield"(%c1) : (tensor<i1>) -> ()
}, {
// Body region
// CHECK: ^bb0
^bb0(%barg0: tensor<i32>):
%b0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
%b1 = "tf.AddV2"(%barg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
// CHECK: %[[COMPILE:.*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"()
%compile:2 = "tf_device.launch"() ( {
%b2:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64,
// The metadata encodes 2 parameter and two return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
tf_device.return %b2#0, %b2#1 : tensor<!tf.string>, tensor<2x!tf.string>
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
"tf_device.launch"() ( {
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
tf_device.return
}) {device = "/device:CPU:0"} : () -> ()
// CHECK: tf_device.replicate
// CHECK-SAME: [%[[ARG0]], %[[ARG1]]] as %[[R0:.*]]: tensor<*x!tf.resource<tensor<f32>>>,
// CHECK-SAME: [%[[STATE0]], %[[STATE1]]] as %[[R_STATE:.*]]: tensor<!tf.resource<tensor<2x!tf.string>>>
// CHECK-SAME: %[[ARG2]] as %[[R1:.*]]: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>
// CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]
%rep:2 = tf_device.replicate([%arg0, %arg1] as %arg30: tensor<*x!tf.resource<tensor<f32>>>,
%arg2 as %arg31: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
{_mirrored_variable_indices = [0, 1], _packed_input_indices = [1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
// CHECK: %[[ID:.*]] = "tf.Identity"(%[[R0]])
%id = "tf.Identity"(%arg30) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource<tensor<f32>>>
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUReshardVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1, %[[R_STATE]])
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
// CHECK: "tf.TPUExecuteAndUpdateVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1)
"tf_device.launch"() ( {
"tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1)
{device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<2x!tf.string>) -> ()
tf_device.return
}) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
%ret = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
tf_device.return %ret : tensor<i32>
}
// CHECK: "tf.Yield"
"tf.Yield"(%b1) : (tensor<i32>) -> ()
}) {device = "", is_stateless = false} : (tensor<i32>) -> (tensor<i32>)
// CHECK: %[[DEFAULT:.*]] = "tf.Const"()
// CHECK: tf_device.replicate
// CHECK-SAME: as %[[V0:.*]]: tensor<*x!tf.resource<tensor<f32>>>,
// CHECK-SAME: [%[[STATE0]], %[[STATE1]]] as %[[STATE:.*]]: tensor<!tf.resource<tensor<2x!tf.string>>>,
// CHECK-SAME: as %[[V1:.*]]: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>
// CHECK-SAME: [%[[ARG0]], %[[ARG1]]] as %[[V0:.*]]: tensor<*x!tf.resource<tensor<f32>>>,
// CHECK-SAME: [%[[STATE0]], %[[STATE1]]] as %[[STATE:.*]]: tensor<!tf.resource<tensor<2x!tf.string>>>
// CHECK-SAME: %[[ARG2]] as %[[V1:.*]]: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>
// CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUReshardVariables"(%[[V0]], %[[V1]], %[[DEFAULT]], %[[STATE]])
@ -320,70 +309,4 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
// CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
return
}
// CHECK-LABEL: func @while_body_7560
func @while_body_7560(%arg0: tensor<i32>,
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
%arg2: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:COMPOSITE:0"})
-> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>) {
// CHECK-SAME: (%[[ITER:.*]]: tensor<i32>,
// CHECK-SAME: %[[BODY_ARG1:.*]]: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
// CHECK-SAME: %[[BODY_ARG2:.*]]: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
// CHECK-SAME: %[[BODY_ARG3:.*]]: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:COMPOSITE:0"},
// CHECK-SAME: %[[STATE_ARG0:.*]]: tensor<!tf.resource<tensor<2x!tf.string>>> {tf.device = "/device:TPU:0"},
// CHECK-SAME: %[[STATE_ARG1:.*]]: tensor<!tf.resource<tensor<2x!tf.string>>> {tf.device = "/device:TPU:1"})
%0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
%1 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
// CHECK: %[[COMPILE:.*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"()
%compile:2 = "tf_device.launch"() ( {
%2:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64,
// The metadata encodes 2 parameter and two return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
tf_device.return %2#0, %2#1 : tensor<!tf.string>, tensor<2x!tf.string>
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
"tf_device.launch"() ( {
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
tf_device.return
}) {device = "/device:CPU:0"} : () -> ()
// CHECK: tf_device.replicate
// CHECK-SAME: [%[[BODY_ARG1]], %[[BODY_ARG2]]] as %[[R0:.*]]: tensor<*x!tf.resource<tensor<f32>>>,
// CHECK-SAME: [%[[STATE_ARG0]], %[[STATE_ARG1]]] as %[[R_STATE:.*]]: tensor<!tf.resource<tensor<2x!tf.string>>>,
// CHECK-SAME: %[[BODY_ARG3]] as %[[R1:.*]]: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>
// CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]
%rep:2 = tf_device.replicate([%arg1, %arg2] as %arg30: tensor<*x!tf.resource<tensor<f32>>>,
%arg3 as %arg31: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
{_mirrored_variable_indices = [0, 1], _packed_input_indices = [1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
// CHECK: %[[ID:.*]] = "tf.Identity"(%[[R0]])
%id = "tf.Identity"(%arg30) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource<tensor<f32>>>
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUReshardVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1, %[[R_STATE]])
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
// CHECK: "tf.TPUExecuteAndUpdateVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1)
"tf_device.launch"() ( {
"tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1)
{device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<2x!tf.string>) -> ()
tf_device.return
}) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
%ret = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
tf_device.return %ret : tensor<i32>
}
return %1, %arg1, %arg2, %arg3 : tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>,
tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>
}
// CHECK-LABEL: func @while_cond_7550
func @while_cond_7550(%arg0: tensor<i32>,
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
%arg2: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:COMPOSITE:0"})
-> tensor<i1> {
%0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
%1 = "tf.GreaterEqual"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
return %1 : tensor<i1>
}
}


@ -120,8 +120,8 @@ void CreateTPUBridgePipeline(OpPassManager &pm) {
pm.addNestedPass<FuncOp>(CreateTPUParallelExecuteSinkResourceWritePass());
pm.addNestedPass<FuncOp>(CreateTPUMergeVariablesWithExecutePass());
pm.addNestedPass<FuncOp>(CreateTPUColocateCompositeResourceOps());
pm.addPass(TF::CreateTFRegionControlFlowToFunctional());
pm.addPass(CreateTPUVariableReformattingPass());
pm.addPass(TF::CreateTFRegionControlFlowToFunctional());
}
void CreateTPUBridgePipelineV1(OpPassManager &pm) {


@ -217,7 +217,7 @@ def : Pat<(TF_RsqrtGradOp $lhs, $rhs),
// TODO(hinsu): Support complex input types.
def LowerTanhGradOp :
Pat<(TF_TanhGradOp TF_FpTensor:$y, TF_FpTensor:$dy),
Pat<(TF_TanhGradOp TF_FloatTensor:$y, TF_FloatTensor:$dy),
(TF_MulOp $dy,
(TF_SubOp (TF_ConstOp (GetScalarOfType<1> $y)),
(TF_SquareOp $y)))>;


@ -77,6 +77,47 @@ void AddRewrittenEmbeddingOps(MLIRContext* context,
TF::SendTPUEmbeddingGradientsOp::getOperationName(), context));
}
// Stack, TensorList and TensorArray ops are rewritten during the second phase
// of the bridge (compilation of TPUCompile op). They would not match any
// legalization/canonicalization pattern and have to be manually added to the
// list of supported ops.
void AddRewrittenCompositeOps(MLIRContext* context,
llvm::DenseSet<OperationName>* supported_ops) {
#define GET_OPERATION_NAME(op) OperationName(op::getOperationName(), context)
llvm::SmallDenseSet<OperationName, 32> allowlist_ops = {
// Stack ops.
GET_OPERATION_NAME(TF::StackV2Op),
GET_OPERATION_NAME(TF::StackPushV2Op),
GET_OPERATION_NAME(TF::StackPopV2Op),
// Tensor Array ops.
GET_OPERATION_NAME(TF::TensorArrayV3Op),
GET_OPERATION_NAME(TF::TensorArrayReadV3Op),
GET_OPERATION_NAME(TF::TensorArrayWriteV3Op),
GET_OPERATION_NAME(TF::TensorArrayConcatV3Op),
GET_OPERATION_NAME(TF::TensorArraySplitV3Op),
GET_OPERATION_NAME(TF::TensorArraySizeV3Op),
GET_OPERATION_NAME(TF::TensorArrayGradV3Op),
GET_OPERATION_NAME(TF::TensorArrayGatherV3Op),
GET_OPERATION_NAME(TF::TensorArrayScatterV3Op),
GET_OPERATION_NAME(TF::TensorListFromTensorOp),
// Tensor List Ops.
GET_OPERATION_NAME(TF::EmptyTensorListOp),
GET_OPERATION_NAME(TF::TensorListReserveOp),
GET_OPERATION_NAME(TF::TensorListFromTensorOp),
GET_OPERATION_NAME(TF::TensorListPushBackOp),
GET_OPERATION_NAME(TF::TensorListPopBackOp),
GET_OPERATION_NAME(TF::TensorListGetItemOp),
GET_OPERATION_NAME(TF::TensorListSetItemOp),
GET_OPERATION_NAME(TF::TensorListLengthOp),
GET_OPERATION_NAME(TF::TensorListElementShapeOp),
GET_OPERATION_NAME(TF::TensorListGatherOp),
GET_OPERATION_NAME(TF::TensorListScatterIntoExistingListOp),
};
#undef GET_OPERATION_NAME
supported_ops->insert(allowlist_ops.begin(), allowlist_ops.end());
}
bool HasStringOperand(Operation& op) {
for (auto operand : op.getOperands()) {
if (getElementTypeOrSelf(operand).isa<TF::StringType>()) return true;
@ -201,6 +242,7 @@ void MarkOpsForOutsideCompilation::runOnOperation() {
}
AddSupportedControlFlowOps(module.getContext(), &supported_ops);
AddRewrittenEmbeddingOps(module.getContext(), &supported_ops);
AddRewrittenCompositeOps(module.getContext(), &supported_ops);
auto result = module.walk([&](tf_device::ClusterOp cluster) {
// Only if `allow_soft_placement` attribute is true should we mark ops


@ -180,8 +180,7 @@ tf_executor::IslandOp CreateInputBarrierIsland(
// Create YieldOp for the new input sink island.
builder->setInsertionPointToEnd(&input_sink_island.GetBody());
builder->create<tf_executor::YieldOp>(island_op.getLoc(),
llvm::to_vector<8>(island_inputs));
builder->create<tf_executor::YieldOp>(island_op.getLoc(), island_inputs);
return input_sink_island;
}


@ -501,9 +501,8 @@ void RegionResourceHoister::ReplaceOpWithNewOp() {
OpBuilder builder(op_);
// Clone the old operation but with new result types.
Operation* new_op = Operation::create(
op_->getLoc(), op_->getName(), new_result_types,
llvm::to_vector<4>(op_->getOperands()), op_->getAttrs(),
llvm::to_vector<4>(op_->getSuccessors()), op_->getNumRegions());
op_->getLoc(), op_->getName(), new_result_types, op_->getOperands(),
op_->getAttrs(), op_->getSuccessors(), op_->getNumRegions());
builder.insert(new_op);
// Move regions to the new op.
@ -1224,14 +1223,11 @@ LogicalResult HoistForControlFlow(
return failure();
} else if (auto case_op = llvm::dyn_cast<TF::CaseOp>(&op)) {
SmallVector<FuncOp, 4> branch_functions;
branch_functions.reserve(case_op.branches().size());
for (const Attribute& branch : case_op.branches()) {
FuncOp func =
module.lookupSymbol<FuncOp>(branch.cast<FlatSymbolRefAttr>());
case_op.get_branch_functions(branch_functions);
for (FuncOp func : branch_functions) {
// Recursively handle the nested control flow.
HoistForControlFlow(&func.front(), module,
lifted_partitioned_call_callees);
branch_functions.push_back(func);
}
if (failed(HandleCaseOrIfOp(case_op, branch_functions))) return failure();
} else if (auto call_op = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {


@ -86,9 +86,8 @@ void EliminateUnusedResults(
// Rebuild the new operation with lesser number of results.
OpBuilder builder(op);
Operation *new_op = Operation::create(
op->getLoc(), op->getName(), new_result_types,
llvm::to_vector<4>(op->getOperands()), op->getAttrs(),
llvm::to_vector<4>(op->getSuccessors()), op->getNumRegions());
op->getLoc(), op->getName(), new_result_types, op->getOperands(),
op->getAttrs(), op->getSuccessors(), op->getNumRegions());
builder.insert(new_op);
// Move region bodies to the new operation.
@ -415,11 +414,7 @@ LogicalResult CleanupAndCanonicalize(Operation *parent_op) {
op, {if_op.then_function(), if_op.else_function()}, if_op.input());
} else if (auto case_op = dyn_cast<TF::CaseOp>(op)) {
SmallVector<FuncOp, 4> branches;
for (Attribute branch : case_op.branches()) {
auto sym = branch.cast<FlatSymbolRefAttr>();
branches.push_back(
SymbolTable::lookupNearestSymbolFrom<FuncOp>(op, sym));
}
case_op.get_branch_functions(branches);
result = CanonicalizeFunctionalIfCase(case_op, branches, case_op.input());
} else if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
if (while_op.cond_function().walk(check_while_cond).wasInterrupted())


@ -927,10 +927,7 @@ LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions(
{if_op.then_function(), if_op.else_function()}, max_iteration);
} else if (auto case_op = dyn_cast<TF::CaseOp>(op)) {
SmallVector<FuncOp, 4> branches;
for (Attribute branch : case_op.branches()) {
auto sym = branch.cast<FlatSymbolRefAttr>();
branches.push_back(SymbolTable::lookupNearestSymbolFrom<FuncOp>(op, sym));
}
case_op.get_branch_functions(branches);
return PropagateShapeToFunctions(module,
drop_begin(case_op.getOperandTypes(), 1),
branches, max_iteration);


@ -139,8 +139,7 @@ void ModifyFunctionSignature(
handle_new_size_vars(func.getArguments().drop_front(original_arg_count));
}
func.setType(FunctionType::get(
new_input_types,
llvm::to_vector<8>(func.front().getTerminator()->getOperandTypes()),
new_input_types, func.front().getTerminator()->getOperandTypes(),
func.getContext()));
}


@ -708,10 +708,7 @@ LogicalResult DecomposeTensorListOpsInternal(
}
} else if (auto case_op = llvm::dyn_cast<TF::CaseOp>(&op)) {
SmallVector<FuncOp, 2> branches;
for (auto branch_symbol : case_op.branches()) {
branches.push_back(module.lookupSymbol<FuncOp>(
branch_symbol.cast<FlatSymbolRefAttr>()));
}
case_op.get_branch_functions(branches);
if (failed(HandleCaseOrIfOp(case_op, branches, module, buffer_to_size,
decomposed_partitioned_call_callees))) {
return failure();


@ -115,9 +115,8 @@ struct TPUSpaceToDepthPass
// Updates func argument type to have the updated input shape.
void UpdateFuncType(FuncOp func) {
auto arg_types = llvm::to_vector<8>(func.front().getArgumentTypes());
auto result_types =
llvm::to_vector<4>(func.front().getTerminator()->getOperandTypes());
auto arg_types = func.front().getArgumentTypes();
auto result_types = func.front().getTerminator()->getOperandTypes();
func.setType(FunctionType::get(arg_types, result_types, func.getContext()));
}


@ -138,14 +138,17 @@ Value SkipIdentity(Value v, bool allow_other_use,
// Finds the formattable arguments of `execute` and annotates the metadata of
// `compile` to record these arguments. In addition, it returns a mapping from
// the formattable arguments of `execute` to the corresponding arguments of
// `while_op` (which should be passed through to `execute` via `replicate`). The
// the formattable arguments of `execute` to the corresponding operand of
// `replicate`. The
// entries in the mapping are sorted in the order of operands of `execute`.
llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<Value, 4>>, 4>
AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(
TF::WhileOp while_op, tf_device::ReplicateOp replicate,
TF::WhileRegionOp while_op, tf_device::ReplicateOp replicate,
TF::TPUExecuteAndUpdateVariablesOp execute,
tf_device::LaunchOp compile_launch, FuncOp body, FuncOp cond) {
tf_device::LaunchOp compile_launch) {
Region& body = while_op.body();
Region& cond = while_op.cond();
llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<Value, 4>>, 4> mapping;
auto mirrored_variable_indices_attr =
replicate.getAttrOfType<ArrayAttr>(kMirroredVariableIndicesAttr);
@ -204,39 +207,43 @@ AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(
// We have found a mirrored variable which is an input to the replicated
// `execute`. Now find if this mirrored variable is a pass-through of while
// arguments.
llvm::SmallVector<Value, 4> while_args;
llvm::SmallVector<Value, 4> replicate_args;
for (int64_t i = 0; i < num_inputs; ++i) {
llvm::SmallPtrSet<Operation*, 4> skipped_identities;
auto replicate_operand = SkipIdentity(
replicate.GetReplicaOperandForBlockArgument(block_arg, i),
/*allow_other_use=*/false, &skipped_identities);
auto block_arg = replicate_operand.dyn_cast<BlockArgument>();
// To qualify for a valid pass-through mirrored variable, it must satisfy
// 1) it is the body's argument;
// 2) it has no other uses than `replicate`, the skipped identity ops,
// or the return;
// 3) the corresponding argument in the cond function has no uses.
if (!block_arg || block_arg.getOwner() != &body.front() ||
llvm::any_of(replicate_operand.getUsers(),
[&](Operation* user) {
return user != body.front().getTerminator() &&
skipped_identities.count(user) == 0 &&
user != replicate;
}) ||
!cond.getArgument(block_arg.getArgNumber()).use_empty()) {
while_args.clear();
// For region based control flow, the resource operand for the replicate
// should be a region capture. If this has any use other than the
// replicate op (within the body of the while) or the skipped identities,
// then do not apply the transformation to this variable.
bool is_region_capture =
replicate_operand.getParentRegion()->isProperAncestor(&body);
bool has_other_use_in_body =
llvm::any_of(replicate_operand.getUsers(), [&](Operation* user) {
// Ignore uses that are not in the while body or condition.
if (!body.isAncestor(user->getParentRegion()) &&
!cond.isAncestor(user->getParentRegion()))
return false;
// Within the body or cond, only uses in replicate and the skipped
// identities is allowed.
return user != replicate && skipped_identities.count(user) == 0;
});
if (!is_region_capture || has_other_use_in_body) {
replicate_args.clear();
break;
}
while_args.push_back(while_op.getOperand(block_arg.getArgNumber()));
replicate_args.push_back(replicate_operand);
}
if (while_args.empty()) continue;
if (replicate_args.empty()) continue;
// Now set the enable_xla_sharding field in the metadata to inform the
// compile op.
auto metadata_arg = metadata.mutable_args(it->second);
metadata_arg->set_enable_xla_sharding(
::tensorflow::tpu::TPUCompileMetadataProto_Arg::ALLOWED);
mapping.emplace_back(it->second, std::move(while_args));
mapping.emplace_back(it->second, std::move(replicate_args));
}
// Sort the mapping according to execute operand order.
llvm::sort(mapping, llvm::less_first());
@ -261,7 +268,8 @@ AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(
// Adds a new replicated input to the replicate op.
tf_device::ReplicateOp AddInputsToReplicateOp(
tf_device::ReplicateOp replicate, ArrayRef<Value> new_inputs,
tf_device::ReplicateOp replicate,
MutableArrayRef<TF::VarHandleOp> new_inputs,
const llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>>&
devices) {
int64_t num_replicas = replicate.n();
@ -293,7 +301,11 @@ tf_device::ReplicateOp AddInputsToReplicateOp(
new_packed_inputs.emplace_back(
replicate.GetReplicaOperandForBlockArgument(arg, /*replica=*/0));
}
new_replicated_inputs.emplace_back(new_inputs, new_inputs.front().getType());
SmallVector<Value, 4> new_input_values;
new_input_values.reserve(new_inputs.size());
for (auto var : new_inputs) new_input_values.push_back(var.resource());
new_replicated_inputs.emplace_back(new_input_values,
new_input_values.front().getType());
OpBuilder builder(replicate);
auto new_replicate = builder.create<tf_device::ReplicateOp>(
replicate.getLoc(), num_replicas, devices, new_replicated_inputs,
@ -319,58 +331,6 @@ tf_device::ReplicateOp AddInputsToReplicateOp(
return new_replicate;
}
// Adds the per-device state variables to the while-loop's inputs/outputs.
TF::WhileOp AddStateVarsToWhileOp(TF::WhileOp while_op, FuncOp body,
FuncOp cond,
ArrayRef<TF::VarHandleOp> state_vars) {
auto body_return = llvm::cast<ReturnOp>(body.front().back());
auto new_body_return_vals = llvm::to_vector<4>(body_return.getOperands());
auto new_while_operands = llvm::to_vector<4>(while_op.getOperands());
auto append_types = [&](ArrayRef<Type> types) {
auto new_types = llvm::to_vector<4>(types);
for (auto state_var : state_vars) {
new_types.push_back(state_var.resource().getType());
}
return new_types;
};
for (auto state_var : state_vars) {
body.front().addArgument(state_var.resource().getType());
cond.front().addArgument(state_var.resource().getType());
auto inner_arg = body.getArgument(body.front().getNumArguments() - 1);
new_body_return_vals.push_back(inner_arg);
new_while_operands.push_back(state_var.resource());
}
OpBuilder builder = OpBuilder::atBlockEnd(&body.front());
// Update return values.
builder.create<ReturnOp>(body_return.getLoc(), new_body_return_vals);
body_return.erase();
body.setType(FunctionType::get(append_types(body.getType().getInputs()),
append_types(body.getType().getResults()),
body.getContext()));
cond.setType(FunctionType::get(append_types(cond.getType().getInputs()),
cond.getType().getResults(),
cond.getContext()));
for (int64_t i = 0, end = state_vars.size(); i < end; ++i) {
int64_t arg_index = body.getNumArguments() - state_vars.size() + i;
TF::VarHandleOp state_var = state_vars[i];
auto device_attr = state_var.getAttr(kDeviceAttr);
if (device_attr) {
body.setArgAttr(arg_index, kFuncDeviceAttr, device_attr);
cond.setArgAttr(arg_index, kFuncDeviceAttr, device_attr);
}
}
builder.setInsertionPoint(while_op);
auto new_while_op = builder.create<TF::WhileOp>(
while_op.getLoc(),
append_types(llvm::to_vector<4>(while_op.getResultTypes())),
new_while_operands, while_op.getAttrs());
while_op.replaceAllUsesWith(
new_while_op.getResults().take_front(while_op.getNumResults()));
while_op.erase();
return new_while_op;
}
// Creates the per-device variables that represent the formatting state of each
// device.
llvm::SmallVector<TF::VarHandleOp, 4> CreateStateVars(
@ -421,8 +381,8 @@ void WrapOpInLaunch(OpBuilder* builder, Location loc, Operation* op,
}
// Performs the transformation for a replicate op inside a while loop.
void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate,
MLIRContext* context) {
void HandleReplicateOp(TF::WhileRegionOp while_op,
tf_device::ReplicateOp replicate) {
int64_t num_replicas = replicate.n();
if (num_replicas == 1) return;
tf_device::LaunchOp execute_launch;
@ -452,13 +412,10 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate,
!llvm::isa<TF::_TPUCompileMlirOp>(compile_launch.GetBody().front()))
return;
FuncOp body = while_op.body_function();
FuncOp cond = while_op.cond_function();
// Analyze the formattable inputs.
auto execute_arg_to_outer_args =
AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(
while_op, replicate, execute, compile_launch, body, cond);
while_op, replicate, execute, compile_launch);
if (execute_arg_to_outer_args.empty()) return;
// Extract the replicated devices.
@ -489,16 +446,7 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate,
RankedTensorType::get({2}, TF::StringType::get(builder.getContext()));
auto state_vars =
CreateStateVars(devices, while_op.getLoc(), key_type, &builder);
while_op = AddStateVarsToWhileOp(while_op, body, cond, state_vars);
// Add the new while loop inputs to the replicate op inside the body.
int64_t new_while_operand_count = while_op.getNumOperands();
llvm::SmallVector<Value, 4> inner_state_vars;
for (int64_t i = new_while_operand_count - num_replicas;
i < new_while_operand_count; ++i) {
inner_state_vars.push_back(body.front().getArgument(i));
}
replicate = AddInputsToReplicateOp(replicate, inner_state_vars, devices);
replicate = AddInputsToReplicateOp(replicate, state_vars, devices);
// Build the reformat according to the compilation. Build it inside
// `replicate`.
llvm::SmallVector<Value, 8> reformat_operands;
@ -576,10 +524,9 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate,
void TPUVariableRuntimeReformattingPass::runOnOperation() {
auto module = getOperation();
module.walk([&](TF::WhileOp while_op) {
auto body = llvm::cast<FuncOp>(module.lookupSymbol(while_op.body()));
module.walk([&](TF::WhileRegionOp while_op) {
tf_device::ReplicateOp replicate;
body.walk([&](tf_device::ReplicateOp replicate_op) {
while_op.body().walk([&](tf_device::ReplicateOp replicate_op) {
if (replicate == nullptr) {
replicate = replicate_op;
return WalkResult::advance();
@ -592,7 +539,7 @@ void TPUVariableRuntimeReformattingPass::runOnOperation() {
// `tf_device.parallel_execute` op in the `tf_device.replicate` is present.
if (replicate &&
replicate.GetBody().getOps<tf_device::ParallelExecuteOp>().empty())
HandleReplicateOp(while_op, replicate, &getContext());
HandleReplicateOp(while_op, replicate);
});
}


@ -137,7 +137,7 @@ void LowerCase(TF::CaseOp op, ModuleOp module) {
auto tuple_input = builder.create<mhlo::TupleOp>(loc, inputs);
// Create replica of input tuple for each branch
SmallVector<Value, 4> n_tuple_inputs(op.branches().size(), tuple_input);
SmallVector<Value, 4> n_tuple_inputs(op.num_branches(), tuple_input);
// Create the new case op with tuple inputs.
auto case_op =
@ -145,9 +145,8 @@ void LowerCase(TF::CaseOp op, ModuleOp module) {
n_tuple_inputs, op.branches().size());
// Import the regions for all branches.
for (unsigned i = 0; i < op.branches().size(); ++i) {
mlir::FuncOp branch_func = module.lookupSymbol<mlir::FuncOp>(
op.branches()[i].cast<SymbolRefAttr>());
for (unsigned i = 0; i < op.num_branches(); ++i) {
mlir::FuncOp branch_func = op.branch_function(i);
ImportXlaRegion(branch_func, &case_op.branches()[i], loc,
/*tuple_return=*/false);
}


@ -586,6 +586,7 @@ foreach Mapping = [
[TF_RealOp, HLO_RealOp],
[TF_RsqrtOp, HLO_RsqrtOp],
[TF_SigmoidOp, HLO_LogisticOp],
[TF_SinhOp, HLOClient_SinhOp],
[TF_SinOp, HLO_SinOp],
[TF_SqrtOp, HLO_SqrtOp],
[TF_TanhOp, HLO_TanhOp],


@ -436,8 +436,8 @@ Status LhloDialectEmitter::Initialize() {
}
}
FunctionType function_type = builder_.getFunctionType(
llvm::to_vector<8>(block->getArgumentTypes()), {});
FunctionType function_type =
builder_.getFunctionType(block->getArgumentTypes(), {});
func_op.setType(function_type);
func_op.setAllArgAttrs(args_attrs);


@ -31,6 +31,7 @@ import six
from tensorflow.compiler.tests import xla_test
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import gradient_checker_v2
@ -54,6 +55,16 @@ def _igammac(a, x):
return math_ops.igammac(a, x)
@def_function.function(experimental_compile=True)
def _polygamma(n, x):
return math_ops.polygamma(n, x)
@def_function.function(experimental_compile=True)
def _zeta(a, q):
return math_ops.zeta(a, q)
# This is df/da / df/dx, where f = igamma.
def implicit_reparameterization_grad(a, x):
log_prob = math_ops.xlogy(a - 1., x) - math_ops.lgamma(a) - x
@ -136,6 +147,208 @@ class Log1pTest(xla_test.XLATestCase, parameterized.TestCase):
self._test_range(0., 3., dtype, rtol, atol, is_negative=False)
class ZetaTest(xla_test.XLATestCase, parameterized.TestCase):
def setUp(self):
if flags.FLAGS.vary_seed:
entropy = os.urandom(64)
if six.PY2:
answer = int(entropy.encode('hex'), 16)
else:
answer = int.from_bytes(entropy, 'big')
np.random.seed(answer % (2**32 - 1))
super(ZetaTest, self).setUp()
def adjust_tolerance_for_tpu(self, dtype, rtol, atol):
if self.device not in ['TPU']:
return rtol, atol
if dtype == np.float32:
return 2e-2, 1e-7
return 2e-4, 1e-20
@test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
def testBadValues(self):
q = np.random.uniform(low=0.3, high=20., size=[10])
with self.session() as sess:
with self.test_scope():
y = _zeta(np.float64(1.), q)
actual = sess.run(y)
# When x == 1, this is the Harmonic series.
self.assertTrue(np.all(np.isinf(actual)))
with self.session() as sess:
with self.test_scope():
y = _zeta(np.float64(0.1), q)
actual = sess.run(y)
# When x < 1, this is undefined.
self.assertTrue(np.all(np.isnan(actual)))
with self.session() as sess:
with self.test_scope():
y = _zeta([1., 1.1], [-1.1, -1.])
actual = sess.run(y)
# When q is negative, zeta is not defined
# if q is an integer or x is not an integer.
self.assertTrue(np.all(np.isinf(actual)))
@parameterized.parameters((np.float32, 1e-2, 1e-11),
(np.float64, 1e-4, 1e-30))
@test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
def testLargeXSmallQ(self, dtype, rtol, atol):
rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
if self.device not in ['XLA_GPU', 'XLA_CPU'] and dtype == np.float64:
# TODO(b/165739664): Figure out why on TPU F64 Zeta sometimes returns
# infs.
self.skipTest(
'Skipping test because some F64 operations are numerically '
'unstable on TPU.')
x = np.random.uniform(low=100., high=200., size=[NUM_SAMPLES]).astype(dtype)
q = np.random.uniform(low=0.3, high=1., size=[NUM_SAMPLES]).astype(dtype)
expected_values = sps.zeta(x, q)
with self.session() as sess:
with self.test_scope():
y = _zeta(x, q)
actual = sess.run(y)
self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
@parameterized.parameters((np.float32, 1e-2, 1e-11),
(np.float64, 1e-4, 1e-30))
@test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
def testSmallValues(self, dtype, rtol, atol):
rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
# Test values near zero.
x = np.random.uniform(low=1.1, high=10., size=[NUM_SAMPLES]).astype(dtype)
q = np.random.uniform(
low=np.finfo(dtype).tiny, high=1., size=[NUM_SAMPLES]).astype(dtype)
expected_values = sps.zeta(x, q)
with self.session() as sess:
with self.test_scope():
actual = sess.run(_zeta(x, q))
self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
@parameterized.parameters((np.float32, 1e-2, 1e-11),
(np.float64, 1e-4, 1e-30))
@test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
def testMediumValues(self, dtype, rtol, atol):
rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
x = np.random.uniform(low=1.1, high=100., size=[NUM_SAMPLES]).astype(dtype)
q = np.random.uniform(low=1., high=1e1, size=[NUM_SAMPLES]).astype(dtype)
expected_values = sps.zeta(x, q)
with self.session() as sess:
with self.test_scope():
actual = sess.run(_zeta(x, q))
self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
@parameterized.parameters((np.float32, 2e-2, 1e-5), (np.float64, 1e-4, 1e-30))
@test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
def testLargeValues(self, dtype, rtol, atol):
x = np.random.uniform(
low=100., high=int(1e3), size=[NUM_SAMPLES]).astype(dtype)
q = np.random.uniform(
low=1., high=int(1e1), size=[NUM_SAMPLES]).astype(dtype)
expected_values = sps.zeta(x, q)
with self.session() as sess:
with self.test_scope():
actual = sess.run(_zeta(x, q))
self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
class PolygammaTest(xla_test.XLATestCase, parameterized.TestCase):
def setUp(self):
if flags.FLAGS.vary_seed:
entropy = os.urandom(64)
if six.PY2:
answer = int(entropy.encode('hex'), 16)
else:
answer = int.from_bytes(entropy, 'big')
np.random.seed(answer % (2**32 - 1))
super(PolygammaTest, self).setUp()
def adjust_tolerance_for_tpu(self, dtype, rtol, atol):
if self.device not in ['TPU']:
return rtol, atol
if dtype == np.float32:
return 2e-2, 1e-7
return 2e-4, 1e-20
@test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
def testBadValues(self):
x = np.random.uniform(low=0.3, high=20., size=[10])
with self.session() as sess:
with self.test_scope():
y = _polygamma(np.float64(-1.), x)
actual = sess.run(y)
# Not defined for negative numbers.
self.assertTrue(np.all(np.isnan(actual)))
with self.session() as sess:
with self.test_scope():
y = _polygamma(np.float64(0.1), x)
actual = sess.run(y)
# Not defined for non-integers.
self.assertTrue(np.all(np.isnan(actual)))
@parameterized.parameters((np.float32, 1e-2, 1e-11),
(np.float64, 1e-4, 1e-30))
@test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
def testRecoverDigamma(self, dtype, rtol, atol):
rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
if self.device not in ['XLA_GPU', 'XLA_CPU'] and dtype == np.float64:
self.skipTest(
'Skipping test because some F64 operations are '
'numerically unstable on TPU.'
)
x = np.random.uniform(low=0.1, high=50., size=[NUM_SAMPLES]).astype(dtype)
expected_values = sps.digamma(x)
with self.session() as sess:
with self.test_scope():
y = _polygamma(dtype(0.), x)
actual = sess.run(y)
self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
@parameterized.parameters((np.float32, 1e-2, 1e-11),
(np.float64, 1e-4, 1e-30))
@test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
def testSmallN(self, dtype, rtol, atol):
rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
# Test values near zero.
n = np.random.randint(low=1, high=5, size=[NUM_SAMPLES]).astype(dtype)
x = np.random.uniform(
low=np.finfo(dtype).tiny, high=1., size=[NUM_SAMPLES]).astype(dtype)
expected_values = sps.polygamma(n, x)
with self.session() as sess:
with self.test_scope():
actual = sess.run(_polygamma(n, x))
self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
@parameterized.parameters((np.float32, 1e-2, 1e-11),
(np.float64, 1e-4, 1e-30))
@test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
def testMediumLargeN(self, dtype, rtol, atol):
rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
n = np.random.randint(low=5, high=10, size=[NUM_SAMPLES]).astype(dtype)
x = np.random.uniform(low=1., high=1e1, size=[NUM_SAMPLES]).astype(dtype)
expected_values = sps.polygamma(n, x)
with self.session() as sess:
with self.test_scope():
actual = sess.run(_polygamma(n, x))
self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
class IgammaTest(xla_test.XLATestCase, parameterized.TestCase):
def setUp(self):


@ -290,6 +290,21 @@ xla::XlaOp IgammacImpl(xla::XlaOp x, xla::XlaOp y,
XLA_MAKE_BINARY(Igammac, IgammacImpl(lhs, rhs, broadcast_helper));
xla::XlaOp PolygammaImpl(xla::XlaOp n, xla::XlaOp x,
const BCast& broadcast_helper) {
std::tie(n, x) = XlaBinaryOp::Broadcast(n, x, broadcast_helper);
return xla::Polygamma(n, x);
}
XLA_MAKE_BINARY(Polygamma, PolygammaImpl(lhs, rhs, broadcast_helper));
xla::XlaOp ZetaImpl(xla::XlaOp x, xla::XlaOp q, const BCast& broadcast_helper) {
std::tie(x, q) = XlaBinaryOp::Broadcast(x, q, broadcast_helper);
return xla::Zeta(x, q);
}
XLA_MAKE_BINARY(Zeta, ZetaImpl(lhs, rhs, broadcast_helper));
#undef XLA_MAKE_BINARY
class ApproximateEqualOp : public XlaOpKernel {


@ -206,6 +206,8 @@ igamma = _broadcasting_binary_op(math_ops.igamma)
igamma_grad_a = _broadcasting_binary_op(gen_math_ops.igamma_grad_a)
random_gamma_grad = _broadcasting_binary_op(gen_random_ops.random_gamma_grad)
igammac = _broadcasting_binary_op(math_ops.igammac)
polygamma = _broadcasting_binary_op(math_ops.polygamma)
zeta = _broadcasting_binary_op(math_ops.zeta)
def _binary_op(fn):


@ -1832,4 +1832,139 @@ XlaOp RegularizedIncompleteBeta(XlaOp a, XlaOp b, XlaOp x) {
});
}
XlaOp Polygamma(XlaOp n, XlaOp x) {
auto& builder = *x.builder();
auto doit = [](XlaOp n, XlaOp x, PrimitiveType type) -> XlaOp {
XlaOp n_plus_one = n + ScalarLike(n, 1.);
XlaOp sign =
(ScalarLike(n, 2.) * Rem(n, ScalarLike(n, 2.)) - ScalarLike(n, 1.));
const double nan = std::numeric_limits<double>::quiet_NaN();
XlaOp output = Select(Eq(n, ScalarLike(n, 0.)), Digamma(x),
sign * Exp(Lgamma(n_plus_one)) * Zeta(n_plus_one, x));
// Check that n is a natural number.
output = Select(Or(Ne(n, Floor(n)), Lt(n, ScalarLike(n, 0.))),
ScalarLike(n, nan), output);
return output;
};
return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(auto n_shape, builder.GetShape(n));
TF_ASSIGN_OR_RETURN(auto x_shape, builder.GetShape(x));
if (n_shape != x_shape) {
return InvalidArgument(
"Arguments to Polygamma must have equal shapes and types; "
"got %s and %s",
n_shape.ToString(), x_shape.ToString());
}
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Zeta", x));
bool needs_upcast =
n_shape.element_type() == F16 || x_shape.element_type() == BF16;
if (needs_upcast) {
n = ConvertElementType(n, F32);
x = ConvertElementType(x, F32);
}
XlaOp result = doit(n, x, n_shape.element_type());
if (needs_upcast) {
result = ConvertElementType(result, n_shape.element_type());
}
return result;
});
}
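As a quick cross-check of the identity the kernel above relies on, namely polygamma(n, x) = (-1)^(n+1) * n! * zeta(n+1, x) with digamma as the n = 0 case and NaN for non-natural n, here is a small SciPy sketch; the helper name is ours and not part of this change:
import numpy as np
from scipy import special as sps
def polygamma_via_zeta(n, x):
  # Mirrors the selects above: digamma for n == 0, NaN for non-natural n,
  # otherwise sign * n! * zeta(n + 1, x) with sign = (-1)**(n + 1).
  if n != np.floor(n) or n < 0:
    return np.nan
  if n == 0:
    return sps.digamma(x)
  return (-1.0) ** (n + 1) * sps.gamma(n + 1.0) * sps.zeta(n + 1.0, x)
print(polygamma_via_zeta(3, 2.5), sps.polygamma(3, 2.5))  # the two values agree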
XlaOp Zeta(XlaOp x, XlaOp q) {
auto& builder = *x.builder();
auto doit = [&builder](XlaOp x, XlaOp q, PrimitiveType type) -> XlaOp {
// (2k) ! / B_{2k}, where B_{2k} are the Bernoulli numbers.
// These are ordered in reverse.
static const std::array<double, 12> kZetaCoeffs{
-7.1661652561756670113e18,
1.8152105401943546773e17,
-4.5979787224074726105e15,
1.1646782814350067249e14,
-2.950130727918164224e12,
7.47242496e10,
-1.8924375803183791606e9,
47900160.0,
-1209600.0,
30240.0,
-720.0,
12.0,
};
// For speed we'll always use 9 iterations for the initial series estimate,
// and a 12 term expansion for the Euler-Maclaurin formula.
XlaOp a = q;
XlaOp neg_power = ScalarLike(a, 0.);
XlaOp initial_sum = Pow(q, Neg(x));
for (int i = 0; i < 9; ++i) {
a = a + ScalarLike(a, 1.);
neg_power = Pow(a, Neg(x));
initial_sum = initial_sum + neg_power;
}
a = a + ScalarLike(a, 1.);
neg_power = Pow(a, Neg(x));
XlaOp s = initial_sum + neg_power * a / (x - ScalarLike(a, 1.));
XlaOp a_inverse_square = Reciprocal(Square(a));
XlaOp horner_sum = ScalarLike(a, 0.);
XlaOp factor = ScalarLike(a, 1.);
// Use Horner's rule for this.
// Note this differs from Cephes which does a 'naive' polynomial evaluation.
// Using Horner's rule helps avoid some NaNs and Infs from appearing,
// resulting in more numerically stable code.
for (int i = 0; i < 11; ++i) {
factor =
(x - ScalarLike(x, 22 - 2 * i)) * (x - ScalarLike(x, 21 - 2 * i));
horner_sum = factor * a_inverse_square *
(horner_sum + ScalarLike(a, 1. / kZetaCoeffs[i]));
}
s = s + neg_power *
(ScalarLike(neg_power, 0.5) +
x / a * (ScalarLike(a, 1. / kZetaCoeffs[11]) + horner_sum));
const double nan = std::numeric_limits<double>::quiet_NaN();
const double inf = std::numeric_limits<double>::infinity();
// Use the initial zeta sum without the correction term coming
// from Euler-Maclaurin if it is accurate enough.
XlaOp output =
Select(Lt(Abs(neg_power), Abs(initial_sum) * Epsilon(&builder, type)),
initial_sum, s);
// This is the harmonic series.
output = Select(Eq(x, ScalarLike(x, 1.)), ScalarLike(x, inf), output);
// Function is not defined for x < 1.
output = Select(Lt(x, ScalarLike(x, 1.)), ScalarLike(x, nan), output);
// If q <= 0: a non-positive integer q yields a pole (Inf), and a
// non-integer x yields NaN.
XlaOp domain_error = And(Le(q, ScalarLike(x, 0.)), Ne(x, Floor(x)));
XlaOp negative_integer_q = And(Le(q, ScalarLike(x, 0.)), Eq(q, Floor(q)));
output = Select(negative_integer_q, ScalarLike(x, inf), output);
output = Select(domain_error, ScalarLike(x, nan), output);
return output;
};
return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(auto x_shape, builder.GetShape(x));
TF_ASSIGN_OR_RETURN(auto q_shape, builder.GetShape(q));
if (x_shape != q_shape) {
return InvalidArgument(
"Arguments to Zeta must have equal shapes and types; got %s and %s",
x_shape.ToString(), q_shape.ToString());
}
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Zeta", x));
bool needs_upcast =
x_shape.element_type() == F16 || x_shape.element_type() == BF16;
if (needs_upcast) {
x = ConvertElementType(x, F32);
q = ConvertElementType(q, F32);
}
XlaOp result = doit(x, q, x_shape.element_type());
if (needs_upcast) {
result = ConvertElementType(result, x_shape.element_type());
}
return result;
});
}
} // namespace xla
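The comments above describe the scheme: a short direct series head, then an Euler-Maclaurin tail evaluated with Horner's rule over kZetaCoeffs (each entry is (2k)!/B_2k). A rough NumPy sketch of the same idea, under an invented name and with only three Bernoulli correction terms, agrees with SciPy's reference:
import numpy as np
from scipy import special as sps
def hurwitz_zeta_sketch(x, q, head_terms=10):
  # Direct evaluation of the first few series terms.
  k = np.arange(head_terms)
  head = np.sum((q + k) ** (-x))
  a = q + head_terms
  # Integral and trapezoid pieces of the Euler-Maclaurin tail.
  tail = a ** (1.0 - x) / (x - 1.0) + 0.5 * a ** (-x)
  # First three Bernoulli corrections: B2/2! = 1/12, B4/4! = -1/720,
  # B6/6! = 1/30240 (the reciprocals of the last three kZetaCoeffs entries).
  corr = x / (12.0 * a ** (x + 1.0))
  corr -= x * (x + 1.0) * (x + 2.0) / (720.0 * a ** (x + 3.0))
  corr += (x * (x + 1.0) * (x + 2.0) * (x + 3.0) * (x + 4.0) /
           (30240.0 * a ** (x + 5.0)))
  return head + tail + corr
print(hurwitz_zeta_sketch(3.0, 0.5), sps.zeta(3.0, 0.5))  # both ~8.4144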


@ -72,6 +72,12 @@ XlaOp RandomGammaGrad(XlaOp a, XlaOp x);
// Computes an approximation of the complementary incomplete gamma function.
XlaOp Igammac(XlaOp a, XlaOp x);
// Computes the Polygamma of two arguments.
XlaOp Polygamma(XlaOp n, XlaOp x);
// Computes the Riemann zeta function of two arguments.
XlaOp Zeta(XlaOp x, XlaOp q);
// Rounds the given number to even when the number is equidistant between two
// integers.
XlaOp RoundToEven(XlaOp x);


@ -239,7 +239,7 @@ struct CacheEntry {
// a signature and if the object has been inserted already, other threads
// will wait for the notification.
absl::Notification compilation_complete;
absl::optional<std::exception> compilation_error = absl::nullopt;
absl::optional<Status> compilation_error = absl::nullopt;
};
// A `CompiledFunction` is associated to a `jax.jit(f)` and takes care of the
@ -314,7 +314,7 @@ class CompiledFunction {
// absl::optional<absl::Notification> is not supported
bool first_compilation_started_ = false;
absl::Notification first_compilation_complete_;
absl::optional<std::exception> first_compilation_error_ = absl::nullopt;
absl::optional<Status> first_compilation_error_ = absl::nullopt;
};
CompiledFunction::CompiledFunction(py::function fun,
@ -646,7 +646,8 @@ CacheEntry& CompiledFunction::GetCacheEntry(
py::gil_scoped_release gil_release;
found_iterator->second->compilation_complete.WaitForNotification();
if (found_iterator->second->compilation_error) {
throw found_iterator->second->compilation_error.value();
throw std::invalid_argument(
found_iterator->second->compilation_error.value().error_message());
}
}
return *(found_iterator->second);
@ -671,8 +672,8 @@ CacheEntry& CompiledFunction::SetAndReturnCacheEntry(
} else {
try {
executable_and_pytree = cache_miss_fun_(*args, **kwargs);
} catch (const std::exception& e) {
cache_entry.compilation_error = e;
} catch (const py::error_already_set& e) {
cache_entry.compilation_error = InvalidArgument("%s", e.what());
cache_entry.compilation_complete.Notify();
throw;
}
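The hunks above change the stored compilation failure from a sliced std::exception to a Status that waiting threads re-raise. The caching pattern itself is: the first caller for a signature compiles and notifies, and concurrent callers block on the notification, then either reuse the executable or re-raise the stored error. A hedged Python sketch of that pattern (all names invented):
import threading
class CacheEntry:
  def __init__(self):
    self.compilation_complete = threading.Event()
    self.compilation_error = None  # stands in for absl::optional<Status>
    self.executable = None
def get_or_compile(cache, lock, signature, compile_fn):
  with lock:
    entry = cache.get(signature)
    is_owner = entry is None
    if is_owner:
      entry = cache[signature] = CacheEntry()
  if not is_owner:
    # Another thread owns compilation for this signature: wait for it and
    # surface its failure instead of recompiling.
    entry.compilation_complete.wait()
    if entry.compilation_error is not None:
      raise RuntimeError(entry.compilation_error)
    return entry.executable
  try:
    entry.executable = compile_fn()
  except Exception as exc:  # mirrors catching py::error_already_set
    entry.compilation_error = str(exc)
    entry.compilation_complete.set()
    raise
  entry.compilation_complete.set()
  return entry.executable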
@ -736,16 +737,17 @@ py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) {
if (!first_compilation_complete_.HasBeenNotified()) {
py::gil_scoped_release gil_release;
first_compilation_complete_.WaitForNotification();
if (first_compilation_error_) {
throw first_compilation_error_.value();
}
}
if (first_compilation_error_) {
throw std::invalid_argument(
first_compilation_error_.value().error_message());
}
} else {
first_compilation_started_ = true;
try {
cache_miss_result = cache_miss_fun_(*args, **kwargs);
} catch (const std::exception& e) {
first_compilation_error_ = e;
} catch (const py::error_already_set& e) {
first_compilation_error_ = InvalidArgument("%s", e.what());
first_compilation_complete_.Notify();
throw;
}
@ -754,9 +756,14 @@ py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) {
pyclient_ = executable->client();
default_device_ = executable->LocalDevices()[0].contents;
if (!default_device_) {
throw std::invalid_argument(
"executable->LocalDevices()[0] should not be null!");
}
first_compilation_complete_.Notify();
}
}
CHECK(default_device_);
// The C++ jit do not support Tracers arguments yet. The Python-based jit
// function will be called if any of the dynamic arguments is unsupported.


@ -291,6 +291,7 @@ void BuildOpsSubmodule(py::module* m) {
ops.def("RandomGammaGrad", &RandomGammaGrad, py::arg("a"), py::arg("x"));
ops.def("RegularizedIncompleteBeta", &RegularizedIncompleteBeta, py::arg("a"),
py::arg("b"), py::arg("x"));
ops.def("Zeta", &Zeta, py::arg("x"), py::arg("q"));
#define BINARY_OP(op) \
ops.def( \


@ -56,17 +56,21 @@ StatusOr<std::pair<XlaOp, XlaOp>> CholeskyExpander::CholeskyUnblocked(
XlaOp a, PrecisionConfig::Precision precision) {
XlaBuilder* builder = a.builder();
TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
const int n_dims = a_shape.rank();
const int ndims = a_shape.rank();
const int64 n = ShapeUtil::GetDimension(a_shape, -1);
std::vector<int64> error_dims(a_shape.dimensions().begin(),
a_shape.dimensions().end());
error_dims.back() = error_dims.at(ndims - 2) = 1;
auto major_dims = AsInt64Slice(a_shape.dimensions())
.subspan(
/*pos=*/0,
/*len=*/n_dims - 2);
/*len=*/ndims - 2);
auto matrix_dims = AsInt64Slice(a_shape.dimensions())
.subspan(
/*pos=*/0,
/*len=*/n_dims);
/*len=*/ndims);
XlaOp l = ZerosLike(a);
@ -79,9 +83,9 @@ StatusOr<std::pair<XlaOp, XlaOp>> CholeskyExpander::CholeskyUnblocked(
auto body_l = loop_vars[1];
auto seen_error = loop_vars[2];
auto iota_row =
Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), n_dims - 1);
Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), ndims - 1);
auto iota_col =
Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), n_dims - 2);
Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), ndims - 2);
auto mask_pred = Ge(iota_col, iota_row);
mask_pred = And(mask_pred, Eq(iota_row, i));
@ -91,25 +95,32 @@ StatusOr<std::pair<XlaOp, XlaOp>> CholeskyExpander::CholeskyUnblocked(
// L * L.T. This matrix has a lot of multiplication by zero
// (namely, L[:, j:] = 0) and redundant computation, but it is faster
// than slicing.
auto l_square = BatchDot(body_l, false, body_l, true, precision);
auto l_square =
BatchDot(body_l, false, MaybeConjugate(body_l, true), true, precision);
// A - L*L.T
l_square = body_a - l_square;
auto l_ii = DynamicSliceInMinorDims(l_square, {i, i}, {1, 1});
l_ii = Sqrt(l_ii);
if (ShapeUtil::ElementIsComplex(a_shape)) {
auto sqrt = Sqrt(Real(l_ii));
l_ii = Complex(sqrt, ZerosLike(sqrt));
seen_error = Or(seen_error, IsNan(sqrt));
} else {
l_ii = Sqrt(l_ii);
seen_error = Or(seen_error, IsNan(l_ii));
}
// L = (A - L*L.T) / l_ii * mask + L
body_l = Select(mask_pred, l_square / l_ii, mask_zeros) + body_l;
seen_error =
Or(seen_error, Any(Or(Le(l_ii, ZerosLike(l_ii)), IsNan(l_ii))));
return std::vector<XlaOp>{body_a, body_l, seen_error};
};
TF_ASSIGN_OR_RETURN(
auto cholesky_while,
ForEachIndex(n, S32, body_fn, {a, l, ConstantR0<bool>(builder, false)},
"unblocked", builder));
ForEachIndex(
n, S32, body_fn,
{a, l, Zeros(builder, ShapeUtil::MakeShape(PRED, error_dims))},
"unblocked", builder));
return std::make_pair(cholesky_while[1], cholesky_while[2]);
}
@ -133,23 +144,23 @@ XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64 block_size,
ShapeUtil::HumanString(a_shape));
}
if (primitive_util::IsComplexType(a_shape.element_type())) {
return Unimplemented(
"Complex types are not implemented in Cholesky; got shape %s",
ShapeUtil::HumanString(a_shape));
}
if (block_size < 1) {
return InvalidArgument(
"block_size argument to Cholesky must be >= 1; got %d", block_size);
}
std::vector<int64> error_dims(a_shape.dimensions().begin(),
a_shape.dimensions().end());
error_dims.back() = error_dims.at(ndims - 2) = 1;
std::vector<int64> error_dim_indices(ndims);
absl::c_iota(error_dim_indices, 0);
// Blocked left-looking Cholesky factorization.
// Algorithm 1 from
// Haidar, Azzam, et al. "High-performance Cholesky factorization for
// GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017.
XlaOp l = ZerosLike(a);
XlaOp seen_error = ConstantR0<bool>(builder, false);
XlaOp seen_error = Zeros(builder, ShapeUtil::MakeShape(PRED, error_dims));
for (int64 i = 0; i < n; i += block_size) {
int64 k = std::min(block_size, n - i);
auto panel = SliceInMinorDims(a, {i, i}, {n, i + k});
@ -159,7 +170,8 @@ XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64 block_size,
// a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i]))
auto lhs = SliceInMinorDims(l, {i, 0}, {n, i});
auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i});
auto delta = BatchDot(lhs, false, rhs, true, precision);
auto delta =
BatchDot(lhs, false, MaybeConjugate(rhs, true), true, precision);
panel = panel - delta;
}
@ -170,8 +182,14 @@ XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64 block_size,
// other elements.
XlaOp factorized_error;
if (k == 1) {
factorized = Sqrt(x);
factorized_error = Any(IsNan(factorized));
if (ShapeUtil::ElementIsComplex(a_shape)) {
auto sqrt = Sqrt(Real(x));
factorized = Complex(sqrt, ZerosLike(sqrt));
factorized_error = IsNan(sqrt);
} else {
factorized = Sqrt(x);
factorized_error = IsNan(factorized);
}
} else {
TF_ASSIGN_OR_RETURN(auto tile_output, CholeskyUnblocked(x, precision));
std::tie(factorized, factorized_error) = tile_output;
@ -187,12 +205,13 @@ XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64 block_size,
/*left_side=*/false,
/*lower=*/true,
/*unit_diagonal=*/false,
/*transpose_a=*/TriangularSolveOptions::TRANSPOSE);
/*transpose_a=*/TriangularSolveOptions::ADJOINT);
l = UpdateSliceInMinorDims(l, update, {i + k, i});
}
}
return Select(seen_error,
FullLike(l, std::numeric_limits<float>::quiet_NaN()), l);
return Select(
BroadcastInDim(seen_error, a_shape.dimensions(), error_dim_indices),
FullLike(l, std::numeric_limits<float>::quiet_NaN()), l);
});
}
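For intuition about the hunks above, here is a NumPy sketch of the blocked left-looking update order from BuildCholesky, using the conjugate transpose that the new complex-type path relies on; np.linalg.cholesky stands in for the unblocked kernel and all names are illustrative:
import numpy as np
def blocked_cholesky(a, block_size=2):
  a = np.asarray(a, dtype=complex)
  n = a.shape[-1]
  l = np.zeros_like(a)
  for i in range(0, n, block_size):
    k = min(block_size, n - i)
    panel = a[i:, i:i + k].copy()
    if i > 0:
      # a[i:, i:i+k] -= l[i:, :i] @ conj(l[i:i+k, :i]).T
      panel -= l[i:, :i] @ l[i:i + k, :i].conj().T
    # Factor the diagonal tile, then solve update @ conj(tile).T = panel[k:, :]
    # for the off-diagonal block (the ADJOINT triangular solve above).
    tile = np.linalg.cholesky(panel[:k, :k])
    l[i:i + k, i:i + k] = tile
    if i + k < n:
      l[i + k:, i:i + k] = np.linalg.solve(tile.conj(), panel[k:, :].T).T
  return l
a = np.array([[4.0, 2.0], [2.0, 3.0]])
print(blocked_cholesky(a, block_size=1))
print(np.linalg.cholesky(a))  # same factor, up to the complex dtype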


@ -113,28 +113,47 @@ int64 CountNonLeafOps(const OpCollection& ops) {
// of reuses. This is used as a placeholder only, assuming all
// instructions can be fused to enable data reuses.
int64 ReusesCarriedBy(HloInstruction* op, HloInstruction* user) {
// Reuses in some way work like forces that pull instructions
// towards each other. We use a number 0-10 to classify how strong the force
// is between a pair of operations. Given a group of instructions that can be
// moved together, if the forces inside a conditional are stronger, the group
// will be moved inside or remain inside the conditional; otherwise, it will
// be moved outside to or remain outside of the conditional.
VLOG(2) << "ConditionalCodeMotion: Add reuses carried by instr: "
<< op->ToString() << "=>" << user->ToString() << "\n";
switch (user->opcode()) {
case HloOpcode::kGetTupleElement:
case HloOpcode::kTuple:
return 0;
case HloOpcode::kConvert:
// Because convert is treated as not movable when following Dot or
// Convolution, if op is a dot or convolution here, they must be separated
// by a conditional boundary. Here we do not try to pull convert inside
// conditionals to be together with the dot or convolution.
switch (op->opcode()) {
case HloOpcode::kConvolution:
case HloOpcode::kDot:
return 0;
default:
break;
}
break;
default:
break;
}
switch (op->opcode()) {
// These instructions are lightweight and easy to fuse.
// These instructions do not carry weight of reuse themselves.
case HloOpcode::kParameter:
case HloOpcode::kConstant:
case HloOpcode::kGetTupleElement:
return 0;
case HloOpcode::kConditional:
return 10;
default:
// Assume fusion will not happen anyway if user count > 1.
if (CountNonLeafOps(op->users()) > 1) {
return 0;
}
return 10;
default: {
// Assume the reuse decreases with increasing user count.
int count1 = CountNonLeafOps(op->users());
int count2 = CountNonLeafOps(user->operands());
return 10 / count1 / count2;
}
}
}
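A toy rendering of the scoring rule above: glue ops and leaf values carry no reuse, a conditional operand carries the maximum pull of 10, and everything else gets 10 scaled down by the non-leaf fan-out on both sides. Purely illustrative, with the convert-after-dot special case omitted:
def reuses_carried_by(op_kind, user_kind, op_nonleaf_users, user_nonleaf_operands):
  # Glue ops on the user side never carry reuse.
  if user_kind in ("get-tuple-element", "tuple"):
    return 0
  # Leaf-like values carry no reuse of their own.
  if op_kind in ("parameter", "constant", "get-tuple-element"):
    return 0
  if op_kind == "conditional":
    return 10
  # The pull weakens as either side fans out to more non-leaf ops.
  return 10 // max(op_nonleaf_users, 1) // max(user_nonleaf_operands, 1)
print(reuses_carried_by("add", "multiply", 1, 1))       # 10: strongest pull
print(reuses_carried_by("add", "multiply", 2, 1))       # 5
print(reuses_carried_by("constant", "multiply", 1, 1))  # 0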
@ -192,17 +211,35 @@ Status CopyInOrOutOfConditional(
absl::InlinedVector<HloInstruction*, 4> new_operands;
for (int i = 0; i < op->operands().size(); ++i) {
auto op_i = op->operands()[i];
VLOG(2) << "Looking for operand:" << op_i->ToString() << "\n";
VLOG(2) << "Looking for " << op_i->ToString() << "\n";
if (ContainsKey(hoisted_instructions, op_i)) {
auto new_op_i =
FindOrDie(hoisted_instructions, op_i).operands()[dest_index];
VLOG(2) << "new operand:" << new_op_i->ToString() << "\n";
VLOG(2) << "new instruction:" << new_op_i->ToString() << "\n";
new_operands.push_back(new_op_i);
} else {
CHECK(op_i->opcode() == HloOpcode::kConstant);
auto new_op_i = parent->AddInstruction(op_i->Clone());
VLOG(2) << "new operand:" << new_op_i->ToString() << "\n";
new_operands.push_back(new_op_i);
switch (op_i->opcode()) {
case HloOpcode::kConstant: {
auto new_op_i = parent->AddInstruction(op_i->Clone());
VLOG(2) << "new instruction:" << new_op_i->ToString() << "\n";
new_operands.push_back(new_op_i);
break;
}
case HloOpcode::kGetTupleElement: {
auto gte = Cast<HloGetTupleElementInstruction>(op_i);
int64 index = gte->tuple_index();
HloInstruction* root = parent->root_instruction();
CHECK(root->opcode() == HloOpcode::kTuple &&
index < root->operand_count());
auto new_op_i = root->mutable_operand(index);
VLOG(2) << "new instruction:" << new_op_i->ToString() << "\n";
new_operands.push_back(new_op_i);
break;
}
default:
LOG(FATAL) << "Unexpected out-of-boundary instruction:"
<< op_i->ToString() << "\n";
}
}
}
HloInstruction* new_instruction = parent->AddInstruction(
@ -492,6 +529,7 @@ StatusOr<bool> ConditionalCodeMotion::MoveInstructionOut(
int64 index = tuple_opd->tuple_index();
CHECK(old_root->operands().size() > index);
HloInstruction* old_opd = old_root->operands()[index];
VLOG(2) << "old opd = " << old_opd << "\n";
CHECK(ContainsKey(hoisted_instructions, old_opd));
HloInstruction* new_opd = hoisted_instructions[old_opd].operands()[0];
CHECK(old_opd != nullptr);
@ -535,6 +573,7 @@ StatusOr<bool> ConditionalCodeMotion::MoveInstructionOut(
HloInstruction* new_root =
conditional->branch_computation(0)->root_instruction();
*conditional->mutable_shape() = new_root->shape();
//
VLOG(1) << "done moving instructions out of branches\n"
<< conditional_parent->ToString(HloPrintOptions::Fingerprint())
@ -558,16 +597,26 @@ StatusOr<bool> ConditionalCodeMotion::MoveInstructionIn(
absl::flat_hash_map<HloInstruction*, Boundary> hoisted_instructions;
int64 to_move_in_size = to_move_in.size();
int64 branch_count = conditional->branch_count();
HloGetTupleElementInstruction* tuple_use =
DynCast<HloGetTupleElementInstruction>(to_move_in[0].operands()[0]);
// If use_index is -1, the old conditional root entry used by to_move_in
// instructions still needs to be included as an entry of the modified
// conditional root, and the new result of the to_move_in instructions
// needs to be added as an extra entry of the modified root; otherwise, the
// old root entry will be replaced with the new result in the modified root.
// The entry replacement should be allowed only if tuple_use has <=1 users.
int64 use_index = (tuple_use != nullptr && tuple_use->user_count() == 1)
? tuple_use->tuple_index()
: -1;
VLOG(2) << "Tuple use index = " << use_index << "\n";
// Number of old conditional entries still to be used outside.
// If conditional shape is not tuple, will create a tuple and use subscript
// 0 to save the old operand being used.
int64 op_index = conditional->shape().IsTuple()
? conditional->shape().tuple_shapes_size() - 1
: 0;
HloGetTupleElementInstruction* tuple_use =
dynamic_cast<HloGetTupleElementInstruction*>(to_move_in[0].operands()[0]);
int64 use_index = (tuple_use != nullptr) ? tuple_use->tuple_index() : -1;
VLOG(2) << "Tuple use index = " << use_index << "\n";
int64 op_index =
conditional->shape().IsTuple()
? ((use_index >= 0) ? conditional->shape().tuple_shapes_size() - 1
: conditional->shape().tuple_shapes_size())
: 0;
// Used to map the tuple_use instruction to its operand;
Boundary b_opd_use(Boundary::Position::kInsideBranch);
Boundary b_old_root(Boundary::Position::kInsideBranch);
@ -628,26 +677,29 @@ StatusOr<bool> ConditionalCodeMotion::MoveInstructionIn(
hoisted_instructions[conditional] = b_old_root;
int64 cp_start = 0;
if (use_index >= 0) {
VLOG(2) << "Mapping GTE: " << tuple_use->ToString() << "\n";
hoisted_instructions[tuple_use] = b_opd_use;
cp_start = 1;
}
for (int64 i = cp_start; i < to_move_in_size; i++) {
Boundary b_to_move = to_move_in[i];
cp_start = (tuple_use != nullptr) ? 1 : 0;
for (int64 to_move_index = cp_start; to_move_index < to_move_in_size;
to_move_index++) {
Boundary b_to_move = to_move_in[to_move_index];
HloInstruction* op = b_to_move.operands()[0];
CHECK(op != nullptr);
bool to_be_used_outside = true;
VLOG(2) << "Mapping new boundary instr: " << op->ToString() << "\n";
if (i < to_move_in_size - 1 && op->user_count() == 1 &&
op->users()[0] == to_move_in[i + 1].operands()[0]) {
if (to_move_index < to_move_in_size - 1 && op->user_count() == 1 &&
op->users()[0] == to_move_in[to_move_index + 1].operands()[0]) {
to_be_used_outside = false;
VLOG(2) << "Instruction is not to be used outside the branch\n";
}
Boundary b(Boundary::Position::kInsideBranch);
for (int i = 0; i < branch_count; i++) {
auto computation = conditional->branch_computation(i);
VLOG(2) << "Copying to branch: " << i << "\n";
TF_RETURN_IF_ERROR(CopyInOrOutOfConditional(b_to_move, i, computation,
hoisted_instructions));
VLOG(2) << "After Copying to branch: " << computation->ToString() << "\n";
VLOG(2) << "Done:" << computation->ToString() << "\n";
if (to_be_used_outside) {
auto new_op = hoisted_instructions[op].operands()[i];
auto new_root = computation->root_instruction();
@ -681,18 +733,19 @@ StatusOr<bool> ConditionalCodeMotion::MoveInstructionIn(
// Remove hoisted instructions from the branches.
for (int64 i = to_move_in_size - 1; i >= 0; i--) {
Boundary boundary_to_move_in = to_move_in[i];
VLOG(2) << "Removing boundary:" << boundary_to_move_in.ToString() << "\n";
HloInstruction* op = boundary_to_move_in.operands()[0];
for (auto user : op->users()) {
VLOG(2) << "Has User: " << user->ToString() << "\n";
if (op->user_count() == 0) {
VLOG(2) << "Removing boundary:" << boundary_to_move_in.ToString() << "\n";
TF_RETURN_IF_ERROR(conditional->parent()->RemoveInstruction(op));
VLOG(2) << "Done removing boundary.\n";
}
TF_RETURN_IF_ERROR(conditional->parent()->RemoveInstruction(op));
}
// Reset shapes of user gtes to the new shape.
if (use_index != -1) {
for (auto* user : conditional->users()) {
if (user->opcode() == HloOpcode::kGetTupleElement) {
VLOG(2) << "Resetting shape of user: " << user->ToString() << "\n";
*user->mutable_shape() =
conditional->shape().tuple_shapes(user->tuple_index());
}
@ -712,14 +765,21 @@ class GroupConnectedBoundaries {
HloComputation* conditional_parent_;
bool is_layout_sensitive_;
// Instructions that have been visited but are not going to be moved.
absl::flat_hash_set<HloInstruction*> visited_;
absl::flat_hash_map<HloInstruction*, int>& visited_count_;
public:
explicit GroupConnectedBoundaries(HloInstruction* conditional,
bool is_layout_sensitive)
explicit GroupConnectedBoundaries(
HloInstruction* conditional, bool is_layout_sensitive,
absl::flat_hash_map<HloInstruction*, int>& visited_count)
: conditional_(conditional),
conditional_parent_(conditional->parent()),
is_layout_sensitive_(is_layout_sensitive) {}
is_layout_sensitive_(is_layout_sensitive),
visited_count_(visited_count) {}
void clear_recently_visited() {
for (const auto& boundary : new_boundaries_) {
visited_count_.erase(boundary.operands()[0]);
}
}
// Returns true if `instruction` is worth hoisting.
bool WorthHoisting(HloInstruction* instruction) {
// This is needed for the "moving-in" transformation, to prevent the root
@ -736,19 +796,26 @@ class GroupConnectedBoundaries {
// ops such as Dot or Convolutional, it is better to keep convert
// within conditional so that convert can be fused with Dot or
// Convolutional.
//
// TODO(b/154283721): figure out the scenario when convert can be
// fused with AllReduce out of conditional.
switch (instruction->operand(0)->opcode()) {
case HloOpcode::kAllReduce:
case HloOpcode::kReshape:
case HloOpcode::kGetTupleElement:
return true;
default:
VLOG(2) << "Instruction is convert and its operand is not know to "
VLOG(2) << "Instruction is convert and its operand is not known to "
"be worth hoisting\n";
return false;
}
case HloOpcode::kGetTupleElement:
switch (instruction->operand(0)->opcode()) {
// do not move GTE if its operand is a parameter
case HloOpcode::kParameter:
return false;
default:
return true;
}
case HloOpcode::kAllReduce:
case HloOpcode::kReduce:
case HloOpcode::kAdd:
case HloOpcode::kPower:
case HloOpcode::kCopy:
@ -758,8 +825,10 @@ class GroupConnectedBoundaries {
case HloOpcode::kDivide:
case HloOpcode::kTuple:
case HloOpcode::kSqrt:
case HloOpcode::kRsqrt:
case HloOpcode::kReshape:
case HloOpcode::kGetTupleElement:
case HloOpcode::kMinimum:
case HloOpcode::kMaximum:
return true;
default:
VLOG(2) << "Instruction is not known to be worth hoisting\n";
@ -772,14 +841,20 @@ class GroupConnectedBoundaries {
// The operand must be an instruction that is not going to be moved (if
// user is inside the conditional); otherwise it must be the conditional
// itself and its user must be outside of the conditional.
if (!ContainsKey(visited_, op) && op != conditional_) {
if (!ContainsKey(visited_count_, op) && op != conditional_) {
continue;
}
// Only consider single-user cases as reusable.
if (user->opcode() == HloOpcode::kGetTupleElement &&
user->user_count() == 1) {
if (auto tuple_gte = DynCast<HloGetTupleElementInstruction>(user)) {
if (op->opcode() == HloOpcode::kConditional) {
auto tuple = op->branch_computation(0)->root_instruction();
if (tuple->opcode() == HloOpcode::kTuple) {
auto index = tuple_gte->tuple_index();
CHECK(index < tuple->operand_count());
op = tuple->mutable_operand(index);
}
}
reuses += ReusesCarriedBy(op, user->users()[0]);
} else if (op->user_count() == 1) {
} else {
reuses += ReusesCarriedBy(op, user);
}
}
@ -797,6 +872,7 @@ class GroupConnectedBoundaries {
// some aspects of the overall algorithm need to be redesigned to
// accommodate the change.
if (all_users.size() > 1) {
VLOG(2) << "Having multiple users from: " << user->ToString() << "\n";
return 0;
}
if (!all_users.empty()) {
@ -818,7 +894,7 @@ class GroupConnectedBoundaries {
}
}
}
} else if (ContainsKey(visited_, op)) {
} else if (ContainsKey(visited_count_, op)) {
reuses += ReusesCarriedBy(user, op);
}
VLOG(2) << "reuses after instruction " << user->ToString() << ":"
@ -866,6 +942,50 @@ class GroupConnectedBoundaries {
}
return b2;
}
// Checking whether it is safe to move a boundary when visited through a
// dependent already considered for moving.
bool IsSafeToMoveBoundary(const Boundary& next_boundary) {
int64 next_boundary_count =
(next_boundary.IsInsideBranch())
? next_boundary.operands()[0]->user_count()
: CountNonLeafOps(next_boundary.operands()[0]->operands());
if (next_boundary_count <= 1) {
// If the boundary has no more than one dependent, it is safe to move.
return true;
} else {
if (!ContainsKey(visited_count_, next_boundary.operands()[0])) {
VLOG(2) << "Skip next boundary " << next_boundary.ToString() << "\n"
<< " because it has multiple dependents: "
<< next_boundary_count << "\n";
visited_count_[next_boundary.operands()[0]] = 1;
new_boundaries_.push_back(next_boundary);
} else {
auto pos = std::find(new_boundaries_.begin(), new_boundaries_.end(),
next_boundary);
if (pos != new_boundaries_.end() ||
next_boundary.operands().size() == 1) {
int count = ++visited_count_[next_boundary.operands()[0]];
if (count == next_boundary_count) {
VLOG(2) << "Recovering next boundary " << next_boundary.ToString()
<< "\n"
<< " because all of its dependents have been visited: "
<< next_boundary_count << "\n";
visited_count_.erase(next_boundary.operands()[0]);
if (pos != new_boundaries_.end()) {
new_boundaries_.erase(pos);
}
return true;
}
} else {
VLOG(2) << "Skip incompatible multi-dependent boundary: "
<< next_boundary.ToString() << ":" << next_boundary_count
<< "\n";
}
}
}
return false;
}
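
The gate above defers a boundary with multiple dependents until every dependent has been reached through the group being formed, and only then recovers it. A standalone sketch of that counting scheme, with boundaries modelled as opaque pointers and the dependent count supplied by the caller (the real pass also tracks the deferred boundary in new_boundaries_):

#include <unordered_map>

// Returns true when it is safe to pull `boundary` into the current group:
// either it has at most one dependent, or all of its dependents have now been
// visited through the group. Illustrative model of the pass's logic.
bool SafeToMove(void* boundary, int dependent_count,
                std::unordered_map<void*, int>& visited_count) {
  if (dependent_count <= 1) return true;
  int count = ++visited_count[boundary];  // first visit inserts count = 1
  if (count == dependent_count) {
    visited_count.erase(boundary);        // all dependents reached: recover it
    return true;
  }
  return false;                           // keep waiting for other dependents
}
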
// This function is reused both for moving the boundary out of and into a
// conditional. As a result, the readability is somewhat compromised.
// It might be nice to refactor this function to factor the outside-inside
@ -879,7 +999,7 @@ class GroupConnectedBoundaries {
VLOG(2) << "visiting boundary " << b.ToString() << "\n";
if ((b.IsOutsideBranch() || InstructionWithinBranchIdentical(
b.operands(), is_layout_sensitive_)) &&
WorthHoisting(b.operands()[0])) {
IsSafeToMoveBoundary(b) && WorthHoisting(b.operands()[0])) {
connected_boundaries_.push_back(b);
VLOG(2) << "boundary can be moved\n";
int64 operand_count = (b.IsInsideBranch())
@ -887,26 +1007,12 @@ class GroupConnectedBoundaries {
: b.operands()[0]->users().size();
for (int i = 0; i < operand_count; i++) {
Boundary next_boundary = GetNextBoundary(b, i);
int64 next_boundary_count =
(next_boundary.IsInsideBranch())
? next_boundary.operands()[0]->user_count()
: CountNonLeafOps(next_boundary.operands()[0]->operands());
// only consider adding an exclusive producor into the same group.
if (next_boundary_count == 1) {
VLOG(2) << "Add operand " << i << " to visit later\n";
visitor.AddToWorkList(next_boundary);
} else {
VLOG(2) << "Next boundary " << i
<< " has multiple uses: " << next_boundary_count << "\n";
if (!ContainsKey(visited_, next_boundary.operands()[0])) {
visited_.insert(next_boundary.operands()[0]);
new_boundaries_.push_back(next_boundary);
}
}
VLOG(2) << "Add operand/user " << i << " to visit later\n";
visitor.AddToWorkList(next_boundary);
}
} else {
VLOG(2) << "boundary cannot be moved\n";
visited_.insert(b.operands()[0]);
visited_count_[b.operands()[0]] = 1;
new_boundaries_.push_back(b);
}
}
@ -948,8 +1054,10 @@ class GroupConnectedBoundaries {
ConditionalCodeMotion::Decision ConditionalCodeMotion::ConsiderCodeMotion(
HloInstruction* conditional, const Boundary& cur_boundary,
std::vector<Boundary>& to_move, std::vector<Boundary>& new_boundaries) {
GroupConnectedBoundaries connect(conditional, is_layout_sensitive_);
std::vector<Boundary>& to_move, std::vector<Boundary>& new_boundaries,
absl::flat_hash_map<HloInstruction*, int>& visited_count) {
GroupConnectedBoundaries connect(conditional, is_layout_sensitive_,
visited_count);
auto move_in_or_out =
connect.BoundariesToMoveInOrOut(conditional, cur_boundary);
if (!move_in_or_out.empty()) {
@ -964,16 +1072,21 @@ ConditionalCodeMotion::Decision ConditionalCodeMotion::ConsiderCodeMotion(
// at the first entry of the sequence is sufficient to know which
// direction the move is intended.
to_move = move_in_or_out;
return to_move[0].IsInsideBranch() ? Decision::kMoveOutOfBranch
: Decision::kMoveIntoBranch;
return Decision(to_move[0].IsInsideBranch()
? Decision::Direction::kMoveOutOfBranch
: Decision::Direction::kMoveIntoBranch,
benefit);
} else {
connect.clear_recently_visited();
}
} else {
connect.AddNewBoundaries(new_boundaries);
}
return ConditionalCodeMotion::Decision::kNoChange;
return Decision(Decision::Direction::kNoChange, 0);
}
StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
VLOG(2) << "Begin a new pass of conditional code motion optimization.\n";
bool changed = false;
bool cleanup_changed = false;
{
@ -1018,6 +1131,8 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
}
}
// Used to collect mappings between original and cloned instructions.
HloCloneContext clone_context(module);
for (HloInstruction* conditional : conditional_ops) {
int branch_count = conditional->branch_count();
// check for shared conditional computations
@ -1031,7 +1146,13 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
}
// Boundaries to move out or to move into the branches.
std::vector<Boundary> to_move_out, to_move_in, new_boundaries;
std::vector<std::vector<Boundary> > to_move_out, to_move_in;
std::vector<std::vector<Boundary> > new_boundaries_for_moveout;
std::vector<std::vector<Boundary> > new_boundaries_for_movein;
// Number of times each instruction has been visited for moving.
absl::flat_hash_map<HloInstruction*, int> visited_count;
int benefit_move_out = 0, benefit_move_in = 0;
Decision::Direction final_d = Decision::Direction::kNoChange;
// The conditional is moved into a worklist as the seed (starting point).
// The conditional will be expanded into multiple seeds (starting points),
// its roots and its users, when it is visited by GroupConnectedBoundaries.
@ -1039,76 +1160,114 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
// so that the other seeding boundaries can be visited in turn.
BoundaryVisitor visitor(conditional);
VLOG(2) << "Analyzing conditional:" << conditional->ToString() << "\n";
ConditionalCodeMotion::Decision d = Decision::kNoChange;
// The following loop breaks out as soon as a decision to modify the
// conditional is reached --- irrespective of whether visitor is empty.
while (d == Decision::kNoChange && visitor.HasNextBoundary()) {
// Try to visit all the boundaries, collect the analysis results, and save
// all the beneficial non-conflicting decisions. If two decisions conflict
// with each other, save the more beneficial one.
while (visitor.HasNextBoundary()) {
std::vector<Boundary> to_move, next_boundary;
Boundary boundary = visitor.PopNextBoundary();
VLOG(2) << "Analyzing boundary:" << boundary.ToString() << "\n";
d = ConsiderCodeMotion(conditional, boundary, to_move, next_boundary);
if (d != Decision::kNoChange && conditional_is_shared) {
for (int i = 0; i < branch_count; ++i) {
HloComputation* branch_i = conditional->branch_computation(i);
if (conditional_computations[branch_i] > 0) {
// Cloning is absolutely needed if the computation is shared by
// different branches, but the cloning can be potentially avoided
// if the sharing is only among branches of the same conditional.
// If cloning these branches causes a problem due to space issues,
// a fix can pass a vector of unique branches to the actual
// transformations, as an alternative representation of the
// conditional branches to be modified. Right now we assume the
// overhead of cloning is minimal since later stages of the compiler
// inline all the computations anyway.
HloComputation* clone_i =
conditional->parent()->parent()->AddEmbeddedComputation(
branch_i->Clone());
conditional->set_branch_computation(i, clone_i);
conditional_computations[branch_i]--;
auto d = ConsiderCodeMotion(conditional, boundary, to_move, next_boundary,
visited_count);
switch (d.GetDirection()) {
case Decision::Direction::kMoveOutOfBranch:
VLOG(2) << "Local Decision is move out of branch\n";
to_move_out.push_back(to_move);
new_boundaries_for_moveout.push_back(next_boundary);
benefit_move_out += d.GetBenefit();
if (benefit_move_out >= benefit_move_in) {
final_d = Decision::Direction::kMoveOutOfBranch;
VLOG(2) << "Current Decision is move out of branch\n";
} else {
VLOG(2) << "Current Decision remains move into branch\n";
}
}
to_move.clear();
next_boundary.clear();
VLOG(2) << "Cloned branches as needed: " << conditional->ToString()
<< "\n";
// Need to reanalyze the cloned code to generate correct result.
d = ConsiderCodeMotion(conditional, boundary, to_move, next_boundary);
}
switch (d) {
case Decision::kMoveOutOfBranch:
VLOG(2) << "Decision is move out of branch\n";
to_move_out.insert(to_move_out.end(), to_move.begin(), to_move.end());
new_boundaries.insert(new_boundaries.end(), next_boundary.begin(),
next_boundary.end());
break;
case Decision::kMoveIntoBranch:
case Decision::Direction::kMoveIntoBranch:
VLOG(2) << "Decision is move into branch\n";
to_move_in.insert(to_move_in.end(), to_move.begin(), to_move.end());
new_boundaries.insert(new_boundaries.end(), next_boundary.begin(),
next_boundary.end());
to_move_in.push_back(to_move);
new_boundaries_for_movein.push_back(next_boundary);
benefit_move_in += d.GetBenefit();
if (benefit_move_out >= benefit_move_in) {
VLOG(2) << "Current Decision remains move out of branch\n";
} else {
final_d = Decision::Direction::kMoveIntoBranch;
VLOG(2) << "Current Decision is move into branch\n";
}
break;
case Decision::kNoChange:
case Decision::Direction::kNoChange:
VLOG(2) << "Decision is no change\n";
for (const Boundary& b : next_boundary) {
visitor.AddToWorkList(b);
VLOG(2) << "Adding new boundary to worklist:" << b.ToString()
<< "\n";
}
break;
}
}
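
Each boundary's decision now adds its benefit to a running total for its direction, and the winning direction is re-evaluated after every contribution, with move-out winning ties. A simplified standalone sketch of that tally (the real loop also records which boundaries to move and which new boundaries to keep):

enum class Direction { kMoveOutOfBranch, kMoveIntoBranch, kNoChange };

// Accumulates per-boundary benefits; move-out wins ties, mirroring the >=
// comparison used in the pass. Illustrative model only.
class BenefitTally {
 public:
  void Add(Direction d, int benefit) {
    if (d == Direction::kMoveOutOfBranch) { out_ += benefit; seen_ = true; }
    if (d == Direction::kMoveIntoBranch) { in_ += benefit; seen_ = true; }
  }
  Direction Final() const {
    if (!seen_) return Direction::kNoChange;
    return out_ >= in_ ? Direction::kMoveOutOfBranch
                       : Direction::kMoveIntoBranch;
  }
 private:
  int out_ = 0;
  int in_ = 0;
  bool seen_ = false;
};
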
// If a modification is to be made, the shared branches need to be cloned.
if (final_d != Decision::Direction::kNoChange && conditional_is_shared) {
for (int i = 0; i < branch_count; ++i) {
HloComputation* branch_i = conditional->branch_computation(i);
if (conditional_computations[branch_i] > 0) {
// Cloning is absolutely needed if the computation is shared by
// different branches, but the cloning can be potentially avoided
// if the sharing is only among branches of the same conditional.
// If cloning these branches causes a problem due to space issues,
// a fix can pass a vector of unique branches to the actual
// transformations, as an alternative representation of the
// conditional branches to be modified. Right now we assume the
// overhead of cloning is minimal since later stages of the compiler
// inline all the computations anyway.
HloComputation* clone_i =
conditional->parent()->parent()->AddEmbeddedComputation(
branch_i->Clone("clone", &clone_context));
conditional->set_branch_computation(i, clone_i);
conditional_computations[branch_i]--;
// Need to translate the analysis results to refer to the cloned instructions.
auto update_boundary = [&](Boundary& boundary) {
auto cloned_instr =
clone_context.FindInstruction(boundary.operands()[i]);
CHECK(cloned_instr != nullptr);
VLOG(2) << "boundary before cloning:" << boundary.operands()[i]
<< "\n";
boundary.mutable_operands()[i] = cloned_instr;
VLOG(2) << "boundary after cloning:" << boundary.operands()[i]
<< "\n";
};
// Only boundaries to move out need to be updated.
if (final_d == Decision::Direction::kMoveOutOfBranch) {
for (int i = 0; i < to_move_out.size(); ++i) {
std::vector<Boundary>& m = to_move_out[i];
std::for_each(m.begin(), m.end(), update_boundary);
}
for (int i = 0; i < new_boundaries_for_moveout.size(); ++i) {
std::vector<Boundary>& m = new_boundaries_for_moveout[i];
std::for_each(m.begin(), m.end(), update_boundary);
}
}
}
}
VLOG(2) << "Cloned branches as needed: " << conditional->ToString()
<< "\n";
}
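
After cloning a shared branch, the boundaries collected during analysis still point at instructions of the original computation, so the update above rewrites each affected boundary operand to the clone recorded by the clone context. A standalone sketch of that old-to-new remapping, with the clone map modelled as a plain hash map; it is simplified in that it remaps every operand, whereas the pass only rewrites the operand belonging to the branch that was cloned:

#include <cassert>
#include <unordered_map>
#include <vector>

struct Instr {};  // stand-in for HloInstruction

// Rewrites each boundary operand to the clone recorded while cloning the
// shared branch computation. The pass CHECK-fails if a clone is missing; here
// the lookup is asserted instead.
void RemapToClones(
    std::vector<Instr*>& boundary_operands,
    const std::unordered_map<const Instr*, Instr*>& clone_map) {
  for (Instr*& op : boundary_operands) {
    auto it = clone_map.find(op);
    assert(it != clone_map.end());
    op = it->second;  // point at the cloned instruction instead
  }
}
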
// At most one of to_move_out or to_move_in can be non-empty, since there is
// only one optimization decision.
if (!to_move_out.empty()) {
TF_ASSIGN_OR_RETURN(
bool result,
MoveInstructionOut(conditional, to_move_out, new_boundaries));
VLOG(2) << "moving out result:" << result << "\n";
changed |= result;
} else if (!to_move_in.empty()) {
TF_ASSIGN_OR_RETURN(
bool result,
MoveInstructionIn(conditional, to_move_in, new_boundaries));
VLOG(2) << "moving in result:" << result << "\n";
changed |= result;
if (final_d == Decision::Direction::kMoveOutOfBranch) {
CHECK(to_move_out.size() == new_boundaries_for_moveout.size());
for (int i = 0; i < to_move_out.size(); ++i) {
TF_ASSIGN_OR_RETURN(bool result,
MoveInstructionOut(conditional, to_move_out[i],
new_boundaries_for_moveout[i]));
changed |= result;
}
} else if (final_d == Decision::Direction::kMoveIntoBranch) {
CHECK(to_move_in.size() == new_boundaries_for_movein.size());
for (int i = 0; i < to_move_in.size(); ++i) {
TF_ASSIGN_OR_RETURN(bool result,
MoveInstructionIn(conditional, to_move_in[i],
new_boundaries_for_movein[i]));
changed |= result;
}
} else if (pursue_full_conditional_code_motion_ && !conditional_is_shared) {
// Invoke special handling for convert rematerialization/hoisting
// We need to make sure no sharing is present in the branches because no
@ -1118,11 +1277,13 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
bool convert_result,
ConvertSpecialMove(conditional, is_layout_sensitive_));
changed |= convert_result;
VLOG(2) << "Done special moving of convert\n";
}
}
if (changed) {
HloPassPipeline subpipeline(
"after_conditional_code_motion_after_convert_hoisting");
VLOG(2) << "starting after motion passes: DCE\n";
subpipeline.AddPass<HloDCE>();
subpipeline.AddPass<TupleSimplifier>();
subpipeline.AddPass<HloDCE>();

View File

@ -52,6 +52,9 @@ class Boundary {
}
return res;
}
bool operator==(const Boundary& that) {
return ContainersEqual(operands_, that.operands_);
}
private:
// Boundary instructions in the conditional branches, one from each branch
@ -78,13 +81,30 @@ class ConditionalCodeMotion : public HloModulePass {
StatusOr<bool> Run(HloModule* module) override;
// Optimization decision for each boundary of the conditional instruction.
enum class Decision { kMoveOutOfBranch, kMoveIntoBranch, kNoChange };
class Decision {
public:
enum class Direction : uint8 {
kMoveOutOfBranch,
kMoveIntoBranch,
kNoChange
};
public:
Decision(Direction direction, int benefit)
: direction_(direction), benefit_(benefit) {}
Direction GetDirection() const { return direction_; }
int GetBenefit() const { return benefit_; }
private:
Direction direction_;
int benefit_;
};
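
Replacing the plain enum with this small class lets every boundary report a direction together with an estimated benefit, so the driver can weigh competing move-in and move-out proposals instead of acting on the first one. A hypothetical producer-side sketch, assuming only the Decision class declared above; this helper is not part of the change and its threshold is made up for illustration:

// Hypothetical: an analysis helper reporting its proposal as a Decision.
ConditionalCodeMotion::Decision ProposeMove(bool move_out, int benefit) {
  using Direction = ConditionalCodeMotion::Decision::Direction;
  if (benefit <= 0) {
    return ConditionalCodeMotion::Decision(Direction::kNoChange, 0);
  }
  return ConditionalCodeMotion::Decision(
      move_out ? Direction::kMoveOutOfBranch : Direction::kMoveIntoBranch,
      benefit);
}
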
// If the optimization decision is NO_CHANGE, new_boundary is set to nullptr;
// otherwise, it is set to the new boundary after proposed optimization.
virtual Decision ConsiderCodeMotion(HloInstruction* conditional,
const Boundary& cur_boundary,
std::vector<Boundary>& to_move,
std::vector<Boundary>& new_boundaries);
virtual Decision ConsiderCodeMotion(
HloInstruction* conditional, const Boundary& cur_boundary,
std::vector<Boundary>& to_move, std::vector<Boundary>& new_boundaries,
absl::flat_hash_map<HloInstruction*, int>& visited_count);
private:
const bool is_layout_sensitive_;

View File

@ -234,17 +234,16 @@ ENTRY main {
const HloInstruction* conditional =
FindInstruction(module.get(), "conditional");
const HloComputation* on_true = conditional->branch_computation(0);
ASSERT_EQ(on_true->instruction_count(), 2);
ASSERT_EQ(on_true->instruction_count(), 1);
const HloComputation* on_false = conditional->branch_computation(1);
ASSERT_EQ(on_false->instruction_count(), 2);
ASSERT_EQ(on_false->instruction_count(), 1);
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(
root,
AllOf(op::Tuple(op::Add(op::Convert(op::Reshape(op::GetTupleElement(
op::GetTupleElement(op::Conditional())))),
op::Convert(op::Reshape(op::GetTupleElement(
op::GetTupleElement(op::Conditional()))))))));
AllOf(op::Tuple(op::Add(
op::Convert(op::Reshape(op::GetTupleElement(op::Conditional()))),
op::Convert(op::Reshape(op::GetTupleElement(op::Conditional())))))));
}
TEST_F(ConditionalCodeMotionTest, UserShareOperandCannotBeMoved) {
@ -335,7 +334,7 @@ on_false {
get-tuple-element.2 = f32[] get-tuple-element(arg_tuple.2), index=0
constant.3 = f32[] constant(1)
constant.4 = f32[] constant(2)
add.4 = f32[] add(get-tuple-element.2, constant.3)
add.4 = f32[] add(constant.4, constant.3)
add.5 = f32[] add(get-tuple-element.2, constant.4)
add.6 = f32[] add(add.4, add.5)
ROOT tuple.4 = (f32[]) tuple(add.6)
@ -360,7 +359,7 @@ ENTRY main {
const HloComputation* on_true = conditional->branch_computation(0);
ASSERT_EQ(on_true->instruction_count(), 1);
const HloComputation* on_false = conditional->branch_computation(1);
ASSERT_EQ(on_false->instruction_count(), 1);
ASSERT_EQ(on_false->instruction_count(), 3);
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(
@ -543,6 +542,7 @@ ENTRY main {
pred.1 = pred[] parameter(0)
arg_tuple.3 = (bf16[2,54,168,128], bf16[2,52,168,128]) parameter(1)
arg_tuple.4 = (bf16[2,86,104,128], bf16[2,84,104,128]) parameter(2)
arg_tuple.5 = f32[3,3,128,128] parameter(3)
conditional = (f32[3,3,128,128])
conditional(pred.1, arg_tuple.3, arg_tuple.4), true_computation=on_true,
false_computation=on_false
@ -557,6 +557,7 @@ ENTRY main {
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
const HloInstruction* conditional =
FindInstruction(module.get(), "conditional");
CHECK(conditional != nullptr);
const HloComputation* on_true = conditional->branch_computation(0);
ASSERT_EQ(on_true->instruction_count(), 5);
const HloComputation* on_false = conditional->branch_computation(1);
@ -619,7 +620,7 @@ ENTRY main {
EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional())));
}
TEST_F(ConditionalCodeMotionTest, NoMoveInWithMultipleGTE) {
TEST_F(ConditionalCodeMotionTest, MoveInWithMultipleGTE) {
absl::string_view hlo_string =
R"(
HloModule RemoveIdenticalInstruction
@ -647,19 +648,19 @@ ENTRY main {
false_computation=on_false
get-first-index = f32[10] get-tuple-element(conditional), index=0
get-first-index.2 = f32[10] get-tuple-element(conditional), index=0
pow.1 = f32[10] power(get-first-index, get-first-index)
pow.1 = f32[10] power(get-first-index, get-first-index.2)
ROOT tuple.3 = (f32[10], f32[10]) tuple(pow.1, get-first-index.2)
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass(true, true);
ASSERT_FALSE(pass.Run(&*module).ValueOrDie());
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root,
op::Tuple(op::Power(), op::GetTupleElement(op::Conditional())));
EXPECT_THAT(root, op::Tuple(op::GetTupleElement(op::Conditional()),
op::GetTupleElement(op::Conditional())));
}
TEST_F(ConditionalCodeMotionTest, MovePowInWithSharedBranch) {
TEST_F(ConditionalCodeMotionTest, MoveOutWithSharedBranch) {
absl::string_view hlo_string =
R"(
HloModule RemoveIdenticalInstruction
@ -688,12 +689,16 @@ ENTRY main {
const HloInstruction* conditional =
FindInstruction(module.get(), "conditional");
const HloComputation* on_true = conditional->branch_computation(0);
ASSERT_EQ(on_true->instruction_count(), 5);
ASSERT_EQ(on_true->instruction_count(), 1);
const HloComputation* on_false = conditional->branch_computation(1);
ASSERT_EQ(on_false->instruction_count(), 5);
ASSERT_EQ(on_false->instruction_count(), 1);
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional())));
EXPECT_THAT(
root, AllOf(op::Power(op::Add(op::GetTupleElement(op::Conditional()),
op::GetTupleElement(op::Conditional())),
op::Add(op::GetTupleElement(op::Conditional()),
op::GetTupleElement(op::Conditional())))));
}
TEST_F(ConditionalCodeMotionTest, MovePowInWithNonTupleRoot) {
@ -959,6 +964,104 @@ ENTRY main {
op::AllReduce(op::GetTupleElement(op::Conditional())))))));
}
TEST_F(ConditionalCodeMotionTest, DoNotMoveWithExtraOperand) {
absl::string_view hlo_string =
R"(
HloModule RemoveIdenticalInstruction
branch {
arg.1 = f32[10] parameter(0)
ROOT add.1 = f32[10] add(arg.1, arg.1)
}
ENTRY main {
pred.1 = pred[] parameter(0)
tuple.1 = f32[10] parameter(1)
tuple.2 = f32[10] parameter(2)
conditional = f32[10]
conditional(pred.1, tuple.1, tuple.2), true_computation=branch,
false_computation=branch
ROOT pow.1 = f32[10] power(conditional, tuple.2)
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass(true, true);
ASSERT_FALSE(pass.Run(&*module).ValueOrDie());
}
TEST_F(ConditionalCodeMotionTest, MultipleIndependentMoveIns) {
absl::string_view hlo_string =
R"(
HloModule FromNMT
%add.31755 (x.139: f32[], y.139: bf16[]) -> bf16[] {
%x.139 = bf16[]{:T(512)} parameter(0)
%y.139 = bf16[]{:T(512)} parameter(1)
ROOT %add.44073 = bf16[]{:T(512)} add(bf16[]{:T(512)} %x.139, bf16[]{:T(512)} %y.139)
}
%nmt.1 {
%wide_param.3 = (bf16[1024,4096]{1,0}, bf16[18,64,1024]{2,1,0}, s32[]) parameter(0)
%get-tuple-element.16525 = bf16[1024,4096]{1,0} get-tuple-element((bf16[1024,4096]{1,0}, bf16[18,64,1024]{2,1,0}, s32[]) %wide_param.3), index=0
%get-tuple-element.16527 = bf16[18,64,1024]{2,1,0} get-tuple-element((bf16[1024,4096]{1,0}, bf16[18,64,1024]{2,1,0}, s32[]) %wide_param.3), index=1
%get-tuple-element.16588 = s32[] get-tuple-element((bf16[1024,4096]{1,0}, bf16[18,64,1024]{2,1,0}, s32[]) %wide_param.3), index=2
%add.3764 = s32[] add(s32[] %get-tuple-element.16588, s32[] %get-tuple-element.16588), metadata={op_type="Sub" op_name="sub"}
%reshape.9821 = s32[1]{0} reshape(s32[] %add.3764)
%reshape.9822 = s32[] reshape(s32[1]{0} %reshape.9821)
%constant.13127 = s32[] constant(0)
%dynamic-slice.1245 = bf16[1,64,1024]{2,1,0} dynamic-slice(bf16[18,64,1024]{2,1,0} %get-tuple-element.16527, s32[] %reshape.9822, s32[] %constant.13127, s32[] %constant.13127), dynamic_slice_sizes={1,64,1024}
%reshape.9825 = bf16[64,1024]{1,0} reshape(bf16[1,64,1024]{2,1,0} %dynamic-slice.1245), metadata={op_type="GatherV2" op_name="GatherV2"}
%logistic.814 = bf16[64,1024]{1,0} logistic(bf16[64,1024]{1,0} %reshape.9825), metadata={op_type="Sigmoid" op_name="Sigmoid"}
%multiply.4890 = bf16[64,1024]{1,0} multiply(bf16[64,1024]{1,0} %reshape.9825, bf16[64,1024]{1,0} %logistic.814), metadata={op_type="Mul" op_name="mul"}
%tanh.573 = bf16[64,1024]{1,0} tanh(bf16[64,1024]{1,0} %reshape.9825), metadata={op_type="Tanh" op_name="Tanh"}
%multiply.4891 = bf16[64,1024]{1,0} multiply(bf16[64,1024]{1,0} %logistic.814, bf16[64,1024]{1,0} %tanh.573), metadata={op_type="Mul" op_name="mul_1"}
%add.3766 = bf16[64,1024]{1,0} add(bf16[64,1024]{1,0} %multiply.4890, bf16[64,1024]{1,0} %multiply.4891), metadata={op_type="AddV2" op_name="add_1"}
%multiply.4894 = bf16[64,1024]{1,0} multiply(bf16[64,1024]{1,0} %add.3766, bf16[64,1024]{1,0} %logistic.814), metadata={op_type="Mul" op_name="gradients_1/mul_grad/Mul"}
%constant.10568 = bf16[] constant(1), metadata={op_type="TanhGrad" op_name="gradients/Tanh_1_grad/TanhGrad"}
%broadcast.7198 = bf16[64,1024]{1,0} broadcast(bf16[] %constant.10568), dimensions={}, metadata={op_type="TanhGrad" op_name="gradients/Tanh_1_grad/TanhGrad"}
%multiply.4896 = bf16[64,1024]{1,0} multiply(bf16[64,1024]{1,0} %tanh.573, bf16[64,1024]{1,0} %tanh.573), metadata={op_type="TanhGrad" op_name="gradients/Tanh_1_grad/TanhGrad"}
%constant.10571 = bf16[] constant(1), metadata={op_type="SigmoidGrad" op_name="gradients/Sigmoid_grad/SigmoidGrad"}
%broadcast.7201 = bf16[64,1024]{1,0} broadcast(bf16[] %constant.10571), dimensions={}, metadata={op_type="SigmoidGrad" op_name="gradients/Sigmoid_grad/SigmoidGrad"}
%subtract.1702 = bf16[64,1024]{1,0} subtract(bf16[64,1024]{1,0} %broadcast.7201, bf16[64,1024]{1,0} %logistic.814), metadata={op_type="SigmoidGrad" op_name="gradients/Sigmoid_grad/SigmoidGrad"}
%multiply.4907 = bf16[64,1024]{1,0} multiply(bf16[64,1024]{1,0} %tanh.573, bf16[64,1024]{1,0} %add.3766), metadata={op_type="Mul" op_name="gradients/mul_2_grad/Mul_1"}
%multiply.4908 = bf16[64,1024]{1,0} multiply(bf16[64,1024]{1,0} %multiply.4907, bf16[64,1024]{1,0} %logistic.814), metadata={op_type="SigmoidGrad" op_name="gradients/Sigmoid_2_grad/SigmoidGrad"}
%dot.781 = bf16[64,4096]{1,0} dot(bf16[64,1024]{1,0} %multiply.4908, bf16[1024,4096]{1,0} %get-tuple-element.16525), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_type="MatMul" op_name="MatMul"}
ROOT %tuple.3200 = (bf16[64,1024]{1,0}, bf16[64,4096]{1,0}, s32[]) tuple(bf16[64,1024]{1,0} %multiply.4894, bf16[64,4096]{1,0} %dot.781, s32[] %reshape.9822)
}
ENTRY main {
pred.1 = pred[] parameter(0)
arg_tuple.3 = (bf16[1024,4096]{1,0}, bf16[18,64,1024]{2,1,0}, s32[]) parameter(1)
arg_tuple.4 = (bf16[1024,4096]{1,0}, bf16[18,64,1024]{2,1,0}, s32[]) parameter(2)
%arg.2 = s32[] parameter(3)
%conditional.3 = (bf16[64,1024]{1,0}, bf16[64,4096]{1,0}, s32[]) conditional(pred.1, arg_tuple.3, arg_tuple.4), true_computation=nmt.1, false_computation=nmt.1
%get-tuple-element.15889 = bf16[64,1024]{1,0} get-tuple-element((bf16[64,1024]{1,0}, bf16[64,4096]{1,0}, s32[]) %conditional.3), index=0, metadata={op_type="Case" op_name="switch_case/indexed_case"}
%multiply.4596 = bf16[64,1024]{1,0} multiply(bf16[64,1024]{1,0} %get-tuple-element.15889, bf16[64,1024]{1,0} %get-tuple-element.15889), metadata={op_type="L2Loss" op_name="global_norm/L2Loss"}
%constant.10279 = bf16[] constant(0), metadata={op_type="L2Loss" op_name="global_norm/L2Loss"}
%reduce.844 = bf16[] reduce(bf16[64,1024]{1,0} %multiply.4596, bf16[] %constant.10279), dimensions={0,1}, to_apply=%add.31755, metadata={op_type="L2Loss" op_name="global_norm/L2Loss"}
%get-tuple-element.15890 = bf16[64,4096]{1,0} get-tuple-element((bf16[64,1024]{1,0}, bf16[64,4096]{1,0}, s32[]) %conditional.3), index=1, metadata={op_type="Case" op_name="switch_case/indexed_case"}
%multiply.4597 = bf16[64,4096]{1,0} multiply(bf16[64,4096]{1,0} %get-tuple-element.15890, bf16[64,4096]{1,0} %get-tuple-element.15890), metadata={op_type="L2Loss" op_name="global_norm/L2Loss"}
%constant.10280 = bf16[] constant(0), metadata={op_type="L2Loss" op_name="global_norm/L2Loss"}
%reduce.845 = bf16[] reduce(bf16[64,4096]{1,0} %multiply.4597, bf16[] %constant.10280), dimensions={0,1}, to_apply=%add.31755, metadata={op_type="L2Loss" op_name="global_norm/L2Loss"}
%multiply.4667 = bf16[] multiply(bf16[] %reduce.845, bf16[]{:T(128)} %reduce.844), metadata={op_type="L2Loss" op_name="global_norm/L2Loss"}
ROOT %tuple.3200 = (bf16[], s32[]) tuple(%multiply.4667, s32[] %arg.2)
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
const HloInstruction* conditional =
FindInstruction(module.get(), "conditional.3");
CHECK(conditional != nullptr);
const HloComputation* on_true = conditional->branch_computation(0);
ASSERT_EQ(on_true->instruction_count(), 27);
const HloComputation* on_false = conditional->branch_computation(1);
ASSERT_EQ(on_false->instruction_count(), 27);
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Tuple(op::GetTupleElement(op::Conditional()),
op::Parameter())));
}
} // namespace conditional_opt
} // namespace xla

View File

@ -237,8 +237,7 @@ Status GpuCompiler::OptimizeHloModule(
return IsMatrixMultiplication(dot)
? candidate_operands
: TransposeFolding::OperandIndices{};
},
TransposeFolding::NeverFoldTranspose);
});
pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
pipeline.AddPass<HloDCE>();

View File

@ -84,11 +84,10 @@ Status RunGpuConvForward(GpuConvParams params,
"StreamExecutor doesn't support scaled convolution: %lf.",
params.conv_result_scale);
}
stream->ThenConvolveWithAlgorithm(
return stream->ConvolveWithAlgorithm(
params.input_descriptor, input_buf, params.filter_descriptor, filter_buf,
params.conv_desc, params.output_descriptor, &output_buf,
scratch_allocator, algorithm, options.profile_result);
return Status::OK();
}
template <typename ElementType, typename BiasType, typename OutputType>
@ -123,15 +122,13 @@ Status RunGpuConvForwardActivation(GpuConvParams params,
side_input = output_buf;
}
stream->ThenFusedConvolveWithAlgorithm(
return stream->FusedConvolveWithAlgorithm(
params.input_descriptor, input_buf, params.conv_result_scale,
params.filter_descriptor, filter_buf, params.conv_desc, side_input,
params.fusion->side_input_scale, bias_desc,
DeviceMemory<BiasType>(params.fusion->bias_buf), params.fusion->mode,
params.output_descriptor, &output_buf, scratch_allocator, algorithm,
options.profile_result);
return Status::OK();
}
// StreamExecutor supports various data types via overloading, and the support
@ -162,7 +159,7 @@ Status RunGpuConvInternalImpl(GpuConvParams params,
"StreamExecutor doesn't support scaled convolution: %lf.",
params.conv_result_scale);
}
stream->ThenConvolveBackwardDataWithAlgorithm(
return stream->ConvolveBackwardDataWithAlgorithm(
params.filter_descriptor, filter_buf, params.output_descriptor,
output_buf, params.conv_desc, params.input_descriptor, &input_buf,
scratch_allocator, algorithm, options.profile_result);
@ -173,7 +170,7 @@ Status RunGpuConvInternalImpl(GpuConvParams params,
"StreamExecutor doesn't support scaled convolution: %lf.",
params.conv_result_scale);
}
stream->ThenConvolveBackwardFilterWithAlgorithm(
return stream->ConvolveBackwardFilterWithAlgorithm(
params.input_descriptor, input_buf, params.output_descriptor,
output_buf, params.conv_desc, params.filter_descriptor, &filter_buf,
scratch_allocator, algorithm, options.profile_result);
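
The hunks above switch from the fluent Then*WithAlgorithm style, where a failure is only recorded on the stream and the launcher always returned Status::OK(), to overloads that return the status directly to the caller. A small sketch of that calling-convention difference, using hypothetical stand-in types rather than the real StreamExecutor API:

#include <string>

// Hypothetical stand-ins; illustrative only.
struct Status {
  bool ok = true;
  std::string message;
  static Status OK() { return {}; }
};

struct Stream {
  // Old style: record the error on the stream and return *this for chaining.
  Stream& ThenLaunch(bool fail) { if (fail) ok_ = false; return *this; }
  // New style: report success or failure directly to the caller.
  Status Launch(bool fail) {
    return fail ? Status{false, "launch failed"} : Status::OK();
  }
  bool ok_ = true;
};

Status RunOldStyle(Stream* stream, bool fail) {
  stream->ThenLaunch(fail);  // failure is only visible via the stream later
  return Status::OK();       // always reports success here
}

Status RunNewStyle(Stream* stream, bool fail) {
  return stream->Launch(fail);  // failure propagates immediately
}
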

View File

@ -3341,6 +3341,7 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() {
last_use_instruction, parameter_time, last_use_time,
absl::StrCat(indent_string, " ")));
} else {
last_use_time = std::min(last_use_time, end_time);
TF_RETURN_IF_ERROR(add_allocation_and_verify(
parameter_time, last_use_time, chunk, value));
}
@ -3359,12 +3360,13 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() {
TF_RETURN_IF_ERROR(split_conditional_buffer(
last_use_instruction, time_bound.start, time_bound.end, " "));
} else if (!value->uses().empty()) {
last_use_time = std::min(last_use_time, time_bound.end);
VLOG(3) << " buffer: " << buffer.ToString()
<< " value: " << value->ToShortString() << ": ("
<< time_bound.start << ", " << time_bound.end
<< time_bound.start << ", " << last_use_time
<< ") off: " << chunk.offset << ", size: " << chunk.size;
TF_RETURN_IF_ERROR(add_allocation_and_verify(
time_bound.start, time_bound.end, chunk, value));
time_bound.start, last_use_time, chunk, value));
}
}
}
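
The fix above clamps the verified live range to the allocation's end time, so a use recorded past the time bound no longer extends the interval handed to the heap-simulator verification. A small numeric sketch of the clamp, with hypothetical schedule times:

#include <algorithm>
#include <cassert>

int main() {
  // Hypothetical: the value's last recorded use is at time 42, but the
  // allocation's time bound ends at 40.
  int last_use_time = 42;
  const int end_time = 40;
  last_use_time = std::min(last_use_time, end_time);  // clamp to the bound
  assert(last_use_time == 40);  // the verified interval now ends at 40
  return 0;
}
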

View File

@ -40,8 +40,12 @@ namespace {
// Partition convolution with batch group count.
StatusOr<HloInstruction*> PartitionConvolutionWithBatchGroupCount(
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
const HloSharding& output_sharding, const Window& conv_window,
HloInstruction* original_hlo, int64 num_partitions, SpmdBuilder* b) {
const HloSharding& output_sharding,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*,
const Window& conv_window)>& create_sharded_conv,
const Window& conv_window, HloInstruction* original_hlo,
int64 num_partitions, SpmdBuilder* b) {
TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
if (original_hlo->batch_group_count() == 1 ||
original_hlo->batch_group_count() < num_partitions) {
@ -115,21 +119,10 @@ StatusOr<HloInstruction*> PartitionConvolutionWithBatchGroupCount(
auto aligned_output_sharding = hlo_sharding_util::TransposeSharding(
lhs.sharding(), lhs_to_output_indices);
// Get LHS and RHS sharded shape.
auto lhs_shard_shape = MakePartitionedShape(lhs.base_shape(), lhs.sharding());
auto rhs_shard_shape = MakePartitionedShape(rhs.base_shape(), rhs.sharding());
const int64 batch_group_count =
CeilOfRatio(original_hlo->batch_group_count(), num_partitions);
// Create partitioned convolution.
TF_ASSIGN_OR_RETURN(
Shape sharded_conv_shape,
ShapeInference::InferConvolveShape(
lhs_shard_shape, rhs_shard_shape, original_hlo->feature_group_count(),
batch_group_count, conv_window, dnums));
auto sharded_conv = b->AddInstruction(HloInstruction::CreateConvolve(
sharded_conv_shape, lhs.hlo(), rhs.hlo(),
original_hlo->feature_group_count(), batch_group_count, conv_window,
dnums, original_hlo->precision_config()));
auto sharded_conv,
create_sharded_conv(lhs.hlo(), rhs.hlo(), b, conv_window));
sharded_conv->set_sharding(aligned_output_sharding);
return PartitionedHlo(sharded_conv, output_base_shape, lhs.state())
.Reshard(output_sharding)
@ -139,8 +132,12 @@ StatusOr<HloInstruction*> PartitionConvolutionWithBatchGroupCount(
// Partition convolution with feature group count.
StatusOr<HloInstruction*> PartitionConvolutionWithFeatureGroupCount(
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
const HloSharding& output_sharding, const Window& conv_window,
HloInstruction* original_hlo, int64 num_partitions, SpmdBuilder* b) {
const HloSharding& output_sharding,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*,
const Window& conv_window)>& create_sharded_conv,
const Window& conv_window, HloInstruction* original_hlo,
int64 num_partitions, SpmdBuilder* b) {
TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
if (original_hlo->feature_group_count() == 1 ||
original_hlo->feature_group_count() < num_partitions) {
@ -215,20 +212,9 @@ StatusOr<HloInstruction*> PartitionConvolutionWithFeatureGroupCount(
auto aligned_output_sharding = hlo_sharding_util::TransposeSharding(
lhs.sharding(), lhs_to_output_indices);
auto lhs_shard_shape = MakePartitionedShape(lhs.base_shape(), lhs.sharding());
auto rhs_shard_shape = MakePartitionedShape(rhs.base_shape(), rhs.sharding());
int64 feature_group_count =
CeilOfRatio(original_hlo->feature_group_count(), num_partitions);
TF_ASSIGN_OR_RETURN(
Shape sharded_conv_shape,
ShapeInference::InferConvolveShape(
lhs_shard_shape, rhs_shard_shape, feature_group_count,
original_hlo->batch_group_count(), conv_window, dnums));
auto sharded_conv = b->AddInstruction(HloInstruction::CreateConvolve(
sharded_conv_shape, lhs.hlo(), rhs.hlo(), feature_group_count,
original_hlo->batch_group_count(), conv_window, dnums,
original_hlo->precision_config()));
auto sharded_conv,
create_sharded_conv(lhs.hlo(), rhs.hlo(), b, conv_window));
sharded_conv->set_sharding(aligned_output_sharding);
return PartitionedHlo(sharded_conv, output_base_shape, lhs.state())
.Reshard(output_sharding)
@ -240,9 +226,12 @@ StatusOr<HloInstruction*> PartitionConvolutionWithFeatureGroupCount(
StatusOr<HloInstruction*>
PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS(
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
const HloSharding& output_sharding, const Window& conv_window,
HloInstruction* original_hlo, HloInstruction* partition_id,
HloModule* module, SpmdBuilder* b) {
const HloSharding& output_sharding,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*,
const Window& conv_window)>& create_sharded_conv,
const Window& conv_window, HloInstruction* original_hlo,
HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
TF_RET_CHECK(!lhs.sharding().IsTileMaximal() &&
!rhs.sharding().IsTileMaximal());
@ -491,10 +480,9 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS(
rhs_with_halo = *concat;
}
auto conv = b->AddInstruction(HloInstruction::CreateConvolve(
output_base_shape, conv_lhs, rhs_with_halo,
original_hlo->feature_group_count(), original_hlo->batch_group_count(),
new_window, dnums, original_hlo->precision_config()));
TF_ASSIGN_OR_RETURN(
auto conv, create_sharded_conv(conv_lhs, rhs_with_halo, b, new_window));
auto ar = collective_ops_creator.create_cross_partition_all_reduce(
b, conv, MakeBinaryAdd(original_hlo->shape().element_type(), module), {},
(*lhs.state().next_channel_id)++);
@ -509,9 +497,12 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS(
StatusOr<HloInstruction*>
PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS(
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
const HloSharding& output_sharding, const Window& conv_window,
HloInstruction* original_hlo, HloInstruction* partition_id,
HloModule* module, SpmdBuilder* b) {
const HloSharding& output_sharding,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*,
const Window& conv_window)>& create_sharded_conv,
const Window& conv_window, HloInstruction* original_hlo,
HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
TF_RET_CHECK(!lhs.sharding().IsTileMaximal() &&
!rhs.sharding().IsTileMaximal());
@ -583,7 +574,6 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS(
rhs =
rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero, reversed_rhs_dims);
}
// Reshard LHS by exchanging halo such that each shard computes the partial
// sum of the full shape result, and add AllReduce.
//
@ -701,11 +691,8 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS(
lhs_with_halo = *concat;
}
auto conv = b->AddInstruction(HloInstruction::CreateConvolve(
output_base_shape, lhs_with_halo, rhs.hlo(),
original_hlo->feature_group_count(), original_hlo->batch_group_count(),
new_window, original_hlo->convolution_dimension_numbers(),
original_hlo->precision_config()));
TF_ASSIGN_OR_RETURN(
auto conv, create_sharded_conv(lhs_with_halo, rhs.hlo(), b, new_window));
auto ar =
lhs.state().collective_ops_creator.create_cross_partition_all_reduce(
b, conv, MakeBinaryAdd(output_base_shape.element_type(), module), {},
@ -720,8 +707,11 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS(
// RHS.
StatusOr<HloInstruction*> PartitionConvolutionTiledOutput(
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
const HloSharding& output_sharding, const Window& conv_window,
HloInstruction* original_hlo, SpmdBuilder* b) {
const HloSharding& output_sharding,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*,
const Window& conv_window)>& create_sharded_conv,
const Window& conv_window, HloInstruction* original_hlo, SpmdBuilder* b) {
TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
const auto& dnums = original_hlo->convolution_dimension_numbers();
TF_RET_CHECK(!output_sharding.IsTileMaximal());
@ -772,19 +762,13 @@ StatusOr<HloInstruction*> PartitionConvolutionTiledOutput(
resharded_operand_and_window->shard_window.dimensions(
dnums.input_spatial_dimensions(i));
}
TF_ASSIGN_OR_RETURN(
Shape sharded_conv_shape,
ShapeInference::InferConvolveShape(
resharded_operand_and_window->sharded_input->shape(),
rhs.hlo()->shape(), original_hlo->feature_group_count(),
original_hlo->batch_group_count(), new_window, dnums));
auto sharded_conv,
create_sharded_conv(resharded_operand_and_window->sharded_input,
rhs.hlo(), b, new_window));
auto shard_shape = MakePartitionedShape(output_base_shape, output_sharding);
*sharded_conv_shape.mutable_layout() = shard_shape.layout();
auto sharded_conv = b->AddInstruction(HloInstruction::CreateConvolve(
sharded_conv_shape, resharded_operand_and_window->sharded_input,
rhs.hlo(), original_hlo->feature_group_count(),
original_hlo->batch_group_count(), new_window, dnums,
original_hlo->precision_config()));
if (!resharded_operand_and_window->dynamic_slice_index_on_output
.has_value()) {
CHECK(ShapeUtil::Compatible(shard_shape, sharded_conv->shape()));
@ -799,29 +783,34 @@ StatusOr<HloInstruction*> PartitionConvolutionTiledOutput(
// Partition convolution with only one kind of dims partitioned.
StatusOr<HloInstruction*> PartitionConvolutionBaseCase(
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
const HloSharding& output_sharding, const Window& conv_window,
HloInstruction* original_hlo, int64 num_partitions,
const SpmdPartitionerOptions& options, HloInstruction* partition_id,
HloModule* module, SpmdBuilder* b) {
const HloSharding& output_sharding,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*,
const Window& conv_window)>& create_sharded_conv,
const Window& conv_window, HloInstruction* original_hlo,
int64 num_partitions, const SpmdPartitionerOptions& options,
HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
// Case 1: Handle depthwise convolution with batch group count or
// feature group count.
if (original_hlo->batch_group_count() > 1) {
TF_ASSIGN_OR_RETURN(auto parallel_partitioned_conv,
PartitionConvolutionWithBatchGroupCount(
lhs, rhs, output_base_shape, output_sharding,
conv_window, original_hlo, num_partitions, b));
TF_ASSIGN_OR_RETURN(
auto parallel_partitioned_conv,
PartitionConvolutionWithBatchGroupCount(
lhs, rhs, output_base_shape, output_sharding, create_sharded_conv,
conv_window, original_hlo, num_partitions, b));
if (parallel_partitioned_conv) {
return parallel_partitioned_conv;
}
}
if (original_hlo->feature_group_count() > 1) {
TF_ASSIGN_OR_RETURN(auto parallel_partitioned_conv,
PartitionConvolutionWithFeatureGroupCount(
lhs, rhs, output_base_shape, output_sharding,
conv_window, original_hlo, num_partitions, b));
TF_ASSIGN_OR_RETURN(
auto parallel_partitioned_conv,
PartitionConvolutionWithFeatureGroupCount(
lhs, rhs, output_base_shape, output_sharding, create_sharded_conv,
conv_window, original_hlo, num_partitions, b));
if (parallel_partitioned_conv) {
return parallel_partitioned_conv;
}
@ -837,8 +826,8 @@ StatusOr<HloInstruction*> PartitionConvolutionBaseCase(
TF_ASSIGN_OR_RETURN(
auto partitioned_conv,
PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS(
lhs, rhs, output_base_shape, output_sharding, conv_window,
original_hlo, partition_id, module, b));
lhs, rhs, output_base_shape, output_sharding, create_sharded_conv,
conv_window, original_hlo, partition_id, module, b));
if (partitioned_conv) {
return partitioned_conv;
}
@ -846,8 +835,8 @@ StatusOr<HloInstruction*> PartitionConvolutionBaseCase(
TF_ASSIGN_OR_RETURN(
auto partitioned_conv,
PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS(
lhs, rhs, output_base_shape, output_sharding, conv_window,
original_hlo, partition_id, module, b));
lhs, rhs, output_base_shape, output_sharding, create_sharded_conv,
conv_window, original_hlo, partition_id, module, b));
if (partitioned_conv) {
return partitioned_conv;
@ -860,7 +849,7 @@ StatusOr<HloInstruction*> PartitionConvolutionBaseCase(
TF_ASSIGN_OR_RETURN(auto partitioned_conv,
PartitionConvolutionTiledOutput(
lhs, rhs, output_base_shape, output_sharding,
conv_window, original_hlo, b));
create_sharded_conv, conv_window, original_hlo, b));
if (partitioned_conv) {
return partitioned_conv;
@ -869,22 +858,92 @@ StatusOr<HloInstruction*> PartitionConvolutionBaseCase(
return nullptr;
}
StatusOr<std::unique_ptr<HloInstruction>> CreateShardedConvConvolution(
const HloInstruction& conv,
const dot_as_convolution_util::DotConvolutionDimsInfo& dot_dnums,
HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo,
const Window& conv_window) {
CHECK_EQ(conv.opcode(), HloOpcode::kConvolution);
const auto& conv_dnums = conv.convolution_dimension_numbers();
auto window = conv.window();
for (const auto& dim : dot_dnums.batch_dims) {
auto wd = window.mutable_dimensions(dim.spatial_dim);
wd->set_size(sharded_lhs_hlo->shape().dimensions(
conv_dnums.input_spatial_dimensions(dim.spatial_dim)));
wd->set_stride(std::max<int64>(1, wd->size() - 1));
wd->set_base_dilation(wd->size());
}
for (const auto& dim : dot_dnums.contracting_dims) {
if (dim.spatial_dim < 0) {
continue;
}
auto wd = window.mutable_dimensions(dim.spatial_dim);
wd->set_size(sharded_lhs_hlo->shape().dimensions(
conv_dnums.input_spatial_dimensions(dim.spatial_dim)));
}
for (const auto& dim : dot_dnums.rhs_non_contracting_dims) {
if (dim.spatial_dim < 0) {
continue;
}
auto wd = window.mutable_dimensions(dim.spatial_dim);
wd->set_size(sharded_rhs_hlo->shape().dimensions(
conv_dnums.kernel_spatial_dimensions(dim.spatial_dim)));
wd->set_padding_high(wd->size() - 1);
wd->set_padding_low(wd->size() - 1);
}
for (const auto& dim : dot_dnums.conv_spatial_dims) {
auto wd = window.mutable_dimensions(dim.spatial_dim);
const auto& new_window_dimension = conv_window.dimensions(dim.spatial_dim);
wd->set_size(new_window_dimension.size());
wd->set_padding_high(new_window_dimension.padding_high());
wd->set_padding_low(new_window_dimension.padding_low());
}
int64 feature_group_count = conv.feature_group_count();
if (feature_group_count > 1) {
feature_group_count = sharded_lhs_hlo->shape().dimensions(
conv_dnums.input_feature_dimension()) /
sharded_rhs_hlo->shape().dimensions(
conv_dnums.kernel_input_feature_dimension());
}
int64 batch_group_count = conv.batch_group_count();
if (batch_group_count > 1) {
batch_group_count =
sharded_lhs_hlo->shape().dimensions(conv_dnums.input_batch_dimension());
}
TF_ASSIGN_OR_RETURN(
Shape sharded_conv_shape,
ShapeInference::InferConvolveShape(
sharded_lhs_hlo->shape(), sharded_rhs_hlo->shape(),
feature_group_count, batch_group_count, window, conv_dnums));
*sharded_conv_shape.mutable_layout() = conv.shape().layout();
return HloInstruction::CreateConvolve(
sharded_conv_shape, sharded_lhs_hlo, sharded_rhs_hlo, feature_group_count,
batch_group_count, window, conv_dnums, conv.precision_config());
}
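
In the helper above, dot-as-convolution batch dimensions are re-encoded in the window of the sharded convolution (size set to the sharded input extent, stride to size - 1, base dilation to size), and the feature group count is recomputed from the sharded operand shapes. A tiny numeric illustration of those computations under hypothetical shard extents; the numbers are made up:

#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  // Hypothetical sharded extent along a dot-as-convolution batch dimension.
  const int64_t sharded_batch_extent = 8;
  const int64_t window_size = sharded_batch_extent;
  const int64_t window_stride = std::max<int64_t>(1, window_size - 1);  // 7
  const int64_t base_dilation = window_size;                            // 8

  // Hypothetical sharded feature extents: the per-shard group count becomes
  // sharded input features divided by sharded kernel input features.
  const int64_t sharded_input_features = 64;
  const int64_t sharded_kernel_input_features = 16;
  const int64_t feature_group_count =
      sharded_input_features / sharded_kernel_input_features;  // 4

  assert(window_stride == 7 && base_dilation == 8 && feature_group_count == 4);
  return 0;
}
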
} // namespace
// Partition convolution.
StatusOr<HloInstruction*> PartitionConvolution(
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*,
const Window& conv_window)>& create_sharded_conv,
const Window& conv_window, HloInstruction* original_hlo,
int64 num_partitions, const SpmdPartitionerOptions& options,
HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
TF_ASSIGN_OR_RETURN(
auto try_partitioned_conv,
PartitionConvolutionBaseCase(lhs, rhs, output_base_shape, output_sharding,
conv_window, original_hlo, num_partitions,
options, partition_id, module, b));
TF_ASSIGN_OR_RETURN(auto try_partitioned_conv,
PartitionConvolutionBaseCase(
lhs, rhs, output_base_shape, output_sharding,
create_sharded_conv, conv_window, original_hlo,
num_partitions, options, partition_id, module, b));
if (try_partitioned_conv) {
return try_partitioned_conv;
}
@ -932,13 +991,22 @@ Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) {
}
auto create_sharded_conv =
[&](HloInstruction* lhs_hlo, HloInstruction* rhs_hlo,
spmd::SpmdBuilder* b) -> StatusOr<HloInstruction*> {
TF_ASSIGN_OR_RETURN(
auto sharded_conv,
dot_as_convolution_util::CreateShardedConvForDotGeneralConvolution(
*hlo, dims_info, lhs_hlo, rhs_hlo));
return b->AddInstruction(std::move(sharded_conv));
spmd::SpmdBuilder* b,
const Window& conv_window) -> StatusOr<HloInstruction*> {
if (dims_info.conv_spatial_dims.empty()) {
TF_ASSIGN_OR_RETURN(
auto sharded_conv,
dot_as_convolution_util::CreateShardedConvForDotGeneralConvolution(
*hlo, dims_info, lhs_hlo, rhs_hlo));
return b->AddInstruction(std::move(sharded_conv));
} else {
TF_ASSIGN_OR_RETURN(auto sharded_conv,
CreateShardedConvConvolution(*hlo, dims_info, lhs_hlo,
rhs_hlo, conv_window));
return b->AddInstruction(std::move(sharded_conv));
}
};
return HandleDotHelper(hlo, mapping, create_sharded_conv);
}

View File

@ -29,6 +29,9 @@ namespace spmd {
StatusOr<HloInstruction*> PartitionConvolution(
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*,
const Window& conv_window)>& create_sharded_conv,
const Window& conv_window, HloInstruction* original_hlo,
int64 num_partitions, const SpmdPartitionerOptions& options,
HloInstruction* partition_id, HloModule* module, SpmdBuilder* b);
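
The same std::function signature for create_sharded_conv is now threaded through every partitioning helper. A type alias could shorten the repeated declarations; the following is only a hypothetical sketch with stand-in types, not part of the change:

#include <functional>

// Stand-ins; in the real header these are xla::StatusOr<HloInstruction*>,
// HloInstruction, spmd::SpmdBuilder, and xla::Window.
struct HloInstruction {};
struct SpmdBuilder {};
struct Window {};
template <typename T> struct StatusOr { T value; };

// Hypothetical alias abbreviating the callback signature repeated above.
using CreateShardedConvFn = std::function<StatusOr<HloInstruction*>(
    HloInstruction*, HloInstruction*, SpmdBuilder*, const Window&)>;
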

View File

@ -19,8 +19,10 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
@ -29,6 +31,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/numbers.h"
@ -72,8 +75,9 @@ Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) {
mapping.rhs_non_contracting_dims.back().rhs = i;
mapping.rhs_non_contracting_dims.back().output = next_output_dim++;
}
auto create_sharded_dot = [&](HloInstruction* l, HloInstruction* r,
SpmdBuilder* b) -> StatusOr<HloInstruction*> {
auto create_sharded_dot =
[&](HloInstruction* l, HloInstruction* r, SpmdBuilder* b,
const Window& conv_window) -> StatusOr<HloInstruction*> {
TF_ASSIGN_OR_RETURN(
auto sharded_dot_shape,
ShapeInference::InferDotOpShape(l->shape(), r->shape(),
@ -92,11 +96,13 @@ StatusOr<HloInstruction*> PartitionBaseCase(
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
int64 num_partitions,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
HloModule* module, HloInstruction* original_hlo, int64 lhs_batch_partitions,
int64 rhs_batch_partitions, int64 output_batch_partitions,
int64 lhs_contracting_partitions, int64 rhs_contracting_partitions,
int64 lhs_non_contracting_partitions, int64 rhs_non_contracting_partitions,
HloInstruction*, HloInstruction*, SpmdBuilder*,
const Window& conv_window)>& create_sharded_dot,
const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
int64 lhs_batch_partitions, int64 rhs_batch_partitions,
int64 output_batch_partitions, int64 lhs_contracting_partitions,
int64 rhs_contracting_partitions, int64 lhs_non_contracting_partitions,
int64 rhs_non_contracting_partitions,
int64 output_lhs_non_contracting_partitions,
int64 output_rhs_non_contracting_partitions,
const SpmdPartitionerOptions& options, SpmdBuilder* b,
@ -170,7 +176,8 @@ StatusOr<HloInstruction*> PartitionBaseCase(
if (lhs_batch_partitions == rhs_batch_partitions &&
rhs_batch_partitions == num_partitions &&
lhs_sharding_transposed_to_match_rhs == rhs_sharding) {
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b));
TF_ASSIGN_OR_RETURN(
auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
dot->set_sharding(*lhs_sharding_transposed_to_match_output);
return PartitionedHlo(dot, output_base_shape, lhs.state())
.Reshard(output_sharding)
@ -196,7 +203,8 @@ StatusOr<HloInstruction*> PartitionBaseCase(
}
auto resharded_rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs);
TF_ASSIGN_OR_RETURN(
auto dot, create_sharded_dot(lhs.hlo(), resharded_rhs.hlo(), b));
auto dot,
create_sharded_dot(lhs.hlo(), resharded_rhs.hlo(), b, conv_window));
return dot;
}
// RHS and output are batch partitioned in the same way.
@ -212,7 +220,8 @@ StatusOr<HloInstruction*> PartitionBaseCase(
}
auto resharded_lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs);
TF_ASSIGN_OR_RETURN(
auto dot, create_sharded_dot(resharded_lhs.hlo(), rhs.hlo(), b));
auto dot,
create_sharded_dot(resharded_lhs.hlo(), rhs.hlo(), b, conv_window));
return dot;
}
return nullptr;
@ -310,8 +319,8 @@ StatusOr<HloInstruction*> PartitionBaseCase(
dot_rhs = slice;
}
}
TF_ASSIGN_OR_RETURN(auto dot,
create_sharded_dot(dot_lhs, dot_rhs, &body_b));
TF_ASSIGN_OR_RETURN(
auto dot, create_sharded_dot(dot_lhs, dot_rhs, &body_b, conv_window));
if (windowed_at_contracting_dims) {
// Accumulate the partial output to the result buffer.
o = body_b.AddInstruction(
@ -465,7 +474,8 @@ StatusOr<HloInstruction*> PartitionBaseCase(
rhs =
rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero);
}
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b));
TF_ASSIGN_OR_RETURN(
auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
auto ar =
lhs.state().collective_ops_creator.create_cross_partition_all_reduce(
b, dot, MakeBinaryAdd(output_base_shape.element_type(), module), {},
@ -481,8 +491,8 @@ StatusOr<HloInstruction*> PartitionBaseCase(
output_lhs_non_contracting_partitions == num_partitions &&
lhs_sharding_transposed_to_match_output == output_sharding) {
auto rhs_replicated = rhs.Reshard(HloSharding::Replicate()).hlo();
TF_ASSIGN_OR_RETURN(auto dot,
create_sharded_dot(lhs.hlo(), rhs_replicated, b));
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs_replicated,
b, conv_window));
return dot;
}
@ -491,8 +501,8 @@ StatusOr<HloInstruction*> PartitionBaseCase(
output_rhs_non_contracting_partitions == num_partitions &&
rhs_sharding_transposed_to_match_output == output_sharding) {
auto lhs_replicated = lhs.Reshard(HloSharding::Replicate()).hlo();
TF_ASSIGN_OR_RETURN(auto dot,
create_sharded_dot(lhs_replicated, rhs.hlo(), b));
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs_replicated, rhs.hlo(),
b, conv_window));
return dot;
}
@ -503,8 +513,9 @@ StatusOr<HloInstruction*> PartitionBaseCase(
lhs.Reshard(*output_sharding_transposed_to_match_lhs);
auto resharded_rhs =
rhs.Reshard(*output_sharding_transposed_to_match_rhs);
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(resharded_lhs.hlo(),
resharded_rhs.hlo(), b));
TF_ASSIGN_OR_RETURN(
auto dot, create_sharded_dot(resharded_lhs.hlo(), resharded_rhs.hlo(),
b, conv_window));
return dot;
}
// Output is partitioned along LHS non-contracting dimensions.
@ -513,8 +524,8 @@ StatusOr<HloInstruction*> PartitionBaseCase(
lhs.Reshard(*output_sharding_transposed_to_match_lhs);
auto replicated_rhs = rhs.Reshard(HloSharding::Replicate());
TF_ASSIGN_OR_RETURN(
auto dot,
create_sharded_dot(resharded_lhs.hlo(), replicated_rhs.hlo(), b));
auto dot, create_sharded_dot(resharded_lhs.hlo(),
replicated_rhs.hlo(), b, conv_window));
return dot;
}
// Output is partitioned along RHS non-contracting dimensions.
@ -522,8 +533,9 @@ StatusOr<HloInstruction*> PartitionBaseCase(
auto replicated_lhs = lhs.Reshard(HloSharding::Replicate());
auto resharded_rhs =
rhs.Reshard(*output_sharding_transposed_to_match_rhs);
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(replicated_lhs.hlo(),
resharded_rhs.hlo(), b));
TF_ASSIGN_OR_RETURN(
auto dot, create_sharded_dot(replicated_lhs.hlo(),
resharded_rhs.hlo(), b, conv_window));
return dot;
}
}
@ -566,7 +578,8 @@ StatusOr<HloInstruction*> PartitionBaseCase(
rhs =
rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero);
}
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b));
TF_ASSIGN_OR_RETURN(
auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
return lhs.state().collective_ops_creator.create_cross_partition_all_reduce(
b, dot, MakeBinaryAdd(output_base_shape.element_type(), module), {},
(*lhs.state().next_channel_id)++);
@ -579,8 +592,9 @@ StatusOr<HloInstruction*> PartitionDot(
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
int64 num_partitions,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
HloModule* module, HloInstruction* original_hlo,
HloInstruction*, HloInstruction*, SpmdBuilder*,
const Window& conv_window)>& create_sharded_dot,
const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
const SpmdPartitionerOptions& options, SpmdBuilder* b,
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
windowed_dot_general_loops);
@ -592,8 +606,9 @@ StatusOr<HloInstruction*> PartitionDotGroupOnBatch(
int64 rhs_contracting_partitions, int64 lhs_non_contracting_partitions,
int64 rhs_non_contracting_partitions,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
HloModule* module, HloInstruction* original_hlo,
HloInstruction*, HloInstruction*, SpmdBuilder*,
const Window& conv_window)>& create_sharded_dot,
const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
bool require_matching_devices_to_group,
const SpmdPartitionerOptions& options, SpmdBuilder* b,
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
@ -808,8 +823,8 @@ StatusOr<HloInstruction*> PartitionDotGroupOnBatch(
GetPerGroupBaseShape(output_grouped, output_base_shape),
output_grouped.sharding, dims_mapping,
num_partitions / output_grouped.device_groups.size(),
create_sharded_dot, module, original_hlo, options, b,
windowed_dot_general_loops));
create_sharded_dot, conv_window, module, original_hlo,
options, b, windowed_dot_general_loops));
dot->set_sharding(UngroupSharding(output_grouped));
return PartitionedHlo(dot, output_base_shape, lhs.state())
.Reshard(output_sharding)
@ -826,8 +841,9 @@ StatusOr<HloInstruction*> PartitionDotGroupOnNonContracting(
const Shape& output_base_shape, const HloSharding& output_sharding,
const DotConvDimsMapping& dims_mapping, int64 num_partitions,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
HloModule* module, HloInstruction* original_hlo,
HloInstruction*, HloInstruction*, SpmdBuilder*,
const Window& conv_window)>& create_sharded_dot,
const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
bool require_matching_devices_to_group,
const SpmdPartitionerOptions& options, SpmdBuilder* b,
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
@ -952,8 +968,8 @@ StatusOr<HloInstruction*> PartitionDotGroupOnNonContracting(
GetPerGroupBaseShape(output_grouped, output_base_shape),
output_grouped.sharding, dims_mapping,
num_partitions / matching_grouped.device_groups.size(),
create_sharded_dot, module, original_hlo, options, b,
windowed_dot_general_loops));
create_sharded_dot, conv_window, module, original_hlo,
options, b, windowed_dot_general_loops));
return dot;
}
@ -966,8 +982,9 @@ StatusOr<HloInstruction*> PartitionDotGroupOnContracting(
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
int64 num_partitions,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
HloModule* module, HloInstruction* original_hlo,
HloInstruction*, HloInstruction*, SpmdBuilder*,
const Window& conv_window)>& create_sharded_dot,
const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
bool require_matching_devices_to_group,
const SpmdPartitionerOptions& options, SpmdBuilder* b,
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
@ -1090,10 +1107,9 @@ StatusOr<HloInstruction*> PartitionDotGroupOnContracting(
PartitionedHlo(rhs.hlo(),
GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()),
inner_state),
MakePartitionedShape(output_base_shape, outer_output_tmp_sharding),
inner_output_sharding, dims_mapping, num_partitions / group_count,
create_sharded_dot, module, original_hlo, options, b,
windowed_dot_general_loops));
output_base_shape, inner_output_sharding, dims_mapping,
num_partitions / group_count, create_sharded_dot, conv_window, module,
original_hlo, options, b, windowed_dot_general_loops));
if (!dot) {
return nullptr;
}
@ -1124,8 +1140,9 @@ StatusOr<HloInstruction*> PartitionDot(
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
int64 num_partitions,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
HloModule* module, HloInstruction* original_hlo,
HloInstruction*, HloInstruction*, SpmdBuilder*,
const Window& conv_window)>& create_sharded_dot,
const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
bool require_matching_devices_to_group,
const SpmdPartitionerOptions& options, SpmdBuilder* b,
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
@ -1180,35 +1197,25 @@ StatusOr<HloInstruction*> PartitionDot(
// Try partition the purely spatially-partitioned convolution with convolution
// spatial dimension partitioned or depthwise parallel dimension partitioned.
if (!dims_mapping.conv_spatial_dims.empty() &&
bool is_conv_spatial_dim_partitioned =
(lhs_conv_spatial_partitions > 1 || rhs_conv_spatial_partitions > 1 ||
output_conv_spatial_partitions > 1 ||
original_hlo->batch_group_count() > 1 ||
original_hlo->feature_group_count() > 1)) {
const auto& conv_dnums = original_hlo->convolution_dimension_numbers();
auto window = original_hlo->window();
// TODO(wangtao): remove this hack by passing create_sharded_conv to
// PartitionConv.
// Update convolution window when it is in the recursive call for
// batch_dims.
if (original_hlo->batch_group_count() == 1 &&
original_hlo->feature_group_count() == 1 &&
!ShapeUtil::Compatible(original_hlo->shape(), output_base_shape)) {
for (const auto& dim : dims_mapping.batch_dims) {
auto wd = window.mutable_dimensions(dim.spatial);
wd->set_size(lhs.hlo()->shape().dimensions(
conv_dnums.input_spatial_dimensions(dim.spatial)));
wd->set_stride(std::max<int64>(1, wd->size() - 1));
wd->set_base_dilation(wd->size());
}
}
output_conv_spatial_partitions > 1);
bool is_conv_batch_or_contracting_dim_partitioned =
(lhs_batch_partitions > 1 || rhs_batch_partitions > 1 ||
output_batch_partitions > 1 ||
(lhs_contracting_partitions > 1 && rhs_contracting_partitions > 1));
if ((!dims_mapping.conv_spatial_dims.empty() &&
is_conv_spatial_dim_partitioned &&
!is_conv_batch_or_contracting_dim_partitioned) ||
(original_hlo->opcode() == HloOpcode::kConvolution &&
(original_hlo->batch_group_count() > 1 ||
original_hlo->feature_group_count() > 1))) {
TF_ASSIGN_OR_RETURN(
auto partitioned_conv,
PartitionConvolution(lhs, rhs, output_base_shape, output_sharding,
dims_mapping, window, original_hlo, num_partitions,
options, lhs.state().partition_id, module, b));
dims_mapping, create_sharded_dot, conv_window,
original_hlo, num_partitions, options,
lhs.state().partition_id, module, b));
if (partitioned_conv) {
return partitioned_conv;
@ -1219,7 +1226,7 @@ StatusOr<HloInstruction*> PartitionDot(
auto try_partitioned_dot,
PartitionBaseCase(
lhs, rhs, output_base_shape, output_sharding, dims_mapping,
num_partitions, create_sharded_dot, module, original_hlo,
num_partitions, create_sharded_dot, conv_window, module, original_hlo,
lhs_batch_partitions, rhs_batch_partitions, output_batch_partitions,
lhs_contracting_partitions, rhs_contracting_partitions,
lhs_non_contracting_partitions, rhs_non_contracting_partitions,
@ -1243,8 +1250,8 @@ StatusOr<HloInstruction*> PartitionDot(
lhs, rhs, output_base_shape, output_sharding, dims_mapping,
num_partitions, lhs_contracting_partitions,
rhs_contracting_partitions, lhs_non_contracting_partitions,
rhs_non_contracting_partitions, create_sharded_dot, module,
original_hlo, require_matching_devices_to_group, options, b,
rhs_non_contracting_partitions, create_sharded_dot, conv_window,
module, original_hlo, require_matching_devices_to_group, options, b,
windowed_dot_general_loops));
if (dot) {
return dot;
@ -1268,7 +1275,6 @@ StatusOr<HloInstruction*> PartitionDot(
ShapeUtil::ByteSizeOf(rhs.hlo()->shape()) <=
rhs_non_contracting_partitions *
ShapeUtil::ByteSizeOf(lhs.hlo()->shape()));
TF_ASSIGN_OR_RETURN(
auto dot,
PartitionDotGroupOnNonContracting(
@ -1284,7 +1290,7 @@ StatusOr<HloInstruction*> PartitionDot(
lhs_matching ? output_rhs_non_contracting_partitions
: output_lhs_non_contracting_partitions,
output_base_shape, output_sharding, dims_mapping, num_partitions,
create_sharded_dot, module, original_hlo,
create_sharded_dot, conv_window, module, original_hlo,
require_matching_devices_to_group, options, b,
windowed_dot_general_loops));
if (dot) {
@ -1304,15 +1310,15 @@ StatusOr<HloInstruction*> PartitionDot(
}
if (!matching_dims.empty()) {
TF_ASSIGN_OR_RETURN(
auto dot,
PartitionDotGroupOnNonContracting(
/*lhs_matching=*/true, lhs, rhs, lhs_contracting_partitions,
rhs_contracting_partitions, matching_dims,
rhs_non_contracting_partitions,
output_rhs_non_contracting_partitions, output_base_shape,
output_sharding, dims_mapping, num_partitions, create_sharded_dot,
module, original_hlo, require_matching_devices_to_group, options,
b, windowed_dot_general_loops));
auto dot, PartitionDotGroupOnNonContracting(
/*lhs_matching=*/true, lhs, rhs,
lhs_contracting_partitions, rhs_contracting_partitions,
matching_dims, rhs_non_contracting_partitions,
output_rhs_non_contracting_partitions,
output_base_shape, output_sharding, dims_mapping,
num_partitions, create_sharded_dot, conv_window, module,
original_hlo, require_matching_devices_to_group,
options, b, windowed_dot_general_loops));
if (dot) {
return dot;
}
@ -1331,15 +1337,15 @@ StatusOr<HloInstruction*> PartitionDot(
}
if (!matching_dims.empty()) {
TF_ASSIGN_OR_RETURN(
auto dot,
PartitionDotGroupOnNonContracting(
/*lhs_matching=*/false, rhs, lhs, rhs_contracting_partitions,
lhs_contracting_partitions, matching_dims,
lhs_non_contracting_partitions,
output_lhs_non_contracting_partitions, output_base_shape,
output_sharding, dims_mapping, num_partitions, create_sharded_dot,
module, original_hlo, require_matching_devices_to_group, options,
b, windowed_dot_general_loops));
auto dot, PartitionDotGroupOnNonContracting(
/*lhs_matching=*/false, rhs, lhs,
rhs_contracting_partitions, lhs_contracting_partitions,
matching_dims, lhs_non_contracting_partitions,
output_lhs_non_contracting_partitions,
output_base_shape, output_sharding, dims_mapping,
num_partitions, create_sharded_dot, conv_window, module,
original_hlo, require_matching_devices_to_group,
options, b, windowed_dot_general_loops));
if (dot) {
return dot;
}
@ -1356,7 +1362,8 @@ StatusOr<HloInstruction*> PartitionDot(
output_lhs_non_contracting_partitions,
output_rhs_non_contracting_partitions, output_base_shape,
output_sharding, dims_mapping, num_partitions, create_sharded_dot,
module, original_hlo, require_matching_devices_to_group, options, b,
conv_window, module, original_hlo,
require_matching_devices_to_group, options, b,
windowed_dot_general_loops));
if (dot) {
return dot;
@ -1374,14 +1381,14 @@ StatusOr<HloInstruction*> PartitionDot(
}
if (!matching_dims.empty()) {
TF_ASSIGN_OR_RETURN(
auto dot,
PartitionDotGroupOnContracting(
lhs, rhs, matching_dims, output_batch_partitions,
output_lhs_non_contracting_partitions,
output_rhs_non_contracting_partitions, output_base_shape,
output_sharding, dims_mapping, num_partitions, create_sharded_dot,
module, original_hlo, require_matching_devices_to_group, options,
b, windowed_dot_general_loops));
auto dot, PartitionDotGroupOnContracting(
lhs, rhs, matching_dims, output_batch_partitions,
output_lhs_non_contracting_partitions,
output_rhs_non_contracting_partitions,
output_base_shape, output_sharding, dims_mapping,
num_partitions, create_sharded_dot, conv_window, module,
original_hlo, require_matching_devices_to_group,
options, b, windowed_dot_general_loops));
if (dot) {
return dot;
}
@ -1401,8 +1408,9 @@ StatusOr<HloInstruction*> PartitionDot(
PartitionDot(PartitionedHlo(lhs.hlo(), lhs.base_shape(), inner_state),
PartitionedHlo(rhs.hlo(), rhs.base_shape(), inner_state),
output_base_shape, grouped_output.sharding, dims_mapping,
output_sharding.NumTiles(), create_sharded_dot, module,
original_hlo, options, b, windowed_dot_general_loops));
output_sharding.NumTiles(), create_sharded_dot,
conv_window, module, original_hlo, options, b,
windowed_dot_general_loops));
if (dot) {
return dot;
}
@ -1414,7 +1422,7 @@ StatusOr<HloInstruction*> PartitionDot(
auto dot,
PartitionBaseCase(
lhs, rhs, output_base_shape, output_sharding, dims_mapping,
num_partitions, create_sharded_dot, module, original_hlo,
num_partitions, create_sharded_dot, conv_window, module, original_hlo,
lhs_batch_partitions, rhs_batch_partitions, output_batch_partitions,
lhs_contracting_partitions, rhs_contracting_partitions,
lhs_non_contracting_partitions, rhs_non_contracting_partitions,
@ -1433,8 +1441,9 @@ StatusOr<HloInstruction*> PartitionDot(
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
int64 num_partitions,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
HloModule* module, HloInstruction* original_hlo,
HloInstruction*, HloInstruction*, SpmdBuilder*,
const Window& conv_window)>& create_sharded_dot,
const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
const SpmdPartitionerOptions& options, SpmdBuilder* b,
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
windowed_dot_general_loops) {
@ -1444,17 +1453,18 @@ StatusOr<HloInstruction*> PartitionDot(
TF_ASSIGN_OR_RETURN(
auto try_partition,
PartitionDot(lhs, rhs, output_base_shape, output_sharding, dims_mapping,
num_partitions, create_sharded_dot, module, original_hlo,
require_matching_devices_to_group, options, b,
windowed_dot_general_loops));
num_partitions, create_sharded_dot, conv_window, module,
original_hlo, require_matching_devices_to_group, options,
b, windowed_dot_general_loops));
if (try_partition) {
return try_partition;
}
}
// Default action.
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.Replicate().hlo(),
rhs.Replicate().hlo(), b));
TF_ASSIGN_OR_RETURN(
auto dot, create_sharded_dot(lhs.Replicate().hlo(), rhs.Replicate().hlo(),
b, conv_window));
dot->set_sharding(HloSharding::Replicate());
return PartitionedHlo(dot, output_base_shape, lhs.state())
.Reshard(output_sharding)
@ -1466,14 +1476,20 @@ StatusOr<HloInstruction*> PartitionDot(
Status SpmdPartitioningVisitor::HandleDotHelper(
HloInstruction* hlo, const DotConvDimsMapping& dims_mapping,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot) {
HloInstruction*, HloInstruction*, SpmdBuilder*,
const Window& conv_window)>& create_sharded_dot) {
auto& lhs = GetPartitionedHlo(hlo->operand(0));
auto& rhs = GetPartitionedHlo(hlo->operand(1));
Window conv_window;
if (hlo->opcode() == HloOpcode::kConvolution) {
conv_window = hlo->window();
}
TF_ASSIGN_OR_RETURN(
auto partitioned_dot,
PartitionDot(lhs, rhs, hlo->shape(), hlo->sharding(), dims_mapping,
num_partitions_, create_sharded_dot, module_, hlo, options_,
&b_, &windowed_dot_general_loops_));
num_partitions_, create_sharded_dot, conv_window, module_,
hlo, options_, &b_, &windowed_dot_general_loops_));
SetPartitionedHlo(hlo, [&] { return partitioned_dot; });
return Status::OK();
}

View File

@ -407,10 +407,11 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
Status HandlePartitionId(HloInstruction* hlo) override;
// Implementation of dot partitioning given DotGeneralDimsMapping.
Status HandleDotHelper(
HloInstruction* hlo, const DotConvDimsMapping& dims_mapping,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot);
Status HandleDotHelper(HloInstruction* hlo,
const DotConvDimsMapping& dims_mapping,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*,
const Window& conv_window)>& create_sharded_dot);
// Common handle for elementwise HLOs.
Status HandleElementwise(HloInstruction* hlo);

View File

@ -5893,6 +5893,50 @@ ENTRY entry {
op::Shape("f32[1,1,128,256]")));
}
TEST_F(SpmdPartitioningTest,
ConvolutionInputSpatialDimAndFeatureDimPartitioned) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[8,210,210,12] parameter(0)
%lhs.copy = f32[8,210,210,12] copy(f32[8,210,210,12] %lhs),
sharding={devices=[1,2,1,2]0,1,2,3}
%rhs = f32[3,3,12,32] parameter(1)
%rhs.copy = f32[3,3,12,32] copy(f32[3,3,12,32] %rhs),
sharding={devices=[1,1,2,1,2]0,1,2,3 last_tile_dim_replicate}
ROOT %conv = f32[8,210,210,32] convolution(
f32[8,210,210,12] %lhs.copy,
f32[3,3,12,32] %rhs.copy),
window={size=3x3 pad=1_1x1_1},
dim_labels=b01f_01io->b01f,
sharding={devices=[1,2,1,1,2]0,1,2,3 last_tile_dim_replicate}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/4));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto lhs = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
op::Constant(), op::Reshape())),
op::Shape("f32[8,105,210,6]"));
auto left_halo =
AllOf(op::CollectivePermute(op::Slice(lhs)), op::Shape("f32[8,1,210,6]"));
auto right_halo =
AllOf(op::CollectivePermute(op::Slice(lhs)), op::Shape("f32[8,1,210,6]"));
auto exchanged_lhs = AllOf(
op::Select(op::And(_, _), op::Concatenate(left_halo, lhs, right_halo),
op::Broadcast(_)),
op::Shape("f32[8,107,210,6]"));
auto rhs = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
op::Reshape(), op::Constant())),
op::Shape("f32[3,3,6,32]"));
EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(
exchanged_lhs, op::CollectivePermute(rhs))),
op::Shape("f32[8,105,210,32]")));
}
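As a hand check on the shapes asserted above (arithmetic only, nothing the test itself emits): the 210-row input spatial dimension is split across 2 devices and the 3x3 window with pad 1 needs one halo row from each neighbor, while the input-feature dimension 12 is split into 6 on both operands and the partial products are combined by the AllReduce in the root expectation:
\[
\frac{210}{2} = 105 \ \text{rows per shard}, \qquad
\left\lfloor \tfrac{3}{2} \right\rfloor = 1 \ \text{halo row per side}, \qquad
105 + 1 + 1 = 107, \qquad
107 - 3 + 1 = 105 \ \text{output rows per shard},
\]
matching the f32[8,105,210,6], f32[8,1,210,6], f32[8,107,210,6], and f32[8,105,210,32] shapes in the expectations.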
} // namespace
} // namespace spmd
} // namespace xla

View File

@ -2676,6 +2676,7 @@ xla_test(
xla_test(
name = "cholesky_test",
srcs = ["cholesky_test.cc"],
real_hardware_only = True,
tags = [
"no_rocm",
"optonly",

View File

@ -61,6 +61,44 @@ XLA_TEST_F(CholeskyTest, NonPSDInput) {
ErrorSpec(1e-4, 1e-4));
}
XLA_TEST_F(CholeskyTest, NonPSDBatched) {
XlaBuilder builder(TestName());
Array3D<float> a_vals({
{
{10, 0, 0},
{1, 20, 0},
{1, 1, 30},
},
{
{1, 1, 1},
{1, 1, 1},
{1, 1, 1},
},
});
XlaOp a;
auto a_data = CreateR3Parameter<float>(a_vals, 0, "a", &builder, &a);
Cholesky(a, /*lower=*/true);
float nan = std::numeric_limits<float>::quiet_NaN();
Array3D<float> expected({
{
{3.16227766, 0., 0.},
{0.31622777, 4.4609416, 0.},
{0.31622777, 0.20175113, 5.46436606},
},
{
{nan, nan, nan},
{nan, nan, nan},
{nan, nan, nan},
},
});
ComputeAndCompareR3<float>(&builder, expected, {a_data.get()},
ErrorSpec(1e-4, 1e-4));
}
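As a sketch of where the expected constants above come from (a hand computation, not part of the test): the first batch element is the symmetric positive definite matrix with rows {10,1,1}, {1,20,1}, {1,1,30}, and its lower Cholesky factor L (with M = L L^T) satisfies
\[
L_{11} = \sqrt{10} \approx 3.16227766, \qquad
L_{21} = L_{31} = \tfrac{1}{\sqrt{10}} \approx 0.31622777, \qquad
L_{22} = \sqrt{20 - 0.1} \approx 4.4609416,
\]
\[
L_{32} = \frac{1 - 0.1}{L_{22}} \approx 0.20175113, \qquad
L_{33} = \sqrt{30 - L_{31}^{2} - L_{32}^{2}} \approx 5.46436606,
\]
while the second (all-ones) element is singular, so the factorization breaks down and the test expects NaN throughout.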
XLA_TEST_F(CholeskyTest, Lower) {
XlaBuilder builder(TestName());
@ -181,7 +219,7 @@ class RandomCholeskyTest
: public ClientLibraryTestBase,
public ::testing::WithParamInterface<CholeskyTestCase> {};
XLA_TEST_P(RandomCholeskyTest, Random) {
XLA_TEST_P(RandomCholeskyTest, Real) {
// Test fails with TensorFloat-32 enabled
tensorflow::enable_tensor_float_32_execution(false);
XlaBuilder builder(TestName());
@ -220,14 +258,65 @@ XLA_TEST_P(RandomCholeskyTest, Random) {
ErrorSpec(1e-4, 1e-4));
}
XLA_TEST_P(RandomCholeskyTest, Complex) {
// Test fails with TensorFloat-32 enabled
tensorflow::enable_tensor_float_32_execution(false);
XlaBuilder builder(TestName());
auto test_params = GetParam();
std::vector<int64> dimensions = {std::get<0>(test_params),
std::get<1>(test_params),
std::get<1>(test_params)};
bool lower = std::get<2>(test_params);
Shape shape = ShapeUtil::MakeShape(F32, dimensions);
TF_ASSERT_OK_AND_ASSIGN(
auto literal_real,
LiteralUtil::CreateRandomLiteral<F32>(shape, 0.0, 1.0));
TF_ASSERT_OK_AND_ASSIGN(
auto literal_imag,
LiteralUtil::CreateRandomLiteral<F32>(shape, 0.0, 1.0));
auto input_real = Parameter(&builder, 0, shape, "input_real");
auto input_imag = Parameter(&builder, 1, shape, "input_imag");
auto input = Complex(input_real, input_imag);
// Form a random positive definite matrix.
auto matrix = BatchDot(input, TransposeInMinorDims(Conj(input)),
PrecisionConfig::HIGHEST);
auto cholesky = Triangle(Cholesky(matrix, lower), lower);
// Verify that ||matrix - cholesky * cholesky_t||_2 ~= 0
XlaOp verification;
if (lower) {
verification = BatchDot(cholesky, TransposeInMinorDims(Conj(cholesky)),
PrecisionConfig::HIGHEST);
} else {
verification = BatchDot(TransposeInMinorDims(Conj(cholesky)), cholesky,
PrecisionConfig::HIGHEST);
}
auto delta = matrix - verification;
Reduce(Abs(delta * Conj(delta)), ConstantR0<float>(&builder, 0.0),
CreateScalarAddComputation(F32, &builder), {0, 1, 2});
TF_ASSERT_OK_AND_ASSIGN(auto input_data_real,
client_->TransferToServer(literal_real));
TF_ASSERT_OK_AND_ASSIGN(auto input_data_imag,
client_->TransferToServer(literal_imag));
ComputeAndCompareR0<float>(&builder, 0.0,
{input_data_real.get(), input_data_imag.get()},
ErrorSpec(1e-4, 1e-4));
}
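The complex variant mirrors the real-valued test above: it forms a Hermitian positive definite matrix from a random complex A, factors it, rebuilds it from the factor, and reduces the squared residual over the batch and matrix dimensions to a scalar that should be close to zero,
\[
M = A A^{H}, \qquad
\widehat{M} = \begin{cases} L L^{H} & \text{lower} \\ U^{H} U & \text{upper} \end{cases}, \qquad
\sum_{b,i,j} \bigl| M - \widehat{M} \bigr|_{bij}^{2} \approx 0,
\]
which is exactly the quantity ComputeAndCompareR0 checks against 0 with the 1e-4 tolerance.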
INSTANTIATE_TEST_SUITE_P(RandomCholeskyTestInstance, RandomCholeskyTest,
::testing::Values(CholeskyTestCase{1, 1, true},
CholeskyTestCase{1, 2, true},
CholeskyTestCase{1, 50, true},
CholeskyTestCase{1, 50, false},
CholeskyTestCase{1, 255, false},
CholeskyTestCase{10, 5, true},
CholeskyTestCase{5, 10, false},
CholeskyTestCase{2, 20, true}));
CholeskyTestCase{2, 20, true},
CholeskyTestCase{2, 129, true}));
} // namespace
} // namespace xla

View File

@ -6,9 +6,11 @@ op {
element in the tensor. Input range is `[-inf, inf]` and
output range is `[-1,1]`.
```python
x = tf.constant([-float("inf"), -5, -0.5, 1, 1.2, 2, 3, float("inf")])
tf.math.tanh(x) ==> [-1. -0.99990916 -0.46211717 0.7615942 0.8336547 0.9640276 0.9950547 1.]
```
>>> x = tf.constant([-float("inf"), -5, -0.5, 1, 1.2, 2, 3, float("inf")])
>>> tf.math.tanh(x)
<tf.Tensor: shape=(8,), dtype=float32, numpy=
array([-1. , -0.99990916, -0.46211717, 0.7615942 , 0.8336547 ,
0.9640276 , 0.9950547 , 1. ], dtype=float32)>
END
}

View File

@ -218,6 +218,9 @@ void BaseCollectiveExecutor::StartAbort(const Status& s) {
VLOG(1) << "BaseCollectiveExecutor::StartAbort " << s;
cem_->GetParamResolver()->StartAbort(s);
remote_access_->StartAbort(s);
if (cem_->GetNcclCommunicator() != nullptr) {
cem_->GetNcclCommunicator()->StartAbort(s);
}
}
void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,

View File

@ -102,7 +102,6 @@ class TestCollectiveExecutorMgr : public CollectiveExecutorMgrInterface {
}
NcclCommunicatorInterface* GetNcclCommunicator() const override {
LOG(FATAL) << "Unimplemented"; // Crash OK
return nullptr;
}

View File

@ -227,7 +227,9 @@ Status DataServiceDispatcherClient::EnsureInitialized() {
std::shared_ptr<grpc::ChannelCredentials> credentials;
TF_RETURN_IF_ERROR(
CredentialsFactory::CreateClientCredentials(protocol_, &credentials));
auto channel = grpc::CreateChannel(address_, credentials);
grpc::ChannelArguments args;
args.SetMaxReceiveMessageSize(std::numeric_limits<int32>::max());
auto channel = grpc::CreateCustomChannel(address_, credentials, args);
stub_ = DispatcherService::NewStub(channel);
return Status::OK();
}

View File

@ -70,6 +70,10 @@ Status DataServiceWorkerImpl::Start(const std::string& worker_address) {
Status s = Heartbeat();
while (!s.ok()) {
if (!errors::IsUnavailable(s) && !errors::IsAborted(s) &&
!errors::IsCancelled(s)) {
return s;
}
LOG(WARNING) << "Failed to register with dispatcher at "
<< config_.dispatcher_address() << ": " << s;
Env::Default()->SleepForMicroseconds(kRetryIntervalMicros);

View File

@ -25,6 +25,15 @@ namespace data {
namespace model {
namespace {
// Helper function for node traversal that doesn't skip any nodes.
inline bool IsAnyNode(const std::shared_ptr<Node> node) { return true; }
// Helper function for node traversal that filters out nodes for which
// autotuning is disabled.
inline bool IsAutotuneNode(const std::shared_ptr<Node> node) {
return node->autotune();
}
// Wrapper for the square function to reduce verbosity.
inline double Square(double x) { return x * x; }
@ -82,7 +91,8 @@ class InterleaveMany : public Node {
if (num_inputs() <= 1) {
(*output_times)[long_name()] = self_processing_time;
if (gradients) {
for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) {
for (const auto& node :
CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
gradients->erase(node->long_name());
}
}
@ -94,7 +104,8 @@ class InterleaveMany : public Node {
(*output_times)[inputs_.front()->long_name()]) /
static_cast<double>(num_inputs() - 1);
if (gradients) {
for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) {
for (const auto& node :
CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
auto* gradient = gtl::FindOrNull(*gradients, node->long_name());
if (gradient) {
*gradient /= static_cast<double>(num_inputs() - 1);
@ -211,7 +222,8 @@ class AsyncInterleaveMany : public Node {
if (num_inputs() <= 1) {
(*output_times)[long_name()] = self_processing_time;
if (gradients) {
for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) {
for (const auto& node :
CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
gradients->erase(node->long_name());
}
}
@ -245,7 +257,8 @@ class AsyncInterleaveMany : public Node {
consumer_time_der +
producer_time_der * inputs_time_der_sum / parallelism;
for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) {
for (const auto& node :
CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
auto* gradient = gtl::FindOrNull(*gradients, node->long_name());
if (gradient) {
*gradient *= (producer_time_der /
@ -345,14 +358,16 @@ class KnownRatio : public Node {
if (ratio_ == 0) {
(*output_times)[long_name()] = self_processing_time;
if (gradients) {
for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) {
for (const auto& node :
CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
gradients->erase(node->long_name());
}
}
return;
}
if (gradients) {
for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) {
for (const auto& node :
CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
auto* gradient = gtl::FindOrNull(*gradients, node->long_name());
if (gradient) {
*gradient *= ratio_;
@ -483,7 +498,8 @@ class AsyncKnownRatio : public Node {
consumer_time = input_time;
producer_time = 0.0L;
if (gradients) {
for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) {
for (const auto& node :
CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
gradients->erase(node->long_name());
}
@ -528,7 +544,8 @@ class AsyncKnownRatio : public Node {
(*output_time_gradients)[long_name()] =
consumer_time_der + producer_time_der * inputs_time_der_sum;
for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) {
for (const auto& node :
CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
auto* gradient = gtl::FindOrNull(*gradients, node->long_name());
if (gradient) {
*gradient *= (ratio_ * producer_time_der);
@ -629,7 +646,8 @@ class UnknownRatio : public Node {
inputs_.front()->num_elements() == 0) {
(*output_times)[long_name()] = self_processing_time;
if (gradients) {
for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) {
for (const auto& node :
CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
gradients->erase(node->long_name());
}
}
@ -640,7 +658,8 @@ class UnknownRatio : public Node {
double ratio = static_cast<double>(inputs_.front()->num_elements()) /
static_cast<double>(num_elements_);
if (gradients) {
for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) {
for (const auto& node :
CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
auto* gradient = gtl::FindOrNull(*gradients, node->long_name());
if (gradient) {
*gradient *= ratio;
@ -917,7 +936,8 @@ void Node::CollectTunableParameters(
absl::flat_hash_map<string, std::shared_ptr<Parameter>>* parameters) const {
tf_shared_lock l(mu_);
// Collect tunable parameters from the leaves of the nodes tree to the root.
for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) {
for (const auto& node :
CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
tf_shared_lock l(node->mu_);
node->CollectTunableParametersHelper(parameters);
}
@ -928,7 +948,8 @@ string Node::DebugString() const {
absl::flat_hash_map<string, string> debug_strings;
tf_shared_lock l(mu_);
// Build up the debug string from the leaves of the nodes tree to the root.
for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) {
for (const auto& node :
CollectNodes(TraversalOrder::REVERSE_BFS, IsAnyNode)) {
tf_shared_lock l(node->mu_);
node->DebugStringHelper(&debug_strings);
}
@ -952,7 +973,7 @@ double Node::OutputTime(absl::flat_hash_map<string, double>* input_times,
// `nullptr`) and the output time for each node.
absl::flat_hash_map<string, double> output_time_gradients, output_times;
tf_shared_lock l(mu_);
auto nodes = CollectNodes(TraversalOrder::BFS);
auto nodes = CollectNodes(TraversalOrder::BFS, IsAutotuneNode);
// Computes and stores input time for each node from the root to leaves of the
// nodes tree.
@ -1001,7 +1022,8 @@ double Node::TotalBufferedBytes() const {
absl::flat_hash_map<string, double> total_bytes;
tf_shared_lock l(mu_);
// Compute total buffered bytes from the leaves of the nodes tree to the root.
for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) {
for (const auto& node :
CollectNodes(TraversalOrder::REVERSE_BFS, IsAnyNode)) {
tf_shared_lock l(node->mu_);
node->TotalBufferedBytesHelper(&total_bytes);
}
@ -1015,7 +1037,8 @@ double Node::TotalMaximumBufferedBytes() const {
tf_shared_lock l(mu_);
// Compute total maximum buffered bytes from the leaves of the nodes tree
// to the root.
for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) {
for (const auto& node :
CollectNodes(TraversalOrder::REVERSE_BFS, IsAnyNode)) {
tf_shared_lock l(node->mu_);
node->TotalMaximumBufferedBytesHelper(&total_bytes);
}
@ -1033,7 +1056,8 @@ double Node::TotalProcessingTime(
// Computes per-element CPU time spent in the subtree rooted in the node from
// the leaves of the nodes tree to the root.
for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) {
for (const auto& node :
CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
tf_shared_lock l(node->mu_);
node->TotalProcessingTimeLocked(processing_times, &total_processing_times);
}
@ -1123,14 +1147,17 @@ double Node::SelfProcessingTimeLocked() const {
static_cast<double>(num_elements_);
}
Node::NodeVector Node::CollectNodes(TraversalOrder order) const
Node::NodeVector Node::CollectNodes(
TraversalOrder order, bool collect_node(const std::shared_ptr<Node>)) const
TF_SHARED_LOCKS_REQUIRED(mu_) {
NodeVector node_vector;
std::list<std::shared_ptr<Node>> temp_list;
for (auto& input : inputs_) {
node_vector.push_back(input);
temp_list.push_back(input);
if (collect_node(input)) {
node_vector.push_back(input);
temp_list.push_back(input);
}
}
while (!temp_list.empty()) {
@ -1138,8 +1165,10 @@ Node::NodeVector Node::CollectNodes(TraversalOrder order) const
temp_list.pop_front();
tf_shared_lock l(cur_node->mu_);
for (auto& input : cur_node->inputs_) {
node_vector.push_back(input);
temp_list.push_back(input);
if (collect_node(input)) {
node_vector.push_back(input);
temp_list.push_back(input);
}
}
}

View File

@ -485,8 +485,12 @@ class Node {
// Returns a vector of nodes of the subtree rooted in this node. The nodes are
// either in breadth-first search or reverse breadth-first search order
// depending on the `order` argument. The root node itself is not collected.
NodeVector CollectNodes(TraversalOrder order) const
// depending on the `order` argument. The nodes are collected based on the
// results of the `collect_node` predicate: if the predicate returns `false`
// for a given node, then the subtree rooted in this node is excluded. The
// root node itself is not collected.
NodeVector CollectNodes(TraversalOrder order,
bool collect_node(const std::shared_ptr<Node>)) const
TF_SHARED_LOCKS_REQUIRED(mu_);
// Collect tunable parameters for the node.

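The predicate-based traversal documented above prunes an entire subtree as soon as one node fails the test, which is why passing IsAutotuneNode in model.cc skips whole autotune-disabled branches. A minimal self-contained sketch of the same pattern follows; ToyNode, OnlyEnabled, and CollectToyNodes are stand-ins invented for illustration, not part of the TensorFlow API:

#include <list>
#include <memory>
#include <vector>

// Stand-in node type; the real code uses tensorflow::data::model::Node.
struct ToyNode {
  bool autotune = true;
  std::vector<std::shared_ptr<ToyNode>> inputs;
};

// Keep only nodes for which autotuning is enabled (mirrors IsAutotuneNode).
inline bool OnlyEnabled(const std::shared_ptr<ToyNode> node) {
  return node->autotune;
}

// Breadth-first collection over the subtree rooted at `root`, skipping any
// input for which `collect_node` returns false; because a skipped node is
// never enqueued, its whole subtree is pruned, just like in the filtered
// Node::CollectNodes above. The root itself is not collected.
std::vector<std::shared_ptr<ToyNode>> CollectToyNodes(
    const ToyNode& root, bool collect_node(const std::shared_ptr<ToyNode>)) {
  std::vector<std::shared_ptr<ToyNode>> collected;
  std::list<std::shared_ptr<ToyNode>> frontier;
  for (const auto& input : root.inputs) {
    if (collect_node(input)) {
      collected.push_back(input);
      frontier.push_back(input);
    }
  }
  while (!frontier.empty()) {
    auto current = frontier.front();
    frontier.pop_front();
    for (const auto& input : current->inputs) {
      if (collect_node(input)) {
        collected.push_back(input);
        frontier.push_back(input);
      }
    }
  }
  return collected;  // BFS order; reverse for REVERSE_BFS semantics.
}

Called as CollectToyNodes(root, OnlyEnabled), a disabled node and everything below it never enter the result, which is why the gradient and output-time passes in model.cc can switch to IsAutotuneNode while DebugString and the buffered-bytes accounting keep IsAnyNode.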
View File

@ -531,7 +531,6 @@ Status SchedulerState::Init(const GrapplerItem* item,
initial_nodes->push_back(curr_node);
VLOG(3) << "Added ready node: " << curr_node->name();
}
feed_nodes.erase(curr_node->name());
if (IsPersistent(*curr_node)) {
@ -778,6 +777,10 @@ NodeState& SchedulerState::GetNodeStateOrCreateIt(const NodeDef* node) {
node_state.num_outputs_executed[-1] = 0;
node_state.outputs[-1] = {};
// Initialize time_scheduled to infinity, so we know whether it has been
// assigned a non-default value later.
node_state.time_scheduled = Costs::Duration().infinity();
return it->second;
}
@ -862,10 +865,16 @@ std::vector<const NodeDef*> SchedulerState::MarkNodeExecuted(
// Node is scheduled when the device is available AND all the inputs are
// ready; hence, time_scheduled is time_ready if time_ready > device curr
// time.
node_state.time_scheduled =
std::max(device.GetCurrTime(), node_state.time_ready);
// Override device curr time with the time_scheduled.
device.device_costs.execution_time = node_state.time_scheduled;
// NodeState times are assigned infinity at initialization. If they are
// still infinity here, we need to assign them. If not, it has been assigned
// already, so skip. This latter case may occur when a scheduler in-lines
// function calls, and thus schedules only function sub-nodes.
if (node_state.time_scheduled == Costs::Duration().infinity()) {
node_state.time_scheduled =
std::max(device.GetCurrTime(), node_state.time_ready);
// Override device curr time with the time_scheduled.
device.device_costs.execution_time = node_state.time_scheduled;
}
device.device_costs = CombineCosts(device.device_costs, total_node_costs);
auto curr_time = device.GetCurrTime();
node_state.time_finished = curr_time;
@ -1000,8 +1009,13 @@ Costs SchedulerState::Summary() const {
for (const auto& node_port : state.persistent_nodes) {
const auto* node = node_port.first;
const auto port = node_port.second;
const auto output_size =
CalculateOutputSize(node_map_.at(node).output_properties, port);
auto output_size = 0;
// Check if the node is in the node_map. It may be that the node executed
// on this device was executed by a different Scheduler.
if (node_map_.find(node) != node_map_.end()) {
output_size =
CalculateOutputSize(node_map_.at(node).output_properties, port);
}
persistent_memory_usage += output_size;
op_to_memory[node->op()] += output_size;
persistent_ops.insert(node->op());
@ -1048,8 +1062,12 @@ Costs SchedulerState::Summary() const {
for (const auto& node_port : state.mem_usage_snapshot_at_peak) {
const auto* node = node_port.first;
const auto port = node_port.second;
op_to_memory[node->op()] +=
CalculateOutputSize(node_map_.at(node).output_properties, port);
// Check if the node is in the node_map. It may be that the node executed
// on this device was executed by a different Scheduler.
if (node_map_.find(node) != node_map_.end()) {
op_to_memory[node->op()] +=
CalculateOutputSize(node_map_.at(node).output_properties, port);
}
}
Costs::NanoSeconds total_compute_time_ns;
bool is_total_cost_accurate = true;
@ -1132,6 +1150,12 @@ void SchedulerState::GenerateRunMetadata(RunMetadata* metadata) {
DeviceStepStats* device_stepstats = stepstats->add_dev_stats();
device_stepstats->set_device(device.first);
for (const auto& node_def : device.second.nodes_executed) {
// Only proceed if the node is in the node_map. This is to cover the case
// where a device has executed a node that is not in the node_map of
// this scheduler.
if (node_map_.find(node_def) == node_map_.end()) {
continue;
}
const NodeState& nodestate = node_map_.at(node_def);
NodeExecStats* node_stats = device_stepstats->add_node_stats();
uint64 total_output_size = 0;
@ -1229,6 +1253,12 @@ SchedulerState::GetPersistentMemoryUsage() const {
return result;
}
void SchedulerState::SetNodeStateTimeScheduled(const NodeDef* node) {
auto& node_state = node_map_.at(node);
auto& device = device_[node_state.device_name];
node_state.time_scheduled = device.GetCurrTime();
}
VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
const bool use_aggressive_shape_inference,
Cluster* cluster,

View File

@ -371,6 +371,10 @@ class SchedulerState {
}
protected:
// Assigns the time_scheduled in the NodeState of node to the current
// execution_time of the device executing this node.
void SetNodeStateTimeScheduled(const NodeDef* node);
// This method can be used by a class derived from SchedulerState to
// access the device state map.
std::unordered_map<string, DeviceState>* GetMutableDeviceState() {

View File

@ -5228,6 +5228,24 @@ tf_kernel_library(
deps = STRING_DEPS,
)
tf_cc_test(
name = "as_string_op_test",
size = "small",
srcs = ["as_string_op_test.cc"],
deps = [
":as_string_op",
":ops_testutil",
":ops_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
tf_kernel_library(
name = "unicode_ops",
prefix = "unicode_ops",

View File

@ -65,9 +65,26 @@ class AsStringOp : public OpKernel {
OP_REQUIRES(ctx, !(scientific && shortest),
errors::InvalidArgument(
"Cannot select both scientific and shortest notation"));
format_ = "%";
if (!fill_string.empty()) {
switch (fill_string[0]) {
case ' ':
case '+':
case '-':
case '0':
case '#':
strings::Appendf(&format_, "%s", fill_string.c_str());
break;
default:
bool fill_not_supported = true;
OP_REQUIRES(ctx, !fill_not_supported,
errors::InvalidArgument("Fill argument not supported: \"",
fill_string, "\""));
}
}
if (width > -1) {
strings::Appendf(&format_, "%s%d", fill_string.c_str(), width);
strings::Appendf(&format_, "%d", width);
}
if (precision > -1) {
strings::Appendf(&format_, ".%d", precision);

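The reworked constructor above first validates the fill character against the small whitelist (space, '+', '-', '0', '#') and only then appends the width, instead of splicing an arbitrary fill string into the "%s%d" specifier as before. For fill="0" and width=4 this yields the prefix "%04". Below is a minimal stand-alone sketch of the resulting formatting; the "lld" conversion suffix for int64 is an assumption here, since the kernel appends its type-specific suffix outside this hunk, and it reproduces the FillWithZero expectations in the new test further down.

#include <cstdio>
#include <string>

int main() {
  // Mirror the hunk above for fill="0", width=4, precision unset:
  // "%" + fill + width, then a type-specific conversion suffix.
  std::string format = "%";
  format += "0";                // validated fill character
  format += std::to_string(4);  // width attr
  format += "lld";              // assumed int64 conversion suffix
  const long long values[] = {-42, 0, 42};
  char buf[32];
  for (long long v : values) {
    std::snprintf(buf, sizeof(buf), format.c_str(), v);
    std::printf("%s\n", buf);  // prints -042, 0000, 0042
  }
  return 0;
}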
View File

@ -0,0 +1,245 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace tensorflow {
namespace {
class AsStringGraphTest : public OpsTestBase {
protected:
Status Init(DataType input_type, const string& fill = "", int width = -1,
int precision = -1, bool scientific = false,
bool shortest = false) {
TF_CHECK_OK(NodeDefBuilder("op", "AsString")
.Input(FakeInput(input_type))
.Attr("fill", fill)
.Attr("precision", precision)
.Attr("scientific", scientific)
.Attr("shortest", shortest)
.Attr("width", width)
.Finalize(node_def()));
return InitOp();
}
};
TEST_F(AsStringGraphTest, Int8) {
TF_ASSERT_OK(Init(DT_INT8));
AddInputFromArray<int8>(TensorShape({3}), {-42, 0, 42});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({3}));
test::FillValues<tstring>(&expected, {"-42", "0", "42"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}
TEST_F(AsStringGraphTest, Int64) {
TF_ASSERT_OK(Init(DT_INT64));
AddInputFromArray<int64>(TensorShape({3}), {-42, 0, 42});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({3}));
test::FillValues<tstring>(&expected, {"-42", "0", "42"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}
TEST_F(AsStringGraphTest, FloatDefault) {
TF_ASSERT_OK(Init(DT_FLOAT));
AddInputFromArray<float>(TensorShape({4}), {-42, 0, 3.14159, 42});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({4}));
test::FillValues<tstring>(
&expected, {"-42.000000", "0.000000", "3.141590", "42.000000"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}
TEST_F(AsStringGraphTest, FloatScientific) {
TF_ASSERT_OK(Init(DT_FLOAT, /*fill=*/"", /*width=*/-1, /*precision=*/-1,
/*scientific=*/true));
AddInputFromArray<float>(TensorShape({4}), {-42, 0, 3.14159, 42});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({4}));
test::FillValues<tstring>(&expected, {"-4.200000e+01", "0.000000e+00",
"3.141590e+00", "4.200000e+01"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}
TEST_F(AsStringGraphTest, FloatShortest) {
TF_ASSERT_OK(Init(DT_FLOAT, /*fill=*/"", /*width=*/-1, /*precision=*/-1,
/*scientific=*/false, /*shortest=*/true));
AddInputFromArray<float>(TensorShape({4}), {-42, 0, 3.14159, 42});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({4}));
test::FillValues<tstring>(&expected, {"-42", "0", "3.14159", "42"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}
TEST_F(AsStringGraphTest, FloatPrecisionOnly) {
TF_ASSERT_OK(Init(DT_FLOAT, /*fill=*/"", /*width=*/-1, /*precision=*/2));
AddInputFromArray<float>(TensorShape({4}), {-42, 0, 3.14159, 42});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({4}));
test::FillValues<tstring>(&expected, {"-42.00", "0.00", "3.14", "42.00"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}
TEST_F(AsStringGraphTest, FloatWidthOnly) {
TF_ASSERT_OK(Init(DT_FLOAT, /*fill=*/"", /*width=*/5));
AddInputFromArray<float>(TensorShape({4}), {-42, 0, 3.14159, 42});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({4}));
test::FillValues<tstring>(
&expected, {"-42.000000", "0.000000", "3.141590", "42.000000"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}
TEST_F(AsStringGraphTest, Float_5_2_Format) {
TF_ASSERT_OK(Init(DT_FLOAT, /*fill=*/"", /*width=*/5, /*precision=*/2));
AddInputFromArray<float>(TensorShape({4}), {-42, 0, 3.14159, 42});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({4}));
test::FillValues<tstring>(&expected, {"-42.00", " 0.00", " 3.14", "42.00"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}
TEST_F(AsStringGraphTest, Complex) {
TF_ASSERT_OK(Init(DT_COMPLEX64, /*fill=*/"", /*width=*/5, /*precision=*/2));
AddInputFromArray<complex64>(TensorShape({3}), {{-4, 2}, {0}, {3.14159, -1}});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({3}));
test::FillValues<tstring>(
&expected, {"(-4.00, 2.00)", "( 0.00, 0.00)", "( 3.14,-1.00)"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}
TEST_F(AsStringGraphTest, Bool) {
TF_ASSERT_OK(Init(DT_BOOL));
AddInputFromArray<bool>(TensorShape({2}), {true, false});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({2}));
test::FillValues<tstring>(&expected, {"true", "false"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}
TEST_F(AsStringGraphTest, String) {
Status s = Init(DT_STRING);
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
ASSERT_TRUE(absl::StrContains(
s.error_message(),
"Value for attr 'T' of string is not in the list of allowed values"));
}
TEST_F(AsStringGraphTest, OnlyOneOfScientificAndShortest) {
Status s = Init(DT_FLOAT, /*fill=*/"", /*width=*/-1, /*precision=*/-1,
/*scientific=*/true, /*shortest=*/true);
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
ASSERT_TRUE(
absl::StrContains(s.error_message(),
"Cannot select both scientific and shortest notation"));
}
TEST_F(AsStringGraphTest, NoShortestForNonFloat) {
Status s = Init(DT_INT32, /*fill=*/"", /*width=*/-1, /*precision=*/-1,
/*scientific=*/false, /*shortest=*/true);
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
ASSERT_TRUE(absl::StrContains(
s.error_message(),
"scientific and shortest format not supported for datatype"));
}
TEST_F(AsStringGraphTest, NoScientificForNonFloat) {
Status s = Init(DT_INT32, /*fill=*/"", /*width=*/-1, /*precision=*/-1,
/*scientific=*/true);
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
ASSERT_TRUE(absl::StrContains(
s.error_message(),
"scientific and shortest format not supported for datatype"));
}
TEST_F(AsStringGraphTest, NoPrecisionForNonFloat) {
Status s = Init(DT_INT32, /*fill=*/"", /*width=*/-1, /*precision=*/5);
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
ASSERT_TRUE(absl::StrContains(s.error_message(),
"precision not supported for datatype"));
}
TEST_F(AsStringGraphTest, LongFill) {
Status s = Init(DT_INT32, /*fill=*/"asdf");
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
ASSERT_TRUE(absl::StrContains(s.error_message(),
"Fill string must be one or fewer characters"));
}
TEST_F(AsStringGraphTest, FillWithZero) {
TF_ASSERT_OK(Init(DT_INT64, /*fill=*/"0", /*width=*/4));
AddInputFromArray<int64>(TensorShape({3}), {-42, 0, 42});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({3}));
test::FillValues<tstring>(&expected, {"-042", "0000", "0042"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}
TEST_F(AsStringGraphTest, FillWithSpace) {
TF_ASSERT_OK(Init(DT_INT64, /*fill=*/" ", /*width=*/4));
AddInputFromArray<int64>(TensorShape({3}), {-42, 0, 42});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({3}));
test::FillValues<tstring>(&expected, {" -42", " 0", " 42"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}
TEST_F(AsStringGraphTest, FillWithChar1) {
TF_ASSERT_OK(Init(DT_INT64, /*fill=*/"-", /*width=*/4));
AddInputFromArray<int64>(TensorShape({3}), {-42, 0, 42});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({3}));
test::FillValues<tstring>(&expected, {"-42 ", "0 ", "42 "});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}
TEST_F(AsStringGraphTest, FillWithChar3) {
Status s = Init(DT_INT32, /*fill=*/"s");
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
ASSERT_TRUE(
absl::StrContains(s.error_message(), "Fill argument not supported"));
}
TEST_F(AsStringGraphTest, FillWithChar4) {
Status s = Init(DT_INT32, /*fill=*/"n");
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
ASSERT_TRUE(
absl::StrContains(s.error_message(), "Fill argument not supported"));
}
} // end namespace
} // end namespace tensorflow

Some files were not shown because too many files have changed in this diff.