Merge branch 'master' into toupstream/16x8_batch_matmul
commit 0f082f6b40
@@ -619,6 +619,8 @@ tf_cuda_library(
        ":c_api",
        ":c_api_experimental",
        "//tensorflow/c:c_test_util",
        "//tensorflow/c:tf_datatype",
        "//tensorflow/c:tf_tensor",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
@@ -17,12 +17,16 @@ limitations under the License.

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/protobuf/cluster.pb.h"

using tensorflow::string;
using tensorflow::tstring;

TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value) {
  float data[] = {value};
@@ -36,6 +40,19 @@ TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value) {
  return th;
}

TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx,
                                         const tensorflow::tstring& value) {
  TF_Status* status = TF_NewStatus();
  TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_STRING, nullptr, 0, status);
  tstring* data = static_cast<tstring*>(TF_TensorData(t));
  *data = value;
  TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteTensor(t);
  TF_DeleteStatus(status);
  return th;
}

TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, int value) {
  int data[] = {value};
  TF_Status* status = TF_NewStatus();
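A minimal usage sketch for the new tstring overload above; the context setup mirrors the existing tests, and the variable names are illustrative rather than part of the diff:

  // Assumes an already-created TFE_Context* ctx, as in the surrounding tests.
  tensorflow::tstring contents("some scalar string");
  TFE_TensorHandle* th = TestScalarTensorHandle(ctx, contents);
  // ... pass `th` as an input to the op under test ...
  TFE_DeleteTensorHandle(th);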
@@ -16,6 +16,7 @@ limitations under the License.
#define TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

@@ -28,6 +29,10 @@ TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, int value);
// Return a tensor handle containing a bool scalar
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, bool value);

// Return a tensor handle containing a tstring scalar
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx,
                                         const tensorflow::tstring& value);

// Return a tensor handle containing a 2x2 matrix of doubles
TFE_TensorHandle* DoubleTestMatrixTensorHandle(TFE_Context* ctx);
@@ -249,21 +249,36 @@ void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
}

void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
  auto tf_dlm_context = GetDlContext(h, status);
  if (!status->status.ok()) {
    return nullptr;
  }

  auto* tf_dlm_data = TFE_TensorHandleDevicePointer(h, status);
  if (!status->status.ok()) {
    return nullptr;
  }

  const Tensor* tensor = GetTensorFromHandle(h, status);
  TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
  TensorReference tensor_ref(*tensor);  // This will call buf_->Ref()

  auto tf_dlm_type = GetDlDataType(data_type, status);
  if (!status->status.ok()) {
    return nullptr;
  }

  TensorReference tensor_ref(*tensor);  // This will call buf_->Ref()
  auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref);
  tf_dlm_tensor_ctx->reference = tensor_ref;

  DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
  dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
  dlm_tensor->deleter = &DLManagedTensorDeleter;
  dlm_tensor->dl_tensor.ctx = GetDlContext(h, status);
  dlm_tensor->dl_tensor.ctx = tf_dlm_context;
  int ndim = tensor->dims();
  dlm_tensor->dl_tensor.ndim = ndim;
  dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
  dlm_tensor->dl_tensor.dtype = GetDlDataType(data_type, status);
  dlm_tensor->dl_tensor.data = tf_dlm_data;
  dlm_tensor->dl_tensor.dtype = tf_dlm_type;

  std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
  std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
@@ -276,13 +291,14 @@ void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
    (*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
  }

  dlm_tensor->dl_tensor.shape = &(*shape_arr)[0];
  dlm_tensor->dl_tensor.shape = shape_arr->data();
  // There are two ways to represent compact row-major data
  // 1) nullptr indicates tensor is compact and row-majored.
  // 2) fill in the strides array as the real case for compact row-major data.
  // Here we choose option 2, since some frameworks didn't handle the strides
  // argument properly.
  dlm_tensor->dl_tensor.strides = &(*stride_arr)[0];
  dlm_tensor->dl_tensor.strides = stride_arr->data();

  dlm_tensor->dl_tensor.byte_offset =
      0;  // TF doesn't handle the strides and byte_offsets here
  return static_cast<void*>(dlm_tensor);
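The loop in the hunk above fills the DLPack strides for a compact row-major tensor. A standalone sketch of that computation, written here only for illustration (the helper name is not part of the diff):

  #include <cstdint>
  #include <vector>

  // Row-major (C-order) strides: the innermost dimension has stride 1 and
  // each outer stride is the product of all inner dimension sizes.
  std::vector<int64_t> RowMajorStrides(const std::vector<int64_t>& shape) {
    std::vector<int64_t> strides(shape.size(), 1);
    for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i) {
      strides[i] = shape[i + 1] * strides[i + 1];
    }
    return strides;  // e.g. shape {2, 3, 4} -> strides {12, 4, 1}
  }

As the comment in the diff notes, a null strides pointer would also signal compact row-major data, but the strides are filled in explicitly because some consuming frameworks do not handle a null strides argument correctly.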
@ -62,6 +62,7 @@ cc_library(
|
||||
":function_metadata",
|
||||
"//tensorflow/c:tf_tensor_internal",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:asset",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:variable",
|
||||
@ -69,6 +70,7 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
@ -138,7 +140,6 @@ cc_library(
|
||||
":saved_model_api",
|
||||
":saved_model_utils",
|
||||
":signature_def_function",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/c/experimental/saved_model/core/ops:restore_ops",
|
||||
@ -148,10 +149,10 @@ cc_library(
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:variable",
|
||||
"//tensorflow/cc/saved_model:bundle_v2",
|
||||
"//tensorflow/cc/saved_model:constants",
|
||||
"//tensorflow/cc/saved_model:loader_util",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
|
@ -8,6 +8,25 @@ package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "asset",
|
||||
srcs = [
|
||||
"asset.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"asset.h",
|
||||
],
|
||||
deps = [
|
||||
":tensorhandle_convertible",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/cc/saved_model:constants",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "constant",
|
||||
srcs = [
|
||||
|
@ -0,0 +1,49 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/cc/saved_model/constants.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/path.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Asset::Asset(ImmediateTensorHandlePtr handle)
|
||||
: TensorHandleConvertible(std::move(handle)) {}
|
||||
|
||||
Status Asset::Create(ImmediateExecutionContext* ctx,
|
||||
const std::string& saved_model_dir,
|
||||
const std::string& asset_filename,
|
||||
std::unique_ptr<Asset>* output) {
|
||||
std::string abs_path =
|
||||
io::JoinPath(saved_model_dir, kSavedModelAssetsDirectory, asset_filename);
|
||||
AbstractTensorPtr tensor(ctx->CreateStringScalar(abs_path));
|
||||
if (tensor.get() == nullptr) {
|
||||
return errors::Internal(
|
||||
"Failed to create scalar string tensor for Asset at path ", abs_path);
|
||||
}
|
||||
|
||||
ImmediateTensorHandlePtr handle(ctx->CreateLocalHandle(tensor.get()));
|
||||
output->reset(new Asset(std::move(handle)));
|
||||
return Status();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,50 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_ASSET_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_ASSET_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class Asset : public TensorHandleConvertible {
|
||||
public:
|
||||
static Status Create(ImmediateExecutionContext* ctx,
|
||||
const std::string& saved_model_dir,
|
||||
const std::string& asset_filename,
|
||||
std::unique_ptr<Asset>* output);
|
||||
|
||||
// Asset is movable, but not copyable.
|
||||
Asset(Asset&& other) = default;
|
||||
Asset& operator=(Asset&& other) = default;
|
||||
|
||||
~Asset() override = default;
|
||||
|
||||
private:
|
||||
explicit Asset(ImmediateTensorHandlePtr handle);
|
||||
Asset(const Asset&) = delete;
|
||||
Asset& operator=(const Asset&) = delete;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_ASSET_H_
|
@ -100,6 +100,20 @@ Status ValidateSavedFunctionCompatibleWithFunctionDef(
|
||||
|
||||
} // namespace
|
||||
|
||||
Status LoadSavedAsset(ImmediateExecutionContext* ctx, const SavedAsset& asset,
|
||||
const std::string& saved_model_dir,
|
||||
absl::Span<const AssetFileDef> assets,
|
||||
std::unique_ptr<Asset>* output) {
|
||||
int asset_index = asset.asset_file_def_index();
|
||||
if (asset_index >= assets.size()) {
|
||||
return errors::FailedPrecondition(
|
||||
"SavedAsset contained asset index ", asset_index,
|
||||
" but AssetFileDef only contains ", assets.size(), " # of assets");
|
||||
}
|
||||
const std::string& asset_filename = assets[asset_index].filename();
|
||||
return Asset::Create(ctx, saved_model_dir, asset_filename, output);
|
||||
}
|
||||
|
||||
Status TensorProtoToConstant(ImmediateExecutionContext* ctx,
|
||||
const TensorProto& proto,
|
||||
std::unique_ptr<Constant>* output) {
|
||||
@ -211,7 +225,8 @@ Status FlattenSignature(const StructuredValue& signature,
|
||||
}
|
||||
|
||||
const SavedObject* FindNodeAtPath(StringPiece path,
|
||||
const SavedObjectGraph& object_graph) {
|
||||
const SavedObjectGraph& object_graph,
|
||||
int* node_id) {
|
||||
const auto& nodes = object_graph.nodes();
|
||||
if (nodes.empty()) {
|
||||
return nullptr;
|
||||
@ -231,6 +246,9 @@ const SavedObject* FindNodeAtPath(StringPiece path,
|
||||
if (child_node_iter == current_node->children().end()) {
|
||||
return nullptr;
|
||||
}
|
||||
if (node_id) {
|
||||
*node_id = child_node_iter->node_id();
|
||||
}
|
||||
current_node = &nodes.Get(child_node_iter->node_id());
|
||||
}
|
||||
|
||||
|
@@ -22,7 +22,9 @@ limitations under the License.
#include <memory>
#include <unordered_map>

#include "absl/types/span.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
@@ -31,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h"

@@ -52,6 +55,11 @@ Status LoadSavedVariable(ImmediateExecutionContext* ctx,
                         const SavedVariable& variable,
                         std::unique_ptr<Variable>* output);

Status LoadSavedAsset(ImmediateExecutionContext* ctx, const SavedAsset& asset,
                      const std::string& saved_model_dir,
                      absl::Span<const AssetFileDef> assets,
                      std::unique_ptr<Asset>* output);

// Creates a TFConcreteFunction from a SavedConcreteFunction.
Status LoadTFConcreteFunction(
    const SavedConcreteFunction& saved_concrete_function,
@@ -70,9 +78,11 @@ Status FlattenSignature(const StructuredValue& signature,
// Find the SavedObject in `object_graph` at location `path`. `path` must be
// a dot-delimited string of object names relative to the root object. If no
// object is found, returns nullptr. Callers must ensure `object_graph`
// outlives the returned pointer.
// outlives the returned pointer. If not `nullptr`, `node_id` will contain the
// index of the returned object in the `SavedObjectGraph.nodes` array.
const SavedObject* FindNodeAtPath(StringPiece path,
                                  const SavedObjectGraph& object_graph);
                                  const SavedObjectGraph& object_graph,
                                  int* node_id = nullptr);

// Maps each node in `graphdef` to its corresponding Attribute Map.
// Callers must ensure that `graphdef` outlives the returned map.
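A small usage sketch for the extended FindNodeAtPath signature above, matching how TFSavedModelAPI::GetVariable uses it later in this diff; `object_graph` is assumed to be an already-deserialized SavedObjectGraph and the path string is illustrative:

  int node_id = 0;
  const SavedObject* object =
      internal::FindNodeAtPath("child.some_variable", object_graph, &node_id);
  if (object != nullptr && object->kind_case() == SavedObject::kVariable) {
    // node_id now indexes the matching entry in object_graph.nodes(), so it
    // can be used to look up the corresponding revived object.
  }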
@ -37,6 +37,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h"
|
||||
#include "tensorflow/cc/saved_model/bundle_v2.h"
|
||||
#include "tensorflow/cc/saved_model/constants.h"
|
||||
#include "tensorflow/cc/saved_model/loader_util.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
@ -108,12 +109,17 @@ Status ConstantFromSavedConstant(
|
||||
// SavedResources. These are returned via the `out` parameter.
|
||||
Status ReviveObjects(
|
||||
const MetaGraphDef& metagraph, ImmediateExecutionContext* context,
|
||||
const std::string& directory,
|
||||
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>*
|
||||
revived_objects) {
|
||||
// This is needed to restore "Constant" nodes by looking up their
|
||||
// "Value" attribute.
|
||||
NodeAttrMap node_attr_map = internal::NodeToAttrMap(metagraph.graph_def());
|
||||
|
||||
// These are needed for creating "Assets", by looking up their filenames.
|
||||
std::vector<AssetFileDef> assets;
|
||||
TF_RETURN_IF_ERROR(internal::GetAssetFileDefs(metagraph, &assets));
|
||||
|
||||
// Iterate through all the saved objects, restoring objects as we go.
|
||||
// We don't recreate functions until all other objects have been created.
|
||||
for (int i = 0; i < metagraph.object_graph_def().nodes_size(); ++i) {
|
||||
@ -129,12 +135,10 @@ Status ReviveObjects(
|
||||
node_attr_map, &constant));
|
||||
(*revived_objects)[i] = std::move(constant);
|
||||
} else if (node.kind_case() == SavedObject::kAsset) {
|
||||
// TODO(bmzhao): Implement Asset C++ class. This should be just recreating
|
||||
// the full path to the asset file:
|
||||
// https://github.com/tensorflow/tensorflow/blob/6a0bdbdb7c48a3491ae1277083ae3dafb4ab4d7a/tensorflow/python/saved_model/load.py#L395-L396
|
||||
// and storing it as a string tensor:
|
||||
// https://github.com/tensorflow/tensorflow/blob/6a0bdbdb7c48a3491ae1277083ae3dafb4ab4d7a/tensorflow/python/training/tracking/tracking.py#L324-L325
|
||||
return errors::Unimplemented("SavedAsset loading is not implemented yet");
|
||||
std::unique_ptr<Asset> asset;
|
||||
TF_RETURN_IF_ERROR(internal::LoadSavedAsset(context, node.asset(),
|
||||
directory, assets, &asset));
|
||||
(*revived_objects)[i] = std::move(asset);
|
||||
} else if (node.kind_case() == SavedObject::kResource) {
|
||||
// TODO(bmzhao): Figure out how resource loading works and implement it
|
||||
return errors::Unimplemented(
|
||||
@ -264,6 +268,12 @@ Status RestoreCheckpoint(SavedModelV2Bundle* bundle,
|
||||
}
|
||||
|
||||
const std::string& checkpoint_key = attribute->checkpoint_key();
|
||||
if (!bundle->variable_reader()->Contains(checkpoint_key)) {
|
||||
LOG(WARNING) << "No checkpoint entry found for " << checkpoint_key
|
||||
<< ". Variable will be uninitialized.";
|
||||
return Status();
|
||||
}
|
||||
|
||||
std::string variables_path_prefix =
|
||||
io::JoinPath(directory, kSavedModelVariablesDirectory,
|
||||
kSavedModelVariablesFilename);
|
||||
@ -321,6 +331,31 @@ std::vector<ConcreteFunction*> TFSavedModelAPI::ListFunctions() {
|
||||
return result;
|
||||
}
|
||||
|
||||
Status TFSavedModelAPI::GetVariable(const std::string& variable_path,
|
||||
Variable** variable) {
|
||||
int node_id;
|
||||
const SavedObject* object = internal::FindNodeAtPath(
|
||||
variable_path, bundle_.saved_object_graph(), &node_id);
|
||||
if (object == nullptr) {
|
||||
return errors::NotFound("No saved object found at path ", variable_path);
|
||||
}
|
||||
|
||||
if (object->kind_case() == SavedObject::kVariable) {
|
||||
auto iter = revived_objects_.find(node_id);
|
||||
if (iter == revived_objects_.end()) {
|
||||
return errors::Internal("Variable ", variable_path,
|
||||
" was not properly revived.");
|
||||
}
|
||||
*variable = static_cast<Variable*>(iter->second.get());
|
||||
return Status();
|
||||
}
|
||||
|
||||
*variable = nullptr;
|
||||
return errors::InvalidArgument(
|
||||
variable_path, " is not a path to a Variable (kind=", object->kind_case(),
|
||||
")");
|
||||
}
|
||||
|
||||
TFSavedModelAPI::TFSavedModelAPI(
|
||||
const std::string& directory, SavedModelV2Bundle bundle,
|
||||
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>
|
||||
@ -352,8 +387,8 @@ Status TFSavedModelAPI::Load(
|
||||
// https://github.com/tensorflow/tensorflow/blob/285b5fa15405c5e2c084080f52a1818be8648079/tensorflow/python/saved_model/function_deserialization.py#L438-L454
|
||||
|
||||
RevivedObjectMap revived_objects;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ReviveObjects(bundle.meta_graph_def(), context, &revived_objects));
|
||||
TF_RETURN_IF_ERROR(ReviveObjects(bundle.meta_graph_def(), context, directory,
|
||||
&revived_objects));
|
||||
|
||||
// TODO(bmzhao): When we later add support for loading resources, we need to
|
||||
// handle the case where materializing a function's captures requires invoking
|
||||
|
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h"
#include "tensorflow/cc/saved_model/bundle_v2.h"
@@ -68,6 +69,8 @@ class TFSavedModelAPI : public SavedModelAPI {

  ~TFSavedModelAPI() override = default;

  Status GetVariable(const std::string& variable_path, Variable** variable);

 private:
  TFSavedModelAPI(
      const std::string& directory, SavedModelV2Bundle bundle,
@ -245,14 +245,17 @@ tf_cc_test(
|
||||
"saved_model_api_test.cc",
|
||||
],
|
||||
data = [
|
||||
"//tensorflow/c/experimental/saved_model/internal/testdata:saved_models",
|
||||
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
|
||||
],
|
||||
deps = [
|
||||
":saved_model_api_type",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c:tf_tensor",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
"//tensorflow/c/eager:c_api_test_util",
|
||||
"//tensorflow/c/experimental/saved_model/core:tf_saved_model_api",
|
||||
"//tensorflow/c/experimental/saved_model/public:concrete_function",
|
||||
"//tensorflow/c/experimental/saved_model/public:saved_model_api",
|
||||
"//tensorflow/core:lib",
|
||||
|
@ -21,15 +21,20 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/tstring.h"
|
||||
|
||||
namespace {
|
||||
|
||||
using tensorflow::tstring;
|
||||
|
||||
constexpr char kTestData[] = "cc/saved_model/testdata";
|
||||
const char* kServeTag[] = {"serve"};
|
||||
|
||||
@ -137,6 +142,103 @@ TEST_P(CSavedModelAPITest, LoadsSavedModel) {
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
TEST_P(CSavedModelAPITest, LoadsAssetSavedModel) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
bool use_tfrt = GetParam();
|
||||
if (use_tfrt) {
|
||||
TFE_DeleteContextOptions(opts);
|
||||
TF_DeleteStatus(status);
|
||||
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
|
||||
}
|
||||
|
||||
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
|
||||
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
std::string model_dir = SavedModelPath("AssetModule");
|
||||
|
||||
TF_SavedModel* saved_model =
|
||||
TF_LoadSavedModel(model_dir.c_str(), ctx, status);
|
||||
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TF_ConcreteFunction* read_file_fn =
|
||||
TF_GetSavedModelConcreteFunction(saved_model, "read_file", status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TFE_Op* read_file_op =
|
||||
TF_ConcreteFunctionMakeCallOp(read_file_fn, nullptr, 0, status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
// TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many
|
||||
// inputs + outputs a function has.
|
||||
TFE_TensorHandle* read_file_fn_outputs[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
|
||||
TFE_Execute(read_file_op, &read_file_fn_outputs[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TF_Tensor* result = TFE_TensorHandleResolve(read_file_fn_outputs[0], status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
EXPECT_EQ(TF_NumDims(result), 0);
|
||||
tensorflow::tstring* output_value =
|
||||
static_cast<tensorflow::tstring*>(TF_TensorData(result));
|
||||
EXPECT_EQ(std::string(*output_value), "TEST ASSET FILE CONTENTS\n");
|
||||
|
||||
TF_DeleteTensor(result);
|
||||
TFE_DeleteTensorHandle(read_file_fn_outputs[0]);
|
||||
TFE_DeleteOp(read_file_op);
|
||||
TF_DeleteSavedModel(saved_model);
|
||||
TF_DeleteStatus(status);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
TEST_P(CSavedModelAPITest, LoadSavedModelWithUninitializedVariable) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
bool use_tfrt = GetParam();
|
||||
if (use_tfrt) {
|
||||
TFE_DeleteContextOptions(opts);
|
||||
TF_DeleteStatus(status);
|
||||
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
|
||||
}
|
||||
|
||||
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
|
||||
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
std::string model_dir = tensorflow::io::JoinPath(
|
||||
tensorflow::testing::TensorFlowSrcRoot(),
|
||||
"c/experimental/saved_model/internal/testdata/UninitializedVariable");
|
||||
|
||||
TF_SavedModel* saved_model =
|
||||
TF_LoadSavedModel(model_dir.c_str(), ctx, status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
tensorflow::TFSavedModelAPI* model_api =
|
||||
tensorflow::down_cast<tensorflow::TFSavedModelAPI*>(
|
||||
tensorflow::unwrap(saved_model));
|
||||
tensorflow::Variable* uninitialized_variable;
|
||||
ASSERT_EQ(tensorflow::Status::OK(),
|
||||
model_api->GetVariable("uninitialized_variable",
|
||||
&uninitialized_variable));
|
||||
ASSERT_EQ(tensorflow::DT_FLOAT, uninitialized_variable->dtype());
|
||||
|
||||
ASSERT_EQ(tensorflow::Status::OK(),
|
||||
model_api->GetVariable("sub_module.uninitialized_variable",
|
||||
&uninitialized_variable));
|
||||
ASSERT_EQ(tensorflow::DT_INT64, uninitialized_variable->dtype());
|
||||
|
||||
TF_DeleteSavedModel(saved_model);
|
||||
TF_DeleteStatus(status);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticSavedModelTests, CSavedModelAPITest,
|
||||
::testing::Bool());
|
||||
|
||||
|
tensorflow/c/experimental/saved_model/internal/testdata/BUILD (new file, 36 lines)
@@ -0,0 +1,36 @@
load("//tensorflow:tensorflow.bzl", "py_strict_binary")

package(
    licenses = ["notice"],  # Apache 2.0
)

# Run this binary manually, with an argument pointing to the testdata/
# directory, to generate the test files used by the filegroup rule below.
py_strict_binary(
    name = "gen_saved_models",
    srcs = ["gen_saved_models.py"],
    python_version = "PY3",
    deps = [
        "//tensorflow/python:dtypes",
        "//tensorflow/python:platform",
        "//tensorflow/python:resource_variable_ops",
        "//tensorflow/python:tensor_spec",
        "//tensorflow/python:variables",
        "//tensorflow/python/compat:v2_compat",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/module",
        "//tensorflow/python/saved_model",
        "//tensorflow/python/saved_model:save_options",
    ],
)

# Files generated by the binary above.
filegroup(
    name = "saved_models",
    srcs = glob([
        "UninitializedVariable/**",
    ]),
    visibility = [
        "//tensorflow/c/experimental/saved_model/internal:__pkg__",
    ],
)
tensorflow/c/experimental/saved_model/internal/testdata/UninitializedVariable/saved_model.pb (new binary file, not shown)
(two additional new binary files not shown)
tensorflow/c/experimental/saved_model/internal/testdata/gen_saved_models.py (new file, 84 lines)
@@ -0,0 +1,84 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Creates saved models used for testing.

This executable should be run with an argument pointing to the testdata/ folder
in this directory. It will re-generate the saved models that are used for
testing.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import google_type_annotations
from __future__ import print_function

import os

from tensorflow.python.compat import v2_compat

from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec
from tensorflow.python.module import module
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import app
from tensorflow.python.saved_model import saved_model


def _gen_uninitialized_variable(base_dir):
  """Generates a saved model with an uninitialized variable."""

  class SubModule(module.Module):
    """A module with an UninitializedVariable."""

    def __init__(self):
      self.uninitialized_variable = resource_variable_ops.UninitializedVariable(
          name="uninitialized_variable", dtype=dtypes.int64)

  class Module(module.Module):
    """A module with an UninitializedVariable."""

    def __init__(self):
      super(Module, self).__init__()
      self.sub_module = SubModule()
      self.initialized_variable = variables.Variable(
          1.0, name="initialized_variable")
      # An UninitializedVariable with the same name as the variable in the
      # SubModule, but with a different type.
      self.uninitialized_variable = resource_variable_ops.UninitializedVariable(
          name="uninitialized_variable", dtype=dtypes.float32)

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec((), dtypes.float32)])
    def compute(self, value):
      return self.initialized_variable + value

  to_save = Module()
  saved_model.save(
      to_save, export_dir=os.path.join(base_dir, "UninitializedVariable"))


def main(args):
  if len(args) != 2:
    raise app.UsageError("Expected one argument (base_dir).")
  _, base_dir = args
  _gen_uninitialized_variable(base_dir)


if __name__ == "__main__":
  v2_compat.enable_v2_behavior()
  app.run(main)
@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/stream_executor/executor_cache.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/stream_executor_pimpl.h"
#include "tensorflow/stream_executor/timer.h"
@ -209,9 +209,13 @@ tf_cc_test(
|
||||
py_binary(
|
||||
name = "testdata/generate_saved_models",
|
||||
srcs = ["testdata/generate_saved_models.py"],
|
||||
data = [
|
||||
":saved_model_asset_data",
|
||||
],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:tensor_spec",
|
||||
@ -221,6 +225,7 @@ py_binary(
|
||||
"//tensorflow/python/module",
|
||||
"//tensorflow/python/saved_model",
|
||||
"//tensorflow/python/saved_model:save_options",
|
||||
"//tensorflow/python/training/tracking",
|
||||
"@absl_py//absl:app",
|
||||
],
|
||||
)
|
||||
@ -229,6 +234,7 @@ py_binary(
|
||||
filegroup(
|
||||
name = "saved_model_test_files",
|
||||
srcs = glob([
|
||||
"testdata/AssetModule/**",
|
||||
"testdata/half_plus_two_pbtxt/**",
|
||||
"testdata/half_plus_two_main_op/**",
|
||||
"testdata/half_plus_two/**",
|
||||
@ -245,6 +251,13 @@ alias(
|
||||
actual = ":saved_model_test_files",
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "saved_model_asset_data",
|
||||
srcs = [
|
||||
"testdata/test_asset.txt",
|
||||
],
|
||||
)
|
||||
|
||||
exports_files(
|
||||
glob([
|
||||
"testdata/half_plus_two_pbtxt/**",
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/saved_model/loader_util.h"
|
||||
#include "tensorflow/cc/saved_model/reader.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
@ -73,26 +74,41 @@ uint64 GetLatencyMicroseconds(const uint64 start_microseconds) {
|
||||
// Ensure that constant tensors loaded from the saved model have valid shape.
|
||||
// Also ensure that constant nodes have a value assigned to them.
|
||||
// TODO(b/154763635): this is temporary and will be replaced with a better audit
|
||||
static Status ValidateNode(const NodeDef& node) {
|
||||
const auto node_iterator = node.attr().find("value");
|
||||
if (node_iterator != node.attr().end()) {
|
||||
AttrValue node_value = node_iterator->second;
|
||||
if (node_value.has_tensor()) {
|
||||
const PartialTensorShape node_shape(node_value.tensor().tensor_shape());
|
||||
if (node_shape.num_elements() < 0) {
|
||||
return errors::FailedPrecondition(
|
||||
"Saved model contains node \"", node.name(), "\" (op \"", node.op(),
|
||||
"\") which initializes from a tensor with ",
|
||||
node_shape.num_elements(), " elements");
|
||||
}
|
||||
}
|
||||
} else if (node.op() == "Const") {
|
||||
return errors::FailedPrecondition(
|
||||
"Saved model contains node \"", node.name(),
|
||||
"\" which is a constant tensor but no value has been provided");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
static Status ValidateSavedTensors(const GraphDef& graph_def) {
|
||||
for (const auto& node : graph_def.node()) {
|
||||
const auto node_iterator = node.attr().find("value");
|
||||
if (node_iterator != node.attr().end()) {
|
||||
AttrValue node_value = node_iterator->second;
|
||||
if (node_value.has_tensor()) {
|
||||
const PartialTensorShape node_shape(node_value.tensor().tensor_shape());
|
||||
if (node_shape.num_elements() < 0) {
|
||||
return errors::FailedPrecondition(
|
||||
"Saved model contains node \"", node.name(), "\" (op \"",
|
||||
node.op(), "\") which initializes from a tensor with ",
|
||||
node_shape.num_elements(), " elements");
|
||||
}
|
||||
TF_RETURN_IF_ERROR(ValidateNode(node));
|
||||
}
|
||||
|
||||
if (graph_def.has_library()) {
|
||||
const FunctionDefLibrary& library = graph_def.library();
|
||||
for (const auto& function : library.function()) {
|
||||
for (const auto& node : function.node_def()) {
|
||||
TF_RETURN_IF_ERROR(ValidateNode(node));
|
||||
}
|
||||
} else if (node.op() == "Const") {
|
||||
return errors::FailedPrecondition(
|
||||
"Saved model contains node \"", node.name(),
|
||||
"\" which is a constant tensor but no value has been provided");
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -45,6 +45,8 @@ constexpr char kTestFuzzGeneratedNegativeShape[] =
|
||||
"cc/saved_model/testdata/fuzz_generated/negative_shape";
|
||||
constexpr char kTestFuzzGeneratedConstWithNoValue[] =
|
||||
"cc/saved_model/testdata/fuzz_generated/const_with_no_value";
|
||||
constexpr char kTestFuzzGeneratedBadNodeAttr[] =
|
||||
"cc/saved_model/testdata/fuzz_generated/bad_node_attr";
|
||||
|
||||
class LoaderTest : public ::testing::Test {
|
||||
protected:
|
||||
@ -328,5 +330,20 @@ TEST_F(LoaderTest, ConstNoValue) {
|
||||
std::string::npos);
|
||||
}
|
||||
|
||||
TEST_F(LoaderTest, BadNodeAttr) {
|
||||
SavedModelBundle bundle;
|
||||
RunOptions run_options;
|
||||
SessionOptions session_options;
|
||||
|
||||
const string export_dir =
|
||||
io::JoinPath(testing::TensorFlowSrcRoot(), kTestFuzzGeneratedBadNodeAttr);
|
||||
Status st = LoadSavedModel(session_options, run_options, export_dir,
|
||||
{kSavedModelTagServe}, &bundle);
|
||||
EXPECT_FALSE(st.ok());
|
||||
EXPECT_NE(
|
||||
st.error_message().find("constant tensor but no value has been provided"),
|
||||
std::string::npos);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
tensorflow/cc/saved_model/testdata/AssetModule/assets/test_asset.txt (new file, 1 line)
@@ -0,0 +1 @@
TEST ASSET FILE CONTENTS
tensorflow/cc/saved_model/testdata/AssetModule/saved_model.pb (new binary file, not shown)
tensorflow/cc/saved_model/testdata/AssetModule/variables/variables.data-00000-of-00001 (new binary file, not shown)
tensorflow/cc/saved_model/testdata/AssetModule/variables/variables.index (new binary file, not shown)
tensorflow/cc/saved_model/testdata/fuzz_generated/bad_node_attr/assets/empty (new empty file)
tensorflow/cc/saved_model/testdata/fuzz_generated/bad_node_attr/saved_model.pb (new binary file, not shown)
tensorflow/cc/saved_model/testdata/fuzz_generated/bad_node_attr/variables/variables.index (new binary file, not shown)
@@ -29,9 +29,12 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.module import module
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.saved_model import save_options
from tensorflow.python.saved_model import saved_model
from tensorflow.python.training.tracking import tracking


class VarsAndArithmeticObjectGraph(module.Module):
@@ -68,9 +71,21 @@ class CyclicModule(module.Module):
    self.child = ReferencesParent(self)


class AssetModule(module.Module):

  def __init__(self):
    self.asset = tracking.Asset(
        test.test_src_dir_path("cc/saved_model/testdata/test_asset.txt"))

  @def_function.function(input_signature=[])
  def read_file(self):
    return io_ops.read_file(self.asset)


MODULE_CTORS = {
    "VarsAndArithmeticObjectGraph": VarsAndArithmeticObjectGraph,
    "CyclicModule": CyclicModule,
    "AssetModule": AssetModule,
}
tensorflow/cc/saved_model/testdata/test_asset.txt (new file, 1 line)
@@ -0,0 +1 @@
TEST ASSET FILE CONTENTS
@ -1708,40 +1708,6 @@ std::atomic<int64>* GetPointerToFuel(int64 initial_value) {
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap*
|
||||
uncompilable_node_info) {
|
||||
Device* device = flr->device();
|
||||
const XlaOpRegistry::DeviceRegistration* registration;
|
||||
CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(),
|
||||
®istration));
|
||||
|
||||
// We can always *compile* resource operations, stateful RNGs and dummy ops,
|
||||
// even if we are sometimes unable to auto-cluster them.
|
||||
RecursiveCompilabilityChecker::OperationFilter op_filter;
|
||||
op_filter.allow_resource_ops_in_called_functions = true;
|
||||
op_filter.allow_stack_ops = true;
|
||||
op_filter.allow_tensor_array_ops = true;
|
||||
op_filter.allow_stateful_rng_ops = true;
|
||||
op_filter.allow_control_trigger = true;
|
||||
op_filter.allow_eliding_assert_and_checknumerics_ops = true;
|
||||
op_filter.allow_ops_producing_or_consuming_variant = true;
|
||||
op_filter.allow_slow_ops = true;
|
||||
op_filter.allow_inaccurate_ops = true;
|
||||
|
||||
RecursiveCompilabilityChecker checker{
|
||||
op_filter, DeviceType{registration->compilation_device_name}};
|
||||
if (!uncompilable_node_info) {
|
||||
// We do not need uncompilable node info. Just return the result.
|
||||
return checker.IsCompilableCall(ndef, flr);
|
||||
}
|
||||
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_node_result =
|
||||
checker.FindUncompilableNodes(ndef, flr);
|
||||
uncompilable_node_info->swap(uncompilable_node_result);
|
||||
return uncompilable_node_info->empty();
|
||||
}
|
||||
|
||||
Status MarkForCompilationPass::Run(
|
||||
const GraphOptimizationPassOptions& options) {
|
||||
MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
|
||||
@ -1951,6 +1917,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
|
||||
"ParallelDynamicStitch",
|
||||
"ParameterizedTruncatedNormal",
|
||||
"PartitionedCall",
|
||||
"Polygamma",
|
||||
"PopulationCount",
|
||||
"Qr",
|
||||
"QuantizeAndDequantizeV2",
|
||||
@ -2094,6 +2061,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
|
||||
"XlaSpmdShardToFullShape",
|
||||
"XlaSvd",
|
||||
"XlaWhile",
|
||||
"Zeta",
|
||||
"_Arg",
|
||||
"_ArrayToList",
|
||||
"_ListToArray",
|
||||
|
@ -50,14 +50,6 @@ class MarkForCompilationPass : public GraphOptimizationPass {
|
||||
friend class MarkForCompilationPassTestHelper;
|
||||
};
|
||||
|
||||
// Returns true iff 'ndef' is a call to a function that is compilable. A
|
||||
// function is compilable iff every operator in the function body is
|
||||
// compilable. If 'ndef' is not compilable and 'uncompilable_node_info' is not
|
||||
// null, we will populate 'uncompilable_node_info' with uncompilable node info.
|
||||
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap*
|
||||
uncompilable_node_info = nullptr);
|
||||
|
||||
absl::flat_hash_map<string, std::vector<string>>* GetAllowlistTable();
|
||||
|
||||
namespace testing {
|
||||
|
@ -294,4 +294,12 @@ se::Stream* XlaDeviceContext::GetDeviceToDeviceStream() {
|
||||
return device_to_device_stream(stream);
|
||||
}
|
||||
|
||||
Status XlaDeviceContext::ThenExecute(Device* device,
|
||||
stream_executor::Stream* stream,
|
||||
std::function<void()> func) {
|
||||
VLOG(2) << "XlaDeviceContext::ThenExecute";
|
||||
stream->ThenDoHostCallback(std::move(func));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -86,6 +86,9 @@ class XlaDeviceContext : public DeviceContext {
|
||||
// Returns a device-to-device stream, in round-robin fashion.
|
||||
se::Stream* GetDeviceToDeviceStream();
|
||||
|
||||
Status ThenExecute(Device* device, stream_executor::Stream* stream,
|
||||
std::function<void()> func) override;
|
||||
|
||||
private:
|
||||
bool UseMultipleStreams() const { return stream_ != host_to_device_stream_; }
|
||||
|
||||
|
@ -21,7 +21,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
|
||||
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
|
||||
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
@ -32,6 +31,44 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Returns true iff 'ndef' is a call to a function that is compilable. A
|
||||
// function is compilable iff every operator in the function body is
|
||||
// compilable. If 'ndef' is not compilable and 'uncompilable_node_info' is not
|
||||
// null, we will populate 'uncompilable_node_info' with uncompilable node info.
|
||||
static bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap*
|
||||
uncompilable_node_info) {
|
||||
Device* device = flr->device();
|
||||
const XlaOpRegistry::DeviceRegistration* registration;
|
||||
CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(),
|
||||
®istration));
|
||||
|
||||
// We can always *compile* resource operations, stateful RNGs and dummy ops,
|
||||
// even if we are sometimes unable to auto-cluster them.
|
||||
RecursiveCompilabilityChecker::OperationFilter op_filter;
|
||||
op_filter.allow_resource_ops_in_called_functions = true;
|
||||
op_filter.allow_stack_ops = true;
|
||||
op_filter.allow_tensor_array_ops = true;
|
||||
op_filter.allow_stateful_rng_ops = true;
|
||||
op_filter.allow_control_trigger = true;
|
||||
op_filter.allow_eliding_assert_and_checknumerics_ops = true;
|
||||
op_filter.allow_ops_producing_or_consuming_variant = true;
|
||||
op_filter.allow_slow_ops = true;
|
||||
op_filter.allow_inaccurate_ops = true;
|
||||
|
||||
RecursiveCompilabilityChecker checker{
|
||||
op_filter, DeviceType{registration->compilation_device_name}};
|
||||
if (!uncompilable_node_info) {
|
||||
// We do not need uncompilable node info. Just return the result.
|
||||
return checker.IsCompilableCall(ndef, flr);
|
||||
}
|
||||
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_node_result =
|
||||
checker.FindUncompilableNodes(ndef, flr);
|
||||
uncompilable_node_info->swap(uncompilable_node_result);
|
||||
return uncompilable_node_info->empty();
|
||||
}
|
||||
|
||||
bool XlaKernelCreator::CanCreateKernel(
|
||||
const FunctionLibraryRuntime& flr,
|
||||
const std::shared_ptr<const NodeProperties>& props) const {
|
||||
|
@@ -366,6 +366,20 @@ def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", [],
  }];
}

def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh", [],
    HLO_FpOrComplexTensor> {
  let summary = "Sinh operation";

  let description = [{
    Returns `Sinh(operand)` element-wise.

    $$
    \sinh(x) = (e^x - e^-x) / 2                     if |x| < 1
             = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
    $$
  }];
}

def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan", [],
    HLO_FpOrComplexTensor> {
  let summary = "Tan operation";
@@ -49,6 +49,45 @@ def : Pat<(HLOClient_AcosOp $input),
            ),
            (HLO_ConstantLike<"M_PI"> $input))>;

// Express `sinh` as
//   sinh(x) = (e^x - e^-x) / 2                     if |x| < 1
//           = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
def : Pat<(HLOClient_SinhOp $input),
  (HLO_SelectOp
    (HLO_CompareOp
      (HLO_AbsOp $input),
      (HLO_ConstantLike<"1"> $input),
      HLO_COMPARISON_DIRECTION_LT
    ),
    (HLO_DivOp
      (HLO_SubOp
        (HLO_ExpOp $input),
        (HLO_ExpOp
          (HLO_NegOp $input)
        )
      ),
      (HLO_ConstantLike<"2"> $input)
    ),
    (HLO_SubOp
      (HLO_ExpOp
        (HLO_AddOp
          $input,
          (HLO_LogOp
            (HLO_ConstantLike<"0.5"> $input)
          )
        )
      ),
      (HLO_ExpOp
        (HLO_SubOp
          (HLO_LogOp
            (HLO_ConstantLike<"0.5"> $input)
          ),
          $input
        )
      )
    )
  )>;

// Express tan in MHLO dialect as
//   tan(x) = sin(x) / cos(x).
def : Pat<(HLOClient_TanOp $input),
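For reference, a scalar C++ sketch of the same rewrite (illustrative only, not part of the MLIR lowering): the second branch computes e^(x + log(1/2)) - e^(-x + log(1/2)), i.e. (e^x - e^-x)/2 with the halving folded into the exponent, which avoids overflowing the intermediate e^x for large |x|.

  #include <cmath>

  double SinhViaRewrite(double x) {
    if (std::abs(x) < 1.0) {
      return (std::exp(x) - std::exp(-x)) / 2.0;
    }
    const double log_half = std::log(0.5);
    return std::exp(x + log_half) - std::exp(-x + log_half);
  }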
@@ -47,7 +47,8 @@ namespace {
      sep fn(ShiftRightLogicalOp) sep fn(SubOp)

// TODO(herhut): Generate these out of op definitions.
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) fn(TanOp) sep fn(AcosOp)
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
  fn(TanOp) sep fn(AcosOp) sep fn(SinhOp)

template <typename OpTy>
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
@@ -357,11 +357,7 @@ ResourceAliasAnalysisInfo::ResourceAliasAnalysisInfo(
                                          while_region.body()));
  } else if (auto case_op = dyn_cast<CaseOp>(op)) {
    llvm::SmallVector<FuncOp, 4> functions;
    functions.reserve(case_op.branches().size());
    for (auto branch : case_op.branches())
      functions.emplace_back(SymbolTable::lookupNearestSymbolFrom<FuncOp>(
          case_op, branch.cast<SymbolRefAttr>()));

    case_op.get_branch_functions(functions);
    AnalyzeFunctionalCaseOrIfOp(case_op, functions, backtrack_analysis);
  } else if (auto if_op = dyn_cast<IfOp>(op)) {
    AnalyzeFunctionalCaseOrIfOp(
@@ -653,9 +653,8 @@ Status MlirFunctionContext::Finalize(OutputList* outputs,
  }
  builder_.create<ReturnOp>(func_.getLoc(), ret_operands);

  auto arg_types = llvm::to_vector<8>(body.getArgumentTypes());
  auto result_types =
      llvm::to_vector<8>(body.getTerminator()->getOperandTypes());
  auto arg_types = body.getArgumentTypes();
  auto result_types = body.getTerminator()->getOperandTypes();
  func_.setType(FunctionType::get(arg_types, result_types, func_.getContext()));
  *f = new MlirFunction(std::move(context_), std::move(module_), func_);
  return Status::OK();
File diff suppressed because it is too large
@ -110,27 +110,37 @@ def TF_StackResource : TF_ResourceBase<"Stack">;
|
||||
def TF_TensorArrayResource : TF_ResourceBase<"TensorArray">;
|
||||
def TF_SummaryResource : TF_ResourceBase<"Summary">;
|
||||
def TF_LookupTableResource : TF_ResourceBase<"LookupTable">;
|
||||
def TF_DatasetSeedGeneratorResource : TF_ResourceBase<"DatasetSeedGenerator">;
|
||||
def TF_DatasetMemoryCacheResource : TF_ResourceBase<"DatasetMemoryCache">;
|
||||
|
||||
def TF_VariableRead : MemRead<TF_VariableResource>;
|
||||
def TF_StackRead : MemRead<TF_StackResource>;
|
||||
def TF_TensorArrayRead : MemRead<TF_TensorArrayResource>;
|
||||
def TF_LookupTableRead : MemRead<TF_LookupTableResource>;
|
||||
def TF_DatasetSeedGeneratorRead : MemRead<TF_DatasetSeedGeneratorResource>;
|
||||
def TF_DatasetMemoryCacheRead : MemRead<TF_DatasetMemoryCacheResource>;
|
||||
|
||||
def TF_VariableWrite : MemWrite<TF_VariableResource>;
|
||||
def TF_StackWrite : MemWrite<TF_StackResource>;
|
||||
def TF_TensorArrayWrite : MemWrite<TF_TensorArrayResource>;
|
||||
def TF_SummaryWrite : MemWrite<TF_SummaryResource>;
|
||||
def TF_LookupTableWrite : MemWrite<TF_LookupTableResource>;
|
||||
def TF_DatasetSeedGeneratorWrite : MemWrite<TF_DatasetSeedGeneratorResource>;
|
||||
def TF_DatasetMemoryCacheWrite : MemWrite<TF_DatasetMemoryCacheResource>;
|
||||
|
||||
def TF_VariableAlloc : MemAlloc<TF_VariableResource>;
|
||||
def TF_StackAlloc : MemAlloc<TF_StackResource>;
|
||||
def TF_TensorArrayAlloc : MemAlloc<TF_TensorArrayResource>;
|
||||
def TF_SummaryAlloc : MemAlloc<TF_SummaryResource>;
|
||||
def TF_LookupTableAlloc : MemAlloc<TF_LookupTableResource>;
|
||||
def TF_DatasetSeedGeneratorAlloc : MemAlloc<TF_DatasetSeedGeneratorResource>;
|
||||
def TF_DatasetMemoryCacheAlloc : MemAlloc<TF_DatasetMemoryCacheResource>;
|
||||
|
||||
def TF_StackFree : MemFree<TF_StackResource>;
|
||||
def TF_TensorArrayFree : MemFree<TF_TensorArrayResource>;
|
||||
def TF_SummaryFree : MemFree<TF_SummaryResource>;
|
||||
def TF_DatasetSeedGeneratorFree : MemFree<TF_DatasetSeedGeneratorResource>;
|
||||
def TF_DatasetMemoryCacheFree : MemFree<TF_DatasetMemoryCacheResource>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorFlow op definitions
|
||||
@ -172,109 +182,19 @@ class TF_TensorFlowType <string name, string description> :
    "TensorFlow " # description # " type">,
    BuildableType<"getType<mlir::TF::" # name # "Type>()">;

//===----------------------------------------------------------------------===//
// Integer types

// TODO(mgester) shouldn't this be SignedIntOfWidths?
def TF_I32Or64 : SignlessIntOfWidths<[32, 64]>;

def TF_I32OrI64Tensor : TensorOf<[TF_I32Or64]>;

def TF_Uint8 : UI<8>;
def TF_Uint8Tensor : TensorOf<[TF_Uint8]>;

def TF_Uint16 : UI<16>;
def TF_Uint16Tensor : TensorOf<[TF_Uint16]>;

def TF_Uint32 : UI<32>;
def TF_Uint32Tensor : TensorOf<[TF_Uint32]>;

def TF_Uint64 : UI<64>;
def TF_Uint64Tensor : TensorOf<[TF_Uint64]>;

// Any unsigned integer type
def TF_UInt : UnsignedIntOfWidths<[8, 16, 32, 64]>;

// Any signed integer type
// TODO(mgester) shouldn't this be SignedIntOfWidths?
def TF_SInt : SignlessIntOfWidths<[8, 16, 32, 64]>;

// Any integer type
def TF_Int : AnyTypeOf<[TF_SInt, TF_UInt], "integer">;

// Any integer tensor types
def TF_IntTensor : TensorOf<[TF_Int]>;

//===----------------------------------------------------------------------===//
// Quantized types
def TF_Qint8 : TF_TensorFlowType<"Qint8", "qint8">;
def TF_Qint16 : TF_TensorFlowType<"Qint16", "qint16">;
def TF_Qint32 : TF_TensorFlowType<"Qint32", "qint32">;
def TF_Quint8 : TF_TensorFlowType<"Quint8", "quint8">;
def TF_Quint16 : TF_TensorFlowType<"Quint16", "quint16">;

// Any quantized type
def TF_Quantized : AnyTypeOf<[TF_Qint8, TF_Qint16, TF_Qint32, TF_Quint8,
                              TF_Quint16], "quantized">;
//===----------------------------------------------------------------------===//
// Floating-point types

def TF_F32Or64 : FloatOfWidths<[32, 64]>;

def TF_F32OrF64Tensor : TensorOf<[TF_F32Or64]>;

def TF_Float : AnyTypeOf<[F16, F32, F64, BF16], "floating-point">;

// Any floating-point tensor types
def TF_FpTensor : TensorOf<[TF_Float]>;

//===----------------------------------------------------------------------===//
// Complex types

// TODO(suderman): Remove TF_Complex64 and use a standard ops declaration, along
// with the associated cleanup.
def TF_Complex64 : Complex<F<32>>;
def TF_Complex64Tensor : TensorOf<[TF_Complex64]>;

def TF_Complex128 : Complex<F<64>>;
def TF_Complex128Tensor : TensorOf<[TF_Complex128]>;

def TF_Complex : AnyTypeOf<[TF_Complex64, TF_Complex128], "complex">;

def TF_ComplexTensor : TensorOf<[TF_Complex]>;

//===----------------------------------------------------------------------===//
// String/variant/resource types

def TF_Str : TF_TensorFlowType<"String", "string">;
def TF_StrTensor : TensorOf<[TF_Str]>;

def TF_Variant : TF_TensorFlowType<"Variant", "variant">;
def TF_VariantTensor : TensorOf<[TF_Variant]>;

def TF_Resource : TF_TensorFlowType<"Resource", "resource">;
def TF_ResourceTensor : TensorOf<[TF_Resource]>;

//===----------------------------------------------------------------------===//
// Reference types

// Float reference types
def TF_F16Ref : TF_TensorFlowType<"HalfRef", "f16ref">;
def TF_F32Ref : TF_TensorFlowType<"FloatRef", "f32ref">;
def TF_F64Ref : TF_TensorFlowType<"DoubleRef", "f64ref">;
def TF_Float16Ref : TF_TensorFlowType<"HalfRef", "f16ref">;
def TF_Float32Ref : TF_TensorFlowType<"FloatRef", "f32ref">;
def TF_Float64Ref : TF_TensorFlowType<"DoubleRef", "f64ref">;
def TF_Bfloat16Ref : TF_TensorFlowType<"Bfloat16Ref", "bf16ref">;

// Any float reference type
def TF_FloatRef : AnyTypeOf<[TF_F16Ref, TF_F32Ref, TF_F64Ref, TF_Bfloat16Ref],
                            "floating-point reference">;

// Complex reference types
def TF_Complex64Ref : TF_TensorFlowType<"Complex64Ref", "complex64ref">;
def TF_Complex128Ref : TF_TensorFlowType<"Complex128Ref", "complex128ref">;

// Any complex reference type
def TF_ComplexRef : AnyTypeOf<[TF_Complex64Ref, TF_Complex128Ref], "complex reference">;

// Integer reference types
def TF_Int8Ref : TF_TensorFlowType<"Int8Ref", "i8ref">;
def TF_Int16Ref : TF_TensorFlowType<"Int16Ref", "i16ref">;
@ -286,17 +206,6 @@ def TF_Uint16Ref : TF_TensorFlowType<"Uint16Ref", "ui16ref">;
def TF_Uint32Ref : TF_TensorFlowType<"Uint32Ref", "ui32ref">;
def TF_Uint64Ref : TF_TensorFlowType<"Uint64Ref", "ui64ref">;

// Any signed integer reference type
def TF_SIntRef : AnyTypeOf<[TF_Int8Ref, TF_Int16Ref, TF_Int32Ref, TF_Int64Ref],
                           "signed integer reference">;

// Any unsigned integer reference type
def TF_UIntRef : AnyTypeOf<[TF_Uint8Ref, TF_Uint16Ref, TF_Uint32Ref,
                            TF_Uint64Ref], "unsigned integer reference">;

// Any integer reference type
def TF_IntRef : AnyTypeOf<[TF_SIntRef, TF_UIntRef], "integer reference">;

// Quantized reference types
def TF_Qint8Ref : TF_TensorFlowType<"Qint8Ref", "qint8ref">;
def TF_Qint16Ref : TF_TensorFlowType<"Qint16Ref", "qint16ref">;
@ -304,54 +213,152 @@ def TF_Qint32Ref : TF_TensorFlowType<"Qint32Ref", "qint32ref">;
def TF_Quint8Ref : TF_TensorFlowType<"Quint8Ref", "quint8ref">;
def TF_Quint16Ref : TF_TensorFlowType<"Quint16Ref", "quint16ref">;

// Any quantized reference type
def TF_QuantizedRef : AnyTypeOf<[TF_Qint8Ref, TF_Qint16Ref, TF_Qint32Ref,
                                 TF_Quint8Ref, TF_Quint16Ref], "quantized reference">;

// Other reference types
def TF_BoolRef : TF_TensorFlowType<"BoolRef", "boolref">;
def TF_ResourceRef : TF_TensorFlowType<"ResourceRef", "resourceref">;
def TF_StrRef : TF_TensorFlowType<"StringRef", "stringref">;
def TF_VariantRef : TF_TensorFlowType<"VariantRef", "variantref">;

// Reference tensor types
def TF_FpRefTensor : TensorOf<[TF_FloatRef]>;
def TF_I32OrI64RefTensor : TensorOf<[TF_Int32Ref, TF_Int64Ref]>;
//===----------------------------------------------------------------------===//
// Integer types (including corresponding reference types)

def TF_Bool : AnyTypeOf<[I<1>, TF_BoolRef], "bool">;

def TF_Int8 : AnyTypeOf<[I8, TF_Int8Ref], "8-bit integer">;
def TF_Int16 : AnyTypeOf<[I16, TF_Int16Ref], "16-bit integer">;
def TF_Int32 : AnyTypeOf<[I32, TF_Int32Ref], "32-bit integer">;
def TF_Int64 : AnyTypeOf<[I64, TF_Int64Ref], "64-bit integer">;
def TF_I32OrI64 : AnyTypeOf<[I32, I64, TF_Int32Ref, TF_Int64Ref],
                            "32/64-bit signed integer">;

def TF_Uint8 : AnyTypeOf<[UI<8>, TF_Uint8Ref], "8-bit unsigned integer">;
def TF_Uint16 : AnyTypeOf<[UI<16>, TF_Uint16Ref], "16-bit unsigned integer">;
def TF_Uint32 : AnyTypeOf<[UI<32>, TF_Uint32Ref], "32-bit unsigned integer">;
def TF_Uint64 : AnyTypeOf<[UI<64>, TF_Uint64Ref], "64-bit unsigned integer">;

// Any unsigned integer type
def TF_UInt : AnyTypeOf<[TF_Uint8, TF_Uint16, TF_Uint32, TF_Uint64],
                        "unsigned integer">;

// Any signed integer type
def TF_SInt : AnyTypeOf<[TF_Int8, TF_Int16, TF_Int32, TF_Int64],
                        "signed integer">;

// Any integer type
def TF_Int : AnyTypeOf<[TF_SInt, TF_UInt], "integer">;

// Tensor types
def TF_BoolTensor : TensorOf<[TF_Bool]>;

def TF_IntTensor : TensorOf<[TF_Int]>;
def TF_Int8Tensor : TensorOf<[TF_Int8]>;
def TF_Int16Tensor : TensorOf<[TF_Int16]>;
def TF_Int32Tensor : TensorOf<[TF_Int32]>;
def TF_Int64Tensor : TensorOf<[TF_Int64]>;
def TF_I32OrI64Tensor : TensorOf<[TF_I32OrI64]>;

def TF_Uint8Tensor : TensorOf<[TF_Uint8]>;
def TF_Uint16Tensor : TensorOf<[TF_Uint16]>;
def TF_Uint32Tensor : TensorOf<[TF_Uint32]>;
def TF_Uint64Tensor : TensorOf<[TF_Uint64]>;

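// Illustrative sketch (hypothetical op): these tensor aliases are what allow
// the op definitions further down in this diff to swap I32Tensor/I64Tensor for
// TF_Int32Tensor/TF_Int64Tensor so that the corresponding reference element
// types are accepted as well.
def TF_ExampleCounterOp : TF_Op<"ExampleCounter", [NoSideEffect]> {
  let arguments = (ins TF_Int32Tensor:$increment);
  let results = (outs TF_Int64Tensor:$total);
}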
//===----------------------------------------------------------------------===//
// Quantized types (including corresponding reference types)

def TF_Qint8 : AnyTypeOf<
  [TF_TensorFlowType<"Qint8", "qint8">, TF_Qint8Ref],
  "8-bit quantized integer">;
def TF_Qint16 : AnyTypeOf<
  [TF_TensorFlowType<"Qint16", "qint16">, TF_Qint16Ref],
  "16-bit quantized integer">;
def TF_Qint32 : AnyTypeOf<
  [TF_TensorFlowType<"Qint32", "qint32">, TF_Qint32Ref],
  "32-bit quantized integer">;
def TF_Quint8 : AnyTypeOf<
  [TF_TensorFlowType<"Quint8", "quint8">, TF_Quint8Ref],
  "8-bit quantized unsigned integer">;
def TF_Quint16 : AnyTypeOf<
  [TF_TensorFlowType<"Quint16", "quint16">, TF_Quint16Ref],
  "16-bit quantized unsigned integer">;

// Any quantized type
def TF_Quantized : AnyTypeOf<
  [TF_Qint8, TF_Qint16, TF_Qint32, TF_Quint8, TF_Quint16], "quantized">;

//===----------------------------------------------------------------------===//
// Floating-point types (including corresponding reference types)

def TF_Float16 : AnyTypeOf<[F16, TF_Float16Ref], "16-bit float">;
def TF_Float32 : AnyTypeOf<[F32, TF_Float32Ref], "32-bit float">;
def TF_Float64 : AnyTypeOf<[F64, TF_Float64Ref], "64-bit float">;
def TF_Bfloat16 : AnyTypeOf<[BF16, TF_Bfloat16Ref], "bfloat16">;

def TF_F32OrF64 : AnyTypeOf<[TF_Float32, TF_Float64], "32/64-bit float">;

def TF_Float : AnyTypeOf<
  [TF_Float16, TF_Float32, TF_Float64, TF_Bfloat16,
   TF_Float16Ref, TF_Float32Ref, TF_Float64Ref, TF_Bfloat16Ref],
  "floating-point">;

// Tensor types
def TF_FloatTensor : TensorOf<[TF_Float]>;
def TF_F32OrF64Tensor : TensorOf<[TF_F32OrF64]>;
def TF_Float16Tensor : TensorOf<[TF_Float16]>;
def TF_Float32Tensor : TensorOf<[TF_Float32]>;
def TF_Float64Tensor : TensorOf<[TF_Float64]>;
def TF_Bfloat16Tensor : TensorOf<[TF_Bfloat16]>;

//===----------------------------------------------------------------------===//
// Complex types (including corresponding reference types)

// TODO(suderman): Remove TF_Complex64 and use a standard ops declaration, along
// with the associated cleanup.
def TF_Complex64 : AnyTypeOf<[Complex<F<32>>, TF_Complex64Ref],
                             "64-bit complex">;
def TF_Complex128 : AnyTypeOf<[Complex<F<64>>, TF_Complex128Ref],
                              "128-bit complex">;
def TF_Complex : AnyTypeOf<[TF_Complex64, TF_Complex128], "complex">;

// Tensor types
def TF_ComplexTensor : TensorOf<[TF_Complex]>;
def TF_Complex64Tensor : TensorOf<[TF_Complex64]>;
def TF_Complex128Tensor : TensorOf<[TF_Complex128]>;

//===----------------------------------------------------------------------===//
// String/variant/resource types (including corresponding reference types)

def TF_Str : AnyTypeOf<
  [TF_TensorFlowType<"String", "str">, TF_StrRef], "string">;
def TF_StrTensor : TensorOf<[TF_Str]>;

def TF_Variant : AnyTypeOf<
  [TF_TensorFlowType<"Variant", "var">, TF_VariantRef], "variant">;
def TF_VariantTensor : TensorOf<[TF_Variant]>;

def TF_Resource : AnyTypeOf<
  [TF_TensorFlowType<"Resource", "res">, TF_ResourceRef], "resource">;
def TF_ResourceTensor : TensorOf<[TF_Resource]>;

//===----------------------------------------------------------------------===//
// Multi-category type constraints

def TF_IntOrF32OrF64Tensor: TensorOf<[TF_Int, TF_F32Or64]>;

def TF_FpOrI32OrI64Tensor : TensorOf<[TF_Float, TF_I32Or64]>;

// Any integer or floating-point tensor types
def TF_IntOrF32OrF64Tensor: TensorOf<[TF_Int, TF_F32OrF64]>;
def TF_FpOrI32OrI64Tensor : TensorOf<[TF_Float, TF_I32OrI64]>;
def TF_IntOrFpTensor : TensorOf<[TF_Int, TF_Float]>;

def TF_SintOrFpTensor : TensorOf<[TF_SInt, TF_Float]>;

def TF_FpOrComplexTensor : TensorOf<[TF_Float, TF_Complex]>;

def TF_Number : AnyTypeOf<[TF_Int, TF_Float, TF_Quantized, TF_Complex],
                          "number">;
def TF_NumberRef : AnyTypeOf<[TF_IntRef, TF_FloatRef, TF_QuantizedRef,
                              TF_ComplexRef], "number reference">;

def TF_Number : AnyTypeOf<
  [TF_Int, TF_Float, TF_Quantized, TF_Complex], "number">;
def TF_NumberTensor : TensorOf<[TF_Number]>;
def TF_NumberRefTensor : TensorOf<[TF_NumberRef]>;

def TF_NumberNotQuantizedOrStr :
  AnyTypeOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Str]>;
def TF_NumberNotQuantizedOrStrRef :
  AnyTypeOf<[TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint8Ref, TF_StrRef]>;
def TF_NumberNotQuantizedOrStrTensor : TensorOf<[TF_NumberNotQuantizedOrStr]>;

//===----------------------------------------------------------------------===//
// Tensor and tensor element types

// Bool type
def TF_Bool : I<1>;

// Any tensor element type allowed in TensorFlow ops
// (see https://www.tensorflow.org/api_docs/python/tf/dtypes/DType)
def TF_ElementType : Type<Or<[TF_Float.predicate,
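// Illustrative sketch (hypothetical op): the multi-category constraints above
// are used directly as operand and result constraints, as the summary and
// random ops later in this diff do with TF_IntOrFpTensor and TF_FloatTensor.
def TF_ExampleReduceMeanOp : TF_Op<"ExampleReduceMean", [NoSideEffect]> {
  let arguments = (ins TF_IntOrFpTensor:$values);
  let results = (outs TF_FloatTensor:$mean);
}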
@ -116,6 +116,24 @@ An n-way switch statement, implementing the following:
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
}];
|
||||
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
int num_branches() { return branches().size(); }
|
||||
|
||||
// Gets function corresponding branch # `index`.
|
||||
FuncOp branch_function(int index) {
|
||||
auto flat_sym_ref = branches()[index].cast<FlatSymbolRefAttr>();
|
||||
return SymbolTable::lookupNearestSymbolFrom<FuncOp>(*this, flat_sym_ref);
|
||||
}
|
||||
|
||||
// Gets all branch functions.
|
||||
void get_branch_functions(SmallVectorImpl<FuncOp> &functions) {
|
||||
functions.reserve(num_branches());
|
||||
for (int idx : llvm::seq<int>(0, num_branches()))
|
||||
functions.push_back(branch_function(idx));
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_CaseRegionOp : TF_Op<"CaseRegion",
|
||||
@ -206,12 +224,12 @@ source_target_pairs=`[[0,1],[1,2],[2,3],[3,0]]` gets the outputs:
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
|
||||
TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
|
||||
I32Tensor:$source_target_pairs
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
|
||||
TensorOf<[TF_Bfloat16, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
@ -231,7 +249,7 @@ element_shape: a shape compatible with that of elements in the list.
|
||||
|
||||
let arguments = (ins
|
||||
TF_I32OrI64Tensor:$element_shape,
|
||||
I32Tensor:$max_num_elements
|
||||
TF_Int32Tensor:$max_num_elements
|
||||
);
|
||||
}
|
||||
|
||||
@ -424,7 +442,7 @@ def TF_ParseExampleOp : TF_Op<"ParseExample",
|
||||
TF_StrTensor:$names,
|
||||
Variadic<TF_StrTensor>:$sparse_keys,
|
||||
Variadic<TF_StrTensor>:$dense_keys,
|
||||
Variadic<TensorOf<[F32, I64, TF_Str]>>:$dense_defaults,
|
||||
Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$dense_defaults,
|
||||
|
||||
TF_ShapeAttrArray:$dense_shapes,
|
||||
I32ElementsAttr:$result_segment_sizes,
|
||||
@ -432,10 +450,10 @@ def TF_ParseExampleOp : TF_Op<"ParseExample",
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Variadic<I64Tensor>:$sparse_indices, // len(sparse_types)
|
||||
Variadic<TensorOf<[F32, I64, TF_Str]>>:$sparse_values, // len(sparse_types)
|
||||
Variadic<I64Tensor>:$sparse_shapes, // len(sparse_types)
|
||||
Variadic<TensorOf<[F32, I64, TF_Str]>>:$dense_values // len(Tdense)
|
||||
Variadic<TF_Int64Tensor>:$sparse_indices, // len(sparse_types)
|
||||
Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$sparse_values, // len(sparse_types)
|
||||
Variadic<TF_Int64Tensor>:$sparse_shapes, // len(sparse_types)
|
||||
Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$dense_values // len(Tdense)
|
||||
);
|
||||
|
||||
TF_DerivedOperandSizeAttr Nsparse = TF_DerivedOperandSizeAttr<2>;
|
||||
@ -459,7 +477,7 @@ def TF_ParseExampleV2Op : TF_Op<"ParseExampleV2",
|
||||
TF_StrTensor:$sparse_keys,
|
||||
TF_StrTensor:$dense_keys,
|
||||
TF_StrTensor:$ragged_keys,
|
||||
Variadic<TensorOf<[F32, I64, TF_Str]>>:$dense_defaults,
|
||||
Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$dense_defaults,
|
||||
|
||||
Confined<I64Attr, [IntMinValue<0>]>:$num_sparse,
|
||||
TF_ShapeAttrArray:$dense_shapes,
|
||||
@ -467,13 +485,13 @@ def TF_ParseExampleV2Op : TF_Op<"ParseExampleV2",
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Variadic<I64Tensor>:$sparse_indices, // len(sparse_types)
|
||||
Variadic<TensorOf<[F32, I64, TF_Str]>>:$sparse_values, // len(sparse_types)
|
||||
Variadic<I64Tensor>:$sparse_shapes, // len(sparse_types)
|
||||
Variadic<TensorOf<[F32, I64, TF_Str]>>:$dense_values, // len(Tdense)
|
||||
Variadic<TensorOf<[F32, I64, TF_Str]>>:$ragged_values, // len(ragged_value_types)
|
||||
Variadic<TF_Int64Tensor>:$sparse_indices, // len(sparse_types)
|
||||
Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$sparse_values, // len(sparse_types)
|
||||
Variadic<TF_Int64Tensor>:$sparse_shapes, // len(sparse_types)
|
||||
Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$dense_values, // len(Tdense)
|
||||
Variadic<TensorOf<[TF_Float32, TF_Int64, TF_Str]>>:$ragged_values, // len(ragged_value_types)
|
||||
// = len(ragged_split_types)
|
||||
Variadic<TensorOf<[I32, I64]>>:$ragged_row_splits // len(ragged_split_types)
|
||||
Variadic<TensorOf<[TF_Int32, TF_Int64]>>:$ragged_row_splits // len(ragged_split_types)
|
||||
// = len(ragged_value_types)
|
||||
);
|
||||
|
||||
@ -735,7 +753,7 @@ element_dtype: the desired type of elements in the list.
|
||||
|
||||
let arguments = (ins
|
||||
TF_I32OrI64Tensor:$element_shape,
|
||||
I32Tensor:$num_elements
|
||||
TF_Int32Tensor:$num_elements
|
||||
);
|
||||
}
|
||||
|
||||
@ -956,8 +974,8 @@ Creates a dataset that batches `batch_size` elements from `input_dataset`.
|
||||
|
||||
let arguments = (ins
|
||||
TF_VariantTensor:$input_dataset,
|
||||
I64Tensor:$batch_size,
|
||||
I1Tensor:$drop_remainder,
|
||||
TF_Int64Tensor:$batch_size,
|
||||
TF_BoolTensor:$drop_remainder,
|
||||
|
||||
DefaultValuedAttr<BoolAttr, "false">:$parallel_copy,
|
||||
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
|
||||
@ -1006,9 +1024,9 @@ to `batch_size * num_parallel_batches` copies of `f` in parallel.
|
||||
let arguments = (ins
|
||||
TF_VariantTensor:$input_dataset,
|
||||
Variadic<TF_Tensor>:$other_arguments,
|
||||
I64Tensor:$batch_size,
|
||||
I64Tensor:$num_parallel_calls,
|
||||
I1Tensor:$drop_remainder,
|
||||
TF_Int64Tensor:$batch_size,
|
||||
TF_Int64Tensor:$num_parallel_calls,
|
||||
TF_BoolTensor:$drop_remainder,
|
||||
|
||||
SymbolRefAttr:$f,
|
||||
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
|
||||
@ -1036,7 +1054,7 @@ def TF_ParallelMapDatasetOp : TF_Op<"ParallelMapDataset", [NoSideEffect]> {
|
||||
let arguments = (ins
|
||||
TF_VariantTensor:$input_dataset,
|
||||
Variadic<TF_Tensor>:$other_arguments,
|
||||
I32Tensor:$num_parallel_calls,
|
||||
TF_Int32Tensor:$num_parallel_calls,
|
||||
|
||||
SymbolRefAttr:$f,
|
||||
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
|
||||
@ -1118,11 +1136,11 @@ This function is faster and numerically stabler than `bessel_i0(x)`.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_FpTensor:$x
|
||||
TF_FloatTensor:$x
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_FpTensor:$y
|
||||
TF_FloatTensor:$y
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
@ -1139,11 +1157,11 @@ This function is faster and numerically stabler than `bessel_i1(x)`.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_FpTensor:$x
|
||||
TF_FloatTensor:$x
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_FpTensor:$y
|
||||
TF_FloatTensor:$y
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
@ -1154,7 +1172,7 @@ def TF_TPUPartitionedCallOp : TF_Op<"TPUPartitionedCall", [CallOpInterface]> {
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<TF_Tensor>:$args,
|
||||
I32Tensor:$device_ordinal,
|
||||
TF_Int32Tensor:$device_ordinal,
|
||||
|
||||
SymbolRefAttr:$f,
|
||||
DefaultValuedAttr<I64Attr, "0">:$autotuner_thresh
|
||||
@ -1264,12 +1282,12 @@ def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastable
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Uint32, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint8Ref, TF_Uint32Ref]>:$x,
|
||||
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Uint32, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint8Ref, TF_Uint32Ref]>:$y
|
||||
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Uint32]>:$x,
|
||||
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Uint32]>:$y
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Uint32, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint8Ref, TF_Uint32Ref]>:$z
|
||||
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Uint32]>:$z
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
@ -1289,12 +1307,12 @@ def TF_DivNoNanOp : TF_Op<"DivNoNan", [NoSideEffect, ResultsBroadcastableShape,
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[F16, F32, F64, TF_Complex, TF_F16Ref, TF_F32Ref, TF_F64Ref, TF_ComplexRef]>:$x,
|
||||
TensorOf<[F16, F32, F64, TF_Complex, TF_F16Ref, TF_F32Ref, TF_F64Ref, TF_ComplexRef]>:$y
|
||||
TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Complex]>:$x,
|
||||
TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Complex]>:$y
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[F16, F32, F64, TF_Complex, TF_F16Ref, TF_F32Ref, TF_F64Ref, TF_ComplexRef]>:$z
|
||||
TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Complex]>:$z
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
@ -1310,12 +1328,12 @@ def TF_MaximumOp : TF_Op<"Maximum", [NoSideEffect, ResultsBroadcastableShape, TF
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[TF_Float, I16, I32, I64, TF_Uint8, TF_FloatRef, TF_Int16Ref, TF_Int32Ref, TF_Int64Ref, TF_Uint8Ref]>:$x,
|
||||
TensorOf<[TF_Float, I16, I32, I64, TF_Uint8, TF_FloatRef, TF_Int16Ref, TF_Int32Ref, TF_Int64Ref, TF_Uint8Ref]>:$y
|
||||
TensorOf<[TF_Float, TF_Int16, TF_Int32, TF_Int64, TF_Uint8]>:$x,
|
||||
TensorOf<[TF_Float, TF_Int16, TF_Int32, TF_Int64, TF_Uint8]>:$y
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[TF_Float, I16, I32, I64, TF_Uint8, TF_FloatRef, TF_Int16Ref, TF_Int32Ref, TF_Int64Ref, TF_Uint8Ref]>:$z
|
||||
TensorOf<[TF_Float, TF_Int16, TF_Int32, TF_Int64, TF_Uint8]>:$z
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
@ -1333,12 +1351,12 @@ If `x` and `y` are reals, this will return the floating-point division.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint16, TF_Uint8, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint16Ref, TF_Uint8Ref]>:$x,
|
||||
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint16, TF_Uint8, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint16Ref, TF_Uint8Ref]>:$y
|
||||
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint16, TF_Uint8]>:$x,
|
||||
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint16, TF_Uint8]>:$y
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint16, TF_Uint8, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint16Ref, TF_Uint8Ref]>:$z
|
||||
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint16, TF_Uint8]>:$z
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
@ -1362,12 +1380,12 @@ Both input and output have a range `(-inf, inf)`.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[TF_NumberNotQuantizedOrStr, TF_NumberNotQuantizedOrStrRef]>:$x,
|
||||
TensorOf<[TF_NumberNotQuantizedOrStr, TF_NumberNotQuantizedOrStrRef]>:$y
|
||||
TensorOf<[TF_NumberNotQuantizedOrStr]>:$x,
|
||||
TensorOf<[TF_NumberNotQuantizedOrStr]>:$y
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[TF_NumberNotQuantizedOrStr, TF_NumberNotQuantizedOrStrRef]>:$z
|
||||
TensorOf<[TF_NumberNotQuantizedOrStr]>:$z
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
@ -1383,13 +1401,13 @@ The generated values will have mean 0 and standard deviation 1.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Arg<TF_ResourceTensor, "", [TF_VariableRead, TF_VariableWrite]>:$resource,
|
||||
I64Tensor:$algorithm,
|
||||
Arg<TF_ResourceTensor, "", [TF_VariableRead,TF_VariableWrite]>:$resource,
|
||||
TF_Int64Tensor:$algorithm,
|
||||
TF_I32OrI64Tensor:$shape
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_FpTensor:$output
|
||||
TF_FloatTensor:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr shape_dtype = TF_DerivedOperandTypeAttr<2>;
|
||||
@ -1407,12 +1425,12 @@ deviations from the mean are dropped and re-picked.
|
||||
|
||||
let arguments = (ins
|
||||
Arg<TF_ResourceTensor, "", [TF_VariableRead,TF_VariableWrite]>:$resource,
|
||||
I64Tensor:$algorithm,
|
||||
TF_Int64Tensor:$algorithm,
|
||||
TF_I32OrI64Tensor:$shape
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_FpTensor:$output
|
||||
TF_FloatTensor:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr shape_dtype = TF_DerivedOperandTypeAttr<2>;
|
||||
@ -1429,12 +1447,12 @@ lower bound 0 is included in the range, while the upper bound 1 is excluded.
|
||||
|
||||
let arguments = (ins
|
||||
Arg<TF_ResourceTensor, "", [TF_VariableRead,TF_VariableWrite]>:$resource,
|
||||
I64Tensor:$algorithm,
|
||||
TF_Int64Tensor:$algorithm,
|
||||
TF_I32OrI64Tensor:$shape
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_FpTensor:$output
|
||||
TF_FloatTensor:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr shape_dtype = TF_DerivedOperandTypeAttr<2>;
|
||||
@ -1450,12 +1468,12 @@ The generated values are uniform integers covering the whole range of `dtype`.
|
||||
|
||||
let arguments = (ins
|
||||
Arg<TF_ResourceTensor, "", [TF_VariableRead,TF_VariableWrite]>:$resource,
|
||||
I64Tensor:$algorithm,
|
||||
TF_Int64Tensor:$algorithm,
|
||||
TF_I32OrI64Tensor:$shape
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[I32, I64, TF_Uint32, TF_Uint64]>:$output
|
||||
TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr shape_dtype = TF_DerivedOperandTypeAttr<2>;
|
||||
@ -1480,14 +1498,14 @@ smaller than the range of the output (either `2^32` or `2^64`).
|
||||
|
||||
let arguments = (ins
|
||||
Arg<TF_ResourceTensor, "", [TF_VariableRead,TF_VariableWrite]>:$resource,
|
||||
I64Tensor:$algorithm,
|
||||
TF_Int64Tensor:$algorithm,
|
||||
TF_I32OrI64Tensor:$shape,
|
||||
TensorOf<[I32, I64, TF_Uint32, TF_Uint64]>:$minval,
|
||||
TensorOf<[I32, I64, TF_Uint32, TF_Uint64]>:$maxval
|
||||
TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$minval,
|
||||
TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$maxval
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[I32, I64, TF_Uint32, TF_Uint64]>:$output
|
||||
TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr shape_dtype = TF_DerivedOperandTypeAttr<2>;
|
||||
@ -1560,8 +1578,8 @@ filename_suffix: Every event file's name is suffixed with this suffix.
|
||||
let arguments = (ins
|
||||
Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
|
||||
TF_StrTensor:$logdir,
|
||||
I32Tensor:$max_queue,
|
||||
I32Tensor:$flush_millis,
|
||||
TF_Int32Tensor:$max_queue,
|
||||
TF_Int32Tensor:$flush_millis,
|
||||
TF_StrTensor:$filename_suffix
|
||||
);
|
||||
|
||||
@ -1647,10 +1665,10 @@ max_outputs: Max number of batch elements to generate audio for.
|
||||
|
||||
let arguments = (ins
|
||||
Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
|
||||
I64Tensor:$step,
|
||||
TF_Int64Tensor:$step,
|
||||
TF_StrTensor:$tag,
|
||||
F32Tensor:$tensor,
|
||||
F32Tensor:$sample_rate,
|
||||
TF_Float32Tensor:$tensor,
|
||||
TF_Float32Tensor:$sample_rate,
|
||||
|
||||
Confined<DefaultValuedAttr<I64Attr, "3">, [IntMinValue<1>]>:$max_outputs
|
||||
);
|
||||
@ -1669,7 +1687,7 @@ tensor: A scalar string of the serialized tf.GraphDef proto.
|
||||
|
||||
let arguments = (ins
|
||||
Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
|
||||
I64Tensor:$step,
|
||||
TF_Int64Tensor:$step,
|
||||
TF_StrTensor:$tensor
|
||||
);
|
||||
|
||||
@ -1694,7 +1712,7 @@ values: Any shape. Values to use to build the histogram.
|
||||
|
||||
let arguments = (ins
|
||||
Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
|
||||
I64Tensor:$step,
|
||||
TF_Int64Tensor:$step,
|
||||
TF_StrTensor:$tag,
|
||||
TF_IntOrFpTensor:$values
|
||||
);
|
||||
@ -1753,9 +1771,9 @@ bad_color: Color to use for pixels with non-finite values.
|
||||
|
||||
let arguments = (ins
|
||||
Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
|
||||
I64Tensor:$step,
|
||||
TF_Int64Tensor:$step,
|
||||
TF_StrTensor:$tag,
|
||||
TensorOf<[F16, F32, TF_Uint8]>:$tensor,
|
||||
TensorOf<[TF_Float16, TF_Float32, TF_Uint8]>:$tensor,
|
||||
TF_Uint8Tensor:$bad_color,
|
||||
|
||||
Confined<DefaultValuedAttr<I64Attr, "3">, [IntMinValue<1>]>:$max_images
|
||||
@ -1777,7 +1795,7 @@ tensor: A tensor holding one or more serialized `Summary` protobufs to write.
|
||||
|
||||
let arguments = (ins
|
||||
Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
|
||||
I64Tensor:$step,
|
||||
TF_Int64Tensor:$step,
|
||||
TF_StrTensor:$tensor
|
||||
);
|
||||
|
||||
@ -1798,7 +1816,7 @@ value: Value for the summary.
|
||||
|
||||
let arguments = (ins
|
||||
Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
|
||||
I64Tensor:$step,
|
||||
TF_Int64Tensor:$step,
|
||||
TF_StrTensor:$tag,
|
||||
TF_IntOrFpTensor:$value
|
||||
);
|
||||
@ -1822,7 +1840,7 @@ summary_metadata: Serialized SummaryMetadata protocol buffer containing
|
||||
|
||||
let arguments = (ins
|
||||
Arg<TF_ResourceTensor, "", [TF_SummaryWrite]>:$writer,
|
||||
I64Tensor:$step,
|
||||
TF_Int64Tensor:$step,
|
||||
TF_Tensor:$tensor,
|
||||
TF_StrTensor:$tag,
|
||||
TF_StrTensor:$summary_metadata
|
||||
@ -1833,7 +1851,6 @@ summary_metadata: Serialized SummaryMetadata protocol buffer containing
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>;
|
||||
}
|
||||
|
||||
// TODO(b/168035831): Model dataset read.
|
||||
def TF_InitializeTableFromDatasetOp : TF_Op<"InitializeTableFromDataset", []> {
|
||||
let summary = "";
|
||||
|
||||
@ -1875,4 +1892,22 @@ Where to extract the key and value from a line is specified by `key_index` and
|
||||
let results = (outs);
|
||||
}
|
||||
|
||||
// TODO(b/168035831): Model filename read.
|
||||
def TF_CacheDatasetV2Op : TF_Op<"CacheDatasetV2", []> {
|
||||
let summary = "";
|
||||
|
||||
let arguments = (ins
|
||||
TF_VariantTensor:$input_dataset,
|
||||
TF_StrTensor:$filename,
|
||||
Arg<TF_ResourceTensor, "", [TF_DatasetMemoryCacheRead, TF_DatasetMemoryCacheWrite]>:$cache,
|
||||
|
||||
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
|
||||
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_VariantTensor:$handle
|
||||
);
|
||||
}
|
||||
|
||||
#endif // TF_OPS
|
||||
|
@ -546,8 +546,7 @@ LogicalResult FoldConstantCaseOp::matchAndRewrite(
  if (!matchPattern(op.branch_index(), m_Constant(&branch))) return failure();

  int index = *branch.getValues<int>().begin();
  if (index < 0 || index >= op.branches().size())
    index = op.branches().size() - 1;
  if (index < 0 || index >= op.num_branches()) index = op.num_branches() - 1;

  auto func = op.branches()[index].cast<SymbolRefAttr>();
  auto empty = rewriter.getStringAttr("");
@ -1032,6 +1032,7 @@ static LogicalResult Verify(SizeOp op) {

OpFoldResult SizeOp::fold(ArrayRef<Attribute> operands) {
  ShapedType output_type = getType().cast<ShapedType>();
  if (!output_type.hasRank()) return {};
  ShapedType input_type = getOperand().getType().cast<ShapedType>();
  if (!input_type.hasStaticShape()) return {};
  int size = input_type.getNumElements();
@ -43,6 +43,16 @@ struct LookupTable : ::mlir::SideEffects::Resource::Base<LookupTable> {
  StringRef getName() final { return "LookupTable"; }
};

struct DatasetSeedGenerator
    : ::mlir::SideEffects::Resource::Base<DatasetSeedGenerator> {
  StringRef getName() final { return "DatasetSeedGenerator"; }
};

struct DatasetMemoryCache
    : ::mlir::SideEffects::Resource::Base<DatasetMemoryCache> {
  StringRef getName() final { return "DatasetMemoryCache"; }
};

} // namespace ResourceEffects
} // namespace TF
} // namespace mlir
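// Sketch of the TableGen side assumed to reference the C++ resources above.
// The TF_ResourceBase declaration and its location (tf_op_base.td) are
// assumptions not shown in this diff; the MemRead/MemWrite defs it feeds
// appear near the top of this change:
//   def TF_DatasetMemoryCacheResource : TF_ResourceBase<"DatasetMemoryCache">;
//   def TF_DatasetMemoryCacheRead  : MemRead<TF_DatasetMemoryCacheResource>;
//   def TF_DatasetMemoryCacheWrite : MemWrite<TF_DatasetMemoryCacheResource>;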
@ -74,6 +74,17 @@ func @ignore_embedding_ops() -> () {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @ignore_stack_ops
|
||||
func @ignore_stack_ops(%arg0: tensor<i32>) -> () {
|
||||
"tf_device.cluster"() ( {
|
||||
// CHECK: "tf.StackV2"
|
||||
// CHECK-NOT: _xla_outside_compilation
|
||||
%0 = "tf.StackV2"(%arg0) {elem_type = f32, stack_name = "s"} : (tensor<i32>) -> tensor<!tf.resource>
|
||||
tf_device.return
|
||||
}) {allow_soft_placement = true, num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @op_string_result
|
||||
func @op_string_result() -> tensor<i32> {
|
||||
%0 = "tf_device.cluster"() ( {
|
||||
|
@ -575,4 +575,15 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
func @pcall_resource_result_func(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource<tensor<f32>>> {
  return %arg0 : tensor<*x!tf.resource<tensor<f32>>>
}

// Check that the fold for tf.Size does not crash with unranked output type.
// CHECK-LABEL: func @unranked_tf_size
func @unranked_tf_size() -> tensor<*xi32> {
  %0 = "tf.Const"() {value = dense<[-1, 26]> : tensor<2xi32>} : () -> tensor<2xi32>
  %add = "tf.AddV2"(%0, %0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<*xi32>
  // CHECK: "tf.Size"
  // CHECK-SAME: (tensor<2xi32>) -> tensor<i32>
  %size = "tf.Size"(%add) {device = ""} : (tensor<*xi32>) -> tensor<*xi32>
  return %size : tensor<*xi32>
}
}
@ -242,7 +242,7 @@ func @testReshape(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<1000
|
||||
func @testReshape(tensor<*xf32>, tensor<*xf32>) -> (tensor<100x100xf32>) {
|
||||
^bb0(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>):
|
||||
%shape1 = constant dense<100.> : tensor<2xf32>
|
||||
// expected-error @+1 {{must be tensor of 32/64-bit signless integer values}}
|
||||
// expected-error @+1 {{must be tensor of 32/64-bit signed integer values}}
|
||||
%r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<*xf32>, tensor<2xf32>) -> (tensor<100x100xf32>)
|
||||
return %r1 : tensor<100x100xf32>
|
||||
}
|
||||
@ -1967,7 +1967,7 @@ func @testValidShape(tensor<1x32x32x16xf32>, tensor<*xf32>) -> (tensor<4xi32>, t
|
||||
// -----
|
||||
|
||||
func @testShapeWrongResultElemType(%arg0: tensor<1x32x32x16xf32>) -> tensor<4xf32> {
|
||||
// expected-error @+1 {{result #0 must be tensor of 32/64-bit signless integer values}}
|
||||
// expected-error @+1 {{result #0 must be tensor of 32/64-bit signed integer values}}
|
||||
%0 = "tf.Shape"(%arg0) : (tensor<1x32x32x16xf32>) -> tensor<4xf32>
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
@ -2011,7 +2011,7 @@ func @testValidShapeN(%arg0 : tensor<1x32x32x16xf32>, %arg1 : tensor<*xf32>) ->
|
||||
// -----
|
||||
|
||||
func @testShapeNWrongResultElemType(%arg0: tensor<1x32x32x16xf32>) -> tensor<4xf32> {
|
||||
// expected-error @+1 {{result #1 must be tensor of 32/64-bit signless integer values}}
|
||||
// expected-error @+1 {{result #1 must be tensor of 32/64-bit signed integer values}}
|
||||
%0:2 = "tf.ShapeN"(%arg0, %arg0) : (tensor<1x32x32x16xf32>, tensor<1x32x32x16xf32>) -> (tensor<4xi32>, tensor<4xf32>)
|
||||
return %0#1 : tensor<4xf32>
|
||||
}
|
||||
@ -2072,7 +2072,7 @@ func @testVariableShapeMultipleSubtypes(%arg0: tensor<*x!tf.resource<tensor<1x32
|
||||
// -----
|
||||
|
||||
func @testVariableShapeWrongResultElemType(%arg0: tensor<*x!tf.resource<tensor<1x32x32x16xf32>>>) -> tensor<?xf32> {
|
||||
// expected-error @+1 {{result #0 must be tensor of 32/64-bit signless integer values}}
|
||||
// expected-error @+1 {{result #0 must be tensor of 32/64-bit signed integer values}}
|
||||
%0 = "tf.VariableShape"(%arg0) : (tensor<*x!tf.resource<tensor<1x32x32x16xf32>>>) -> tensor<4xf32>
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
@ -2208,7 +2208,7 @@ func @testTranspose(tensor<2x3x4xf32>) -> tensor<3x2x4xf32> {
|
||||
// Test invalid tf.Less
|
||||
func @testLess(tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> {
|
||||
^bb0(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>):
|
||||
// expected-error @+1 {{op result #0 must be tensor of 1-bit signless integer values}}
|
||||
// expected-error @+1 {{op result #0 must be tensor of bool values}}
|
||||
%0 = "tf.Less"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
return %0 : tensor<4xi32>
|
||||
}
|
||||
@ -2225,7 +2225,7 @@ func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor<i32>) -> tensor<?xf32>
|
||||
|
||||
// tf.ConcatV2 with wrong 'axis' element type
|
||||
func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor<f32>) -> tensor<?xf32> {
|
||||
// expected-error @+1 {{operand #2 must be tensor of 32/64-bit signless integer values}}
|
||||
// expected-error @+1 {{operand #2 must be tensor of 32/64-bit signed integer values}}
|
||||
%0 = "tf.ConcatV2"(%arg, %arg, %axis) : (tensor<8x16xf32>, tensor<8x16xf32>, tensor<f32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
@ -2258,7 +2258,7 @@ func @testAll64(%arg0: tensor<2x2xi1>, %arg1: tensor<i64>) -> tensor<i1> {
|
||||
// -----
|
||||
|
||||
func @testAllFloat(%arg0: tensor<2x2xi1>, %arg1: tensor<f32>) -> tensor<i1> {
|
||||
// expected-error @+1 {{'tf.All' op operand #1 must be tensor of 32/64-bit signless integer values}}
|
||||
// expected-error @+1 {{'tf.All' op operand #1 must be tensor of 32/64-bit signed integer values}}
|
||||
%0 = "tf.All"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor<f32>) -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
@ -2266,7 +2266,7 @@ func @testAllFloat(%arg0: tensor<2x2xi1>, %arg1: tensor<f32>) -> tensor<i1> {
|
||||
// -----
|
||||
|
||||
func @testAllI32(%arg0: tensor<2x2xi32>, %arg1: tensor<f32>) -> tensor<i32> {
|
||||
// expected-error @+1 {{'tf.All' op operand #0 must be tensor of 1-bit signless integer values}}
|
||||
// expected-error @+1 {{'tf.All' op operand #0 must be tensor of bool values}}
|
||||
%0 = "tf.All"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi32>, tensor<f32>) -> tensor<i32>
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
|
@ -2,29 +2,81 @@
|
||||
|
||||
// Tests that the pass can correctly transform a training loop with 2 replicas.
|
||||
|
||||
!tf_res_f32 = type tensor<*x!tf.resource<tensor<f32>>>
|
||||
!tf_res_md_f32 = type tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> // Multi-dim f32
|
||||
|
||||
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
|
||||
// CHECK-LABEL: func @main
|
||||
func @main(%arg0: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
||||
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
|
||||
%arg2: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
|
||||
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"}) {
|
||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
||||
// CHECK-SAME: %[[ARG1:.*]]: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
|
||||
// CHECK-SAME: %[[ARG2:.*]]: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
|
||||
// CHECK-SAME: %[[ARG3:.*]]: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"})
|
||||
func @main(%arg0: !tf_res_f32 {tf.device = "/device:TPU:0"},
|
||||
%arg1: !tf_res_f32 {tf.device = "/device:TPU:1"},
|
||||
%arg2: !tf_res_md_f32 {tf.device = "/device:TPU:0"},
|
||||
%arg3: !tf_res_md_f32 {tf.device = "/device:TPU:1"}) {
|
||||
|
||||
%0 = "tf.Const"() {value = dense<100> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: %[[STATE0:.*]] = "tf.VarHandleOp"()
|
||||
// CHECK-SAME: device = "/device:TPU:0"
|
||||
// CHECK: %[[STATE1:.*]] = "tf.VarHandleOp"()
|
||||
// CHECK-SAME: device = "/device:TPU:1"
|
||||
// CHECK: %[[WHILE:.*]]:7 = "tf.While"(
|
||||
// CHECK-SAME: %[[STATE0]], %[[STATE1]])
|
||||
%1:5 = "tf.While"(%0, %arg0, %arg1, %arg2, %arg3)
|
||||
{T = ["tfdtype$DT_INT32", "tfdtype$DT_RESOURCE",
|
||||
"tfdtype$DT_RESOURCE", "tfdtype$DT_RESOURCE",
|
||||
"tfdtype$DT_RESOURCE"], body = @while_body_7560,
|
||||
cond = @while_cond_7550, device = "", is_stateless = false}
|
||||
: (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
|
||||
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
|
||||
-> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
|
||||
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
|
||||
// CHECK: %[[WHILE:.*]] = "tf.WhileRegion"(
|
||||
%1 = "tf.WhileRegion"(%0) ( {
|
||||
// Condition region
|
||||
// CHECK: ^bb
|
||||
// CHECK: "tf.Yield"
|
||||
^bb0(%carg0: tensor<i32>):
|
||||
%c0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
%c1 = "tf.GreaterEqual"(%carg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||
"tf.Yield"(%c1) : (tensor<i1>) -> ()
|
||||
}, {
|
||||
// Body region
|
||||
// CHECK: ^bb0
|
||||
^bb0(%barg0: tensor<i32>):
|
||||
%b0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
|
||||
%b1 = "tf.AddV2"(%barg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
// CHECK: %[[COMPILE:.*]]:2 = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf._TPUCompileMlir"()
|
||||
%compile:2 = "tf_device.launch"() ( {
|
||||
%b2:2 = "tf._TPUCompileMlir"() {
|
||||
NumDynamicShapes = 0 : i64,
|
||||
// The metadata encodes 2 parameter and two return values.
|
||||
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
|
||||
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
|
||||
tf_device.return %b2#0, %b2#1 : tensor<!tf.string>, tensor<2x!tf.string>
|
||||
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
|
||||
"tf_device.launch"() ( {
|
||||
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
|
||||
tf_device.return
|
||||
}) {device = "/device:CPU:0"} : () -> ()
|
||||
// CHECK: tf_device.replicate
|
||||
// CHECK-SAME: [%[[ARG0]], %[[ARG1]]] as %[[R0:.*]]: tensor<*x!tf.resource<tensor<f32>>>,
|
||||
// CHECK-SAME: [%[[ARG2]], %[[ARG3]]] as %[[R1:.*]]: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
|
||||
// CHECK-SAME: [%[[STATE0]], %[[STATE1]]] as %[[R_STATE:.*]]: tensor<!tf.resource<tensor<2x!tf.string>>>
|
||||
// CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]
|
||||
%rep:2 = tf_device.replicate([%arg0, %arg1] as %arg30: tensor<*x!tf.resource<tensor<f32>>>,
|
||||
[%arg2, %arg3] as %arg31: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
|
||||
{_mirrored_variable_indices = [0, 1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
|
||||
// CHECK: %[[ID:.*]] = "tf.Identity"(%[[R0]])
|
||||
%id = "tf.Identity"(%arg30) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource<tensor<f32>>>
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUReshardVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1, %[[R_STATE]])
|
||||
// CHECK-NEXT: tf_device.return
|
||||
// CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
|
||||
// CHECK: "tf.TPUExecuteAndUpdateVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1)
|
||||
"tf_device.launch"() ( {
|
||||
"tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1)
|
||||
{device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
|
||||
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<2x!tf.string>) -> ()
|
||||
tf_device.return
|
||||
}) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
|
||||
%ret = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
tf_device.return %ret : tensor<i32>
|
||||
}
|
||||
// CHECK: "tf.Yield"
|
||||
"tf.Yield"(%b1) : (tensor<i32>) -> ()
|
||||
}) {device = "", is_stateless = false} : (tensor<i32>) -> (tensor<i32>)
|
||||
// CHECK: %[[DEFAULT:.*]] = "tf.Const"()
|
||||
// CHECK: tf_device.replicate
|
||||
// CHECK-SAME: as %[[V0:.*]]: tensor<*x!tf.resource<tensor<f32>>>,
|
||||
@ -37,165 +89,72 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
|
||||
// CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @while_body_7560
|
||||
func @while_body_7560(%arg0: tensor<i32>,
|
||||
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
||||
%arg2: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
|
||||
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
|
||||
%arg4: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"})
|
||||
-> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
|
||||
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>) {
|
||||
// CHECK-SAME: (%[[ITER:.*]]: tensor<i32>,
|
||||
// CHECK-SAME: %[[BODY_ARG1:.*]]: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
||||
// CHECK-SAME: %[[BODY_ARG2:.*]]: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
|
||||
// CHECK-SAME: %[[BODY_ARG3:.*]]: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
|
||||
// CHECK-SAME: %[[BODY_ARG4:.*]]: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"},
|
||||
// CHECK-SAME: %[[STATE_ARG0:.*]]: tensor<!tf.resource<tensor<2x!tf.string>>> {tf.device = "/device:TPU:0"},
|
||||
// CHECK-SAME: %[[STATE_ARG1:.*]]: tensor<!tf.resource<tensor<2x!tf.string>>> {tf.device = "/device:TPU:1"})
|
||||
%0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
|
||||
%1 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
// CHECK: %[[COMPILE:.*]]:2 = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf._TPUCompileMlir"()
|
||||
%compile:2 = "tf_device.launch"() ( {
|
||||
%2:2 = "tf._TPUCompileMlir"() {
|
||||
NumDynamicShapes = 0 : i64,
|
||||
// The metadata encodes 2 parameter and two return values.
|
||||
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
|
||||
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
|
||||
tf_device.return %2#0, %2#1 : tensor<!tf.string>, tensor<2x!tf.string>
|
||||
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
|
||||
"tf_device.launch"() ( {
|
||||
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
|
||||
tf_device.return
|
||||
}) {device = "/device:CPU:0"} : () -> ()
|
||||
// CHECK: tf_device.replicate
|
||||
// CHECK-SAME: [%[[BODY_ARG1]], %[[BODY_ARG2]]] as %[[R0:.*]]: tensor<*x!tf.resource<tensor<f32>>>,
|
||||
// CHECK-SAME: [%[[BODY_ARG3]], %[[BODY_ARG4]]] as %[[R1:.*]]: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
|
||||
// CHECK-SAME: [%[[STATE_ARG0]], %[[STATE_ARG1]]] as %[[R_STATE:.*]]: tensor<!tf.resource<tensor<2x!tf.string>>>
|
||||
// CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]
|
||||
%rep:2 = tf_device.replicate([%arg1, %arg2] as %arg30: tensor<*x!tf.resource<tensor<f32>>>,
|
||||
[%arg3, %arg4] as %arg31: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
|
||||
{_mirrored_variable_indices = [0, 1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
|
||||
// CHECK: %[[ID:.*]] = "tf.Identity"(%[[R0]])
|
||||
%id = "tf.Identity"(%arg30) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource<tensor<f32>>>
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUReshardVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1, %[[R_STATE]])
|
||||
// CHECK-NEXT: tf_device.return
|
||||
// CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
|
||||
// CHECK: "tf.TPUExecuteAndUpdateVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1)
|
||||
"tf_device.launch"() ( {
|
||||
"tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1)
|
||||
{device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
|
||||
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<2x!tf.string>) -> ()
|
||||
tf_device.return
|
||||
}) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
|
||||
%ret = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
tf_device.return %ret : tensor<i32>
|
||||
}
|
||||
return %1, %arg1, %arg2, %arg3, %arg4 : tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>,
|
||||
tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>
|
||||
}
|
||||
// CHECK-LABEL: func @while_cond_7550
|
||||
func @while_cond_7550(%arg0: tensor<i32>,
|
||||
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
||||
%arg2: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
|
||||
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
|
||||
%arg4: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"})
|
||||
-> tensor<i1> {
|
||||
%0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
%1 = "tf.GreaterEqual"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||
return %1 : tensor<i1>
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that the pass does not format variables with other uses.
|
||||
|
||||
!tf_res_f32 = type tensor<*x!tf.resource<tensor<f32>>>
|
||||
!tf_res_md_f32 = type tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> // Multi-dim f32
|
||||
|
||||
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-NOT: TPUReshardVariables
|
||||
func @main(%arg0: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
||||
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
|
||||
%arg2: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
|
||||
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"},
|
||||
%arg4: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
||||
%arg5: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"}) {
|
||||
func @main(%arg0: !tf_res_f32 {tf.device = "/device:TPU:0"},
|
||||
%arg1: !tf_res_f32 {tf.device = "/device:TPU:1"},
|
||||
%arg2: !tf_res_md_f32 {tf.device = "/device:TPU:0"},
|
||||
%arg3: !tf_res_md_f32 {tf.device = "/device:TPU:1"},
|
||||
%arg4: !tf_res_f32 {tf.device = "/device:TPU:1"}) {
|
||||
|
||||
%0 = "tf.Const"() {value = dense<100> : tensor<i32>} : () -> tensor<i32>
|
||||
%1:7 = "tf.While"(%0, %arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
|
||||
{body = @while_body_7560,
|
||||
cond = @while_cond_7550, device = "", is_stateless = false}
|
||||
: (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
|
||||
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>)
|
||||
-> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
|
||||
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>)
|
||||
%1 = "tf.WhileRegion"(%0) ( {
|
||||
// Condition region
|
||||
^bb0(%carg0: tensor<i32>):
|
||||
%c0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
%c1 = "tf.GreaterEqual"(%carg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||
"tf._UnknownOp1_"(%arg1) : (!tf_res_f32) -> ()
|
||||
"tf.Yield"(%c1) : (tensor<i1>) -> ()
|
||||
}, {
|
||||
// Body region
|
||||
^bb0(%barg0: tensor<i32>):
|
||||
%b0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
|
||||
%b1 = "tf.AddV2"(%barg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
%compile:2 = "tf_device.launch"() ( {
|
||||
%b2:2 = "tf._TPUCompileMlir"() {
|
||||
NumDynamicShapes = 0 : i64,
|
||||
// The metadata encodes 2 parameter and two return values.
|
||||
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
|
||||
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
|
||||
tf_device.return %b2#0, %b2#1 : tensor<!tf.string>, tensor<2x!tf.string>
|
||||
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
|
||||
"tf_device.launch"() ( {
|
||||
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
|
||||
tf_device.return
|
||||
}) {device = "/device:CPU:0"} : () -> ()
|
||||
%id0 = "tf.Identity"(%arg3) : (!tf_res_md_f32) -> !tf_res_md_f32
|
||||
"tf._Unknown_"(%id0) : (!tf_res_md_f32) -> ()
|
||||
%newvar = "tf._SomeOp"() : () -> !tf_res_f32
|
||||
%rep:2 = tf_device.replicate([%arg0, %arg1] as %arg30: !tf_res_f32,
|
||||
[%arg2, %arg3] as %arg31: !tf_res_md_f32,
|
||||
[%newvar, %arg4] as %arg32 : !tf_res_f32)
|
||||
{_mirrored_variable_indices = [0, 1, 2], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
|
||||
// %arg30 is used in the cond function, %arg31 has other uses (%id0), and
|
||||
// %arg32 is not a pass-through.
|
||||
"tf_device.launch"() ( {
|
||||
"tf.TPUExecuteAndUpdateVariables"(%arg30, %arg31, %arg32, %compile#1)
|
||||
{device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
|
||||
: (!tf_res_f32, !tf_res_md_f32, !tf_res_f32, tensor<2x!tf.string>) -> ()
|
||||
tf_device.return
|
||||
}) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
|
||||
%ret = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
tf_device.return %ret : tensor<i32>
|
||||
}
|
||||
"tf.Yield"(%b1) : (tensor<i32>) -> ()
|
||||
}) {device = "", is_stateless = false} : (tensor<i32>) -> (tensor<i32>)
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @while_body_7560
|
||||
// CHECK-NOT: TPUReshardVariables
|
||||
func @while_body_7560(%arg0: tensor<i32>,
|
||||
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
||||
%arg2: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
|
||||
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
|
||||
%arg4: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"},
|
||||
%arg5: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
||||
%arg6: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"})
|
||||
-> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
|
||||
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>) {
|
||||
%0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
|
||||
%1 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
%compile:2 = "tf_device.launch"() ( {
|
||||
%2:2 = "tf._TPUCompileMlir"() {
|
||||
NumDynamicShapes = 0 : i64,
|
||||
// The metadata encodes 2 parameter and two return values.
|
||||
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
|
||||
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
|
||||
tf_device.return %2#0, %2#1 : tensor<!tf.string>, tensor<2x!tf.string>
|
||||
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
|
||||
"tf_device.launch"() ( {
|
||||
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
|
||||
tf_device.return
|
||||
}) {device = "/device:CPU:0"} : () -> ()
|
||||
%id0 = "tf.Identity"(%arg3) : (tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>) -> tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>
|
||||
"tf._Unknown_"(%id0) : (tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>) -> ()
|
||||
%newvar = "tf._SomeOp"() : () -> tensor<*x!tf.resource<tensor<f32>>>
|
||||
tf_device.replicate([%arg1, %arg2] as %arg30: tensor<*x!tf.resource<tensor<f32>>>,
|
||||
[%arg3, %arg4] as %arg31: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
|
||||
[%newvar, %arg6] as %arg32: tensor<*x!tf.resource<tensor<f32>>>)
|
||||
{_mirrored_variable_indices = [0, 1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
|
||||
// %arg30 is used in the cond function, %arg31 has other uses (%id0), and
|
||||
// %arg32 is not a pass-through.
|
||||
"tf_device.launch"() ( {
|
||||
"tf.TPUExecuteAndUpdateVariables"(%arg30, %arg31, %arg32, %compile#1)
|
||||
{device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
|
||||
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<f32>>>, tensor<2x!tf.string>) -> ()
|
||||
tf_device.return
|
||||
}) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
|
||||
tf_device.return
|
||||
}
|
||||
return %1, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6 : tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>,
|
||||
tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>
|
||||
}
|
||||
// CHECK-LABEL: func @while_cond_7550
|
||||
func @while_cond_7550(%arg0: tensor<i32>,
|
||||
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
||||
%arg2: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
|
||||
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
|
||||
%arg4: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"},
|
||||
%arg5: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
||||
%arg6: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"})
|
||||
-> tensor<i1> {
|
||||
%0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
%1 = "tf.GreaterEqual"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||
"tf._UnknownOp1_"(%arg1) : (tensor<*x!tf.resource<tensor<f32>>>) -> ()
|
||||
return %1 : tensor<i1>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
@ -203,81 +162,62 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
|
||||
// Tests that the pass does not format variables when model parallelism is
|
||||
// present.
|
||||
|
||||
!tf_res_f32 = type tensor<*x!tf.resource<tensor<f32>>>
|
||||
!tf_res_md_f32 = type tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> // Multi-dim f32
|
||||
|
||||
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-NOT: TPUReshardVariables
|
||||
func @main(%arg0: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
||||
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
|
||||
%arg2: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
|
||||
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"}) {
|
||||
func @main(%arg0: !tf_res_f32 {tf.device = "/device:TPU:0"},
|
||||
%arg1: !tf_res_f32 {tf.device = "/device:TPU:1"},
|
||||
%arg2: !tf_res_md_f32 {tf.device = "/device:TPU:0"},
|
||||
%arg3: !tf_res_md_f32 {tf.device = "/device:TPU:1"}) {
|
||||
|
||||
%0 = "tf.Const"() {value = dense<100> : tensor<i32>} : () -> tensor<i32>
|
||||
%1:5 = "tf.While"(%0, %arg0, %arg1, %arg2, %arg3)
|
||||
{T = ["tfdtype$DT_INT32", "tfdtype$DT_RESOURCE",
|
||||
"tfdtype$DT_RESOURCE", "tfdtype$DT_RESOURCE",
|
||||
"tfdtype$DT_RESOURCE"], body = @while_body_7560,
|
||||
cond = @while_cond_7550, device = "", is_stateless = false}
|
||||
: (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
|
||||
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
|
||||
-> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
|
||||
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @while_body_7560
|
||||
// CHECK-NOT: TPUReshardVariables
|
||||
func @while_body_7560(%arg0: tensor<i32>,
|
||||
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
||||
%arg2: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
|
||||
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
|
||||
%arg4: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"})
|
||||
-> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
|
||||
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>) {
|
||||
%0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
|
||||
%1 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
%compile:2 = "tf_device.launch"() ( {
|
||||
%2:2 = "tf._TPUCompileMlir"() {
|
||||
NumDynamicShapes = 0 : i64,
|
||||
// The metadata encodes 2 parameters and 2 return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
|
||||
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
|
||||
tf_device.return %2#0, %2#1 : tensor<!tf.string>, tensor<2x!tf.string>
|
||||
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
|
||||
"tf_device.launch"() ( {
|
||||
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
|
||||
tf_device.return
|
||||
}) {device = "/device:CPU:0"} : () -> ()
|
||||
%rep:2 = tf_device.replicate([%arg1, %arg2] as %arg30: tensor<*x!tf.resource<tensor<f32>>>,
|
||||
[%arg3, %arg4] as %arg31: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
|
||||
{_mirrored_variable_indices = [0, 1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
|
||||
%id = "tf.Identity"(%arg30) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource<tensor<f32>>>
|
||||
"tf_device.parallel_execute"() ({
|
||||
"tf_device.launch"() ( {
|
||||
"tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1)
|
||||
{device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
|
||||
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<2x!tf.string>) -> ()
|
||||
tf_device.return
|
||||
}) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
|
||||
tf_device.return
|
||||
%1 = "tf.WhileRegion"(%0) ( {
|
||||
// Condition region
|
||||
^bb0(%carg0: tensor<i32>):
|
||||
%c0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
%c1 = "tf.GreaterEqual"(%carg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||
"tf.Yield"(%c1) : (tensor<i1>) -> ()
|
||||
}, {
|
||||
tf_device.return
|
||||
}) {} : () -> ()
|
||||
%ret = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
tf_device.return %ret : tensor<i32>
|
||||
}
|
||||
return %1, %arg1, %arg2, %arg3, %arg4 : tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>,
|
||||
tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>
|
||||
}
|
||||
// CHECK-LABEL: func @while_cond_7550
|
||||
func @while_cond_7550(%arg0: tensor<i32>,
|
||||
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
||||
%arg2: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
|
||||
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
|
||||
%arg4: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"})
|
||||
-> tensor<i1> {
|
||||
%0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
%1 = "tf.GreaterEqual"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||
return %1 : tensor<i1>
|
||||
// Body region
|
||||
^bb0(%barg0: tensor<i32>):
|
||||
%b0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
|
||||
%b1 = "tf.AddV2"(%barg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
%compile:2 = "tf_device.launch"() ( {
|
||||
%b2:2 = "tf._TPUCompileMlir"() {
|
||||
NumDynamicShapes = 0 : i64,
|
||||
// The metadata encodes 2 parameters and 2 return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
|
||||
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
|
||||
tf_device.return %b2#0, %b2#1 : tensor<!tf.string>, tensor<2x!tf.string>
|
||||
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
|
||||
"tf_device.launch"() ( {
|
||||
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
|
||||
tf_device.return
|
||||
}) {device = "/device:CPU:0"} : () -> ()
|
||||
%rep:2 = tf_device.replicate([%arg0, %arg1] as %arg30: tensor<*x!tf.resource<tensor<f32>>>,
|
||||
[%arg2, %arg3] as %arg31: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
|
||||
{_mirrored_variable_indices = [0, 1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
|
||||
%id = "tf.Identity"(%arg30) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource<tensor<f32>>>
|
||||
"tf_device.parallel_execute"() ({
|
||||
"tf_device.launch"() ( {
|
||||
"tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1)
|
||||
{device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
|
||||
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<2x!tf.string>) -> ()
|
||||
tf_device.return
|
||||
}) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
|
||||
tf_device.return
|
||||
}, {
|
||||
tf_device.return
|
||||
}) {} : () -> ()
|
||||
%ret = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
tf_device.return %ret : tensor<i32>
|
||||
}
|
||||
"tf.Yield"(%b1) : (tensor<i32>) -> ()
|
||||
}) {device = "", is_stateless = false} : (tensor<i32>) -> (tensor<i32>)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@ -285,34 +225,83 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
|
||||
|
||||
// Tests that the pass can correctly transform a training loop with a packed
|
||||
// variable.
|
||||
!tf_res_f32 = type tensor<*x!tf.resource<tensor<f32>>>
|
||||
!tf_res_md_f32 = type tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> // Multi-dim f32
|
||||
|
||||
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
|
||||
// CHECK-LABEL: func @main
|
||||
func @main(%arg0: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
||||
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
|
||||
%arg2: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:COMPOSITE:0"}) {
|
||||
|
||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
||||
// CHECK-SAME: %[[ARG1:.*]]: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
|
||||
// CHECK-SAME: %[[ARG2:.*]]: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:COMPOSITE:0"})
|
||||
func @main(%arg0: !tf_res_f32 {tf.device = "/device:TPU:0"},
|
||||
%arg1: !tf_res_f32 {tf.device = "/device:TPU:1"},
|
||||
%arg2: !tf_res_md_f32 {tf.device = "/device:COMPOSITE:0"}) {
|
||||
%0 = "tf.Const"() {value = dense<100> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: %[[STATE0:.*]] = "tf.VarHandleOp"()
|
||||
// CHECK-SAME: device = "/device:TPU:0"
|
||||
// CHECK: %[[STATE1:.*]] = "tf.VarHandleOp"()
|
||||
// CHECK-SAME: device = "/device:TPU:1"
|
||||
// CHECK: %[[WHILE:.*]]:6 = "tf.While"(
|
||||
// CHECK-SAME: %[[STATE0]], %[[STATE1]])
|
||||
%1:4 = "tf.While"(%0, %arg0, %arg1, %arg2)
|
||||
{T = ["tfdtype$DT_INT32", "tfdtype$DT_RESOURCE",
|
||||
"tfdtype$DT_RESOURCE", "tfdtype$DT_RESOURCE"],
|
||||
body = @while_body_7560,
|
||||
cond = @while_cond_7550, device = "", is_stateless = false}
|
||||
: (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
|
||||
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
|
||||
-> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
|
||||
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
|
||||
// CHECK: %[[WHILE:.*]] = "tf.WhileRegion"(
|
||||
%1 = "tf.WhileRegion"(%0) ( {
|
||||
// Condition region
|
||||
// CHECK: ^bb
|
||||
// CHECK: "tf.Yield"
|
||||
^bb0(%carg0: tensor<i32>):
|
||||
%c0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
%c1 = "tf.GreaterEqual"(%carg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||
"tf.Yield"(%c1) : (tensor<i1>) -> ()
|
||||
}, {
|
||||
// Body region
|
||||
// CHECK: ^bb0
|
||||
^bb0(%barg0: tensor<i32>):
|
||||
%b0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
|
||||
%b1 = "tf.AddV2"(%barg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
// CHECK: %[[COMPILE:.*]]:2 = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf._TPUCompileMlir"()
|
||||
%compile:2 = "tf_device.launch"() ( {
|
||||
%b2:2 = "tf._TPUCompileMlir"() {
|
||||
NumDynamicShapes = 0 : i64,
|
||||
// The metadata encodes 2 parameters and 2 return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
|
||||
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
|
||||
tf_device.return %b2#0, %b2#1 : tensor<!tf.string>, tensor<2x!tf.string>
|
||||
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
|
||||
"tf_device.launch"() ( {
|
||||
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
|
||||
tf_device.return
|
||||
}) {device = "/device:CPU:0"} : () -> ()
|
||||
// CHECK: tf_device.replicate
|
||||
// CHECK-SAME: [%[[ARG0]], %[[ARG1]]] as %[[R0:.*]]: tensor<*x!tf.resource<tensor<f32>>>,
|
||||
// CHECK-SAME: [%[[STATE0]], %[[STATE1]]] as %[[R_STATE:.*]]: tensor<!tf.resource<tensor<2x!tf.string>>>
|
||||
// CHECK-SAME: %[[ARG2]] as %[[R1:.*]]: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>
|
||||
// CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]
|
||||
%rep:2 = tf_device.replicate([%arg0, %arg1] as %arg30: tensor<*x!tf.resource<tensor<f32>>>,
|
||||
%arg2 as %arg31: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
|
||||
{_mirrored_variable_indices = [0, 1], _packed_input_indices = [1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
|
||||
// CHECK: %[[ID:.*]] = "tf.Identity"(%[[R0]])
|
||||
%id = "tf.Identity"(%arg30) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource<tensor<f32>>>
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUReshardVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1, %[[R_STATE]])
|
||||
// CHECK-NEXT: tf_device.return
|
||||
// CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
|
||||
// CHECK: "tf.TPUExecuteAndUpdateVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1)
|
||||
"tf_device.launch"() ( {
|
||||
"tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1)
|
||||
{device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
|
||||
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<2x!tf.string>) -> ()
|
||||
tf_device.return
|
||||
}) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
|
||||
%ret = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
tf_device.return %ret : tensor<i32>
|
||||
}
|
||||
// CHECK: "tf.Yield"
|
||||
"tf.Yield"(%b1) : (tensor<i32>) -> ()
|
||||
}) {device = "", is_stateless = false} : (tensor<i32>) -> (tensor<i32>)
|
||||
// CHECK: %[[DEFAULT:.*]] = "tf.Const"()
|
||||
// CHECK: tf_device.replicate
|
||||
// CHECK-SAME: as %[[V0:.*]]: tensor<*x!tf.resource<tensor<f32>>>,
|
||||
// CHECK-SAME: [%[[STATE0]], %[[STATE1]]] as %[[STATE:.*]]: tensor<!tf.resource<tensor<2x!tf.string>>>,
|
||||
// CHECK-SAME: as %[[V1:.*]]: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>
|
||||
// CHECK-SAME: [%[[ARG0]], %[[ARG1]]] as %[[V0:.*]]: tensor<*x!tf.resource<tensor<f32>>>,
|
||||
// CHECK-SAME: [%[[STATE0]], %[[STATE1]]] as %[[STATE:.*]]: tensor<!tf.resource<tensor<2x!tf.string>>>
|
||||
// CHECK-SAME: %[[ARG2]] as %[[V1:.*]]: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>
|
||||
// CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUReshardVariables"(%[[V0]], %[[V1]], %[[DEFAULT]], %[[STATE]])
|
||||
@ -320,70 +309,4 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
|
||||
// CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @while_body_7560
|
||||
func @while_body_7560(%arg0: tensor<i32>,
|
||||
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
||||
%arg2: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
|
||||
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:COMPOSITE:0"})
|
||||
-> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
|
||||
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>) {
|
||||
// CHECK-SAME: (%[[ITER:.*]]: tensor<i32>,
|
||||
// CHECK-SAME: %[[BODY_ARG1:.*]]: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
||||
// CHECK-SAME: %[[BODY_ARG2:.*]]: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
|
||||
// CHECK-SAME: %[[BODY_ARG3:.*]]: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:COMPOSITE:0"},
|
||||
// CHECK-SAME: %[[STATE_ARG0:.*]]: tensor<!tf.resource<tensor<2x!tf.string>>> {tf.device = "/device:TPU:0"},
|
||||
// CHECK-SAME: %[[STATE_ARG1:.*]]: tensor<!tf.resource<tensor<2x!tf.string>>> {tf.device = "/device:TPU:1"})
|
||||
%0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
|
||||
%1 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
// CHECK: %[[COMPILE:.*]]:2 = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf._TPUCompileMlir"()
|
||||
%compile:2 = "tf_device.launch"() ( {
|
||||
%2:2 = "tf._TPUCompileMlir"() {
|
||||
NumDynamicShapes = 0 : i64,
|
||||
// The metadata encodes 2 parameters and 2 return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
|
||||
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
|
||||
tf_device.return %2#0, %2#1 : tensor<!tf.string>, tensor<2x!tf.string>
|
||||
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
|
||||
"tf_device.launch"() ( {
|
||||
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
|
||||
tf_device.return
|
||||
}) {device = "/device:CPU:0"} : () -> ()
|
||||
// CHECK: tf_device.replicate
|
||||
// CHECK-SAME: [%[[BODY_ARG1]], %[[BODY_ARG2]]] as %[[R0:.*]]: tensor<*x!tf.resource<tensor<f32>>>,
|
||||
// CHECK-SAME: [%[[STATE_ARG0]], %[[STATE_ARG1]]] as %[[R_STATE:.*]]: tensor<!tf.resource<tensor<2x!tf.string>>>,
|
||||
// CHECK-SAME: %[[BODY_ARG3]] as %[[R1:.*]]: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>
|
||||
// CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]
|
||||
%rep:2 = tf_device.replicate([%arg1, %arg2] as %arg30: tensor<*x!tf.resource<tensor<f32>>>,
|
||||
%arg3 as %arg31: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
|
||||
{_mirrored_variable_indices = [0, 1], _packed_input_indices = [1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
|
||||
// CHECK: %[[ID:.*]] = "tf.Identity"(%[[R0]])
|
||||
%id = "tf.Identity"(%arg30) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource<tensor<f32>>>
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUReshardVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1, %[[R_STATE]])
|
||||
// CHECK-NEXT: tf_device.return
|
||||
// CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
|
||||
// CHECK: "tf.TPUExecuteAndUpdateVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1)
|
||||
"tf_device.launch"() ( {
|
||||
"tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1)
|
||||
{device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
|
||||
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<2x!tf.string>) -> ()
|
||||
tf_device.return
|
||||
}) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
|
||||
%ret = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
tf_device.return %ret : tensor<i32>
|
||||
}
|
||||
return %1, %arg1, %arg2, %arg3 : tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>,
|
||||
tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>
|
||||
}
|
||||
// CHECK-LABEL: func @while_cond_7550
|
||||
func @while_cond_7550(%arg0: tensor<i32>,
|
||||
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
||||
%arg2: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
|
||||
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:COMPOSITE:0"})
|
||||
-> tensor<i1> {
|
||||
%0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
%1 = "tf.GreaterEqual"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||
return %1 : tensor<i1>
|
||||
}
|
||||
}
|
||||
|
@ -120,8 +120,8 @@ void CreateTPUBridgePipeline(OpPassManager &pm) {
|
||||
pm.addNestedPass<FuncOp>(CreateTPUParallelExecuteSinkResourceWritePass());
|
||||
pm.addNestedPass<FuncOp>(CreateTPUMergeVariablesWithExecutePass());
|
||||
pm.addNestedPass<FuncOp>(CreateTPUColocateCompositeResourceOps());
|
||||
pm.addPass(TF::CreateTFRegionControlFlowToFunctional());
|
||||
pm.addPass(CreateTPUVariableReformattingPass());
|
||||
pm.addPass(TF::CreateTFRegionControlFlowToFunctional());
|
||||
}
|
||||
|
||||
void CreateTPUBridgePipelineV1(OpPassManager &pm) {
|
||||
|
@ -217,7 +217,7 @@ def : Pat<(TF_RsqrtGradOp $lhs, $rhs),
|
||||
|
||||
// TODO(hinsu): Support complex input types.
|
||||
def LowerTanhGradOp :
|
||||
Pat<(TF_TanhGradOp TF_FpTensor:$y, TF_FpTensor:$dy),
|
||||
Pat<(TF_TanhGradOp TF_FloatTensor:$y, TF_FloatTensor:$dy),
|
||||
(TF_MulOp $dy,
|
||||
(TF_SubOp (TF_ConstOp (GetScalarOfType<1> $y)),
|
||||
(TF_SquareOp $y)))>;
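// The rewrite above encodes the identity d/dx tanh(x) = 1 - tanh(x)^2: with
// y = tanh(x) and incoming gradient dy, the lowered gradient is dy * (1 - y * y).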
|
@ -77,6 +77,47 @@ void AddRewrittenEmbeddingOps(MLIRContext* context,
|
||||
TF::SendTPUEmbeddingGradientsOp::getOperationName(), context));
|
||||
}
|
||||
|
||||
// Stack, TensorList and TensorArray ops are rewritten during the second phase
|
||||
// of the bridge (compilation of TPUCompile op). They would not match any
|
||||
// legalization/canonicalization pattern and have to be manually added to the
|
||||
// list of supported ops.
|
||||
void AddRewrittenCompositeOps(MLIRContext* context,
|
||||
llvm::DenseSet<OperationName>* supported_ops) {
|
||||
#define GET_OPERATION_NAME(op) OperationName(op::getOperationName(), context)
|
||||
llvm::SmallDenseSet<OperationName, 32> allowlist_ops = {
|
||||
// Stack ops.
|
||||
GET_OPERATION_NAME(TF::StackV2Op),
|
||||
GET_OPERATION_NAME(TF::StackPushV2Op),
|
||||
GET_OPERATION_NAME(TF::StackPopV2Op),
|
||||
// Tensor Array ops.
|
||||
GET_OPERATION_NAME(TF::TensorArrayV3Op),
|
||||
GET_OPERATION_NAME(TF::TensorArrayReadV3Op),
|
||||
GET_OPERATION_NAME(TF::TensorArrayWriteV3Op),
|
||||
GET_OPERATION_NAME(TF::TensorArrayConcatV3Op),
|
||||
GET_OPERATION_NAME(TF::TensorArraySplitV3Op),
|
||||
GET_OPERATION_NAME(TF::TensorArraySizeV3Op),
|
||||
GET_OPERATION_NAME(TF::TensorArrayGradV3Op),
|
||||
GET_OPERATION_NAME(TF::TensorArrayGatherV3Op),
|
||||
GET_OPERATION_NAME(TF::TensorArrayScatterV3Op),
|
||||
GET_OPERATION_NAME(TF::TensorListFromTensorOp),
|
||||
// Tensor List Ops.
|
||||
GET_OPERATION_NAME(TF::EmptyTensorListOp),
|
||||
GET_OPERATION_NAME(TF::TensorListReserveOp),
|
||||
GET_OPERATION_NAME(TF::TensorListFromTensorOp),
|
||||
GET_OPERATION_NAME(TF::TensorListPushBackOp),
|
||||
GET_OPERATION_NAME(TF::TensorListPopBackOp),
|
||||
GET_OPERATION_NAME(TF::TensorListGetItemOp),
|
||||
GET_OPERATION_NAME(TF::TensorListSetItemOp),
|
||||
GET_OPERATION_NAME(TF::TensorListLengthOp),
|
||||
GET_OPERATION_NAME(TF::TensorListElementShapeOp),
|
||||
GET_OPERATION_NAME(TF::TensorListGatherOp),
|
||||
GET_OPERATION_NAME(TF::TensorListScatterIntoExistingListOp),
|
||||
};
|
||||
#undef GET_OPERATION_NAME
|
||||
|
||||
supported_ops->insert(allowlist_ops.begin(), allowlist_ops.end());
|
||||
}
|
||||
|
||||
bool HasStringOperand(Operation& op) {
|
||||
for (auto operand : op.getOperands()) {
|
||||
if (getElementTypeOrSelf(operand).isa<TF::StringType>()) return true;
|
||||
@ -201,6 +242,7 @@ void MarkOpsForOutsideCompilation::runOnOperation() {
|
||||
}
|
||||
AddSupportedControlFlowOps(module.getContext(), &supported_ops);
|
||||
AddRewrittenEmbeddingOps(module.getContext(), &supported_ops);
|
||||
AddRewrittenCompositeOps(module.getContext(), &supported_ops);
|
||||
|
||||
auto result = module.walk([&](tf_device::ClusterOp cluster) {
|
||||
// Only if `allow_soft_placement` attribute is true should we mark ops
|
||||
|
@ -180,8 +180,7 @@ tf_executor::IslandOp CreateInputBarrierIsland(
|
||||
|
||||
// Create YieldOp for the new input sink island.
|
||||
builder->setInsertionPointToEnd(&input_sink_island.GetBody());
|
||||
builder->create<tf_executor::YieldOp>(island_op.getLoc(),
|
||||
llvm::to_vector<8>(island_inputs));
|
||||
builder->create<tf_executor::YieldOp>(island_op.getLoc(), island_inputs);
|
||||
return input_sink_island;
|
||||
}
|
||||
|
||||
|
@ -501,9 +501,8 @@ void RegionResourceHoister::ReplaceOpWithNewOp() {
|
||||
OpBuilder builder(op_);
|
||||
// Clone the old operation but with new result types.
Operation* new_op = Operation::create(
|
||||
op_->getLoc(), op_->getName(), new_result_types,
|
||||
llvm::to_vector<4>(op_->getOperands()), op_->getAttrs(),
|
||||
llvm::to_vector<4>(op_->getSuccessors()), op_->getNumRegions());
|
||||
op_->getLoc(), op_->getName(), new_result_types, op_->getOperands(),
|
||||
op_->getAttrs(), op_->getSuccessors(), op_->getNumRegions());
|
||||
builder.insert(new_op);
|
||||
|
||||
// Move regions to the new op.
|
||||
@ -1224,14 +1223,11 @@ LogicalResult HoistForControlFlow(
|
||||
return failure();
|
||||
} else if (auto case_op = llvm::dyn_cast<TF::CaseOp>(&op)) {
|
||||
SmallVector<FuncOp, 4> branch_functions;
|
||||
branch_functions.reserve(case_op.branches().size());
|
||||
for (const Attribute& branch : case_op.branches()) {
|
||||
FuncOp func =
|
||||
module.lookupSymbol<FuncOp>(branch.cast<FlatSymbolRefAttr>());
|
||||
case_op.get_branch_functions(branch_functions);
|
||||
for (FuncOp func : branch_functions) {
|
||||
// Recursively handle the nested control flow.
|
||||
HoistForControlFlow(&func.front(), module,
|
||||
lifted_partitioned_call_callees);
|
||||
branch_functions.push_back(func);
|
||||
}
|
||||
if (failed(HandleCaseOrIfOp(case_op, branch_functions))) return failure();
|
||||
} else if (auto call_op = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
|
||||
|
@ -86,9 +86,8 @@ void EliminateUnusedResults(
|
||||
// Rebuild the new operation with lesser number of results.
|
||||
OpBuilder builder(op);
|
||||
Operation *new_op = Operation::create(
|
||||
op->getLoc(), op->getName(), new_result_types,
|
||||
llvm::to_vector<4>(op->getOperands()), op->getAttrs(),
|
||||
llvm::to_vector<4>(op->getSuccessors()), op->getNumRegions());
|
||||
op->getLoc(), op->getName(), new_result_types, op->getOperands(),
|
||||
op->getAttrs(), op->getSuccessors(), op->getNumRegions());
|
||||
builder.insert(new_op);
|
||||
|
||||
// Move region bodies to the new operation.
|
||||
@ -415,11 +414,7 @@ LogicalResult CleanupAndCanonicalize(Operation *parent_op) {
|
||||
op, {if_op.then_function(), if_op.else_function()}, if_op.input());
|
||||
} else if (auto case_op = dyn_cast<TF::CaseOp>(op)) {
|
||||
SmallVector<FuncOp, 4> branches;
|
||||
for (Attribute branch : case_op.branches()) {
|
||||
auto sym = branch.cast<FlatSymbolRefAttr>();
|
||||
branches.push_back(
|
||||
SymbolTable::lookupNearestSymbolFrom<FuncOp>(op, sym));
|
||||
}
|
||||
case_op.get_branch_functions(branches);
|
||||
result = CanonicalizeFunctionalIfCase(case_op, branches, case_op.input());
|
||||
} else if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
|
||||
if (while_op.cond_function().walk(check_while_cond).wasInterrupted())
|
||||
|
@ -927,10 +927,7 @@ LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions(
|
||||
{if_op.then_function(), if_op.else_function()}, max_iteration);
|
||||
} else if (auto case_op = dyn_cast<TF::CaseOp>(op)) {
|
||||
SmallVector<FuncOp, 4> branches;
|
||||
for (Attribute branch : case_op.branches()) {
|
||||
auto sym = branch.cast<FlatSymbolRefAttr>();
|
||||
branches.push_back(SymbolTable::lookupNearestSymbolFrom<FuncOp>(op, sym));
|
||||
}
|
||||
case_op.get_branch_functions(branches);
|
||||
return PropagateShapeToFunctions(module,
|
||||
drop_begin(case_op.getOperandTypes(), 1),
|
||||
branches, max_iteration);
|
||||
|
@ -139,8 +139,7 @@ void ModifyFunctionSignature(
|
||||
handle_new_size_vars(func.getArguments().drop_front(original_arg_count));
|
||||
}
|
||||
func.setType(FunctionType::get(
|
||||
new_input_types,
|
||||
llvm::to_vector<8>(func.front().getTerminator()->getOperandTypes()),
|
||||
new_input_types, func.front().getTerminator()->getOperandTypes(),
|
||||
func.getContext()));
|
||||
}
|
||||
|
||||
|
@ -708,10 +708,7 @@ LogicalResult DecomposeTensorListOpsInternal(
|
||||
}
|
||||
} else if (auto case_op = llvm::dyn_cast<TF::CaseOp>(&op)) {
|
||||
SmallVector<FuncOp, 2> branches;
|
||||
for (auto branch_symbol : case_op.branches()) {
|
||||
branches.push_back(module.lookupSymbol<FuncOp>(
|
||||
branch_symbol.cast<FlatSymbolRefAttr>()));
|
||||
}
|
||||
case_op.get_branch_functions(branches);
|
||||
if (failed(HandleCaseOrIfOp(case_op, branches, module, buffer_to_size,
|
||||
decomposed_partitioned_call_callees))) {
|
||||
return failure();
|
||||
|
@ -115,9 +115,8 @@ struct TPUSpaceToDepthPass
|
||||
|
||||
// Updates func argument type to have the updated input shape.
|
||||
void UpdateFuncType(FuncOp func) {
|
||||
auto arg_types = llvm::to_vector<8>(func.front().getArgumentTypes());
|
||||
auto result_types =
|
||||
llvm::to_vector<4>(func.front().getTerminator()->getOperandTypes());
|
||||
auto arg_types = func.front().getArgumentTypes();
|
||||
auto result_types = func.front().getTerminator()->getOperandTypes();
|
||||
func.setType(FunctionType::get(arg_types, result_types, func.getContext()));
|
||||
}
|
||||
|
||||
|
@ -138,14 +138,17 @@ Value SkipIdentity(Value v, bool allow_other_use,
|
||||
|
||||
// Finds the formattable arguments of `execute` and annotates the metadata of
|
||||
// `compile` to record these arguments. In addition, it returns a mapping from
|
||||
// the formattable arguments of `execute` to the corresponding arguments of
|
||||
// `while_op` (which should be passed through to `execute` via `replicate`). The
|
||||
// the formattable arguments of `execute` to the corresponding operand of
|
||||
// `replicate`. The
|
||||
// entries in the mapping are sorted in the order of operands of `execute`.
|
||||
llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<Value, 4>>, 4>
|
||||
AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(
|
||||
TF::WhileOp while_op, tf_device::ReplicateOp replicate,
|
||||
TF::WhileRegionOp while_op, tf_device::ReplicateOp replicate,
|
||||
TF::TPUExecuteAndUpdateVariablesOp execute,
|
||||
tf_device::LaunchOp compile_launch, FuncOp body, FuncOp cond) {
|
||||
tf_device::LaunchOp compile_launch) {
|
||||
Region& body = while_op.body();
|
||||
Region& cond = while_op.cond();
|
||||
|
||||
llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<Value, 4>>, 4> mapping;
|
||||
auto mirrored_variable_indices_attr =
|
||||
replicate.getAttrOfType<ArrayAttr>(kMirroredVariableIndicesAttr);
|
||||
@ -204,39 +207,43 @@ AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(
|
||||
// We have found a mirrored variable which is an input to the replicated
|
||||
// `execute`. Now find if this mirrored variable is a pass-through of while
|
||||
// arguments.
|
||||
llvm::SmallVector<Value, 4> while_args;
|
||||
llvm::SmallVector<Value, 4> replicate_args;
|
||||
for (int64_t i = 0; i < num_inputs; ++i) {
|
||||
llvm::SmallPtrSet<Operation*, 4> skipped_identities;
|
||||
|
||||
auto replicate_operand = SkipIdentity(
|
||||
replicate.GetReplicaOperandForBlockArgument(block_arg, i),
|
||||
/*allow_other_use=*/false, &skipped_identities);
|
||||
auto block_arg = replicate_operand.dyn_cast<BlockArgument>();
|
||||
// To qualify for a valid pass-through mirrored variable, it must satisfy
|
||||
// 1) it is the body's argument;
|
||||
// 2) it has no other uses than `replicate`, the skipped identity ops,
// or the return;
|
||||
// 3) the corresponding argument in the cond function has no uses.
|
||||
if (!block_arg || block_arg.getOwner() != &body.front() ||
|
||||
llvm::any_of(replicate_operand.getUsers(),
|
||||
[&](Operation* user) {
|
||||
return user != body.front().getTerminator() &&
|
||||
skipped_identities.count(user) == 0 &&
|
||||
user != replicate;
|
||||
}) ||
|
||||
!cond.getArgument(block_arg.getArgNumber()).use_empty()) {
|
||||
while_args.clear();
|
||||
// For region based control flow, the resource operand for the replicate
|
||||
// should be a region capture. If this has any use other than the
|
||||
// replicate op (within the body of the while) or the skipped identities,
|
||||
// then do not apply the transformation to this variable.
|
||||
bool is_region_capture =
|
||||
replicate_operand.getParentRegion()->isProperAncestor(&body);
|
||||
bool has_other_use_in_body =
|
||||
llvm::any_of(replicate_operand.getUsers(), [&](Operation* user) {
|
||||
// Ignore uses that are not in the while body or condition.
|
||||
if (!body.isAncestor(user->getParentRegion()) &&
|
||||
!cond.isAncestor(user->getParentRegion()))
|
||||
return false;
|
||||
// Within the body or cond, only uses in replicate and the skipped
|
||||
// identities are allowed.
return user != replicate && skipped_identities.count(user) == 0;
|
||||
});
|
||||
|
||||
if (!is_region_capture || has_other_use_in_body) {
|
||||
replicate_args.clear();
|
||||
break;
|
||||
}
|
||||
while_args.push_back(while_op.getOperand(block_arg.getArgNumber()));
|
||||
replicate_args.push_back(replicate_operand);
|
||||
}
|
||||
if (while_args.empty()) continue;
|
||||
if (replicate_args.empty()) continue;
|
||||
// Now set the enable_xla_sharding field in the metadata to inform the
|
||||
// compile op.
|
||||
auto metadata_arg = metadata.mutable_args(it->second);
|
||||
metadata_arg->set_enable_xla_sharding(
|
||||
::tensorflow::tpu::TPUCompileMetadataProto_Arg::ALLOWED);
|
||||
mapping.emplace_back(it->second, std::move(while_args));
|
||||
mapping.emplace_back(it->second, std::move(replicate_args));
|
||||
}
|
||||
// Sort the mapping according to execute operand order.
|
||||
llvm::sort(mapping, llvm::less_first());
|
||||
@ -261,7 +268,8 @@ AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(
|
||||
|
||||
// Adds a new replicated input to the replicate op.
|
||||
tf_device::ReplicateOp AddInputsToReplicateOp(
|
||||
tf_device::ReplicateOp replicate, ArrayRef<Value> new_inputs,
|
||||
tf_device::ReplicateOp replicate,
|
||||
MutableArrayRef<TF::VarHandleOp> new_inputs,
|
||||
const llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>>&
|
||||
devices) {
|
||||
int64_t num_replicas = replicate.n();
|
||||
@ -293,7 +301,11 @@ tf_device::ReplicateOp AddInputsToReplicateOp(
|
||||
new_packed_inputs.emplace_back(
|
||||
replicate.GetReplicaOperandForBlockArgument(arg, /*replica=*/0));
|
||||
}
|
||||
new_replicated_inputs.emplace_back(new_inputs, new_inputs.front().getType());
|
||||
SmallVector<Value, 4> new_input_values;
|
||||
new_input_values.reserve(new_inputs.size());
|
||||
for (auto var : new_inputs) new_input_values.push_back(var.resource());
|
||||
new_replicated_inputs.emplace_back(new_input_values,
|
||||
new_input_values.front().getType());
|
||||
OpBuilder builder(replicate);
|
||||
auto new_replicate = builder.create<tf_device::ReplicateOp>(
|
||||
replicate.getLoc(), num_replicas, devices, new_replicated_inputs,
|
||||
@ -319,58 +331,6 @@ tf_device::ReplicateOp AddInputsToReplicateOp(
|
||||
return new_replicate;
|
||||
}
|
||||
|
||||
// Adds the per-device state variables to the while-loop's inputs/outputs.
|
||||
TF::WhileOp AddStateVarsToWhileOp(TF::WhileOp while_op, FuncOp body,
|
||||
FuncOp cond,
|
||||
ArrayRef<TF::VarHandleOp> state_vars) {
|
||||
auto body_return = llvm::cast<ReturnOp>(body.front().back());
|
||||
auto new_body_return_vals = llvm::to_vector<4>(body_return.getOperands());
|
||||
auto new_while_operands = llvm::to_vector<4>(while_op.getOperands());
|
||||
auto append_types = [&](ArrayRef<Type> types) {
|
||||
auto new_types = llvm::to_vector<4>(types);
|
||||
for (auto state_var : state_vars) {
|
||||
new_types.push_back(state_var.resource().getType());
|
||||
}
|
||||
return new_types;
|
||||
};
|
||||
for (auto state_var : state_vars) {
|
||||
body.front().addArgument(state_var.resource().getType());
|
||||
cond.front().addArgument(state_var.resource().getType());
|
||||
auto inner_arg = body.getArgument(body.front().getNumArguments() - 1);
|
||||
new_body_return_vals.push_back(inner_arg);
|
||||
new_while_operands.push_back(state_var.resource());
|
||||
}
|
||||
OpBuilder builder = OpBuilder::atBlockEnd(&body.front());
|
||||
// Update return values.
|
||||
builder.create<ReturnOp>(body_return.getLoc(), new_body_return_vals);
|
||||
body_return.erase();
|
||||
|
||||
body.setType(FunctionType::get(append_types(body.getType().getInputs()),
|
||||
append_types(body.getType().getResults()),
|
||||
body.getContext()));
|
||||
cond.setType(FunctionType::get(append_types(cond.getType().getInputs()),
|
||||
cond.getType().getResults(),
|
||||
cond.getContext()));
|
||||
for (int64_t i = 0, end = state_vars.size(); i < end; ++i) {
|
||||
int64_t arg_index = body.getNumArguments() - state_vars.size() + i;
|
||||
TF::VarHandleOp state_var = state_vars[i];
|
||||
auto device_attr = state_var.getAttr(kDeviceAttr);
|
||||
if (device_attr) {
|
||||
body.setArgAttr(arg_index, kFuncDeviceAttr, device_attr);
|
||||
cond.setArgAttr(arg_index, kFuncDeviceAttr, device_attr);
|
||||
}
|
||||
}
|
||||
builder.setInsertionPoint(while_op);
|
||||
auto new_while_op = builder.create<TF::WhileOp>(
|
||||
while_op.getLoc(),
|
||||
append_types(llvm::to_vector<4>(while_op.getResultTypes())),
|
||||
new_while_operands, while_op.getAttrs());
|
||||
while_op.replaceAllUsesWith(
|
||||
new_while_op.getResults().take_front(while_op.getNumResults()));
|
||||
while_op.erase();
|
||||
return new_while_op;
|
||||
}
|
||||
|
||||
// Creates the per-device variables that represent the formatting state of each
|
||||
// device.
|
||||
llvm::SmallVector<TF::VarHandleOp, 4> CreateStateVars(
|
||||
@ -421,8 +381,8 @@ void WrapOpInLaunch(OpBuilder* builder, Location loc, Operation* op,
|
||||
}
|
||||
|
||||
// Performs the transformation for a replicate op inside a while loop.
|
||||
void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate,
|
||||
MLIRContext* context) {
|
||||
void HandleReplicateOp(TF::WhileRegionOp while_op,
|
||||
tf_device::ReplicateOp replicate) {
|
||||
int64_t num_replicas = replicate.n();
|
||||
if (num_replicas == 1) return;
|
||||
tf_device::LaunchOp execute_launch;
|
||||
@ -452,13 +412,10 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate,
|
||||
!llvm::isa<TF::_TPUCompileMlirOp>(compile_launch.GetBody().front()))
|
||||
return;
|
||||
|
||||
FuncOp body = while_op.body_function();
|
||||
FuncOp cond = while_op.cond_function();
|
||||
|
||||
// Analyze the formattable inputs.
|
||||
auto execute_arg_to_outer_args =
|
||||
AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(
|
||||
while_op, replicate, execute, compile_launch, body, cond);
|
||||
while_op, replicate, execute, compile_launch);
|
||||
if (execute_arg_to_outer_args.empty()) return;
|
||||
|
||||
// Extract the replicated devices.
|
||||
@ -489,16 +446,7 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate,
|
||||
RankedTensorType::get({2}, TF::StringType::get(builder.getContext()));
|
||||
auto state_vars =
|
||||
CreateStateVars(devices, while_op.getLoc(), key_type, &builder);
|
||||
while_op = AddStateVarsToWhileOp(while_op, body, cond, state_vars);
|
||||
// Add the new while loop inputs to the replicate op inside the body.
|
||||
int64_t new_while_operand_count = while_op.getNumOperands();
|
||||
llvm::SmallVector<Value, 4> inner_state_vars;
|
||||
for (int64_t i = new_while_operand_count - num_replicas;
|
||||
i < new_while_operand_count; ++i) {
|
||||
inner_state_vars.push_back(body.front().getArgument(i));
|
||||
}
|
||||
|
||||
replicate = AddInputsToReplicateOp(replicate, inner_state_vars, devices);
|
||||
replicate = AddInputsToReplicateOp(replicate, state_vars, devices);
|
||||
// Build the reformat according to the compilation. Build it inside
|
||||
// `replicate`.
|
||||
llvm::SmallVector<Value, 8> reformat_operands;
|
||||
@ -576,10 +524,9 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate,
|
||||
|
||||
void TPUVariableRuntimeReformattingPass::runOnOperation() {
|
||||
auto module = getOperation();
|
||||
module.walk([&](TF::WhileOp while_op) {
|
||||
auto body = llvm::cast<FuncOp>(module.lookupSymbol(while_op.body()));
|
||||
module.walk([&](TF::WhileRegionOp while_op) {
|
||||
tf_device::ReplicateOp replicate;
|
||||
body.walk([&](tf_device::ReplicateOp replicate_op) {
|
||||
while_op.body().walk([&](tf_device::ReplicateOp replicate_op) {
|
||||
if (replicate == nullptr) {
|
||||
replicate = replicate_op;
|
||||
return WalkResult::advance();
|
||||
@ -592,7 +539,7 @@ void TPUVariableRuntimeReformattingPass::runOnOperation() {
|
||||
// `tf_device.parallel_execute` op in the `tf_device.replicate` is present.
|
||||
if (replicate &&
|
||||
replicate.GetBody().getOps<tf_device::ParallelExecuteOp>().empty())
|
||||
HandleReplicateOp(while_op, replicate, &getContext());
|
||||
HandleReplicateOp(while_op, replicate);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -137,7 +137,7 @@ void LowerCase(TF::CaseOp op, ModuleOp module) {
|
||||
auto tuple_input = builder.create<mhlo::TupleOp>(loc, inputs);
|
||||
|
||||
// Create replica of input tuple for each branch
|
||||
SmallVector<Value, 4> n_tuple_inputs(op.branches().size(), tuple_input);
|
||||
SmallVector<Value, 4> n_tuple_inputs(op.num_branches(), tuple_input);
|
||||
|
||||
// Create the new case op with tuple inputs.
|
||||
auto case_op =
|
||||
@ -145,9 +145,8 @@ void LowerCase(TF::CaseOp op, ModuleOp module) {
|
||||
n_tuple_inputs, op.branches().size());
|
||||
|
||||
// Import the regions for all branches.
|
||||
for (unsigned i = 0; i < op.branches().size(); ++i) {
|
||||
mlir::FuncOp branch_func = module.lookupSymbol<mlir::FuncOp>(
|
||||
op.branches()[i].cast<SymbolRefAttr>());
|
||||
for (unsigned i = 0; i < op.num_branches(); ++i) {
|
||||
mlir::FuncOp branch_func = op.branch_function(i);
|
||||
ImportXlaRegion(branch_func, &case_op.branches()[i], loc,
|
||||
/*tuple_return=*/false);
|
||||
}
|
||||
|
@ -586,6 +586,7 @@ foreach Mapping = [
|
||||
[TF_RealOp, HLO_RealOp],
|
||||
[TF_RsqrtOp, HLO_RsqrtOp],
|
||||
[TF_SigmoidOp, HLO_LogisticOp],
|
||||
[TF_SinhOp, HLOClient_SinhOp],
|
||||
[TF_SinOp, HLO_SinOp],
|
||||
[TF_SqrtOp, HLO_SqrtOp],
|
||||
[TF_TanhOp, HLO_TanhOp],
|
||||
|
@ -436,8 +436,8 @@ Status LhloDialectEmitter::Initialize() {
|
||||
}
|
||||
}
|
||||
|
||||
FunctionType function_type = builder_.getFunctionType(
|
||||
llvm::to_vector<8>(block->getArgumentTypes()), {});
|
||||
FunctionType function_type =
|
||||
builder_.getFunctionType(block->getArgumentTypes(), {});
|
||||
func_op.setType(function_type);
|
||||
func_op.setAllArgAttrs(args_attrs);
|
||||
|
||||
|
@ -31,6 +31,7 @@ import six
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
from tensorflow.python.ops import gen_random_ops
|
||||
from tensorflow.python.ops import gradient_checker_v2
|
||||
@ -54,6 +55,16 @@ def _igammac(a, x):
|
||||
return math_ops.igammac(a, x)
|
||||
|
||||
|
||||
@def_function.function(experimental_compile=True)
|
||||
def _polygamma(n, x):
|
||||
return math_ops.polygamma(n, x)
|
||||
|
||||
|
||||
@def_function.function(experimental_compile=True)
|
||||
def _zeta(a, q):
|
||||
return math_ops.zeta(a, q)
|
||||
|
||||
|
||||
# This is df/da / df/dx, where f = igamma.
|
||||
def implicit_reparameterization_grad(a, x):
|
||||
log_prob = math_ops.xlogy(a - 1., x) - math_ops.lgamma(a) - x
|
||||
@ -136,6 +147,208 @@ class Log1pTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
self._test_range(0., 3., dtype, rtol, atol, is_negative=False)
|
||||
|
||||
|
||||
class ZetaTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
if flags.FLAGS.vary_seed:
|
||||
entropy = os.urandom(64)
|
||||
if six.PY2:
|
||||
answer = int(entropy.encode('hex'), 16)
|
||||
else:
|
||||
answer = int.from_bytes(entropy, 'big')
|
||||
np.random.seed(answer % (2**32 - 1))
|
||||
super(ZetaTest, self).setUp()
|
||||
|
||||
def adjust_tolerance_for_tpu(self, dtype, rtol, atol):
|
||||
if self.device not in ['TPU']:
|
||||
return rtol, atol
|
||||
|
||||
if dtype == np.float32:
|
||||
return 2e-2, 1e-7
|
||||
return 2e-4, 1e-20
|
||||
|
||||
@test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
|
||||
def testBadValues(self):
|
||||
q = np.random.uniform(low=0.3, high=20., size=[10])
|
||||
with self.session() as sess:
|
||||
with self.test_scope():
|
||||
y = _zeta(np.float64(1.), q)
|
||||
actual = sess.run(y)
|
||||
# When x == 1, this is the Harmonic series.
|
||||
self.assertTrue(np.all(np.isinf(actual)))
|
||||
|
||||
with self.session() as sess:
|
||||
with self.test_scope():
|
||||
y = _zeta(np.float64(0.1), q)
|
||||
actual = sess.run(y)
|
||||
# When x < 1, this is undefined.
|
||||
self.assertTrue(np.all(np.isnan(actual)))
|
||||
|
||||
with self.session() as sess:
|
||||
with self.test_scope():
|
||||
y = _zeta([1., 1.1], [-1.1, -1.])
|
||||
actual = sess.run(y)
|
||||
|
||||
# When q is negative, zeta is not defined
|
||||
# if q is an integer or x is not an integer.
|
||||
self.assertTrue(np.all(np.isinf(actual)))
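    # For reference, the Hurwitz zeta function exercised above is
    #   zeta(x, q) = sum_{k >= 0} (q + k)^(-x),
    # which converges only for x > 1; x == 1 gives the divergent harmonic
    # series and x < 1 is outside the domain, matching the inf/nan
    # expectations checked in this test.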
|
||||
@parameterized.parameters((np.float32, 1e-2, 1e-11),
|
||||
(np.float64, 1e-4, 1e-30))
|
||||
@test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
|
||||
def testLargeXSmallQ(self, dtype, rtol, atol):
|
||||
rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
|
||||
if self.device not in ['XLA_GPU', 'XLA_CPU'] and dtype == np.float64:
|
||||
# TODO(b/165739664): Figure out why on TPU F64 Zeta sometimes returns
|
||||
# infs.
|
||||
self.skipTest(
|
||||
'Skipping test because some F64 operations are numerically '
|
||||
'unstable on TPU.')
|
||||
|
||||
x = np.random.uniform(low=100., high=200., size=[NUM_SAMPLES]).astype(dtype)
|
||||
q = np.random.uniform(low=0.3, high=1., size=[NUM_SAMPLES]).astype(dtype)
|
||||
|
||||
expected_values = sps.zeta(x, q)
|
||||
with self.session() as sess:
|
||||
with self.test_scope():
|
||||
y = _zeta(x, q)
|
||||
actual = sess.run(y)
|
||||
|
||||
self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
|
||||
|
||||
@parameterized.parameters((np.float32, 1e-2, 1e-11),
|
||||
(np.float64, 1e-4, 1e-30))
|
||||
@test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
|
||||
def testSmallValues(self, dtype, rtol, atol):
|
||||
rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
|
||||
# Test values near zero.
|
||||
x = np.random.uniform(low=1.1, high=10., size=[NUM_SAMPLES]).astype(dtype)
|
||||
q = np.random.uniform(
|
||||
low=np.finfo(dtype).tiny, high=1., size=[NUM_SAMPLES]).astype(dtype)
|
||||
|
||||
expected_values = sps.zeta(x, q)
|
||||
with self.session() as sess:
|
||||
with self.test_scope():
|
||||
actual = sess.run(_zeta(x, q))
|
||||
self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
|
||||
|
||||
@parameterized.parameters((np.float32, 1e-2, 1e-11),
|
||||
(np.float64, 1e-4, 1e-30))
|
||||
@test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
|
||||
def testMediumValues(self, dtype, rtol, atol):
|
||||
rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
|
||||
x = np.random.uniform(low=1.1, high=100., size=[NUM_SAMPLES]).astype(dtype)
|
||||
q = np.random.uniform(low=1., high=1e1, size=[NUM_SAMPLES]).astype(dtype)
|
||||
|
||||
expected_values = sps.zeta(x, q)
|
||||
with self.session() as sess:
|
||||
with self.test_scope():
|
||||
actual = sess.run(_zeta(x, q))
|
||||
self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
|
||||
|
||||
@parameterized.parameters((np.float32, 2e-2, 1e-5), (np.float64, 1e-4, 1e-30))
|
||||
@test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
|
||||
def testLargeValues(self, dtype, rtol, atol):
|
||||
x = np.random.uniform(
|
||||
low=100., high=int(1e3), size=[NUM_SAMPLES]).astype(dtype)
|
||||
q = np.random.uniform(
|
||||
low=1., high=int(1e1), size=[NUM_SAMPLES]).astype(dtype)
|
||||
|
||||
expected_values = sps.zeta(x, q)
|
||||
with self.session() as sess:
|
||||
with self.test_scope():
|
||||
actual = sess.run(_zeta(x, q))
|
||||
self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
|
||||
|
||||
|
||||
class PolygammaTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
if flags.FLAGS.vary_seed:
|
||||
entropy = os.urandom(64)
|
||||
if six.PY2:
|
||||
answer = int(entropy.encode('hex'), 16)
|
||||
else:
|
||||
answer = int.from_bytes(entropy, 'big')
|
||||
np.random.seed(answer % (2**32 - 1))
|
||||
super(PolygammaTest, self).setUp()
|
||||
|
||||
def adjust_tolerance_for_tpu(self, dtype, rtol, atol):
|
||||
if self.device not in ['TPU']:
|
||||
return rtol, atol
|
||||
|
||||
if dtype == np.float32:
|
||||
return 2e-2, 1e-7
|
||||
return 2e-4, 1e-20
|
||||
|
||||
@test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
|
||||
def testBadValues(self):
|
||||
x = np.random.uniform(low=0.3, high=20., size=[10])
|
||||
with self.session() as sess:
|
||||
with self.test_scope():
|
||||
y = _polygamma(np.float64(-1.), x)
|
||||
actual = sess.run(y)
|
||||
# Not defined for negative numbers.
|
||||
self.assertTrue(np.all(np.isnan(actual)))
|
||||
|
||||
with self.session() as sess:
|
||||
with self.test_scope():
|
||||
y = _polygamma(np.float64(0.1), x)
|
||||
actual = sess.run(y)
|
||||
# Not defined for non-integers.
|
||||
self.assertTrue(np.all(np.isnan(actual)))
|
||||
|
||||
@parameterized.parameters((np.float32, 1e-2, 1e-11),
|
||||
(np.float64, 1e-4, 1e-30))
|
||||
@test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
|
||||
def testRecoverDigamma(self, dtype, rtol, atol):
|
||||
rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
|
||||
if self.device not in ['XLA_GPU', 'XLA_CPU'] and dtype == np.float64:
|
||||
self.skipTest(
|
||||
'Skipping test because some F64 operations are '
|
||||
'numerically unstable on TPU.'
|
||||
)
|
||||
|
||||
x = np.random.uniform(low=0.1, high=50., size=[NUM_SAMPLES]).astype(dtype)
|
||||
expected_values = sps.digamma(x)
|
||||
with self.session() as sess:
|
||||
with self.test_scope():
|
||||
y = _polygamma(dtype(0.), x)
|
||||
actual = sess.run(y)
|
||||
|
||||
self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
|
||||
|
||||
@parameterized.parameters((np.float32, 1e-2, 1e-11),
|
||||
(np.float64, 1e-4, 1e-30))
|
||||
@test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
|
||||
def testSmallN(self, dtype, rtol, atol):
|
||||
rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
|
||||
# Test values near zero.
|
||||
n = np.random.randint(low=1, high=5, size=[NUM_SAMPLES]).astype(dtype)
|
||||
x = np.random.uniform(
|
||||
low=np.finfo(dtype).tiny, high=1., size=[NUM_SAMPLES]).astype(dtype)
|
||||
|
||||
expected_values = sps.polygamma(n, x)
|
||||
with self.session() as sess:
|
||||
with self.test_scope():
|
||||
actual = sess.run(_polygamma(n, x))
|
||||
self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
|
||||
|
||||
@parameterized.parameters((np.float32, 1e-2, 1e-11),
|
||||
(np.float64, 1e-4, 1e-30))
|
||||
@test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
|
||||
def testMediumLargeN(self, dtype, rtol, atol):
|
||||
rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
|
||||
n = np.random.randint(low=5, high=10, size=[NUM_SAMPLES]).astype(dtype)
|
||||
x = np.random.uniform(low=1., high=1e1, size=[NUM_SAMPLES]).astype(dtype)
|
||||
|
||||
expected_values = sps.polygamma(n, x)
|
||||
with self.session() as sess:
|
||||
with self.test_scope():
|
||||
actual = sess.run(_polygamma(n, x))
|
||||
self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
|
||||
|
||||
|
||||
class IgammaTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -290,6 +290,21 @@ xla::XlaOp IgammacImpl(xla::XlaOp x, xla::XlaOp y,
|
||||
|
||||
XLA_MAKE_BINARY(Igammac, IgammacImpl(lhs, rhs, broadcast_helper));
|
||||
|
||||
xla::XlaOp PolygammaImpl(xla::XlaOp n, xla::XlaOp x,
|
||||
const BCast& broadcast_helper) {
|
||||
std::tie(n, x) = XlaBinaryOp::Broadcast(n, x, broadcast_helper);
|
||||
return xla::Polygamma(n, x);
|
||||
}
|
||||
|
||||
XLA_MAKE_BINARY(Polygamma, PolygammaImpl(lhs, rhs, broadcast_helper));
|
||||
|
||||
xla::XlaOp ZetaImpl(xla::XlaOp x, xla::XlaOp q, const BCast& broadcast_helper) {
|
||||
std::tie(x, q) = XlaBinaryOp::Broadcast(x, q, broadcast_helper);
|
||||
return xla::Zeta(x, q);
|
||||
}
|
||||
|
||||
XLA_MAKE_BINARY(Zeta, ZetaImpl(lhs, rhs, broadcast_helper));
|
||||
|
||||
#undef XLA_MAKE_BINARY
|
||||
|
||||
class ApproximateEqualOp : public XlaOpKernel {
|
||||
|
@ -206,6 +206,8 @@ igamma = _broadcasting_binary_op(math_ops.igamma)
|
||||
igamma_grad_a = _broadcasting_binary_op(gen_math_ops.igamma_grad_a)
|
||||
random_gamma_grad = _broadcasting_binary_op(gen_random_ops.random_gamma_grad)
|
||||
igammac = _broadcasting_binary_op(math_ops.igammac)
|
||||
polygamma = _broadcasting_binary_op(math_ops.polygamma)
|
||||
zeta = _broadcasting_binary_op(math_ops.zeta)
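# Note: like the other wrappers above, polygamma and zeta broadcast their two
# arguments against each other before the underlying op is applied, so e.g.
# zeta(x, q) accepts a tensor-valued x together with a scalar q.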
|
||||
|
||||
def _binary_op(fn):
|
||||
|
@ -1832,4 +1832,139 @@ XlaOp RegularizedIncompleteBeta(XlaOp a, XlaOp b, XlaOp x) {
|
||||
});
|
||||
}
|
||||
|
||||
XlaOp Polygamma(XlaOp n, XlaOp x) {
|
||||
auto& builder = *x.builder();
|
||||
auto doit = [](XlaOp n, XlaOp x, PrimitiveType type) -> XlaOp {
|
||||
XlaOp n_plus_one = n + ScalarLike(n, 1.);
|
||||
XlaOp sign =
|
||||
(ScalarLike(n, 2.) * Rem(n, ScalarLike(n, 2.)) - ScalarLike(n, 1.));
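    // For natural n this evaluates to (-1)^(n + 1): even n gives -1, odd n
    // gives +1, which is the sign in the polygamma identity used below.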
|
||||
const double nan = std::numeric_limits<double>::quiet_NaN();
|
||||
|
||||
XlaOp output = Select(Eq(n, ScalarLike(n, 0.)), Digamma(x),
|
||||
sign * Exp(Lgamma(n_plus_one)) * Zeta(n_plus_one, x));
|
||||
// Check that n is a natural number.
|
||||
output = Select(Or(Ne(n, Floor(n)), Lt(n, ScalarLike(n, 0.))),
|
||||
ScalarLike(n, nan), output);
|
||||
return output;
|
||||
};
|
||||
return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(auto n_shape, builder.GetShape(n));
|
||||
TF_ASSIGN_OR_RETURN(auto x_shape, builder.GetShape(x));
|
||||
if (n_shape != x_shape) {
|
||||
return InvalidArgument(
|
||||
"Arguments to Polygamma must have equal shapes and types; "
|
||||
"got %s and %s",
|
||||
n_shape.ToString(), x_shape.ToString());
|
||||
}
|
||||
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Zeta", x));
|
||||
bool needs_upcast =
|
||||
n_shape.element_type() == F16 || x_shape.element_type() == BF16;
|
||||
|
||||
if (needs_upcast) {
|
||||
n = ConvertElementType(n, F32);
|
||||
x = ConvertElementType(x, F32);
|
||||
}
|
||||
XlaOp result = doit(n, x, n_shape.element_type());
|
||||
if (needs_upcast) {
|
||||
result = ConvertElementType(result, n_shape.element_type());
|
||||
}
|
||||
return result;
|
||||
});
|
||||
}
|
||||
|
||||
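For readers without the XLA context, the branch above is the standard polygamma identity psi^(n)(x) = (-1)^(n+1) * n! * zeta(n+1, x). The following NumPy/SciPy sketch is not part of this change and its names are illustrative; it mirrors the same structure, including the n == 0 digamma case and the NaN result for non-natural n:

# A minimal sketch, assuming NumPy and SciPy are available; it evaluates the
# general branch unconditionally, so harmless overflow warnings are possible.
import numpy as np
from scipy import special as sps

def polygamma_reference(n, x):
  n = np.asarray(n, dtype=np.float64)
  x = np.asarray(x, dtype=np.float64)
  sign = 2.0 * np.remainder(n, 2.0) - 1.0  # (-1)^(n+1)
  general = sign * np.exp(sps.gammaln(n + 1.0)) * sps.zeta(n + 1.0, x)
  out = np.where(n == 0.0, sps.digamma(x), general)
  # n must be a natural number, mirroring the Select on Ne(n, Floor(n)).
  return np.where((n != np.floor(n)) | (n < 0.0), np.nan, out)

# Example: polygamma_reference(3.0, 2.5) agrees with sps.polygamma(3, 2.5).
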
XlaOp Zeta(XlaOp x, XlaOp q) {
  auto& builder = *x.builder();
  auto doit = [&builder](XlaOp x, XlaOp q, PrimitiveType type) -> XlaOp {
    // (2k)! / B_{2k}, where B_{2k} are the Bernoulli numbers.
    // These are ordered in reverse.
    static const std::array<double, 12> kZetaCoeffs{
        -7.1661652561756670113e18,
        1.8152105401943546773e17,
        -4.5979787224074726105e15,
        1.1646782814350067249e14,
        -2.950130727918164224e12,
        7.47242496e10,
        -1.8924375803183791606e9,
        47900160.0,
        -1209600.0,
        30240.0,
        -720.0,
        12.0,
    };

    // For speed we'll always use 9 iterations for the initial series estimate,
    // and a 12 term expansion for the Euler-Maclaurin formula.

    XlaOp a = q;
    XlaOp neg_power = ScalarLike(a, 0.);
    XlaOp initial_sum = Pow(q, Neg(x));
    for (int i = 0; i < 9; ++i) {
      a = a + ScalarLike(a, 1.);
      neg_power = Pow(a, Neg(x));
      initial_sum = initial_sum + neg_power;
    }
    a = a + ScalarLike(a, 1.);
    neg_power = Pow(a, Neg(x));
    XlaOp s = initial_sum + neg_power * a / (x - ScalarLike(a, 1.));
    XlaOp a_inverse_square = Reciprocal(Square(a));
    XlaOp horner_sum = ScalarLike(a, 0.);
    XlaOp factor = ScalarLike(a, 1.);
    // Use Horner's rule for this.
    // Note this differs from Cephes which does a 'naive' polynomial evaluation.
    // Using Horner's rule avoids some NaNs and Infs that would otherwise
    // appear, resulting in more numerically stable code.
    for (int i = 0; i < 11; ++i) {
      factor =
          (x - ScalarLike(x, 22 - 2 * i)) * (x - ScalarLike(x, 21 - 2 * i));
      horner_sum = factor * a_inverse_square *
                   (horner_sum + ScalarLike(a, 1. / kZetaCoeffs[i]));
    }
    s = s + neg_power *
                (ScalarLike(neg_power, 0.5) +
                 x / a * (ScalarLike(a, 1. / kZetaCoeffs[11]) + horner_sum));

    const double nan = std::numeric_limits<double>::quiet_NaN();
    const double inf = std::numeric_limits<double>::infinity();
    // Use the initial zeta sum without the correction term coming
    // from Euler-Maclaurin if it is accurate enough.
    XlaOp output =
        Select(Lt(Abs(neg_power), Abs(initial_sum) * Epsilon(&builder, type)),
               initial_sum, s);
    // This is the harmonic series.
    output = Select(Eq(x, ScalarLike(x, 1.)), ScalarLike(x, inf), output);
    // Function is not defined for x < 1.
    output = Select(Lt(x, ScalarLike(x, 1.)), ScalarLike(x, nan), output);
    // If q <= 0, then when q is an integer or x is not an integer, this is
    // NaN.
    XlaOp domain_error = And(Le(q, ScalarLike(x, 0.)), Ne(x, Floor(x)));
    XlaOp negative_integer_q = And(Le(q, ScalarLike(x, 0.)), Eq(q, Floor(q)));
    output = Select(negative_integer_q, ScalarLike(x, inf), output);
    output = Select(domain_error, ScalarLike(x, nan), output);
    return output;
  };
  return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto x_shape, builder.GetShape(x));
    TF_ASSIGN_OR_RETURN(auto q_shape, builder.GetShape(q));
    if (x_shape != q_shape) {
      return InvalidArgument(
          "Arguments to Zeta must have equal shapes and types; got %s and %s",
          x_shape.ToString(), q_shape.ToString());
    }
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Zeta", x));
    bool needs_upcast =
        x_shape.element_type() == F16 || x_shape.element_type() == BF16;

    if (needs_upcast) {
      x = ConvertElementType(x, F32);
      q = ConvertElementType(q, F32);
    }
    XlaOp result = doit(x, q, x_shape.element_type());
    if (needs_upcast) {
      result = ConvertElementType(result, x_shape.element_type());
    }
    return result;
  });
}

} // namespace xla
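The loop bounds and coefficients above are easier to see in a scalar transcription. The following NumPy sketch is not part of this change and is purely illustrative; it follows the same Euler-Maclaurin scheme (ten direct terms of the series plus a 12-coefficient tail evaluated with Horner's rule) and omits the x == 1, x < 1, and q <= 0 special cases handled with Select above:

# A minimal sketch, assuming only NumPy; kept scalar for readability.
import numpy as np

# (2k)! / B_{2k}, ordered in reverse, matching kZetaCoeffs above.
_ZETA_COEFFS = [
    -7.1661652561756670113e18, 1.8152105401943546773e17,
    -4.5979787224074726105e15, 1.1646782814350067249e14,
    -2.950130727918164224e12, 7.47242496e10, -1.8924375803183791606e9,
    47900160.0, -1209600.0, 30240.0, -720.0, 12.0,
]

def hurwitz_zeta_reference(x, q):
  a = q
  initial_sum = q ** -x
  for _ in range(9):
    a += 1.0
    initial_sum += a ** -x
  a += 1.0
  neg_power = a ** -x
  s = initial_sum + neg_power * a / (x - 1.0)
  a_inv_sq = 1.0 / (a * a)
  horner_sum = 0.0
  for i in range(11):
    factor = (x - (22 - 2 * i)) * (x - (21 - 2 * i))
    horner_sum = factor * a_inv_sq * (horner_sum + 1.0 / _ZETA_COEFFS[i])
  s += neg_power * (0.5 + x / a * (1.0 / _ZETA_COEFFS[11] + horner_sum))
  # Fall back to the plain series when the correction is below round-off.
  if abs(neg_power) < abs(initial_sum) * np.finfo(np.float64).eps:
    return initial_sum
  return s

# Example: hurwitz_zeta_reference(4.0, 2.5) is close to scipy.special.zeta(4.0, 2.5).
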
@ -72,6 +72,12 @@ XlaOp RandomGammaGrad(XlaOp a, XlaOp x);
// Computes an approximation of the complementary incomplete gamma function.
XlaOp Igammac(XlaOp a, XlaOp x);

// Computes the polygamma function psi^(n)(x) of two arguments.
XlaOp Polygamma(XlaOp n, XlaOp x);

// Computes the Hurwitz zeta function zeta(x, q) of two arguments.
XlaOp Zeta(XlaOp x, XlaOp q);

// Rounds the given number to even when the number is equidistant between two
// integers.
XlaOp RoundToEven(XlaOp x);
@ -239,7 +239,7 @@ struct CacheEntry {
|
||||
// a signature and if the object has been inserted already, other threads
|
||||
// will wait for the notification.
|
||||
absl::Notification compilation_complete;
|
||||
absl::optional<std::exception> compilation_error = absl::nullopt;
|
||||
absl::optional<Status> compilation_error = absl::nullopt;
|
||||
};
|
||||
|
||||
// A `CompiledFunction` is associated to a `jax.jit(f)` and takes care of the
|
||||
@ -314,7 +314,7 @@ class CompiledFunction {
|
||||
// absl::optional<absl::Notification> is not supported
|
||||
bool first_compilation_started_ = false;
|
||||
absl::Notification first_compilation_complete_;
|
||||
absl::optional<std::exception> first_compilation_error_ = absl::nullopt;
|
||||
absl::optional<Status> first_compilation_error_ = absl::nullopt;
|
||||
};
|
||||
|
||||
CompiledFunction::CompiledFunction(py::function fun,
|
||||
@ -646,7 +646,8 @@ CacheEntry& CompiledFunction::GetCacheEntry(
|
||||
py::gil_scoped_release gil_release;
|
||||
found_iterator->second->compilation_complete.WaitForNotification();
|
||||
if (found_iterator->second->compilation_error) {
|
||||
throw found_iterator->second->compilation_error.value();
|
||||
throw std::invalid_argument(
|
||||
found_iterator->second->compilation_error.value().error_message());
|
||||
}
|
||||
}
|
||||
return *(found_iterator->second);
|
||||
@ -671,8 +672,8 @@ CacheEntry& CompiledFunction::SetAndReturnCacheEntry(
|
||||
} else {
|
||||
try {
|
||||
executable_and_pytree = cache_miss_fun_(*args, **kwargs);
|
||||
} catch (const std::exception& e) {
|
||||
cache_entry.compilation_error = e;
|
||||
} catch (const py::error_already_set& e) {
|
||||
cache_entry.compilation_error = InvalidArgument("%s", e.what());
|
||||
cache_entry.compilation_complete.Notify();
|
||||
throw;
|
||||
}
|
||||
@ -736,16 +737,17 @@ py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) {
|
||||
if (!first_compilation_complete_.HasBeenNotified()) {
|
||||
py::gil_scoped_release gil_release;
|
||||
first_compilation_complete_.WaitForNotification();
|
||||
if (first_compilation_error_) {
|
||||
throw first_compilation_error_.value();
|
||||
}
|
||||
}
|
||||
if (first_compilation_error_) {
|
||||
throw std::invalid_argument(
|
||||
first_compilation_error_.value().error_message());
|
||||
}
|
||||
} else {
|
||||
first_compilation_started_ = true;
|
||||
try {
|
||||
cache_miss_result = cache_miss_fun_(*args, **kwargs);
|
||||
} catch (const std::exception& e) {
|
||||
first_compilation_error_ = e;
|
||||
} catch (const py::error_already_set& e) {
|
||||
first_compilation_error_ = InvalidArgument("%s", e.what());
|
||||
first_compilation_complete_.Notify();
|
||||
throw;
|
||||
}
|
||||
@ -754,9 +756,14 @@ py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) {
|
||||
|
||||
pyclient_ = executable->client();
|
||||
default_device_ = executable->LocalDevices()[0].contents;
|
||||
if (!default_device_) {
|
||||
throw std::invalid_argument(
|
||||
"executable->LocalDevices()[0] should not be null!");
|
||||
}
|
||||
first_compilation_complete_.Notify();
|
||||
}
|
||||
}
|
||||
CHECK(default_device_);
|
||||
|
||||
// The C++ jit do not support Tracers arguments yet. The Python-based jit
|
||||
// function will be called if any of the dynamic arguments is unsupported.
|
||||
|
@ -291,6 +291,7 @@ void BuildOpsSubmodule(py::module* m) {
|
||||
ops.def("RandomGammaGrad", &RandomGammaGrad, py::arg("a"), py::arg("x"));
|
||||
ops.def("RegularizedIncompleteBeta", &RegularizedIncompleteBeta, py::arg("a"),
|
||||
py::arg("b"), py::arg("x"));
|
||||
ops.def("Zeta", &Zeta, py::arg("x"), py::arg("q"));
|
||||
|
||||
#define BINARY_OP(op) \
|
||||
ops.def( \
|
||||
|
@ -56,17 +56,21 @@ StatusOr<std::pair<XlaOp, XlaOp>> CholeskyExpander::CholeskyUnblocked(
|
||||
XlaOp a, PrecisionConfig::Precision precision) {
|
||||
XlaBuilder* builder = a.builder();
|
||||
TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
|
||||
const int n_dims = a_shape.rank();
|
||||
const int ndims = a_shape.rank();
|
||||
const int64 n = ShapeUtil::GetDimension(a_shape, -1);
|
||||
std::vector<int64> error_dims(a_shape.dimensions().begin(),
|
||||
a_shape.dimensions().end());
|
||||
error_dims.back() = error_dims.at(ndims - 2) = 1;
|
||||
|
||||
auto major_dims = AsInt64Slice(a_shape.dimensions())
|
||||
.subspan(
|
||||
/*pos=*/0,
|
||||
/*len=*/n_dims - 2);
|
||||
/*len=*/ndims - 2);
|
||||
|
||||
auto matrix_dims = AsInt64Slice(a_shape.dimensions())
|
||||
.subspan(
|
||||
/*pos=*/0,
|
||||
/*len=*/n_dims);
|
||||
/*len=*/ndims);
|
||||
|
||||
XlaOp l = ZerosLike(a);
|
||||
|
||||
@ -79,9 +83,9 @@ StatusOr<std::pair<XlaOp, XlaOp>> CholeskyExpander::CholeskyUnblocked(
|
||||
auto body_l = loop_vars[1];
|
||||
auto seen_error = loop_vars[2];
|
||||
auto iota_row =
|
||||
Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), n_dims - 1);
|
||||
Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), ndims - 1);
|
||||
auto iota_col =
|
||||
Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), n_dims - 2);
|
||||
Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), ndims - 2);
|
||||
|
||||
auto mask_pred = Ge(iota_col, iota_row);
|
||||
mask_pred = And(mask_pred, Eq(iota_row, i));
|
||||
@ -91,25 +95,32 @@ StatusOr<std::pair<XlaOp, XlaOp>> CholeskyExpander::CholeskyUnblocked(
|
||||
// L * L.T, This matrix has of a lot of multiplying with zero
|
||||
// (namely, L[:, j:] = 0) and redundant computation, but it is faster
|
||||
// than slice.
|
||||
auto l_square = BatchDot(body_l, false, body_l, true, precision);
|
||||
auto l_square =
|
||||
BatchDot(body_l, false, MaybeConjugate(body_l, true), true, precision);
|
||||
|
||||
// A - L*L.T
|
||||
l_square = body_a - l_square;
|
||||
auto l_ii = DynamicSliceInMinorDims(l_square, {i, i}, {1, 1});
|
||||
l_ii = Sqrt(l_ii);
|
||||
if (ShapeUtil::ElementIsComplex(a_shape)) {
|
||||
auto sqrt = Sqrt(Real(l_ii));
|
||||
l_ii = Complex(sqrt, ZerosLike(sqrt));
|
||||
seen_error = Or(seen_error, IsNan(sqrt));
|
||||
} else {
|
||||
l_ii = Sqrt(l_ii);
|
||||
seen_error = Or(seen_error, IsNan(l_ii));
|
||||
}
|
||||
// L = (A - L*L.T) / l_ii * mask + L
|
||||
body_l = Select(mask_pred, l_square / l_ii, mask_zeros) + body_l;
|
||||
|
||||
seen_error =
|
||||
Or(seen_error, Any(Or(Le(l_ii, ZerosLike(l_ii)), IsNan(l_ii))));
|
||||
|
||||
return std::vector<XlaOp>{body_a, body_l, seen_error};
|
||||
};
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto cholesky_while,
|
||||
ForEachIndex(n, S32, body_fn, {a, l, ConstantR0<bool>(builder, false)},
|
||||
"unblocked", builder));
|
||||
ForEachIndex(
|
||||
n, S32, body_fn,
|
||||
{a, l, Zeros(builder, ShapeUtil::MakeShape(PRED, error_dims))},
|
||||
"unblocked", builder));
|
||||
|
||||
return std::make_pair(cholesky_while[1], cholesky_while[2]);
|
||||
}
|
||||
@ -133,23 +144,23 @@ XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64 block_size,
|
||||
ShapeUtil::HumanString(a_shape));
|
||||
}
|
||||
|
||||
if (primitive_util::IsComplexType(a_shape.element_type())) {
|
||||
return Unimplemented(
|
||||
"Complex types are not implemented in Cholesky; got shape %s",
|
||||
ShapeUtil::HumanString(a_shape));
|
||||
}
|
||||
|
||||
if (block_size < 1) {
|
||||
return InvalidArgument(
|
||||
"block_size argument to Cholesky must be >= 1; got %d", block_size);
|
||||
}
|
||||
|
||||
std::vector<int64> error_dims(a_shape.dimensions().begin(),
|
||||
a_shape.dimensions().end());
|
||||
error_dims.back() = error_dims.at(ndims - 2) = 1;
|
||||
std::vector<int64> error_dim_indices(ndims);
|
||||
absl::c_iota(error_dim_indices, 0);
|
||||
|
||||
// Blocked left-looking Cholesky factorization.
|
||||
// Algorithm 1 from
|
||||
// Haidar, Azzam, et al. "High-performance Cholesky factorization for
|
||||
// GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017.
|
||||
XlaOp l = ZerosLike(a);
|
||||
XlaOp seen_error = ConstantR0<bool>(builder, false);
|
||||
XlaOp seen_error = Zeros(builder, ShapeUtil::MakeShape(PRED, error_dims));
|
||||
for (int64 i = 0; i < n; i += block_size) {
|
||||
int64 k = std::min(block_size, n - i);
|
||||
auto panel = SliceInMinorDims(a, {i, i}, {n, i + k});
|
||||
@ -159,7 +170,8 @@ XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64 block_size,
|
||||
// a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i]))
|
||||
auto lhs = SliceInMinorDims(l, {i, 0}, {n, i});
|
||||
auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i});
|
||||
auto delta = BatchDot(lhs, false, rhs, true, precision);
|
||||
auto delta =
|
||||
BatchDot(lhs, false, MaybeConjugate(rhs, true), true, precision);
|
||||
panel = panel - delta;
|
||||
}
|
||||
|
||||
@ -170,8 +182,14 @@ XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64 block_size,
|
||||
// other elements.
|
||||
XlaOp factorized_error;
|
||||
if (k == 1) {
|
||||
factorized = Sqrt(x);
|
||||
factorized_error = Any(IsNan(factorized));
|
||||
if (ShapeUtil::ElementIsComplex(a_shape)) {
|
||||
auto sqrt = Sqrt(Real(x));
|
||||
factorized = Complex(sqrt, ZerosLike(sqrt));
|
||||
factorized_error = IsNan(sqrt);
|
||||
} else {
|
||||
factorized = Sqrt(x);
|
||||
factorized_error = IsNan(factorized);
|
||||
}
|
||||
} else {
|
||||
TF_ASSIGN_OR_RETURN(auto tile_output, CholeskyUnblocked(x, precision));
|
||||
std::tie(factorized, factorized_error) = tile_output;
|
||||
@ -187,12 +205,13 @@ XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64 block_size,
|
||||
/*left_side=*/false,
|
||||
/*lower=*/true,
|
||||
/*unit_diagonal=*/false,
|
||||
/*transpose_a=*/TriangularSolveOptions::TRANSPOSE);
|
||||
/*transpose_a=*/TriangularSolveOptions::ADJOINT);
|
||||
l = UpdateSliceInMinorDims(l, update, {i + k, i});
|
||||
}
|
||||
}
|
||||
return Select(seen_error,
|
||||
FullLike(l, std::numeric_limits<float>::quiet_NaN()), l);
|
||||
return Select(
|
||||
BroadcastInDim(seen_error, a_shape.dimensions(), error_dim_indices),
|
||||
FullLike(l, std::numeric_limits<float>::quiet_NaN()), l);
|
||||
});
|
||||
}
|
||||
|
||||
|
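To summarize the behavioral change in this expander: the Gram update now conjugates the transposed factor (MaybeConjugate) so complex Hermitian inputs are handled, the diagonal square root goes through Real() for complex types, and the failure flag is tracked per batch element so only the offending sample is replaced by NaNs. A small NumPy sketch of that per-matrix behavior follows; it is not part of this change, and the function and variable names are illustrative:

# A minimal sketch for a single Hermitian positive-definite matrix; the XLA
# expander above does the same thing batched, keeping one error flag per
# batch element.
import numpy as np

def cholesky_with_error_flag(a):
  n = a.shape[-1]
  l = np.zeros_like(a)
  seen_error = False
  for j in range(n):
    row = l[j, :j]
    # a[j, j] - sum_k |l[j, k]|^2, using the conjugated row.
    diag = np.real(a[j, j] - row @ np.conj(row))
    if not diag > 0.0:  # catches non-positive-definite inputs and NaNs
      seen_error = True
      diag = np.nan
    l[j, j] = np.sqrt(diag)
    l[j + 1:, j] = (a[j + 1:, j] - l[j + 1:, :j] @ np.conj(row)) / l[j, j]
  if seen_error:
    l = np.full_like(l, np.nan)
  return l, seen_error
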
@ -113,28 +113,47 @@ int64 CountNonLeafOps(const OpCollection& ops) {
|
||||
// of reuses. This is used as a placeholder only, assuming all
// instructions can be fused to enable data reuses.
int64 ReusesCarriedBy(HloInstruction* op, HloInstruction* user) {
// Reuses in some way work like forces that pull instructions
// towards each other. We use a number 0-10 to classify how strong the force
// is between a pair of operations. Given a group of instructions that can be
// moved together, if the forces inside a conditional are stronger, the group
// will be moved inside or remain inside the conditional; otherwise, it will
// be moved outside of or remain outside of the conditional.
|
||||
VLOG(2) << "ConditionalCodeMotion: Add reuses carried by instr: "
|
||||
<< op->ToString() << "=>" << user->ToString() << "\n";
|
||||
switch (user->opcode()) {
|
||||
case HloOpcode::kGetTupleElement:
|
||||
case HloOpcode::kTuple:
|
||||
return 0;
|
||||
case HloOpcode::kConvert:
|
||||
// Because convert is treated as not movable when following Dot or
|
||||
// convolution, here if op is dot or convolution, they must be separated
|
||||
// by a conditional boundary. Here we do not try to pull convert inside
|
||||
// conditionals to be together with the dot or convolution.
|
||||
switch (op->opcode()) {
|
||||
case HloOpcode::kConvolution:
|
||||
case HloOpcode::kDot:
|
||||
return 0;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
switch (op->opcode()) {
|
||||
// These instructions are lightweight and easy to fuse.
|
||||
// These instructions do not carry weight of reuse themselves.
|
||||
case HloOpcode::kParameter:
|
||||
case HloOpcode::kConstant:
|
||||
case HloOpcode::kGetTupleElement:
|
||||
return 0;
|
||||
case HloOpcode::kConditional:
|
||||
return 10;
|
||||
default:
|
||||
// Assume fusion will not happen anyway if user count > 1)
|
||||
if (CountNonLeafOps(op->users()) > 1) {
|
||||
return 0;
|
||||
}
|
||||
return 10;
|
||||
default: {
|
||||
// Assume the reuse decreases with increasing user count.
|
||||
int count1 = CountNonLeafOps(op->users());
|
||||
int count2 = CountNonLeafOps(user->operands());
|
||||
return 10 / count1 / count2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
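The 0-to-10 weighting above is the only tunable part of this heuristic. The toy rendering below is not part of this change; the opcode strings and count arguments are illustrative stand-ins for the HloOpcode checks and CountNonLeafOps results, shown only to make the 0 / 10 / 10-divided-by-fanout structure explicit:

# A minimal sketch of the scoring shape; counts are assumed to be >= 1.
def reuses_carried_by(op_opcode, user_opcode, op_user_count, user_operand_count):
  if user_opcode in ('get-tuple-element', 'tuple'):
    return 0
  if user_opcode == 'convert' and op_opcode in ('convolution', 'dot'):
    return 0
  if op_opcode in ('parameter', 'constant', 'get-tuple-element'):
    return 0
  if op_opcode == 'conditional':
    return 10
  # Reuse is assumed to decrease with the fan-out on either side.
  return 10 // max(op_user_count, 1) // max(user_operand_count, 1)
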
@ -192,17 +211,35 @@ Status CopyInOrOutOfConditional(
|
||||
absl::InlinedVector<HloInstruction*, 4> new_operands;
|
||||
for (int i = 0; i < op->operands().size(); ++i) {
|
||||
auto op_i = op->operands()[i];
|
||||
VLOG(2) << "Looking for operand:" << op_i->ToString() << "\n";
|
||||
VLOG(2) << "Looking for " << op_i->ToString() << "\n";
|
||||
if (ContainsKey(hoisted_instructions, op_i)) {
|
||||
auto new_op_i =
|
||||
FindOrDie(hoisted_instructions, op_i).operands()[dest_index];
|
||||
VLOG(2) << "new operand:" << new_op_i->ToString() << "\n";
|
||||
VLOG(2) << "new instruction:" << new_op_i->ToString() << "\n";
|
||||
new_operands.push_back(new_op_i);
|
||||
} else {
|
||||
CHECK(op_i->opcode() == HloOpcode::kConstant);
|
||||
auto new_op_i = parent->AddInstruction(op_i->Clone());
|
||||
VLOG(2) << "new operand:" << new_op_i->ToString() << "\n";
|
||||
new_operands.push_back(new_op_i);
|
||||
switch (op_i->opcode()) {
|
||||
case HloOpcode::kConstant: {
|
||||
auto new_op_i = parent->AddInstruction(op_i->Clone());
|
||||
VLOG(2) << "new instruction:" << new_op_i->ToString() << "\n";
|
||||
new_operands.push_back(new_op_i);
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kGetTupleElement: {
|
||||
auto gte = Cast<HloGetTupleElementInstruction>(op_i);
|
||||
int64 index = gte->tuple_index();
|
||||
HloInstruction* root = parent->root_instruction();
|
||||
CHECK(root->opcode() == HloOpcode::kTuple &&
|
||||
index < root->operand_count());
|
||||
auto new_op_i = root->mutable_operand(index);
|
||||
VLOG(2) << "new instruction:" << new_op_i->ToString() << "\n";
|
||||
new_operands.push_back(new_op_i);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
LOG(FATAL) << "Unexpected out-of-boundary instruction:"
|
||||
<< op_i->ToString() << "\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
HloInstruction* new_instruction = parent->AddInstruction(
|
||||
@ -492,6 +529,7 @@ StatusOr<bool> ConditionalCodeMotion::MoveInstructionOut(
|
||||
int64 index = tuple_opd->tuple_index();
|
||||
CHECK(old_root->operands().size() > index);
|
||||
HloInstruction* old_opd = old_root->operands()[index];
|
||||
VLOG(2) << "old opd = " << old_opd << "\n";
|
||||
CHECK(ContainsKey(hoisted_instructions, old_opd));
|
||||
HloInstruction* new_opd = hoisted_instructions[old_opd].operands()[0];
|
||||
CHECK(old_opd != nullptr);
|
||||
@ -535,6 +573,7 @@ StatusOr<bool> ConditionalCodeMotion::MoveInstructionOut(
|
||||
HloInstruction* new_root =
|
||||
conditional->branch_computation(0)->root_instruction();
|
||||
*conditional->mutable_shape() = new_root->shape();
|
||||
|
||||
//
|
||||
VLOG(1) << "done moving instructions out of branches\n"
|
||||
<< conditional_parent->ToString(HloPrintOptions::Fingerprint())
|
||||
@ -558,16 +597,26 @@ StatusOr<bool> ConditionalCodeMotion::MoveInstructionIn(
|
||||
absl::flat_hash_map<HloInstruction*, Boundary> hoisted_instructions;
|
||||
int64 to_move_in_size = to_move_in.size();
|
||||
int64 branch_count = conditional->branch_count();
|
||||
HloGetTupleElementInstruction* tuple_use =
|
||||
DynCast<HloGetTupleElementInstruction>(to_move_in[0].operands()[0]);
|
||||
// If use_index is -1, the old conditional root entry used by to_move_in
|
||||
// instructions still need to be included as an entry of the modified
|
||||
// conditional root, and the new result of the to_move_in instructions
|
||||
// need to be added as an extra entry of the modified root; otherwise, the
|
||||
// old root entry will be replaced with the new result in the modified root.
|
||||
// The entry replacement should be allowed only if tuple_use has <=1 users.
|
||||
int64 use_index = (tuple_use != nullptr && tuple_use->user_count() == 1)
|
||||
? tuple_use->tuple_index()
|
||||
: -1;
|
||||
VLOG(2) << "Tuple use index = " << use_index << "\n";
|
||||
// Number of old conditional entries still to be used outside.
|
||||
// If conditional shape is not tuple, will create a tuple and use subscript
|
||||
// 0 to save the old operand being used.
|
||||
int64 op_index = conditional->shape().IsTuple()
|
||||
? conditional->shape().tuple_shapes_size() - 1
|
||||
: 0;
|
||||
HloGetTupleElementInstruction* tuple_use =
|
||||
dynamic_cast<HloGetTupleElementInstruction*>(to_move_in[0].operands()[0]);
|
||||
int64 use_index = (tuple_use != nullptr) ? tuple_use->tuple_index() : -1;
|
||||
VLOG(2) << "Tuple use index = " << use_index << "\n";
|
||||
int64 op_index =
|
||||
conditional->shape().IsTuple()
|
||||
? ((use_index >= 0) ? conditional->shape().tuple_shapes_size() - 1
|
||||
: conditional->shape().tuple_shapes_size())
|
||||
: 0;
|
||||
// Use to map the tuple_use instruction to its operand;
|
||||
Boundary b_opd_use(Boundary::Position::kInsideBranch);
|
||||
Boundary b_old_root(Boundary::Position::kInsideBranch);
|
||||
@ -628,26 +677,29 @@ StatusOr<bool> ConditionalCodeMotion::MoveInstructionIn(
|
||||
hoisted_instructions[conditional] = b_old_root;
|
||||
int64 cp_start = 0;
|
||||
if (use_index >= 0) {
|
||||
VLOG(2) << "Mapping GTE: " << tuple_use->ToString() << "\n";
|
||||
hoisted_instructions[tuple_use] = b_opd_use;
|
||||
cp_start = 1;
|
||||
}
|
||||
for (int64 i = cp_start; i < to_move_in_size; i++) {
|
||||
Boundary b_to_move = to_move_in[i];
|
||||
cp_start = (tuple_use != nullptr) ? 1 : 0;
|
||||
for (int64 to_move_index = cp_start; to_move_index < to_move_in_size;
|
||||
to_move_index++) {
|
||||
Boundary b_to_move = to_move_in[to_move_index];
|
||||
HloInstruction* op = b_to_move.operands()[0];
|
||||
CHECK(op != nullptr);
|
||||
bool to_be_used_outside = true;
|
||||
VLOG(2) << "Mapping new boundary instr: " << op->ToString() << "\n";
|
||||
if (i < to_move_in_size - 1 && op->user_count() == 1 &&
|
||||
op->users()[0] == to_move_in[i + 1].operands()[0]) {
|
||||
if (to_move_index < to_move_in_size - 1 && op->user_count() == 1 &&
|
||||
op->users()[0] == to_move_in[to_move_index + 1].operands()[0]) {
|
||||
to_be_used_outside = false;
|
||||
VLOG(2) << "Instruction is not to be used outside the branch\n";
|
||||
}
|
||||
Boundary b(Boundary::Position::kInsideBranch);
|
||||
for (int i = 0; i < branch_count; i++) {
|
||||
auto computation = conditional->branch_computation(i);
|
||||
VLOG(2) << "Copying to branch: " << i << "\n";
|
||||
TF_RETURN_IF_ERROR(CopyInOrOutOfConditional(b_to_move, i, computation,
|
||||
hoisted_instructions));
|
||||
VLOG(2) << "After Copying to branch: " << computation->ToString() << "\n";
|
||||
VLOG(2) << "Done:" << computation->ToString() << "\n";
|
||||
if (to_be_used_outside) {
|
||||
auto new_op = hoisted_instructions[op].operands()[i];
|
||||
auto new_root = computation->root_instruction();
|
||||
@ -681,18 +733,19 @@ StatusOr<bool> ConditionalCodeMotion::MoveInstructionIn(
|
||||
// Remove hoisted instructions from the branches.
|
||||
for (int64 i = to_move_in_size - 1; i >= 0; i--) {
|
||||
Boundary boundary_to_move_in = to_move_in[i];
|
||||
VLOG(2) << "Removing boundary:" << boundary_to_move_in.ToString() << "\n";
|
||||
HloInstruction* op = boundary_to_move_in.operands()[0];
|
||||
for (auto user : op->users()) {
|
||||
VLOG(2) << "Has User: " << user->ToString() << "\n";
|
||||
if (op->user_count() == 0) {
|
||||
VLOG(2) << "Removing boundary:" << boundary_to_move_in.ToString() << "\n";
|
||||
TF_RETURN_IF_ERROR(conditional->parent()->RemoveInstruction(op));
|
||||
VLOG(2) << "Done removing boundary.\n";
|
||||
}
|
||||
TF_RETURN_IF_ERROR(conditional->parent()->RemoveInstruction(op));
|
||||
}
|
||||
|
||||
// Reset shapes of user gtes to the new shape.
|
||||
if (use_index != -1) {
|
||||
for (auto* user : conditional->users()) {
|
||||
if (user->opcode() == HloOpcode::kGetTupleElement) {
|
||||
VLOG(2) << "Resetting shape of user: " << user->ToString() << "\n";
|
||||
*user->mutable_shape() =
|
||||
conditional->shape().tuple_shapes(user->tuple_index());
|
||||
}
|
||||
@ -712,14 +765,21 @@ class GroupConnectedBoundaries {
|
||||
HloComputation* conditional_parent_;
|
||||
bool is_layout_sensitive_;
|
||||
// Instructions that have been visited but are not going to be moved.
|
||||
absl::flat_hash_set<HloInstruction*> visited_;
|
||||
absl::flat_hash_map<HloInstruction*, int>& visited_count_;
|
||||
|
||||
public:
|
||||
explicit GroupConnectedBoundaries(HloInstruction* conditional,
|
||||
bool is_layout_sensitive)
|
||||
explicit GroupConnectedBoundaries(
|
||||
HloInstruction* conditional, bool is_layout_sensitive,
|
||||
absl::flat_hash_map<HloInstruction*, int>& visited_count)
|
||||
: conditional_(conditional),
|
||||
conditional_parent_(conditional->parent()),
|
||||
is_layout_sensitive_(is_layout_sensitive) {}
|
||||
is_layout_sensitive_(is_layout_sensitive),
|
||||
visited_count_(visited_count) {}
|
||||
void clear_recently_visited() {
|
||||
for (const auto& boundary : new_boundaries_) {
|
||||
visited_count_.erase(boundary.operands()[0]);
|
||||
}
|
||||
}
|
||||
// Returns true if `instruction` is worth hoisting.
|
||||
bool WorthHoisting(HloInstruction* instruction) {
|
||||
// This is needed for the "moving-in" transformation, to prevent the root
|
||||
@ -736,19 +796,26 @@ class GroupConnectedBoundaries {
|
||||
// ops such as Dot or Convolutional, it is better to keep convert
|
||||
// within conditional so that convert can be fused with Dot or
|
||||
// Convolutional.
|
||||
//
|
||||
// TODO(b/154283721): figure out the scenario when convert can be
|
||||
// fused with AllReduce out of conditional.
|
||||
switch (instruction->operand(0)->opcode()) {
|
||||
case HloOpcode::kAllReduce:
|
||||
case HloOpcode::kReshape:
|
||||
case HloOpcode::kGetTupleElement:
|
||||
return true;
|
||||
default:
|
||||
VLOG(2) << "Instruction is convert and its operand is not know to "
|
||||
VLOG(2) << "Instruction is convert and its operand is not known to "
|
||||
"be worth hoisting\n";
|
||||
return false;
|
||||
}
|
||||
case HloOpcode::kGetTupleElement:
|
||||
switch (instruction->operand(0)->opcode()) {
|
||||
// do not move GTE if its operand is a parameter
|
||||
case HloOpcode::kParameter:
|
||||
return false;
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
case HloOpcode::kAllReduce:
|
||||
case HloOpcode::kReduce:
|
||||
case HloOpcode::kAdd:
|
||||
case HloOpcode::kPower:
|
||||
case HloOpcode::kCopy:
|
||||
@ -758,8 +825,10 @@ class GroupConnectedBoundaries {
|
||||
case HloOpcode::kDivide:
|
||||
case HloOpcode::kTuple:
|
||||
case HloOpcode::kSqrt:
|
||||
case HloOpcode::kRsqrt:
|
||||
case HloOpcode::kReshape:
|
||||
case HloOpcode::kGetTupleElement:
|
||||
case HloOpcode::kMinimum:
|
||||
case HloOpcode::kMaximum:
|
||||
return true;
|
||||
default:
|
||||
VLOG(2) << "Instruction is not known to be worth hoisting\n";
|
||||
@ -772,14 +841,20 @@ class GroupConnectedBoundaries {
|
||||
// The operand must be an instruction that is not going to be moved (if
|
||||
// user is inside the conditional); otherwise it must be the conditional
|
||||
// itself and its user must be outside of the conditional.
|
||||
if (!ContainsKey(visited_, op) && op != conditional_) {
|
||||
if (!ContainsKey(visited_count_, op) && op != conditional_) {
|
||||
continue;
|
||||
}
|
||||
// Only consider single-user cases as reusable.
|
||||
if (user->opcode() == HloOpcode::kGetTupleElement &&
|
||||
user->user_count() == 1) {
|
||||
if (auto tuple_gte = DynCast<HloGetTupleElementInstruction>(user)) {
|
||||
if (op->opcode() == HloOpcode::kConditional) {
|
||||
auto tuple = op->branch_computation(0)->root_instruction();
|
||||
if (tuple->opcode() == HloOpcode::kTuple) {
|
||||
auto index = tuple_gte->tuple_index();
|
||||
CHECK(index < tuple->operand_count());
|
||||
op = tuple->mutable_operand(index);
|
||||
}
|
||||
}
|
||||
reuses += ReusesCarriedBy(op, user->users()[0]);
|
||||
} else if (op->user_count() == 1) {
|
||||
} else {
|
||||
reuses += ReusesCarriedBy(op, user);
|
||||
}
|
||||
}
|
||||
@ -797,6 +872,7 @@ class GroupConnectedBoundaries {
|
||||
// some aspects of the overall algorithm need to be redesigned to
|
||||
// accommodate the change.
|
||||
if (all_users.size() > 1) {
|
||||
VLOG(2) << "Having multiple users from: " << user->ToString() << "\n";
|
||||
return 0;
|
||||
}
|
||||
if (!all_users.empty()) {
|
||||
@ -818,7 +894,7 @@ class GroupConnectedBoundaries {
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (ContainsKey(visited_, op)) {
|
||||
} else if (ContainsKey(visited_count_, op)) {
|
||||
reuses += ReusesCarriedBy(user, op);
|
||||
}
|
||||
VLOG(2) << "reuses after instruction " << user->ToString() << ":"
|
||||
@ -866,6 +942,50 @@ class GroupConnectedBoundaries {
|
||||
}
|
||||
return b2;
|
||||
}
|
||||
|
||||
// Checking whether it is safe to move a boundary when visited through a
|
||||
// dependent already considered for moving.
|
||||
bool IsSafeToMoveBoundary(const Boundary& next_boundary) {
|
||||
int64 next_boundary_count =
|
||||
(next_boundary.IsInsideBranch())
|
||||
? next_boundary.operands()[0]->user_count()
|
||||
: CountNonLeafOps(next_boundary.operands()[0]->operands());
|
||||
if (next_boundary_count <= 1) {
|
||||
// If boundary has only a single or no dependent, safe to move.
|
||||
return true;
|
||||
} else {
|
||||
if (!ContainsKey(visited_count_, next_boundary.operands()[0])) {
|
||||
VLOG(2) << "Skip next boundary " << next_boundary.ToString() << "\n"
|
||||
<< " because it has multiple dependents: "
|
||||
<< next_boundary_count << "\n";
|
||||
visited_count_[next_boundary.operands()[0]] = 1;
|
||||
new_boundaries_.push_back(next_boundary);
|
||||
} else {
|
||||
auto pos = std::find(new_boundaries_.begin(), new_boundaries_.end(),
|
||||
next_boundary);
|
||||
if (pos != new_boundaries_.end() ||
|
||||
next_boundary.operands().size() == 1) {
|
||||
int count = ++visited_count_[next_boundary.operands()[0]];
|
||||
if (count == next_boundary_count) {
|
||||
VLOG(2) << "Recovering next boundary " << next_boundary.ToString()
|
||||
<< "\n"
|
||||
<< " because all of its dependents have been visited: "
|
||||
<< next_boundary_count << "\n";
|
||||
visited_count_.erase(next_boundary.operands()[0]);
|
||||
if (pos != new_boundaries_.end()) {
|
||||
new_boundaries_.erase(pos);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} else {
|
||||
VLOG(2) << "Skip incompatible multi-dependent boundary: "
|
||||
<< next_boundary.ToString() << ":" << next_boundary_count
|
||||
<< "\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
// This function is reused both for moving the boundary outside or into a
|
||||
// conditional. As the result, the readability is somewhat compromised.
|
||||
// It might be nice to refactor this function to factor the outside-inside
|
||||
@ -879,7 +999,7 @@ class GroupConnectedBoundaries {
|
||||
VLOG(2) << "visiting boundary " << b.ToString() << "\n";
|
||||
if ((b.IsOutsideBranch() || InstructionWithinBranchIdentical(
|
||||
b.operands(), is_layout_sensitive_)) &&
|
||||
WorthHoisting(b.operands()[0])) {
|
||||
IsSafeToMoveBoundary(b) && WorthHoisting(b.operands()[0])) {
|
||||
connected_boundaries_.push_back(b);
|
||||
VLOG(2) << "boundary can be moved\n";
|
||||
int64 operand_count = (b.IsInsideBranch())
|
||||
@ -887,26 +1007,12 @@ class GroupConnectedBoundaries {
|
||||
: b.operands()[0]->users().size();
|
||||
for (int i = 0; i < operand_count; i++) {
|
||||
Boundary next_boundary = GetNextBoundary(b, i);
|
||||
int64 next_boundary_count =
|
||||
(next_boundary.IsInsideBranch())
|
||||
? next_boundary.operands()[0]->user_count()
|
||||
: CountNonLeafOps(next_boundary.operands()[0]->operands());
|
||||
// only consider adding an exclusive producor into the same group.
|
||||
if (next_boundary_count == 1) {
|
||||
VLOG(2) << "Add operand " << i << " to visit later\n";
|
||||
visitor.AddToWorkList(next_boundary);
|
||||
} else {
|
||||
VLOG(2) << "Next boundary " << i
|
||||
<< " has multiple uses: " << next_boundary_count << "\n";
|
||||
if (!ContainsKey(visited_, next_boundary.operands()[0])) {
|
||||
visited_.insert(next_boundary.operands()[0]);
|
||||
new_boundaries_.push_back(next_boundary);
|
||||
}
|
||||
}
|
||||
VLOG(2) << "Add operand/user " << i << " to visit later\n";
|
||||
visitor.AddToWorkList(next_boundary);
|
||||
}
|
||||
} else {
|
||||
VLOG(2) << "boundary cannot be moved\n";
|
||||
visited_.insert(b.operands()[0]);
|
||||
visited_count_[b.operands()[0]] = 1;
|
||||
new_boundaries_.push_back(b);
|
||||
}
|
||||
}
|
||||
@ -948,8 +1054,10 @@ class GroupConnectedBoundaries {
|
||||
|
||||
ConditionalCodeMotion::Decision ConditionalCodeMotion::ConsiderCodeMotion(
|
||||
HloInstruction* conditional, const Boundary& cur_boundary,
|
||||
std::vector<Boundary>& to_move, std::vector<Boundary>& new_boundaries) {
|
||||
GroupConnectedBoundaries connect(conditional, is_layout_sensitive_);
|
||||
std::vector<Boundary>& to_move, std::vector<Boundary>& new_boundaries,
|
||||
absl::flat_hash_map<HloInstruction*, int>& visited_count) {
|
||||
GroupConnectedBoundaries connect(conditional, is_layout_sensitive_,
|
||||
visited_count);
|
||||
auto move_in_or_out =
|
||||
connect.BoundariesToMoveInOrOut(conditional, cur_boundary);
|
||||
if (!move_in_or_out.empty()) {
|
||||
@ -964,16 +1072,21 @@ ConditionalCodeMotion::Decision ConditionalCodeMotion::ConsiderCodeMotion(
|
||||
// at the first entry of the sequence is sufficient to know which
|
||||
// direction the move is intended.
|
||||
to_move = move_in_or_out;
|
||||
return to_move[0].IsInsideBranch() ? Decision::kMoveOutOfBranch
|
||||
: Decision::kMoveIntoBranch;
|
||||
return Decision(to_move[0].IsInsideBranch()
|
||||
? Decision::Direction::kMoveOutOfBranch
|
||||
: Decision::Direction::kMoveIntoBranch,
|
||||
benefit);
|
||||
} else {
|
||||
connect.clear_recently_visited();
|
||||
}
|
||||
} else {
|
||||
connect.AddNewBoundaries(new_boundaries);
|
||||
}
|
||||
return ConditionalCodeMotion::Decision::kNoChange;
|
||||
return Decision(Decision::Direction::kNoChange, 0);
|
||||
}
|
||||
|
||||
StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
|
||||
VLOG(2) << "Begin a new pass of conditional code motion optimization.\n";
|
||||
bool changed = false;
|
||||
bool cleanup_changed = false;
|
||||
{
|
||||
@ -1018,6 +1131,8 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
|
||||
}
|
||||
}
|
||||
|
||||
// Use to collect mappings between cloned instructions.
|
||||
HloCloneContext clone_context(module);
|
||||
for (HloInstruction* conditional : conditional_ops) {
|
||||
int branch_count = conditional->branch_count();
|
||||
// check for shared conditional computations
|
||||
@ -1031,7 +1146,13 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
|
||||
}
|
||||
|
||||
// Boundaries to move out or to move into the branches.
|
||||
std::vector<Boundary> to_move_out, to_move_in, new_boundaries;
|
||||
std::vector<std::vector<Boundary> > to_move_out, to_move_in;
|
||||
std::vector<std::vector<Boundary> > new_boundaries_for_moveout;
|
||||
std::vector<std::vector<Boundary> > new_boundaries_for_movein;
|
||||
// Number of times each instruction has been visited for moving.
|
||||
absl::flat_hash_map<HloInstruction*, int> visited_count;
|
||||
int benefit_move_out = 0, benefit_move_in = 0;
|
||||
Decision::Direction final_d = Decision::Direction::kNoChange;
|
||||
// The conditional is moved into a worklist as the seed (starting point).
|
||||
// The conditional will be expanded into multiple seeds (starting points),
|
||||
// its roots and its users, when it is visited by GroupConnectedBoundaries.
|
||||
@ -1039,76 +1160,114 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
|
||||
// so that the other seeding boundaries can be visited in turn.
|
||||
BoundaryVisitor visitor(conditional);
|
||||
VLOG(2) << "Analyzing conditional:" << conditional->ToString() << "\n";
|
||||
ConditionalCodeMotion::Decision d = Decision::kNoChange;
|
||||
// The following loop breaks out as soon as a decision to modify the
|
||||
// conditional is reached --- irrespective of whether visitor is empty.
|
||||
while (d == Decision::kNoChange && visitor.HasNextBoundary()) {
|
||||
// Try to visit all the boundaries, collect the analysis results, and save
// all the beneficial non-conflicting decisions. If two decisions conflict
// with each other, save the more beneficial one.
|
||||
while (visitor.HasNextBoundary()) {
|
||||
std::vector<Boundary> to_move, next_boundary;
|
||||
Boundary boundary = visitor.PopNextBoundary();
|
||||
VLOG(2) << "Analyzing boundary:" << boundary.ToString() << "\n";
|
||||
d = ConsiderCodeMotion(conditional, boundary, to_move, next_boundary);
|
||||
if (d != Decision::kNoChange && conditional_is_shared) {
|
||||
for (int i = 0; i < branch_count; ++i) {
|
||||
HloComputation* branch_i = conditional->branch_computation(i);
|
||||
if (conditional_computations[branch_i] > 0) {
|
||||
// Cloning is absolutely needed if the computation is shared by
|
||||
// different branches, but the cloning can be potentially avoided
|
||||
// if the sharing is only among branches of the same conditional.
|
||||
// If cloning these branches causes a problem due to space issues,
|
||||
// a fix can pass a vector of unique branches to the actual
|
||||
// transformations, as an alternative representation of the
|
||||
// conditional branches to be modified. Right now we assume the
|
||||
// overhead of cloning is minimal since later stages of the compiler
|
||||
// inline all the computations anyway.
|
||||
HloComputation* clone_i =
|
||||
conditional->parent()->parent()->AddEmbeddedComputation(
|
||||
branch_i->Clone());
|
||||
conditional->set_branch_computation(i, clone_i);
|
||||
conditional_computations[branch_i]--;
|
||||
auto d = ConsiderCodeMotion(conditional, boundary, to_move, next_boundary,
|
||||
visited_count);
|
||||
switch (d.GetDirection()) {
|
||||
case Decision::Direction::kMoveOutOfBranch:
|
||||
VLOG(2) << "Local Decision is move out of branch\n";
|
||||
to_move_out.push_back(to_move);
|
||||
new_boundaries_for_moveout.push_back(next_boundary);
|
||||
benefit_move_out += d.GetBenefit();
|
||||
if (benefit_move_out >= benefit_move_in) {
|
||||
final_d = Decision::Direction::kMoveOutOfBranch;
|
||||
VLOG(2) << "Current Decision is move out of branch\n";
|
||||
} else {
|
||||
VLOG(2) << "Current Decision remains move into branch\n";
|
||||
}
|
||||
}
|
||||
to_move.clear();
|
||||
next_boundary.clear();
|
||||
VLOG(2) << "Cloned branches as needed: " << conditional->ToString()
|
||||
<< "\n";
|
||||
// Need to reanalyze the cloned code to generate correct result.
|
||||
d = ConsiderCodeMotion(conditional, boundary, to_move, next_boundary);
|
||||
}
|
||||
switch (d) {
|
||||
case Decision::kMoveOutOfBranch:
|
||||
VLOG(2) << "Decision is move out of branch\n";
|
||||
to_move_out.insert(to_move_out.end(), to_move.begin(), to_move.end());
|
||||
new_boundaries.insert(new_boundaries.end(), next_boundary.begin(),
|
||||
next_boundary.end());
|
||||
break;
|
||||
case Decision::kMoveIntoBranch:
|
||||
case Decision::Direction::kMoveIntoBranch:
|
||||
VLOG(2) << "Decision is move into branch\n";
|
||||
to_move_in.insert(to_move_in.end(), to_move.begin(), to_move.end());
|
||||
new_boundaries.insert(new_boundaries.end(), next_boundary.begin(),
|
||||
next_boundary.end());
|
||||
to_move_in.push_back(to_move);
|
||||
new_boundaries_for_movein.push_back(next_boundary);
|
||||
benefit_move_in += d.GetBenefit();
|
||||
if (benefit_move_out >= benefit_move_in) {
|
||||
VLOG(2) << "Current Decision remains move out of branch\n";
|
||||
} else {
|
||||
final_d = Decision::Direction::kMoveIntoBranch;
|
||||
VLOG(2) << "Current Decision is move into branch\n";
|
||||
}
|
||||
break;
|
||||
case Decision::kNoChange:
|
||||
case Decision::Direction::kNoChange:
|
||||
VLOG(2) << "Decision is no change\n";
|
||||
for (const Boundary& b : next_boundary) {
|
||||
visitor.AddToWorkList(b);
|
||||
VLOG(2) << "Adding new boundary to worklist:" << b.ToString()
|
||||
<< "\n";
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
// If modification is to be made, need to clone the shared branches.
|
||||
if (final_d != Decision::Direction::kNoChange && conditional_is_shared) {
|
||||
for (int i = 0; i < branch_count; ++i) {
|
||||
HloComputation* branch_i = conditional->branch_computation(i);
|
||||
if (conditional_computations[branch_i] > 0) {
|
||||
// Cloning is absolutely needed if the computation is shared by
|
||||
// different branches, but the cloning can be potentially avoided
|
||||
// if the sharing is only among branches of the same conditional.
|
||||
// If cloning these branches causes a problem due to space issues,
|
||||
// a fix can pass a vector of unique branches to the actual
|
||||
// transformations, as an alternative representation of the
|
||||
// conditional branches to be modified. Right now we assume the
|
||||
// overhead of cloning is minimal since later stages of the compiler
|
||||
// inline all the computations anyway.
|
||||
HloComputation* clone_i =
|
||||
conditional->parent()->parent()->AddEmbeddedComputation(
|
||||
branch_i->Clone("clone", &clone_context));
|
||||
conditional->set_branch_computation(i, clone_i);
|
||||
conditional_computations[branch_i]--;
|
||||
// Need to translate the analysis result to generate correct result.
|
||||
auto update_boundary = [&](Boundary& boundary) {
|
||||
auto cloned_instr =
|
||||
clone_context.FindInstruction(boundary.operands()[i]);
|
||||
CHECK(cloned_instr != nullptr);
|
||||
VLOG(2) << "boundary before cloning:" << boundary.operands()[i]
|
||||
<< "\n";
|
||||
boundary.mutable_operands()[i] = cloned_instr;
|
||||
VLOG(2) << "boundary after cloning:" << boundary.operands()[i]
|
||||
<< "\n";
|
||||
};
|
||||
// Only boundaries to move out need to be updated.
|
||||
if (final_d == Decision::Direction::kMoveOutOfBranch) {
|
||||
for (int i = 0; i < to_move_out.size(); ++i) {
|
||||
std::vector<Boundary>& m = to_move_out[i];
|
||||
std::for_each(m.begin(), m.end(), update_boundary);
|
||||
}
|
||||
for (int i = 0; i < new_boundaries_for_moveout.size(); ++i) {
|
||||
std::vector<Boundary>& m = new_boundaries_for_moveout[i];
|
||||
std::for_each(m.begin(), m.end(), update_boundary);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
VLOG(2) << "Cloned branches as needed: " << conditional->ToString()
|
||||
<< "\n";
|
||||
}
|
||||
// At most one of to_move_out or to_move_in can be non-empty, since there is
|
||||
// only one optimization decision.
|
||||
if (!to_move_out.empty()) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
bool result,
|
||||
MoveInstructionOut(conditional, to_move_out, new_boundaries));
|
||||
VLOG(2) << "moving out result:" << result << "\n";
|
||||
changed |= result;
|
||||
} else if (!to_move_in.empty()) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
bool result,
|
||||
MoveInstructionIn(conditional, to_move_in, new_boundaries));
|
||||
VLOG(2) << "moving in result:" << result << "\n";
|
||||
changed |= result;
|
||||
if (final_d == Decision::Direction::kMoveOutOfBranch) {
|
||||
CHECK(to_move_out.size() == new_boundaries_for_moveout.size());
|
||||
for (int i = 0; i < to_move_out.size(); ++i) {
|
||||
TF_ASSIGN_OR_RETURN(bool result,
|
||||
MoveInstructionOut(conditional, to_move_out[i],
|
||||
new_boundaries_for_moveout[i]));
|
||||
changed |= result;
|
||||
}
|
||||
} else if (final_d == Decision::Direction::kMoveIntoBranch) {
|
||||
CHECK(to_move_in.size() == new_boundaries_for_movein.size());
|
||||
for (int i = 0; i < to_move_in.size(); ++i) {
|
||||
TF_ASSIGN_OR_RETURN(bool result,
|
||||
MoveInstructionIn(conditional, to_move_in[i],
|
||||
new_boundaries_for_movein[i]));
|
||||
changed |= result;
|
||||
}
|
||||
} else if (pursue_full_conditional_code_motion_ && !conditional_is_shared) {
|
||||
// Invoke special handling for convert rematerialization/hoisting
|
||||
// We need to make sure no sharing is present in the branches because no
|
||||
@ -1118,11 +1277,13 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
|
||||
bool convert_result,
|
||||
ConvertSpecialMove(conditional, is_layout_sensitive_));
|
||||
changed |= convert_result;
|
||||
VLOG(2) << "Done special moving of convert\n";
|
||||
}
|
||||
}
|
||||
if (changed) {
|
||||
HloPassPipeline subpipeline(
|
||||
"after_conditional_code_motion_after_convert_hoisting");
|
||||
VLOG(2) << "starting after motion passes: DCE\n";
|
||||
subpipeline.AddPass<HloDCE>();
|
||||
subpipeline.AddPass<TupleSimplifier>();
|
||||
subpipeline.AddPass<HloDCE>();
|
||||
|
@ -52,6 +52,9 @@ class Boundary {
|
||||
}
|
||||
return res;
|
||||
}
|
||||
bool operator==(const Boundary& that) {
|
||||
return ContainersEqual(operands_, that.operands_);
|
||||
}
|
||||
|
||||
private:
|
||||
// Boundary instructions in the conditional branches, one from each branch
|
||||
@ -78,13 +81,30 @@ class ConditionalCodeMotion : public HloModulePass {
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
|
||||
// Optimization decision for each boundary of the conditional instruction.
|
||||
enum class Decision { kMoveOutOfBranch, kMoveIntoBranch, kNoChange };
|
||||
class Decision {
|
||||
public:
|
||||
enum class Direction : uint8 {
|
||||
kMoveOutOfBranch,
|
||||
kMoveIntoBranch,
|
||||
kNoChange
|
||||
};
|
||||
|
||||
public:
|
||||
Decision(Direction direction, int benefit)
|
||||
: direction_(direction), benefit_(benefit) {}
|
||||
Direction GetDirection() const { return direction_; }
|
||||
int GetBenefit() const { return benefit_; }
|
||||
|
||||
private:
|
||||
Direction direction_;
|
||||
int benefit_;
|
||||
};
|
||||
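Each boundary now yields a Decision carrying both a direction and an integer benefit; the Run() loop sums the benefits per direction and keeps whichever side is ahead. The schematic below is not part of this change and uses illustrative names; the real pass also records the boundary groups attached to each decision:

# A minimal sketch of the benefit tally; decisions is an iterable of
# (direction, benefit) pairs with direction in {'move_out', 'move_in',
# 'no_change'}.
def pick_final_direction(decisions):
  benefit_out = benefit_in = 0
  final_direction = 'no_change'
  for direction, benefit in decisions:
    if direction == 'move_out':
      benefit_out += benefit
      if benefit_out >= benefit_in:
        final_direction = 'move_out'
    elif direction == 'move_in':
      benefit_in += benefit
      if benefit_in > benefit_out:
        final_direction = 'move_in'
  return final_direction
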
// If the optimization decision is NO_CHANGE, new_boundary is set to nullptr;
|
||||
// otherwise, it is set to the new boundary after proposed optimization.
|
||||
virtual Decision ConsiderCodeMotion(HloInstruction* conditional,
|
||||
const Boundary& cur_boundary,
|
||||
std::vector<Boundary>& to_move,
|
||||
std::vector<Boundary>& new_boundaries);
|
||||
virtual Decision ConsiderCodeMotion(
|
||||
HloInstruction* conditional, const Boundary& cur_boundary,
|
||||
std::vector<Boundary>& to_move, std::vector<Boundary>& new_boundaries,
|
||||
absl::flat_hash_map<HloInstruction*, int>& visited_count);
|
||||
|
||||
private:
|
||||
const bool is_layout_sensitive_;
|
||||
|
@ -234,17 +234,16 @@ ENTRY main {
|
||||
const HloInstruction* conditional =
|
||||
FindInstruction(module.get(), "conditional");
|
||||
const HloComputation* on_true = conditional->branch_computation(0);
|
||||
ASSERT_EQ(on_true->instruction_count(), 2);
|
||||
ASSERT_EQ(on_true->instruction_count(), 1);
|
||||
const HloComputation* on_false = conditional->branch_computation(1);
|
||||
ASSERT_EQ(on_false->instruction_count(), 2);
|
||||
ASSERT_EQ(on_false->instruction_count(), 1);
|
||||
|
||||
HloInstruction* root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(
|
||||
root,
|
||||
AllOf(op::Tuple(op::Add(op::Convert(op::Reshape(op::GetTupleElement(
|
||||
op::GetTupleElement(op::Conditional())))),
|
||||
op::Convert(op::Reshape(op::GetTupleElement(
|
||||
op::GetTupleElement(op::Conditional()))))))));
|
||||
AllOf(op::Tuple(op::Add(
|
||||
op::Convert(op::Reshape(op::GetTupleElement(op::Conditional()))),
|
||||
op::Convert(op::Reshape(op::GetTupleElement(op::Conditional())))))));
|
||||
}
|
||||
|
||||
TEST_F(ConditionalCodeMotionTest, UserShareOperandCannotBeMoved) {
|
||||
@ -335,7 +334,7 @@ on_false {
|
||||
get-tuple-element.2 = f32[] get-tuple-element(arg_tuple.2), index=0
|
||||
constant.3 = f32[] constant(1)
|
||||
constant.4 = f32[] constant(2)
|
||||
add.4 = f32[] add(get-tuple-element.2, constant.3)
|
||||
add.4 = f32[] add(constant.4, constant.3)
|
||||
add.5 = f32[] add(get-tuple-element.2, constant.4)
|
||||
add.6 = f32[] add(add.4, add.5)
|
||||
ROOT tuple.4 = (f32[]) tuple(add.6)
|
||||
@ -360,7 +359,7 @@ ENTRY main {
|
||||
const HloComputation* on_true = conditional->branch_computation(0);
|
||||
ASSERT_EQ(on_true->instruction_count(), 1);
|
||||
const HloComputation* on_false = conditional->branch_computation(1);
|
||||
ASSERT_EQ(on_false->instruction_count(), 1);
|
||||
ASSERT_EQ(on_false->instruction_count(), 3);
|
||||
|
||||
HloInstruction* root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(
|
||||
@ -543,6 +542,7 @@ ENTRY main {
|
||||
pred.1 = pred[] parameter(0)
|
||||
arg_tuple.3 = (bf16[2,54,168,128], bf16[2,52,168,128]) parameter(1)
|
||||
arg_tuple.4 = (bf16[2,86,104,128], bf16[2,84,104,128]) parameter(2)
|
||||
arg_tuple.5 = f32[3,3,128,128] parameter(3)
|
||||
conditional = (f32[3,3,128,128])
|
||||
conditional(pred.1, arg_tuple.3, arg_tuple.4), true_computation=on_true,
|
||||
false_computation=on_false
|
||||
@ -557,6 +557,7 @@ ENTRY main {
|
||||
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
|
||||
const HloInstruction* conditional =
|
||||
FindInstruction(module.get(), "conditional");
|
||||
CHECK(conditional != nullptr);
|
||||
const HloComputation* on_true = conditional->branch_computation(0);
|
||||
ASSERT_EQ(on_true->instruction_count(), 5);
|
||||
const HloComputation* on_false = conditional->branch_computation(1);
|
||||
@ -619,7 +620,7 @@ ENTRY main {
|
||||
EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional())));
|
||||
}
|
||||
|
||||
TEST_F(ConditionalCodeMotionTest, NoMoveInWithMultipleGTE) {
|
||||
TEST_F(ConditionalCodeMotionTest, MoveInWithMultipleGTE) {
|
||||
absl::string_view hlo_string =
|
||||
R"(
|
||||
HloModule RemoveIdenticalInstruction
|
||||
@ -647,19 +648,19 @@ ENTRY main {
|
||||
false_computation=on_false
|
||||
get-first-index = f32[10] get-tuple-element(conditional), index=0
|
||||
get-first-index.2 = f32[10] get-tuple-element(conditional), index=0
|
||||
pow.1 = f32[10] power(get-first-index, get-first-index)
|
||||
pow.1 = f32[10] power(get-first-index, get-first-index.2)
|
||||
ROOT tuple.3 = (f32[10], f32[10]) tuple(pow.1, get-first-index.2)
|
||||
}
|
||||
)";
|
||||
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
|
||||
ConditionalCodeMotion pass(true, true);
|
||||
ASSERT_FALSE(pass.Run(&*module).ValueOrDie());
|
||||
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
|
||||
HloInstruction* root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(root,
|
||||
op::Tuple(op::Power(), op::GetTupleElement(op::Conditional())));
|
||||
EXPECT_THAT(root, op::Tuple(op::GetTupleElement(op::Conditional()),
|
||||
op::GetTupleElement(op::Conditional())));
|
||||
}
|
||||
|
||||
TEST_F(ConditionalCodeMotionTest, MovePowInWithSharedBranch) {
|
||||
TEST_F(ConditionalCodeMotionTest, MoveOutWithSharedBranch) {
|
||||
absl::string_view hlo_string =
|
||||
R"(
|
||||
HloModule RemoveIdenticalInstruction
|
||||
@ -688,12 +689,16 @@ ENTRY main {
|
||||
const HloInstruction* conditional =
|
||||
FindInstruction(module.get(), "conditional");
|
||||
const HloComputation* on_true = conditional->branch_computation(0);
|
||||
ASSERT_EQ(on_true->instruction_count(), 5);
|
||||
ASSERT_EQ(on_true->instruction_count(), 1);
|
||||
const HloComputation* on_false = conditional->branch_computation(1);
|
||||
ASSERT_EQ(on_false->instruction_count(), 5);
|
||||
ASSERT_EQ(on_false->instruction_count(), 1);
|
||||
|
||||
HloInstruction* root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional())));
|
||||
EXPECT_THAT(
|
||||
root, AllOf(op::Power(op::Add(op::GetTupleElement(op::Conditional()),
|
||||
op::GetTupleElement(op::Conditional())),
|
||||
op::Add(op::GetTupleElement(op::Conditional()),
|
||||
                          op::GetTupleElement(op::Conditional())))));
}

TEST_F(ConditionalCodeMotionTest, MovePowInWithNonTupleRoot) {
@ -959,6 +964,104 @@ ENTRY main {
                  op::AllReduce(op::GetTupleElement(op::Conditional())))))));
}

TEST_F(ConditionalCodeMotionTest, DoNotMoveWithExtraOperand) {
  absl::string_view hlo_string =
      R"(
HloModule RemoveIdenticalInstruction

branch {
  arg.1 = f32[10] parameter(0)
  ROOT add.1 = f32[10] add(arg.1, arg.1)
}

ENTRY main {
  pred.1 = pred[] parameter(0)
  tuple.1 = f32[10] parameter(1)
  tuple.2 = f32[10] parameter(2)
  conditional = f32[10]
    conditional(pred.1, tuple.1, tuple.2), true_computation=branch,
    false_computation=branch
  ROOT pow.1 = f32[10] power(conditional, tuple.2)
}
)";
  auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
  ConditionalCodeMotion pass(true, true);
  ASSERT_FALSE(pass.Run(&*module).ValueOrDie());
}

TEST_F(ConditionalCodeMotionTest, MultipleIndependentMoveIns) {
  absl::string_view hlo_string =
      R"(
HloModule FromNMT

%add.31755 (x.139: f32[], y.139: bf16[]) -> bf16[] {
  %x.139 = bf16[]{:T(512)} parameter(0)
  %y.139 = bf16[]{:T(512)} parameter(1)
  ROOT %add.44073 = bf16[]{:T(512)} add(bf16[]{:T(512)} %x.139, bf16[]{:T(512)} %y.139)
}

%nmt.1 {
  %wide_param.3 = (bf16[1024,4096]{1,0}, bf16[18,64,1024]{2,1,0}, s32[]) parameter(0)
  %get-tuple-element.16525 = bf16[1024,4096]{1,0} get-tuple-element((bf16[1024,4096]{1,0}, bf16[18,64,1024]{2,1,0}, s32[]) %wide_param.3), index=0
  %get-tuple-element.16527 = bf16[18,64,1024]{2,1,0} get-tuple-element((bf16[1024,4096]{1,0}, bf16[18,64,1024]{2,1,0}, s32[]) %wide_param.3), index=1
  %get-tuple-element.16588 = s32[] get-tuple-element((bf16[1024,4096]{1,0}, bf16[18,64,1024]{2,1,0}, s32[]) %wide_param.3), index=2
  %add.3764 = s32[] add(s32[] %get-tuple-element.16588, s32[] %get-tuple-element.16588), metadata={op_type="Sub" op_name="sub"}
  %reshape.9821 = s32[1]{0} reshape(s32[] %add.3764)
  %reshape.9822 = s32[] reshape(s32[1]{0} %reshape.9821)
  %constant.13127 = s32[] constant(0)
  %dynamic-slice.1245 = bf16[1,64,1024]{2,1,0} dynamic-slice(bf16[18,64,1024]{2,1,0} %get-tuple-element.16527, s32[] %reshape.9822, s32[] %constant.13127, s32[] %constant.13127), dynamic_slice_sizes={1,64,1024}
  %reshape.9825 = bf16[64,1024]{1,0} reshape(bf16[1,64,1024]{2,1,0} %dynamic-slice.1245), metadata={op_type="GatherV2" op_name="GatherV2"}
  %logistic.814 = bf16[64,1024]{1,0} logistic(bf16[64,1024]{1,0} %reshape.9825), metadata={op_type="Sigmoid" op_name="Sigmoid"}
  %multiply.4890 = bf16[64,1024]{1,0} multiply(bf16[64,1024]{1,0} %reshape.9825, bf16[64,1024]{1,0} %logistic.814), metadata={op_type="Mul" op_name="mul"}
  %tanh.573 = bf16[64,1024]{1,0} tanh(bf16[64,1024]{1,0} %reshape.9825), metadata={op_type="Tanh" op_name="Tanh"}
  %multiply.4891 = bf16[64,1024]{1,0} multiply(bf16[64,1024]{1,0} %logistic.814, bf16[64,1024]{1,0} %tanh.573), metadata={op_type="Mul" op_name="mul_1"}
  %add.3766 = bf16[64,1024]{1,0} add(bf16[64,1024]{1,0} %multiply.4890, bf16[64,1024]{1,0} %multiply.4891), metadata={op_type="AddV2" op_name="add_1"}
  %multiply.4894 = bf16[64,1024]{1,0} multiply(bf16[64,1024]{1,0} %add.3766, bf16[64,1024]{1,0} %logistic.814), metadata={op_type="Mul" op_name="gradients_1/mul_grad/Mul"}
  %constant.10568 = bf16[] constant(1), metadata={op_type="TanhGrad" op_name="gradients/Tanh_1_grad/TanhGrad"}
  %broadcast.7198 = bf16[64,1024]{1,0} broadcast(bf16[] %constant.10568), dimensions={}, metadata={op_type="TanhGrad" op_name="gradients/Tanh_1_grad/TanhGrad"}
  %multiply.4896 = bf16[64,1024]{1,0} multiply(bf16[64,1024]{1,0} %tanh.573, bf16[64,1024]{1,0} %tanh.573), metadata={op_type="TanhGrad" op_name="gradients/Tanh_1_grad/TanhGrad"}
  %constant.10571 = bf16[] constant(1), metadata={op_type="SigmoidGrad" op_name="gradients/Sigmoid_grad/SigmoidGrad"}
  %broadcast.7201 = bf16[64,1024]{1,0} broadcast(bf16[] %constant.10571), dimensions={}, metadata={op_type="SigmoidGrad" op_name="gradients/Sigmoid_grad/SigmoidGrad"}
  %subtract.1702 = bf16[64,1024]{1,0} subtract(bf16[64,1024]{1,0} %broadcast.7201, bf16[64,1024]{1,0} %logistic.814), metadata={op_type="SigmoidGrad" op_name="gradients/Sigmoid_grad/SigmoidGrad"}
  %multiply.4907 = bf16[64,1024]{1,0} multiply(bf16[64,1024]{1,0} %tanh.573, bf16[64,1024]{1,0} %add.3766), metadata={op_type="Mul" op_name="gradients/mul_2_grad/Mul_1"}
  %multiply.4908 = bf16[64,1024]{1,0} multiply(bf16[64,1024]{1,0} %multiply.4907, bf16[64,1024]{1,0} %logistic.814), metadata={op_type="SigmoidGrad" op_name="gradients/Sigmoid_2_grad/SigmoidGrad"}
  %dot.781 = bf16[64,4096]{1,0} dot(bf16[64,1024]{1,0} %multiply.4908, bf16[1024,4096]{1,0} %get-tuple-element.16525), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_type="MatMul" op_name="MatMul"}
  ROOT %tuple.3200 = (bf16[64,1024]{1,0}, bf16[64,4096]{1,0}, s32[]) tuple(bf16[64,1024]{1,0} %multiply.4894, bf16[64,4096]{1,0} %dot.781, s32[] %reshape.9822)
}
ENTRY main {
  pred.1 = pred[] parameter(0)
  arg_tuple.3 = (bf16[1024,4096]{1,0}, bf16[18,64,1024]{2,1,0}, s32[]) parameter(1)
  arg_tuple.4 = (bf16[1024,4096]{1,0}, bf16[18,64,1024]{2,1,0}, s32[]) parameter(2)
  %arg.2 = s32[] parameter(3)
  %conditional.3 = (bf16[64,1024]{1,0}, bf16[64,4096]{1,0}, s32[]) conditional(pred.1, arg_tuple.3, arg_tuple.4), true_computation=nmt.1, false_computation=nmt.1
  %get-tuple-element.15889 = bf16[64,1024]{1,0} get-tuple-element((bf16[64,1024]{1,0}, bf16[64,4096]{1,0}, s32[]) %conditional.3), index=0, metadata={op_type="Case" op_name="switch_case/indexed_case"}
  %multiply.4596 = bf16[64,1024]{1,0} multiply(bf16[64,1024]{1,0} %get-tuple-element.15889, bf16[64,1024]{1,0} %get-tuple-element.15889), metadata={op_type="L2Loss" op_name="global_norm/L2Loss"}
  %constant.10279 = bf16[] constant(0), metadata={op_type="L2Loss" op_name="global_norm/L2Loss"}
  %reduce.844 = bf16[] reduce(bf16[64,1024]{1,0} %multiply.4596, bf16[] %constant.10279), dimensions={0,1}, to_apply=%add.31755, metadata={op_type="L2Loss" op_name="global_norm/L2Loss"}
  %get-tuple-element.15890 = bf16[64,4096]{1,0} get-tuple-element((bf16[64,1024]{1,0}, bf16[64,4096]{1,0}, s32[]) %conditional.3), index=1, metadata={op_type="Case" op_name="switch_case/indexed_case"}
  %multiply.4597 = bf16[64,4096]{1,0} multiply(bf16[64,4096]{1,0} %get-tuple-element.15890, bf16[64,4096]{1,0} %get-tuple-element.15890), metadata={op_type="L2Loss" op_name="global_norm/L2Loss"}
  %constant.10280 = bf16[] constant(0), metadata={op_type="L2Loss" op_name="global_norm/L2Loss"}
  %reduce.845 = bf16[] reduce(bf16[64,4096]{1,0} %multiply.4597, bf16[] %constant.10280), dimensions={0,1}, to_apply=%add.31755, metadata={op_type="L2Loss" op_name="global_norm/L2Loss"}
  %multiply.4667 = bf16[] multiply(bf16[] %reduce.845, bf16[]{:T(128)} %reduce.844), metadata={op_type="L2Loss" op_name="global_norm/L2Loss"}
  ROOT %tuple.3200 = (bf16[], s32[]) tuple(%multiply.4667, s32[] %arg.2)
}
)";
  auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
  ConditionalCodeMotion pass(true, true);
  ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
  const HloInstruction* conditional =
      FindInstruction(module.get(), "conditional.3");
  CHECK(conditional != nullptr);
  const HloComputation* on_true = conditional->branch_computation(0);
  ASSERT_EQ(on_true->instruction_count(), 27);
  const HloComputation* on_false = conditional->branch_computation(1);
  ASSERT_EQ(on_false->instruction_count(), 27);

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Tuple(op::GetTupleElement(op::Conditional()),
                                    op::Parameter())));
}

} // namespace conditional_opt

} // namespace xla
@ -237,8 +237,7 @@ Status GpuCompiler::OptimizeHloModule(
          return IsMatrixMultiplication(dot)
                     ? candidate_operands
                     : TransposeFolding::OperandIndices{};
        },
        TransposeFolding::NeverFoldTranspose);
  });
  pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
  pipeline.AddPass<HloDCE>();

@ -84,11 +84,10 @@ Status RunGpuConvForward(GpuConvParams params,
        "StreamExecutor doesn't support scaled convolution: %lf.",
        params.conv_result_scale);
  }
  stream->ThenConvolveWithAlgorithm(
  return stream->ConvolveWithAlgorithm(
      params.input_descriptor, input_buf, params.filter_descriptor, filter_buf,
      params.conv_desc, params.output_descriptor, &output_buf,
      scratch_allocator, algorithm, options.profile_result);
  return Status::OK();
}

template <typename ElementType, typename BiasType, typename OutputType>
@ -123,15 +122,13 @@ Status RunGpuConvForwardActivation(GpuConvParams params,
    side_input = output_buf;
  }

  stream->ThenFusedConvolveWithAlgorithm(
  return stream->FusedConvolveWithAlgorithm(
      params.input_descriptor, input_buf, params.conv_result_scale,
      params.filter_descriptor, filter_buf, params.conv_desc, side_input,
      params.fusion->side_input_scale, bias_desc,
      DeviceMemory<BiasType>(params.fusion->bias_buf), params.fusion->mode,
      params.output_descriptor, &output_buf, scratch_allocator, algorithm,
      options.profile_result);

  return Status::OK();
}

// StreamExecutor supports various data types via overloading, and the support
@ -162,7 +159,7 @@ Status RunGpuConvInternalImpl(GpuConvParams params,
        "StreamExecutor doesn't support scaled convolution: %lf.",
        params.conv_result_scale);
  }
  stream->ThenConvolveBackwardDataWithAlgorithm(
  return stream->ConvolveBackwardDataWithAlgorithm(
      params.filter_descriptor, filter_buf, params.output_descriptor,
      output_buf, params.conv_desc, params.input_descriptor, &input_buf,
      scratch_allocator, algorithm, options.profile_result);
@ -173,7 +170,7 @@ Status RunGpuConvInternalImpl(GpuConvParams params,
        "StreamExecutor doesn't support scaled convolution: %lf.",
        params.conv_result_scale);
  }
  stream->ThenConvolveBackwardFilterWithAlgorithm(
  return stream->ConvolveBackwardFilterWithAlgorithm(
      params.input_descriptor, input_buf, params.output_descriptor,
      output_buf, params.conv_desc, params.filter_descriptor, &filter_buf,
      scratch_allocator, algorithm, options.profile_result);
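
[Editorial note, not part of the commit: the gpu_conv_runner hunks above replace the fire-and-forget Then*WithAlgorithm stream calls followed by an unconditional `return Status::OK();` with a direct `return` of the corresponding status-returning *WithAlgorithm call, so a failed convolution launch is reported to the caller instead of being dropped. The sketch below illustrates only that control-flow pattern; the types and function names are made up and are not the real StreamExecutor API.]

// Editorial sketch with hypothetical types; compile standalone with any C++11 compiler.
#include <iostream>
#include <string>

struct Status {
  bool ok_value;
  std::string message;
  bool ok() const { return ok_value; }
  static Status OK() { return Status{true, ""}; }
};

struct FakeStream {
  // Stand-in for a status-returning launch such as ConvolveWithAlgorithm.
  Status LaunchConv() { return Status{false, "unsupported algorithm"}; }
};

// Old shape of the runner: the launch result is dropped and OK is always returned.
Status RunConvOld(FakeStream* stream) {
  stream->LaunchConv();  // a failure here is silently ignored
  return Status::OK();
}

// New shape of the runner: the launch status is propagated to the caller.
Status RunConvNew(FakeStream* stream) { return stream->LaunchConv(); }

int main() {
  FakeStream stream;
  std::cout << "old reports ok: " << RunConvOld(&stream).ok() << "\n";  // prints 1
  std::cout << "new reports ok: " << RunConvNew(&stream).ok() << "\n";  // prints 0
  return 0;
}

Propagating the callee's status matters here because the caller selects among several convolution algorithms and needs the failure to fall through to the next candidate rather than being reported as success.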
@ -3341,6 +3341,7 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() {
              last_use_instruction, parameter_time, last_use_time,
              absl::StrCat(indent_string, " ")));
        } else {
          last_use_time = std::min(last_use_time, end_time);
          TF_RETURN_IF_ERROR(add_allocation_and_verify(
              parameter_time, last_use_time, chunk, value));
        }
@ -3359,12 +3360,13 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() {
          TF_RETURN_IF_ERROR(split_conditional_buffer(
              last_use_instruction, time_bound.start, time_bound.end, " "));
        } else if (!value->uses().empty()) {
          last_use_time = std::min(last_use_time, time_bound.end);
          VLOG(3) << " buffer: " << buffer.ToString()
                  << " value: " << value->ToShortString() << ": ("
                  << time_bound.start << ", " << time_bound.end
                  << time_bound.start << ", " << last_use_time
                  << ") off: " << chunk.offset << ", size: " << chunk.size;
          TF_RETURN_IF_ERROR(add_allocation_and_verify(
              time_bound.start, time_bound.end, chunk, value));
              time_bound.start, last_use_time, chunk, value));
        }
      }
    }
@ -40,8 +40,12 @@ namespace {
|
||||
// Partition convolution with batch group count.
|
||||
StatusOr<HloInstruction*> PartitionConvolutionWithBatchGroupCount(
|
||||
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
|
||||
const HloSharding& output_sharding, const Window& conv_window,
|
||||
HloInstruction* original_hlo, int64 num_partitions, SpmdBuilder* b) {
|
||||
const HloSharding& output_sharding,
|
||||
const std::function<StatusOr<HloInstruction*>(
|
||||
HloInstruction*, HloInstruction*, SpmdBuilder*,
|
||||
const Window& conv_window)>& create_sharded_conv,
|
||||
const Window& conv_window, HloInstruction* original_hlo,
|
||||
int64 num_partitions, SpmdBuilder* b) {
|
||||
TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
|
||||
if (original_hlo->batch_group_count() == 1 ||
|
||||
original_hlo->batch_group_count() < num_partitions) {
|
||||
@ -115,21 +119,10 @@ StatusOr<HloInstruction*> PartitionConvolutionWithBatchGroupCount(
|
||||
auto aligned_output_sharding = hlo_sharding_util::TransposeSharding(
|
||||
lhs.sharding(), lhs_to_output_indices);
|
||||
|
||||
// Get LHS and RHS sharded shape.
|
||||
auto lhs_shard_shape = MakePartitionedShape(lhs.base_shape(), lhs.sharding());
|
||||
auto rhs_shard_shape = MakePartitionedShape(rhs.base_shape(), rhs.sharding());
|
||||
const int64 batch_group_count =
|
||||
CeilOfRatio(original_hlo->batch_group_count(), num_partitions);
|
||||
// Create partitioned convolution.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
Shape sharded_conv_shape,
|
||||
ShapeInference::InferConvolveShape(
|
||||
lhs_shard_shape, rhs_shard_shape, original_hlo->feature_group_count(),
|
||||
batch_group_count, conv_window, dnums));
|
||||
auto sharded_conv = b->AddInstruction(HloInstruction::CreateConvolve(
|
||||
sharded_conv_shape, lhs.hlo(), rhs.hlo(),
|
||||
original_hlo->feature_group_count(), batch_group_count, conv_window,
|
||||
dnums, original_hlo->precision_config()));
|
||||
auto sharded_conv,
|
||||
create_sharded_conv(lhs.hlo(), rhs.hlo(), b, conv_window));
|
||||
sharded_conv->set_sharding(aligned_output_sharding);
|
||||
return PartitionedHlo(sharded_conv, output_base_shape, lhs.state())
|
||||
.Reshard(output_sharding)
|
||||
@ -139,8 +132,12 @@ StatusOr<HloInstruction*> PartitionConvolutionWithBatchGroupCount(
|
||||
// Partition convolution with feature group count.
|
||||
StatusOr<HloInstruction*> PartitionConvolutionWithFeatureGroupCount(
|
||||
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
|
||||
const HloSharding& output_sharding, const Window& conv_window,
|
||||
HloInstruction* original_hlo, int64 num_partitions, SpmdBuilder* b) {
|
||||
const HloSharding& output_sharding,
|
||||
const std::function<StatusOr<HloInstruction*>(
|
||||
HloInstruction*, HloInstruction*, SpmdBuilder*,
|
||||
const Window& conv_window)>& create_sharded_conv,
|
||||
const Window& conv_window, HloInstruction* original_hlo,
|
||||
int64 num_partitions, SpmdBuilder* b) {
|
||||
TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
|
||||
if (original_hlo->feature_group_count() == 1 ||
|
||||
original_hlo->feature_group_count() < num_partitions) {
|
||||
@ -215,20 +212,9 @@ StatusOr<HloInstruction*> PartitionConvolutionWithFeatureGroupCount(
|
||||
auto aligned_output_sharding = hlo_sharding_util::TransposeSharding(
|
||||
lhs.sharding(), lhs_to_output_indices);
|
||||
|
||||
auto lhs_shard_shape = MakePartitionedShape(lhs.base_shape(), lhs.sharding());
|
||||
auto rhs_shard_shape = MakePartitionedShape(rhs.base_shape(), rhs.sharding());
|
||||
int64 feature_group_count =
|
||||
CeilOfRatio(original_hlo->feature_group_count(), num_partitions);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
Shape sharded_conv_shape,
|
||||
ShapeInference::InferConvolveShape(
|
||||
lhs_shard_shape, rhs_shard_shape, feature_group_count,
|
||||
original_hlo->batch_group_count(), conv_window, dnums));
|
||||
auto sharded_conv = b->AddInstruction(HloInstruction::CreateConvolve(
|
||||
sharded_conv_shape, lhs.hlo(), rhs.hlo(), feature_group_count,
|
||||
original_hlo->batch_group_count(), conv_window, dnums,
|
||||
original_hlo->precision_config()));
|
||||
auto sharded_conv,
|
||||
create_sharded_conv(lhs.hlo(), rhs.hlo(), b, conv_window));
|
||||
sharded_conv->set_sharding(aligned_output_sharding);
|
||||
return PartitionedHlo(sharded_conv, output_base_shape, lhs.state())
|
||||
.Reshard(output_sharding)
|
||||
@ -240,9 +226,12 @@ StatusOr<HloInstruction*> PartitionConvolutionWithFeatureGroupCount(
|
||||
StatusOr<HloInstruction*>
|
||||
PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS(
|
||||
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
|
||||
const HloSharding& output_sharding, const Window& conv_window,
|
||||
HloInstruction* original_hlo, HloInstruction* partition_id,
|
||||
HloModule* module, SpmdBuilder* b) {
|
||||
const HloSharding& output_sharding,
|
||||
const std::function<StatusOr<HloInstruction*>(
|
||||
HloInstruction*, HloInstruction*, SpmdBuilder*,
|
||||
const Window& conv_window)>& create_sharded_conv,
|
||||
const Window& conv_window, HloInstruction* original_hlo,
|
||||
HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
|
||||
TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
|
||||
TF_RET_CHECK(!lhs.sharding().IsTileMaximal() &&
|
||||
!rhs.sharding().IsTileMaximal());
|
||||
@ -491,10 +480,9 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS(
|
||||
rhs_with_halo = *concat;
|
||||
}
|
||||
|
||||
auto conv = b->AddInstruction(HloInstruction::CreateConvolve(
|
||||
output_base_shape, conv_lhs, rhs_with_halo,
|
||||
original_hlo->feature_group_count(), original_hlo->batch_group_count(),
|
||||
new_window, dnums, original_hlo->precision_config()));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto conv, create_sharded_conv(conv_lhs, rhs_with_halo, b, new_window));
|
||||
|
||||
auto ar = collective_ops_creator.create_cross_partition_all_reduce(
|
||||
b, conv, MakeBinaryAdd(original_hlo->shape().element_type(), module), {},
|
||||
(*lhs.state().next_channel_id)++);
|
||||
@ -509,9 +497,12 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS(
|
||||
StatusOr<HloInstruction*>
|
||||
PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS(
|
||||
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
|
||||
const HloSharding& output_sharding, const Window& conv_window,
|
||||
HloInstruction* original_hlo, HloInstruction* partition_id,
|
||||
HloModule* module, SpmdBuilder* b) {
|
||||
const HloSharding& output_sharding,
|
||||
const std::function<StatusOr<HloInstruction*>(
|
||||
HloInstruction*, HloInstruction*, SpmdBuilder*,
|
||||
const Window& conv_window)>& create_sharded_conv,
|
||||
const Window& conv_window, HloInstruction* original_hlo,
|
||||
HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
|
||||
TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
|
||||
TF_RET_CHECK(!lhs.sharding().IsTileMaximal() &&
|
||||
!rhs.sharding().IsTileMaximal());
|
||||
@ -583,7 +574,6 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS(
|
||||
rhs =
|
||||
rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero, reversed_rhs_dims);
|
||||
}
|
||||
|
||||
// Reshard LHS by exchanging halo such that each shard computes the partial
|
||||
// sum of the full shape result, and add AllReduce.
|
||||
//
|
||||
@ -701,11 +691,8 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS(
|
||||
lhs_with_halo = *concat;
|
||||
}
|
||||
|
||||
auto conv = b->AddInstruction(HloInstruction::CreateConvolve(
|
||||
output_base_shape, lhs_with_halo, rhs.hlo(),
|
||||
original_hlo->feature_group_count(), original_hlo->batch_group_count(),
|
||||
new_window, original_hlo->convolution_dimension_numbers(),
|
||||
original_hlo->precision_config()));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto conv, create_sharded_conv(lhs_with_halo, rhs.hlo(), b, new_window));
|
||||
auto ar =
|
||||
lhs.state().collective_ops_creator.create_cross_partition_all_reduce(
|
||||
b, conv, MakeBinaryAdd(output_base_shape.element_type(), module), {},
|
||||
@ -720,8 +707,11 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS(
|
||||
// RHS.
|
||||
StatusOr<HloInstruction*> PartitionConvolutionTiledOutput(
|
||||
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
|
||||
const HloSharding& output_sharding, const Window& conv_window,
|
||||
HloInstruction* original_hlo, SpmdBuilder* b) {
|
||||
const HloSharding& output_sharding,
|
||||
const std::function<StatusOr<HloInstruction*>(
|
||||
HloInstruction*, HloInstruction*, SpmdBuilder*,
|
||||
const Window& conv_window)>& create_sharded_conv,
|
||||
const Window& conv_window, HloInstruction* original_hlo, SpmdBuilder* b) {
|
||||
TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
|
||||
const auto& dnums = original_hlo->convolution_dimension_numbers();
|
||||
TF_RET_CHECK(!output_sharding.IsTileMaximal());
|
||||
@ -772,19 +762,13 @@ StatusOr<HloInstruction*> PartitionConvolutionTiledOutput(
|
||||
resharded_operand_and_window->shard_window.dimensions(
|
||||
dnums.input_spatial_dimensions(i));
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
Shape sharded_conv_shape,
|
||||
ShapeInference::InferConvolveShape(
|
||||
resharded_operand_and_window->sharded_input->shape(),
|
||||
rhs.hlo()->shape(), original_hlo->feature_group_count(),
|
||||
original_hlo->batch_group_count(), new_window, dnums));
|
||||
auto sharded_conv,
|
||||
create_sharded_conv(resharded_operand_and_window->sharded_input,
|
||||
rhs.hlo(), b, new_window));
|
||||
|
||||
auto shard_shape = MakePartitionedShape(output_base_shape, output_sharding);
|
||||
*sharded_conv_shape.mutable_layout() = shard_shape.layout();
|
||||
auto sharded_conv = b->AddInstruction(HloInstruction::CreateConvolve(
|
||||
sharded_conv_shape, resharded_operand_and_window->sharded_input,
|
||||
rhs.hlo(), original_hlo->feature_group_count(),
|
||||
original_hlo->batch_group_count(), new_window, dnums,
|
||||
original_hlo->precision_config()));
|
||||
if (!resharded_operand_and_window->dynamic_slice_index_on_output
|
||||
.has_value()) {
|
||||
CHECK(ShapeUtil::Compatible(shard_shape, sharded_conv->shape()));
|
||||
@ -799,29 +783,34 @@ StatusOr<HloInstruction*> PartitionConvolutionTiledOutput(
|
||||
// Partition convolution with only one kind of dims partitioned.
|
||||
StatusOr<HloInstruction*> PartitionConvolutionBaseCase(
|
||||
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
|
||||
const HloSharding& output_sharding, const Window& conv_window,
|
||||
HloInstruction* original_hlo, int64 num_partitions,
|
||||
const SpmdPartitionerOptions& options, HloInstruction* partition_id,
|
||||
HloModule* module, SpmdBuilder* b) {
|
||||
const HloSharding& output_sharding,
|
||||
const std::function<StatusOr<HloInstruction*>(
|
||||
HloInstruction*, HloInstruction*, SpmdBuilder*,
|
||||
const Window& conv_window)>& create_sharded_conv,
|
||||
const Window& conv_window, HloInstruction* original_hlo,
|
||||
int64 num_partitions, const SpmdPartitionerOptions& options,
|
||||
HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
|
||||
TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
|
||||
|
||||
// Case 1: Handle depthwise convolution with batch group count or
|
||||
// feature group count.
|
||||
if (original_hlo->batch_group_count() > 1) {
|
||||
TF_ASSIGN_OR_RETURN(auto parallel_partitioned_conv,
|
||||
PartitionConvolutionWithBatchGroupCount(
|
||||
lhs, rhs, output_base_shape, output_sharding,
|
||||
conv_window, original_hlo, num_partitions, b));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto parallel_partitioned_conv,
|
||||
PartitionConvolutionWithBatchGroupCount(
|
||||
lhs, rhs, output_base_shape, output_sharding, create_sharded_conv,
|
||||
conv_window, original_hlo, num_partitions, b));
|
||||
if (parallel_partitioned_conv) {
|
||||
return parallel_partitioned_conv;
|
||||
}
|
||||
}
|
||||
|
||||
if (original_hlo->feature_group_count() > 1) {
|
||||
TF_ASSIGN_OR_RETURN(auto parallel_partitioned_conv,
|
||||
PartitionConvolutionWithFeatureGroupCount(
|
||||
lhs, rhs, output_base_shape, output_sharding,
|
||||
conv_window, original_hlo, num_partitions, b));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto parallel_partitioned_conv,
|
||||
PartitionConvolutionWithFeatureGroupCount(
|
||||
lhs, rhs, output_base_shape, output_sharding, create_sharded_conv,
|
||||
conv_window, original_hlo, num_partitions, b));
|
||||
if (parallel_partitioned_conv) {
|
||||
return parallel_partitioned_conv;
|
||||
}
|
||||
@ -837,8 +826,8 @@ StatusOr<HloInstruction*> PartitionConvolutionBaseCase(
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto partitioned_conv,
|
||||
PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS(
|
||||
lhs, rhs, output_base_shape, output_sharding, conv_window,
|
||||
original_hlo, partition_id, module, b));
|
||||
lhs, rhs, output_base_shape, output_sharding, create_sharded_conv,
|
||||
conv_window, original_hlo, partition_id, module, b));
|
||||
if (partitioned_conv) {
|
||||
return partitioned_conv;
|
||||
}
|
||||
@ -846,8 +835,8 @@ StatusOr<HloInstruction*> PartitionConvolutionBaseCase(
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto partitioned_conv,
|
||||
PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS(
|
||||
lhs, rhs, output_base_shape, output_sharding, conv_window,
|
||||
original_hlo, partition_id, module, b));
|
||||
lhs, rhs, output_base_shape, output_sharding, create_sharded_conv,
|
||||
conv_window, original_hlo, partition_id, module, b));
|
||||
|
||||
if (partitioned_conv) {
|
||||
return partitioned_conv;
|
||||
@ -860,7 +849,7 @@ StatusOr<HloInstruction*> PartitionConvolutionBaseCase(
|
||||
TF_ASSIGN_OR_RETURN(auto partitioned_conv,
|
||||
PartitionConvolutionTiledOutput(
|
||||
lhs, rhs, output_base_shape, output_sharding,
|
||||
conv_window, original_hlo, b));
|
||||
create_sharded_conv, conv_window, original_hlo, b));
|
||||
|
||||
if (partitioned_conv) {
|
||||
return partitioned_conv;
|
||||
@ -869,22 +858,92 @@ StatusOr<HloInstruction*> PartitionConvolutionBaseCase(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<HloInstruction>> CreateShardedConvConvolution(
|
||||
const HloInstruction& conv,
|
||||
const dot_as_convolution_util::DotConvolutionDimsInfo& dot_dnums,
|
||||
HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo,
|
||||
const Window& conv_window) {
|
||||
CHECK_EQ(conv.opcode(), HloOpcode::kConvolution);
|
||||
const auto& conv_dnums = conv.convolution_dimension_numbers();
|
||||
auto window = conv.window();
|
||||
for (const auto& dim : dot_dnums.batch_dims) {
|
||||
auto wd = window.mutable_dimensions(dim.spatial_dim);
|
||||
wd->set_size(sharded_lhs_hlo->shape().dimensions(
|
||||
conv_dnums.input_spatial_dimensions(dim.spatial_dim)));
|
||||
wd->set_stride(std::max<int64>(1, wd->size() - 1));
|
||||
wd->set_base_dilation(wd->size());
|
||||
}
|
||||
for (const auto& dim : dot_dnums.contracting_dims) {
|
||||
if (dim.spatial_dim < 0) {
|
||||
continue;
|
||||
}
|
||||
auto wd = window.mutable_dimensions(dim.spatial_dim);
|
||||
wd->set_size(sharded_lhs_hlo->shape().dimensions(
|
||||
conv_dnums.input_spatial_dimensions(dim.spatial_dim)));
|
||||
}
|
||||
for (const auto& dim : dot_dnums.rhs_non_contracting_dims) {
|
||||
if (dim.spatial_dim < 0) {
|
||||
continue;
|
||||
}
|
||||
auto wd = window.mutable_dimensions(dim.spatial_dim);
|
||||
wd->set_size(sharded_rhs_hlo->shape().dimensions(
|
||||
conv_dnums.kernel_spatial_dimensions(dim.spatial_dim)));
|
||||
wd->set_padding_high(wd->size() - 1);
|
||||
wd->set_padding_low(wd->size() - 1);
|
||||
}
|
||||
|
||||
for (const auto& dim : dot_dnums.conv_spatial_dims) {
|
||||
auto wd = window.mutable_dimensions(dim.spatial_dim);
|
||||
const auto& new_window_dimension = conv_window.dimensions(dim.spatial_dim);
|
||||
wd->set_size(new_window_dimension.size());
|
||||
wd->set_padding_high(new_window_dimension.padding_high());
|
||||
wd->set_padding_low(new_window_dimension.padding_low());
|
||||
}
|
||||
|
||||
int64 feature_group_count = conv.feature_group_count();
|
||||
if (feature_group_count > 1) {
|
||||
feature_group_count = sharded_lhs_hlo->shape().dimensions(
|
||||
conv_dnums.input_feature_dimension()) /
|
||||
sharded_rhs_hlo->shape().dimensions(
|
||||
conv_dnums.kernel_input_feature_dimension());
|
||||
}
|
||||
|
||||
int64 batch_group_count = conv.batch_group_count();
|
||||
if (batch_group_count > 1) {
|
||||
batch_group_count =
|
||||
sharded_lhs_hlo->shape().dimensions(conv_dnums.input_batch_dimension());
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
Shape sharded_conv_shape,
|
||||
ShapeInference::InferConvolveShape(
|
||||
sharded_lhs_hlo->shape(), sharded_rhs_hlo->shape(),
|
||||
feature_group_count, batch_group_count, window, conv_dnums));
|
||||
*sharded_conv_shape.mutable_layout() = conv.shape().layout();
|
||||
return HloInstruction::CreateConvolve(
|
||||
sharded_conv_shape, sharded_lhs_hlo, sharded_rhs_hlo, feature_group_count,
|
||||
batch_group_count, window, conv_dnums, conv.precision_config());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Partition convolution.
|
||||
StatusOr<HloInstruction*> PartitionConvolution(
|
||||
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
|
||||
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
|
||||
const std::function<StatusOr<HloInstruction*>(
|
||||
HloInstruction*, HloInstruction*, SpmdBuilder*,
|
||||
const Window& conv_window)>& create_sharded_conv,
|
||||
const Window& conv_window, HloInstruction* original_hlo,
|
||||
int64 num_partitions, const SpmdPartitionerOptions& options,
|
||||
HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
|
||||
TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto try_partitioned_conv,
|
||||
PartitionConvolutionBaseCase(lhs, rhs, output_base_shape, output_sharding,
|
||||
conv_window, original_hlo, num_partitions,
|
||||
options, partition_id, module, b));
|
||||
TF_ASSIGN_OR_RETURN(auto try_partitioned_conv,
|
||||
PartitionConvolutionBaseCase(
|
||||
lhs, rhs, output_base_shape, output_sharding,
|
||||
create_sharded_conv, conv_window, original_hlo,
|
||||
num_partitions, options, partition_id, module, b));
|
||||
if (try_partitioned_conv) {
|
||||
return try_partitioned_conv;
|
||||
}
|
||||
@ -932,13 +991,22 @@ Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) {
|
||||
}
|
||||
auto create_sharded_conv =
|
||||
[&](HloInstruction* lhs_hlo, HloInstruction* rhs_hlo,
|
||||
spmd::SpmdBuilder* b) -> StatusOr<HloInstruction*> {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto sharded_conv,
|
||||
dot_as_convolution_util::CreateShardedConvForDotGeneralConvolution(
|
||||
*hlo, dims_info, lhs_hlo, rhs_hlo));
|
||||
return b->AddInstruction(std::move(sharded_conv));
|
||||
spmd::SpmdBuilder* b,
|
||||
const Window& conv_window) -> StatusOr<HloInstruction*> {
|
||||
if (dims_info.conv_spatial_dims.empty()) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto sharded_conv,
|
||||
dot_as_convolution_util::CreateShardedConvForDotGeneralConvolution(
|
||||
*hlo, dims_info, lhs_hlo, rhs_hlo));
|
||||
return b->AddInstruction(std::move(sharded_conv));
|
||||
} else {
|
||||
TF_ASSIGN_OR_RETURN(auto sharded_conv,
|
||||
CreateShardedConvConvolution(*hlo, dims_info, lhs_hlo,
|
||||
rhs_hlo, conv_window));
|
||||
return b->AddInstruction(std::move(sharded_conv));
|
||||
}
|
||||
};
|
||||
|
||||
return HandleDotHelper(hlo, mapping, create_sharded_conv);
|
||||
}
|
||||
|
||||
|
@ -29,6 +29,9 @@ namespace spmd {
|
||||
StatusOr<HloInstruction*> PartitionConvolution(
|
||||
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
|
||||
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
|
||||
const std::function<StatusOr<HloInstruction*>(
|
||||
HloInstruction*, HloInstruction*, SpmdBuilder*,
|
||||
const Window& conv_window)>& create_sharded_conv,
|
||||
const Window& conv_window, HloInstruction* original_hlo,
|
||||
int64 num_partitions, const SpmdPartitionerOptions& options,
|
||||
HloInstruction* partition_id, HloModule* module, SpmdBuilder* b);
|
||||
|
@ -19,8 +19,10 @@ limitations under the License.
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
|
||||
#include "tensorflow/compiler/xla/service/shape_inference.h"
|
||||
@ -29,6 +31,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/window_util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/platform/numbers.h"
|
||||
@ -72,8 +75,9 @@ Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) {
|
||||
mapping.rhs_non_contracting_dims.back().rhs = i;
|
||||
mapping.rhs_non_contracting_dims.back().output = next_output_dim++;
|
||||
}
|
||||
auto create_sharded_dot = [&](HloInstruction* l, HloInstruction* r,
|
||||
SpmdBuilder* b) -> StatusOr<HloInstruction*> {
|
||||
auto create_sharded_dot =
|
||||
[&](HloInstruction* l, HloInstruction* r, SpmdBuilder* b,
|
||||
const Window& conv_window) -> StatusOr<HloInstruction*> {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto sharded_dot_shape,
|
||||
ShapeInference::InferDotOpShape(l->shape(), r->shape(),
|
||||
@ -92,11 +96,13 @@ StatusOr<HloInstruction*> PartitionBaseCase(
|
||||
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
|
||||
int64 num_partitions,
|
||||
const std::function<StatusOr<HloInstruction*>(
|
||||
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
|
||||
HloModule* module, HloInstruction* original_hlo, int64 lhs_batch_partitions,
|
||||
int64 rhs_batch_partitions, int64 output_batch_partitions,
|
||||
int64 lhs_contracting_partitions, int64 rhs_contracting_partitions,
|
||||
int64 lhs_non_contracting_partitions, int64 rhs_non_contracting_partitions,
|
||||
HloInstruction*, HloInstruction*, SpmdBuilder*,
|
||||
const Window& conv_window)>& create_sharded_dot,
|
||||
const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
|
||||
int64 lhs_batch_partitions, int64 rhs_batch_partitions,
|
||||
int64 output_batch_partitions, int64 lhs_contracting_partitions,
|
||||
int64 rhs_contracting_partitions, int64 lhs_non_contracting_partitions,
|
||||
int64 rhs_non_contracting_partitions,
|
||||
int64 output_lhs_non_contracting_partitions,
|
||||
int64 output_rhs_non_contracting_partitions,
|
||||
const SpmdPartitionerOptions& options, SpmdBuilder* b,
|
||||
@ -170,7 +176,8 @@ StatusOr<HloInstruction*> PartitionBaseCase(
|
||||
if (lhs_batch_partitions == rhs_batch_partitions &&
|
||||
rhs_batch_partitions == num_partitions &&
|
||||
lhs_sharding_transposed_to_match_rhs == rhs_sharding) {
|
||||
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
|
||||
dot->set_sharding(*lhs_sharding_transposed_to_match_output);
|
||||
return PartitionedHlo(dot, output_base_shape, lhs.state())
|
||||
.Reshard(output_sharding)
|
||||
@ -196,7 +203,8 @@ StatusOr<HloInstruction*> PartitionBaseCase(
|
||||
}
|
||||
auto resharded_rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs);
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto dot, create_sharded_dot(lhs.hlo(), resharded_rhs.hlo(), b));
|
||||
auto dot,
|
||||
create_sharded_dot(lhs.hlo(), resharded_rhs.hlo(), b, conv_window));
|
||||
return dot;
|
||||
}
|
||||
// RHS and output are batch partitioned in the same way.
|
||||
@ -212,7 +220,8 @@ StatusOr<HloInstruction*> PartitionBaseCase(
|
||||
}
|
||||
auto resharded_lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs);
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto dot, create_sharded_dot(resharded_lhs.hlo(), rhs.hlo(), b));
|
||||
auto dot,
|
||||
create_sharded_dot(resharded_lhs.hlo(), rhs.hlo(), b, conv_window));
|
||||
return dot;
|
||||
}
|
||||
return nullptr;
|
||||
@ -310,8 +319,8 @@ StatusOr<HloInstruction*> PartitionBaseCase(
|
||||
dot_rhs = slice;
|
||||
}
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(auto dot,
|
||||
create_sharded_dot(dot_lhs, dot_rhs, &body_b));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto dot, create_sharded_dot(dot_lhs, dot_rhs, &body_b, conv_window));
|
||||
if (windowed_at_contracting_dims) {
|
||||
// Accumulate the partial output to the result buffer.
|
||||
o = body_b.AddInstruction(
|
||||
@ -465,7 +474,8 @@ StatusOr<HloInstruction*> PartitionBaseCase(
|
||||
rhs =
|
||||
rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero);
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
|
||||
auto ar =
|
||||
lhs.state().collective_ops_creator.create_cross_partition_all_reduce(
|
||||
b, dot, MakeBinaryAdd(output_base_shape.element_type(), module), {},
|
||||
@ -481,8 +491,8 @@ StatusOr<HloInstruction*> PartitionBaseCase(
|
||||
output_lhs_non_contracting_partitions == num_partitions &&
|
||||
lhs_sharding_transposed_to_match_output == output_sharding) {
|
||||
auto rhs_replicated = rhs.Reshard(HloSharding::Replicate()).hlo();
|
||||
TF_ASSIGN_OR_RETURN(auto dot,
|
||||
create_sharded_dot(lhs.hlo(), rhs_replicated, b));
|
||||
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs_replicated,
|
||||
b, conv_window));
|
||||
return dot;
|
||||
}
|
||||
|
||||
@ -491,8 +501,8 @@ StatusOr<HloInstruction*> PartitionBaseCase(
|
||||
output_rhs_non_contracting_partitions == num_partitions &&
|
||||
rhs_sharding_transposed_to_match_output == output_sharding) {
|
||||
auto lhs_replicated = lhs.Reshard(HloSharding::Replicate()).hlo();
|
||||
TF_ASSIGN_OR_RETURN(auto dot,
|
||||
create_sharded_dot(lhs_replicated, rhs.hlo(), b));
|
||||
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs_replicated, rhs.hlo(),
|
||||
b, conv_window));
|
||||
return dot;
|
||||
}
|
||||
|
||||
@ -503,8 +513,9 @@ StatusOr<HloInstruction*> PartitionBaseCase(
|
||||
lhs.Reshard(*output_sharding_transposed_to_match_lhs);
|
||||
auto resharded_rhs =
|
||||
rhs.Reshard(*output_sharding_transposed_to_match_rhs);
|
||||
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(resharded_lhs.hlo(),
|
||||
resharded_rhs.hlo(), b));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto dot, create_sharded_dot(resharded_lhs.hlo(), resharded_rhs.hlo(),
|
||||
b, conv_window));
|
||||
return dot;
|
||||
}
|
||||
// Output is partitioned along LHS non-contracting dimensions.
|
||||
@ -513,8 +524,8 @@ StatusOr<HloInstruction*> PartitionBaseCase(
|
||||
lhs.Reshard(*output_sharding_transposed_to_match_lhs);
|
||||
auto replicated_rhs = rhs.Reshard(HloSharding::Replicate());
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto dot,
|
||||
create_sharded_dot(resharded_lhs.hlo(), replicated_rhs.hlo(), b));
|
||||
auto dot, create_sharded_dot(resharded_lhs.hlo(),
|
||||
replicated_rhs.hlo(), b, conv_window));
|
||||
return dot;
|
||||
}
|
||||
// Output is partitioned along RHS non-contracting dimensions.
|
||||
@ -522,8 +533,9 @@ StatusOr<HloInstruction*> PartitionBaseCase(
|
||||
auto replicated_lhs = lhs.Reshard(HloSharding::Replicate());
|
||||
auto resharded_rhs =
|
||||
rhs.Reshard(*output_sharding_transposed_to_match_rhs);
|
||||
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(replicated_lhs.hlo(),
|
||||
resharded_rhs.hlo(), b));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto dot, create_sharded_dot(replicated_lhs.hlo(),
|
||||
resharded_rhs.hlo(), b, conv_window));
|
||||
return dot;
|
||||
}
|
||||
}
|
||||
@ -566,7 +578,8 @@ StatusOr<HloInstruction*> PartitionBaseCase(
|
||||
rhs =
|
||||
rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero);
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
|
||||
return lhs.state().collective_ops_creator.create_cross_partition_all_reduce(
|
||||
b, dot, MakeBinaryAdd(output_base_shape.element_type(), module), {},
|
||||
(*lhs.state().next_channel_id)++);
|
||||
@ -579,8 +592,9 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
|
||||
int64 num_partitions,
|
||||
const std::function<StatusOr<HloInstruction*>(
|
||||
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
|
||||
HloModule* module, HloInstruction* original_hlo,
|
||||
HloInstruction*, HloInstruction*, SpmdBuilder*,
|
||||
const Window& conv_window)>& create_sharded_dot,
|
||||
const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
|
||||
const SpmdPartitionerOptions& options, SpmdBuilder* b,
|
||||
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
|
||||
windowed_dot_general_loops);
|
||||
@ -592,8 +606,9 @@ StatusOr<HloInstruction*> PartitionDotGroupOnBatch(
|
||||
int64 rhs_contracting_partitions, int64 lhs_non_contracting_partitions,
|
||||
int64 rhs_non_contracting_partitions,
|
||||
const std::function<StatusOr<HloInstruction*>(
|
||||
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
|
||||
HloModule* module, HloInstruction* original_hlo,
|
||||
HloInstruction*, HloInstruction*, SpmdBuilder*,
|
||||
const Window& conv_window)>& create_sharded_dot,
|
||||
const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
|
||||
bool require_matching_devices_to_group,
|
||||
const SpmdPartitionerOptions& options, SpmdBuilder* b,
|
||||
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
|
||||
@ -808,8 +823,8 @@ StatusOr<HloInstruction*> PartitionDotGroupOnBatch(
|
||||
GetPerGroupBaseShape(output_grouped, output_base_shape),
|
||||
output_grouped.sharding, dims_mapping,
|
||||
num_partitions / output_grouped.device_groups.size(),
|
||||
create_sharded_dot, module, original_hlo, options, b,
|
||||
windowed_dot_general_loops));
|
||||
create_sharded_dot, conv_window, module, original_hlo,
|
||||
options, b, windowed_dot_general_loops));
|
||||
dot->set_sharding(UngroupSharding(output_grouped));
|
||||
return PartitionedHlo(dot, output_base_shape, lhs.state())
|
||||
.Reshard(output_sharding)
|
||||
@ -826,8 +841,9 @@ StatusOr<HloInstruction*> PartitionDotGroupOnNonContracting(
|
||||
const Shape& output_base_shape, const HloSharding& output_sharding,
|
||||
const DotConvDimsMapping& dims_mapping, int64 num_partitions,
|
||||
const std::function<StatusOr<HloInstruction*>(
|
||||
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
|
||||
HloModule* module, HloInstruction* original_hlo,
|
||||
HloInstruction*, HloInstruction*, SpmdBuilder*,
|
||||
const Window& conv_window)>& create_sharded_dot,
|
||||
const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
|
||||
bool require_matching_devices_to_group,
|
||||
const SpmdPartitionerOptions& options, SpmdBuilder* b,
|
||||
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
|
||||
@ -952,8 +968,8 @@ StatusOr<HloInstruction*> PartitionDotGroupOnNonContracting(
|
||||
GetPerGroupBaseShape(output_grouped, output_base_shape),
|
||||
output_grouped.sharding, dims_mapping,
|
||||
num_partitions / matching_grouped.device_groups.size(),
|
||||
create_sharded_dot, module, original_hlo, options, b,
|
||||
windowed_dot_general_loops));
|
||||
create_sharded_dot, conv_window, module, original_hlo,
|
||||
options, b, windowed_dot_general_loops));
|
||||
return dot;
|
||||
}
|
||||
|
||||
@ -966,8 +982,9 @@ StatusOr<HloInstruction*> PartitionDotGroupOnContracting(
|
||||
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
|
||||
int64 num_partitions,
|
||||
const std::function<StatusOr<HloInstruction*>(
|
||||
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
|
||||
HloModule* module, HloInstruction* original_hlo,
|
||||
HloInstruction*, HloInstruction*, SpmdBuilder*,
|
||||
const Window& conv_window)>& create_sharded_dot,
|
||||
const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
|
||||
bool require_matching_devices_to_group,
|
||||
const SpmdPartitionerOptions& options, SpmdBuilder* b,
|
||||
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
|
||||
@ -1090,10 +1107,9 @@ StatusOr<HloInstruction*> PartitionDotGroupOnContracting(
|
||||
PartitionedHlo(rhs.hlo(),
|
||||
GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()),
|
||||
inner_state),
|
||||
MakePartitionedShape(output_base_shape, outer_output_tmp_sharding),
|
||||
inner_output_sharding, dims_mapping, num_partitions / group_count,
|
||||
create_sharded_dot, module, original_hlo, options, b,
|
||||
windowed_dot_general_loops));
|
||||
output_base_shape, inner_output_sharding, dims_mapping,
|
||||
num_partitions / group_count, create_sharded_dot, conv_window, module,
|
||||
original_hlo, options, b, windowed_dot_general_loops));
|
||||
if (!dot) {
|
||||
return nullptr;
|
||||
}
|
||||
@ -1124,8 +1140,9 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
|
||||
int64 num_partitions,
|
||||
const std::function<StatusOr<HloInstruction*>(
|
||||
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
|
||||
HloModule* module, HloInstruction* original_hlo,
|
||||
HloInstruction*, HloInstruction*, SpmdBuilder*,
|
||||
const Window& conv_window)>& create_sharded_dot,
|
||||
const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
|
||||
bool require_matching_devices_to_group,
|
||||
const SpmdPartitionerOptions& options, SpmdBuilder* b,
|
||||
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
|
||||
@ -1180,35 +1197,25 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
|
||||
// Try partition the purely spatially-partitioned convolution with convolution
|
||||
// spatial dimension partitioned or depthwise parallel dimension partitioned.
|
||||
if (!dims_mapping.conv_spatial_dims.empty() &&
|
||||
bool is_conv_spatial_dim_partitioned =
|
||||
(lhs_conv_spatial_partitions > 1 || rhs_conv_spatial_partitions > 1 ||
|
||||
output_conv_spatial_partitions > 1 ||
|
||||
original_hlo->batch_group_count() > 1 ||
|
||||
original_hlo->feature_group_count() > 1)) {
|
||||
const auto& conv_dnums = original_hlo->convolution_dimension_numbers();
|
||||
auto window = original_hlo->window();
|
||||
|
||||
// TODO(wangtao): remove this hack by passing create_sharded_conv to
|
||||
// PartitionConv.
|
||||
// Update convolution window when it is in the recursive call for
|
||||
// batch_dims.
|
||||
if (original_hlo->batch_group_count() == 1 &&
|
||||
original_hlo->feature_group_count() == 1 &&
|
||||
!ShapeUtil::Compatible(original_hlo->shape(), output_base_shape)) {
|
||||
for (const auto& dim : dims_mapping.batch_dims) {
|
||||
auto wd = window.mutable_dimensions(dim.spatial);
|
||||
wd->set_size(lhs.hlo()->shape().dimensions(
|
||||
conv_dnums.input_spatial_dimensions(dim.spatial)));
|
||||
wd->set_stride(std::max<int64>(1, wd->size() - 1));
|
||||
wd->set_base_dilation(wd->size());
|
||||
}
|
||||
}
|
||||
|
||||
output_conv_spatial_partitions > 1);
|
||||
bool is_conv_batch_or_contracting_dim_partitioned =
|
||||
(lhs_batch_partitions > 1 || rhs_batch_partitions > 1 ||
|
||||
output_batch_partitions > 1 ||
|
||||
(lhs_contracting_partitions > 1 && rhs_contracting_partitions > 1));
|
||||
if ((!dims_mapping.conv_spatial_dims.empty() &&
|
||||
is_conv_spatial_dim_partitioned &&
|
||||
!is_conv_batch_or_contracting_dim_partitioned) ||
|
||||
(original_hlo->opcode() == HloOpcode::kConvolution &&
|
||||
(original_hlo->batch_group_count() > 1 ||
|
||||
original_hlo->feature_group_count() > 1))) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto partitioned_conv,
|
||||
PartitionConvolution(lhs, rhs, output_base_shape, output_sharding,
|
||||
dims_mapping, window, original_hlo, num_partitions,
|
||||
options, lhs.state().partition_id, module, b));
|
||||
dims_mapping, create_sharded_dot, conv_window,
|
||||
original_hlo, num_partitions, options,
|
||||
lhs.state().partition_id, module, b));
|
||||
|
||||
if (partitioned_conv) {
|
||||
return partitioned_conv;
|
||||
@ -1219,7 +1226,7 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
auto try_partitioned_dot,
|
||||
PartitionBaseCase(
|
||||
lhs, rhs, output_base_shape, output_sharding, dims_mapping,
|
||||
num_partitions, create_sharded_dot, module, original_hlo,
|
||||
num_partitions, create_sharded_dot, conv_window, module, original_hlo,
|
||||
lhs_batch_partitions, rhs_batch_partitions, output_batch_partitions,
|
||||
lhs_contracting_partitions, rhs_contracting_partitions,
|
||||
lhs_non_contracting_partitions, rhs_non_contracting_partitions,
|
||||
@ -1243,8 +1250,8 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
lhs, rhs, output_base_shape, output_sharding, dims_mapping,
|
||||
num_partitions, lhs_contracting_partitions,
|
||||
rhs_contracting_partitions, lhs_non_contracting_partitions,
|
||||
rhs_non_contracting_partitions, create_sharded_dot, module,
|
||||
original_hlo, require_matching_devices_to_group, options, b,
|
||||
rhs_non_contracting_partitions, create_sharded_dot, conv_window,
|
||||
module, original_hlo, require_matching_devices_to_group, options, b,
|
||||
windowed_dot_general_loops));
|
||||
if (dot) {
|
||||
return dot;
|
||||
@ -1268,7 +1275,6 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
ShapeUtil::ByteSizeOf(rhs.hlo()->shape()) <=
|
||||
rhs_non_contracting_partitions *
|
||||
ShapeUtil::ByteSizeOf(lhs.hlo()->shape()));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto dot,
|
||||
PartitionDotGroupOnNonContracting(
|
||||
@ -1284,7 +1290,7 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
lhs_matching ? output_rhs_non_contracting_partitions
|
||||
: output_lhs_non_contracting_partitions,
|
||||
output_base_shape, output_sharding, dims_mapping, num_partitions,
|
||||
create_sharded_dot, module, original_hlo,
|
||||
create_sharded_dot, conv_window, module, original_hlo,
|
||||
require_matching_devices_to_group, options, b,
|
||||
windowed_dot_general_loops));
|
||||
if (dot) {
|
||||
@ -1304,15 +1310,15 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
}
|
||||
if (!matching_dims.empty()) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto dot,
|
||||
PartitionDotGroupOnNonContracting(
|
||||
/*lhs_matching=*/true, lhs, rhs, lhs_contracting_partitions,
|
||||
rhs_contracting_partitions, matching_dims,
|
||||
rhs_non_contracting_partitions,
|
||||
output_rhs_non_contracting_partitions, output_base_shape,
|
||||
output_sharding, dims_mapping, num_partitions, create_sharded_dot,
|
||||
module, original_hlo, require_matching_devices_to_group, options,
|
||||
b, windowed_dot_general_loops));
|
||||
auto dot, PartitionDotGroupOnNonContracting(
|
||||
/*lhs_matching=*/true, lhs, rhs,
|
||||
lhs_contracting_partitions, rhs_contracting_partitions,
|
||||
matching_dims, rhs_non_contracting_partitions,
|
||||
output_rhs_non_contracting_partitions,
|
||||
output_base_shape, output_sharding, dims_mapping,
|
||||
num_partitions, create_sharded_dot, conv_window, module,
|
||||
original_hlo, require_matching_devices_to_group,
|
||||
options, b, windowed_dot_general_loops));
|
||||
if (dot) {
|
||||
return dot;
|
||||
}
|
||||
@ -1331,15 +1337,15 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
}
|
||||
if (!matching_dims.empty()) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto dot,
|
||||
PartitionDotGroupOnNonContracting(
|
||||
/*lhs_matching=*/false, rhs, lhs, rhs_contracting_partitions,
|
||||
lhs_contracting_partitions, matching_dims,
|
||||
lhs_non_contracting_partitions,
|
||||
output_lhs_non_contracting_partitions, output_base_shape,
|
||||
output_sharding, dims_mapping, num_partitions, create_sharded_dot,
|
||||
module, original_hlo, require_matching_devices_to_group, options,
|
||||
b, windowed_dot_general_loops));
|
||||
auto dot, PartitionDotGroupOnNonContracting(
|
||||
/*lhs_matching=*/false, rhs, lhs,
|
||||
rhs_contracting_partitions, lhs_contracting_partitions,
|
||||
matching_dims, lhs_non_contracting_partitions,
|
||||
output_lhs_non_contracting_partitions,
|
||||
output_base_shape, output_sharding, dims_mapping,
|
||||
num_partitions, create_sharded_dot, conv_window, module,
|
||||
original_hlo, require_matching_devices_to_group,
|
||||
options, b, windowed_dot_general_loops));
|
||||
if (dot) {
|
||||
return dot;
|
||||
}
|
||||
@ -1356,7 +1362,8 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
output_lhs_non_contracting_partitions,
|
||||
output_rhs_non_contracting_partitions, output_base_shape,
|
||||
output_sharding, dims_mapping, num_partitions, create_sharded_dot,
|
||||
module, original_hlo, require_matching_devices_to_group, options, b,
|
||||
conv_window, module, original_hlo,
|
||||
require_matching_devices_to_group, options, b,
|
||||
windowed_dot_general_loops));
|
||||
if (dot) {
|
||||
return dot;
|
||||
@ -1374,14 +1381,14 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
}
|
||||
if (!matching_dims.empty()) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto dot,
|
||||
PartitionDotGroupOnContracting(
|
||||
lhs, rhs, matching_dims, output_batch_partitions,
|
||||
output_lhs_non_contracting_partitions,
|
||||
output_rhs_non_contracting_partitions, output_base_shape,
|
||||
output_sharding, dims_mapping, num_partitions, create_sharded_dot,
|
||||
module, original_hlo, require_matching_devices_to_group, options,
|
||||
b, windowed_dot_general_loops));
|
||||
auto dot, PartitionDotGroupOnContracting(
|
||||
lhs, rhs, matching_dims, output_batch_partitions,
|
||||
output_lhs_non_contracting_partitions,
|
||||
output_rhs_non_contracting_partitions,
|
||||
output_base_shape, output_sharding, dims_mapping,
|
||||
num_partitions, create_sharded_dot, conv_window, module,
|
||||
original_hlo, require_matching_devices_to_group,
|
||||
options, b, windowed_dot_general_loops));
|
||||
if (dot) {
|
||||
return dot;
|
||||
}
|
||||
@ -1401,8 +1408,9 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
PartitionDot(PartitionedHlo(lhs.hlo(), lhs.base_shape(), inner_state),
|
||||
PartitionedHlo(rhs.hlo(), rhs.base_shape(), inner_state),
|
||||
output_base_shape, grouped_output.sharding, dims_mapping,
|
||||
output_sharding.NumTiles(), create_sharded_dot, module,
|
||||
original_hlo, options, b, windowed_dot_general_loops));
|
||||
output_sharding.NumTiles(), create_sharded_dot,
|
||||
conv_window, module, original_hlo, options, b,
|
||||
windowed_dot_general_loops));
|
||||
if (dot) {
|
||||
return dot;
|
||||
}
|
||||
@ -1414,7 +1422,7 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
auto dot,
|
||||
PartitionBaseCase(
|
||||
lhs, rhs, output_base_shape, output_sharding, dims_mapping,
|
||||
num_partitions, create_sharded_dot, module, original_hlo,
|
||||
num_partitions, create_sharded_dot, conv_window, module, original_hlo,
|
||||
lhs_batch_partitions, rhs_batch_partitions, output_batch_partitions,
|
||||
lhs_contracting_partitions, rhs_contracting_partitions,
|
||||
lhs_non_contracting_partitions, rhs_non_contracting_partitions,
|
||||
@ -1433,8 +1441,9 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
|
||||
int64 num_partitions,
|
||||
const std::function<StatusOr<HloInstruction*>(
|
||||
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
|
||||
HloModule* module, HloInstruction* original_hlo,
|
||||
HloInstruction*, HloInstruction*, SpmdBuilder*,
|
||||
const Window& conv_window)>& create_sharded_dot,
|
||||
const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
|
||||
const SpmdPartitionerOptions& options, SpmdBuilder* b,
|
||||
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
|
||||
windowed_dot_general_loops) {
|
||||
@ -1444,17 +1453,18 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto try_partition,
|
||||
PartitionDot(lhs, rhs, output_base_shape, output_sharding, dims_mapping,
|
||||
num_partitions, create_sharded_dot, module, original_hlo,
|
||||
require_matching_devices_to_group, options, b,
|
||||
windowed_dot_general_loops));
|
||||
num_partitions, create_sharded_dot, conv_window, module,
|
||||
original_hlo, require_matching_devices_to_group, options,
|
||||
b, windowed_dot_general_loops));
|
||||
if (try_partition) {
|
||||
return try_partition;
|
||||
}
|
||||
}
|
||||
|
||||
// Default action.
|
||||
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.Replicate().hlo(),
|
||||
rhs.Replicate().hlo(), b));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto dot, create_sharded_dot(lhs.Replicate().hlo(), rhs.Replicate().hlo(),
|
||||
b, conv_window));
|
||||
dot->set_sharding(HloSharding::Replicate());
|
||||
return PartitionedHlo(dot, output_base_shape, lhs.state())
|
||||
.Reshard(output_sharding)
|
||||
@ -1466,14 +1476,20 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
Status SpmdPartitioningVisitor::HandleDotHelper(
|
||||
HloInstruction* hlo, const DotConvDimsMapping& dims_mapping,
|
||||
const std::function<StatusOr<HloInstruction*>(
|
||||
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot) {
|
||||
HloInstruction*, HloInstruction*, SpmdBuilder*,
|
||||
const Window& conv_window)>& create_sharded_dot) {
|
||||
auto& lhs = GetPartitionedHlo(hlo->operand(0));
|
||||
auto& rhs = GetPartitionedHlo(hlo->operand(1));
|
||||
Window conv_window;
|
||||
if (hlo->opcode() == HloOpcode::kConvolution) {
|
||||
conv_window = hlo->window();
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto partitioned_dot,
|
||||
PartitionDot(lhs, rhs, hlo->shape(), hlo->sharding(), dims_mapping,
|
||||
num_partitions_, create_sharded_dot, module_, hlo, options_,
|
||||
&b_, &windowed_dot_general_loops_));
|
||||
num_partitions_, create_sharded_dot, conv_window, module_,
|
||||
hlo, options_, &b_, &windowed_dot_general_loops_));
|
||||
SetPartitionedHlo(hlo, [&] { return partitioned_dot; });
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -407,10 +407,11 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
  Status HandlePartitionId(HloInstruction* hlo) override;

  // Implementation of dot partitioning given DotGeneralDimsMapping.
  Status HandleDotHelper(
      HloInstruction* hlo, const DotConvDimsMapping& dims_mapping,
      const std::function<StatusOr<HloInstruction*>(
          HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot);
  Status HandleDotHelper(HloInstruction* hlo,
                         const DotConvDimsMapping& dims_mapping,
                         const std::function<StatusOr<HloInstruction*>(
                             HloInstruction*, HloInstruction*, SpmdBuilder*,
                             const Window& conv_window)>& create_sharded_dot);

  // Common handle for elementwise HLOs.
  Status HandleElementwise(HloInstruction* hlo);
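
[Editorial note, not part of the commit: the HandleDotHelper change above, like the rest of this diff, threads a `const Window& conv_window` parameter through the create_sharded_dot / create_sharded_conv callback so the dot-partitioning paths can also emit spatially windowed convolutions. The snippet below only models that callback shape; HloInstruction, SpmdBuilder, StatusOr and Window are replaced by simplified stand-in types, which are assumptions for illustration.]

// Editorial sketch with placeholder types; compile standalone with any C++11 compiler.
#include <functional>
#include <iostream>
#include <string>

struct Window { int size; };
struct Instr { std::string name; };
struct Builder {
  Instr Add(const std::string& name) { return Instr{name}; }
};

// Analogue of the new callback signature: both operands, the builder, and the
// convolution window are passed to the op-creation hook.
using CreateShardedOp = std::function<Instr(
    Instr* lhs, Instr* rhs, Builder* b, const Window& conv_window)>;

// A partitioning helper just forwards the window; the callback decides whether
// it matters (a plain dot ignores it, a spatial convolution uses it).
Instr PartitionOnce(Instr* lhs, Instr* rhs, Builder* b,
                    const CreateShardedOp& create_sharded_op,
                    const Window& conv_window) {
  return create_sharded_op(lhs, rhs, b, conv_window);
}

int main() {
  Builder b;
  Instr lhs{"lhs"}, rhs{"rhs"};
  CreateShardedOp create = [](Instr* l, Instr* r, Builder* bb, const Window& w) {
    return bb->Add(l->name + "*" + r->name + " (window=" +
                   std::to_string(w.size) + ")");
  };
  std::cout << PartitionOnce(&lhs, &rhs, &b, create, Window{3}).name << "\n";
  return 0;
}

Keeping one callback type for dots and convolutions lets the sharding logic be written once: callers that partition a true dot simply ignore the window argument.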
@ -5893,6 +5893,50 @@ ENTRY entry {
                          op::Shape("f32[1,1,128,256]")));
}

TEST_F(SpmdPartitioningTest,
       ConvolutionInputSpatialDimAndFeatureDimParttiioned) {
  const char* const hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[8,210,210,12] parameter(0)
  %lhs.copy = f32[8,210,210,12] copy(f32[8,210,210,12] %lhs),
    sharding={devices=[1,2,1,2]0,1,2,3}
  %rhs = f32[3,3,12,32] parameter(1)
  %rhs.copy = f32[3,3,12,32] copy(f32[3,3,12,32] %rhs),
    sharding={devices=[1,1,2,1,2]0,1,2,3 last_tile_dim_replicate}
  ROOT %conv = f32[8,210,210,32] convolution(
    f32[8,210,210,12] %lhs.copy,
    f32[3,3,12,32] %rhs.copy),
    window={size=3x3 pad=1_1x1_1},
    dim_labels=b01f_01io->b01f,
    sharding={devices=[1,2,1,1,2]0,1,2,3 last_tile_dim_replicate}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << module->ToString();
  auto root = module->entry_computation()->root_instruction();
  auto lhs = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
                                op::Constant(), op::Reshape())),
      op::Shape("f32[8,105,210,6]"));
  auto left_halo =
      AllOf(op::CollectivePermute(op::Slice(lhs)), op::Shape("f32[8,1,210,6]"));
  auto right_halo =
      AllOf(op::CollectivePermute(op::Slice(lhs)), op::Shape("f32[8,1,210,6]"));
  auto exchanged_lhs = AllOf(
      op::Select(op::And(_, _), op::Concatenate(left_halo, lhs, right_halo),
                 op::Broadcast(_)),
      op::Shape("f32[8,107,210,6]"));
  auto rhs = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
                                op::Reshape(), op::Constant())),
      op::Shape("f32[3,3,6,32]"));
  EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(
                              exchanged_lhs, op::CollectivePermute(rhs))),
                          op::Shape("f32[8,105,210,32]")));
}

} // namespace
} // namespace spmd
} // namespace xla
@ -2676,6 +2676,7 @@ xla_test(
xla_test(
name = "cholesky_test",
srcs = ["cholesky_test.cc"],
real_hardware_only = True,
tags = [
"no_rocm",
"optonly",
@ -61,6 +61,44 @@ XLA_TEST_F(CholeskyTest, NonPSDInput) {
ErrorSpec(1e-4, 1e-4));
}

XLA_TEST_F(CholeskyTest, NonPSDBatched) {
XlaBuilder builder(TestName());

Array3D<float> a_vals({
{
{10, 0, 0},
{1, 20, 0},
{1, 1, 30},
},
{
{1, 1, 1},
{1, 1, 1},
{1, 1, 1},
},
});

XlaOp a;
auto a_data = CreateR3Parameter<float>(a_vals, 0, "a", &builder, &a);
Cholesky(a, /*lower=*/true);

float nan = std::numeric_limits<float>::quiet_NaN();
Array3D<float> expected({
{
{3.16227766, 0., 0.},
{0.31622777, 4.4609416, 0.},
{0.31622777, 0.20175113, 5.46436606},
},
{
{nan, nan, nan},
{nan, nan, nan},
{nan, nan, nan},
},
});

ComputeAndCompareR3<float>(&builder, expected, {a_data.get()},
ErrorSpec(1e-4, 1e-4));
}

XLA_TEST_F(CholeskyTest, Lower) {
XlaBuilder builder(TestName());

@ -181,7 +219,7 @@ class RandomCholeskyTest
: public ClientLibraryTestBase,
public ::testing::WithParamInterface<CholeskyTestCase> {};

XLA_TEST_P(RandomCholeskyTest, Random) {
XLA_TEST_P(RandomCholeskyTest, Real) {
// Test fails with TensorFloat-32 enabled
tensorflow::enable_tensor_float_32_execution(false);
XlaBuilder builder(TestName());

@ -220,14 +258,65 @@ XLA_TEST_P(RandomCholeskyTest, Random) {
ErrorSpec(1e-4, 1e-4));
}

XLA_TEST_P(RandomCholeskyTest, Complex) {
// Test fails with TensorFloat-32 enabled
tensorflow::enable_tensor_float_32_execution(false);
XlaBuilder builder(TestName());

auto test_params = GetParam();
std::vector<int64> dimensions = {std::get<0>(test_params),
std::get<1>(test_params),
std::get<1>(test_params)};
bool lower = std::get<2>(test_params);
Shape shape = ShapeUtil::MakeShape(F32, dimensions);
TF_ASSERT_OK_AND_ASSIGN(
auto literal_real,
LiteralUtil::CreateRandomLiteral<F32>(shape, 0.0, 1.0));
TF_ASSERT_OK_AND_ASSIGN(
auto literal_imag,
LiteralUtil::CreateRandomLiteral<F32>(shape, 0.0, 1.0));

auto input_real = Parameter(&builder, 0, shape, "input_real");
auto input_imag = Parameter(&builder, 1, shape, "input_imag");
auto input = Complex(input_real, input_imag);
// Form a random positive definite matrix.
auto matrix = BatchDot(input, TransposeInMinorDims(Conj(input)),
PrecisionConfig::HIGHEST);

auto cholesky = Triangle(Cholesky(matrix, lower), lower);

// Verify that ||matrix - cholesky * cholesky_t||_2 ~= 0
XlaOp verification;
if (lower) {
verification = BatchDot(cholesky, TransposeInMinorDims(Conj(cholesky)),
PrecisionConfig::HIGHEST);
} else {
verification = BatchDot(TransposeInMinorDims(Conj(cholesky)), cholesky,
PrecisionConfig::HIGHEST);
}
auto delta = matrix - verification;
Reduce(Abs(delta * Conj(delta)), ConstantR0<float>(&builder, 0.0),
CreateScalarAddComputation(F32, &builder), {0, 1, 2});

TF_ASSERT_OK_AND_ASSIGN(auto input_data_real,
client_->TransferToServer(literal_real));
TF_ASSERT_OK_AND_ASSIGN(auto input_data_imag,
client_->TransferToServer(literal_imag));
ComputeAndCompareR0<float>(&builder, 0.0,
{input_data_real.get(), input_data_imag.get()},
ErrorSpec(1e-4, 1e-4));
}

INSTANTIATE_TEST_SUITE_P(RandomCholeskyTestInstance, RandomCholeskyTest,
::testing::Values(CholeskyTestCase{1, 1, true},
CholeskyTestCase{1, 2, true},
CholeskyTestCase{1, 50, true},
CholeskyTestCase{1, 50, false},
CholeskyTestCase{1, 255, false},
CholeskyTestCase{10, 5, true},
CholeskyTestCase{5, 10, false},
CholeskyTestCase{2, 20, true}));
CholeskyTestCase{2, 20, true},
CholeskyTestCase{2, 129, true}));

} // namespace
} // namespace xla
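For reference, the verification step in the Complex test above is just the defining identity of the Cholesky factorization, reduced to a scalar; in the lower-triangular case (the upper case swaps the factors):

A \approx L L^{H}, \qquad \sum_{b,i,j} \left| A_{b,i,j} - (L L^{H})_{b,i,j} \right|^{2} \approx 0,

where L^{H} is the conjugate transpose in the two minor (matrix) dimensions and b indexes the batch.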
@ -6,9 +6,11 @@ op {
element in the tensor. Input range is `[-inf, inf]` and
output range is `[-1,1]`.

```python
x = tf.constant([-float("inf"), -5, -0.5, 1, 1.2, 2, 3, float("inf")])
tf.math.tanh(x) ==> [-1. -0.99990916 -0.46211717 0.7615942 0.8336547 0.9640276 0.9950547 1.]
```
>>> x = tf.constant([-float("inf"), -5, -0.5, 1, 1.2, 2, 3, float("inf")])
>>> tf.math.tanh(x)
<tf.Tensor: shape=(8,), dtype=float32, numpy=
array([-1. , -0.99990916, -0.46211717, 0.7615942 , 0.8336547 ,
0.9640276 , 0.9950547 , 1. ], dtype=float32)>

END
}
@ -218,6 +218,9 @@ void BaseCollectiveExecutor::StartAbort(const Status& s) {
VLOG(1) << "BaseCollectiveExecutor::StartAbort " << s;
cem_->GetParamResolver()->StartAbort(s);
remote_access_->StartAbort(s);
if (cem_->GetNcclCommunicator() != nullptr) {
cem_->GetNcclCommunicator()->StartAbort(s);
}
}

void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
@ -102,7 +102,6 @@ class TestCollectiveExecutorMgr : public CollectiveExecutorMgrInterface {
}

NcclCommunicatorInterface* GetNcclCommunicator() const override {
LOG(FATAL) << "Unimplemented"; // Crash OK
return nullptr;
}
@ -227,7 +227,9 @@ Status DataServiceDispatcherClient::EnsureInitialized() {
std::shared_ptr<grpc::ChannelCredentials> credentials;
TF_RETURN_IF_ERROR(
CredentialsFactory::CreateClientCredentials(protocol_, &credentials));
auto channel = grpc::CreateChannel(address_, credentials);
grpc::ChannelArguments args;
args.SetMaxReceiveMessageSize(std::numeric_limits<int32>::max());
auto channel = grpc::CreateCustomChannel(address_, credentials, args);
stub_ = DispatcherService::NewStub(channel);
return Status::OK();
}
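The change above swaps grpc::CreateChannel for grpc::CreateCustomChannel so that channel arguments can raise gRPC's default 4 MB receive limit. A standalone sketch of the same construction outside TensorFlow follows; the target address and the insecure credentials are placeholders, not values used by the data service.

#include <limits>
#include <memory>
#include <string>

#include <grpcpp/grpcpp.h>

std::shared_ptr<grpc::Channel> MakeChannel(const std::string& address) {
  auto credentials = grpc::InsecureChannelCredentials();  // placeholder creds
  grpc::ChannelArguments args;
  // Lift the default receive limit so large responses fit in one gRPC message.
  args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
  return grpc::CreateCustomChannel(address, credentials, args);
}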
@ -70,6 +70,10 @@ Status DataServiceWorkerImpl::Start(const std::string& worker_address) {

Status s = Heartbeat();
while (!s.ok()) {
if (!errors::IsUnavailable(s) && !errors::IsAborted(s) &&
!errors::IsCancelled(s)) {
return s;
}
LOG(WARNING) << "Failed to register with dispatcher at "
<< config_.dispatcher_address() << ": " << s;
Env::Default()->SleepForMicroseconds(kRetryIntervalMicros);
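The loop above keeps retrying the worker's registration heartbeat only for error codes treated as transient and propagates everything else. A self-contained sketch of that policy; the status codes and the fake Heartbeat() below are stand-ins, not the tf.data service API.

#include <chrono>
#include <iostream>
#include <thread>

enum class Code { kOk, kUnavailable, kAborted, kCancelled, kInvalidArgument };

bool IsRetriable(Code c) {
  return c == Code::kUnavailable || c == Code::kAborted || c == Code::kCancelled;
}

// Pretend the dispatcher only becomes reachable on the third attempt.
Code Heartbeat() {
  static int attempts = 0;
  return ++attempts < 3 ? Code::kUnavailable : Code::kOk;
}

Code RegisterWithDispatcher() {
  Code s = Heartbeat;
  s = Heartbeat();
  while (s != Code::kOk) {
    if (!IsRetriable(s)) {
      return s;  // permanent failure: propagate to the caller
    }
    std::cerr << "Failed to register with dispatcher, retrying...\n";
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
    s = Heartbeat();
  }
  return Code::kOk;
}

int main() { return RegisterWithDispatcher() == Code::kOk ? 0 : 1; }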
@ -25,6 +25,15 @@ namespace data {
namespace model {
namespace {

// Helper function for node traversal that doesn't skip any nodes.
inline bool IsAnyNode(const std::shared_ptr<Node> node) { return true; }

// Helper function for node traversal that filters out nodes for which
// autotuning is disabled.
inline bool IsAutotuneNode(const std::shared_ptr<Node> node) {
return node->autotune();
}

// Wrapper for the square function to reduce verbosity.
inline double Square(double x) { return x * x; }
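The predicates above are used by the later hunks to filter node traversal: autotune-aware walks skip subtrees whose nodes have autotuning disabled, while byte-accounting and debug walks still visit every node. Below is a compact, self-contained sketch of such a predicate-filtered breadth-first collection; the Node type here is a hypothetical stand-in, not tf.data's model::Node.

#include <iostream>
#include <list>
#include <memory>
#include <string>
#include <vector>

struct Node {
  std::string name;
  bool autotune = true;
  std::vector<std::shared_ptr<Node>> inputs;
};

using NodeVector = std::vector<std::shared_ptr<Node>>;

// Breadth-first collection of the subtree below `root`. If `collect_node`
// rejects a child, that child and its whole subtree are skipped.
NodeVector CollectNodes(const Node& root,
                        bool collect_node(const std::shared_ptr<Node>)) {
  NodeVector collected;
  std::list<std::shared_ptr<Node>> frontier;
  for (const auto& input : root.inputs) {
    if (collect_node(input)) {
      collected.push_back(input);
      frontier.push_back(input);
    }
  }
  while (!frontier.empty()) {
    auto current = frontier.front();
    frontier.pop_front();
    for (const auto& input : current->inputs) {
      if (collect_node(input)) {
        collected.push_back(input);
        frontier.push_back(input);
      }
    }
  }
  return collected;
}

inline bool IsAnyNode(const std::shared_ptr<Node> node) { return true; }
inline bool IsAutotuneNode(const std::shared_ptr<Node> node) {
  return node->autotune;
}

int main() {
  auto leaf = std::make_shared<Node>(Node{"leaf", true, {}});
  auto middle = std::make_shared<Node>(Node{"middle", false, {leaf}});
  Node root{"root", true, {middle}};
  // The autotune-aware walk skips "middle" and, with it, the "leaf" subtree.
  std::cout << CollectNodes(root, IsAnyNode).size() << " vs "
            << CollectNodes(root, IsAutotuneNode).size() << "\n";  // 2 vs 0
  return 0;
}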
@ -82,7 +91,8 @@ class InterleaveMany : public Node {
if (num_inputs() <= 1) {
(*output_times)[long_name()] = self_processing_time;
if (gradients) {
for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) {
for (const auto& node :
CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
gradients->erase(node->long_name());
}
}

@ -94,7 +104,8 @@ class InterleaveMany : public Node {
(*output_times)[inputs_.front()->long_name()]) /
static_cast<double>(num_inputs() - 1);
if (gradients) {
for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) {
for (const auto& node :
CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
auto* gradient = gtl::FindOrNull(*gradients, node->long_name());
if (gradient) {
*gradient /= static_cast<double>(num_inputs() - 1);

@ -211,7 +222,8 @@ class AsyncInterleaveMany : public Node {
if (num_inputs() <= 1) {
(*output_times)[long_name()] = self_processing_time;
if (gradients) {
for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) {
for (const auto& node :
CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
gradients->erase(node->long_name());
}
}

@ -245,7 +257,8 @@ class AsyncInterleaveMany : public Node {
consumer_time_der +
producer_time_der * inputs_time_der_sum / parallelism;

for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) {
for (const auto& node :
CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
auto* gradient = gtl::FindOrNull(*gradients, node->long_name());
if (gradient) {
*gradient *= (producer_time_der /

@ -345,14 +358,16 @@ class KnownRatio : public Node {
if (ratio_ == 0) {
(*output_times)[long_name()] = self_processing_time;
if (gradients) {
for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) {
for (const auto& node :
CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
gradients->erase(node->long_name());
}
}
return;
}
if (gradients) {
for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) {
for (const auto& node :
CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
auto* gradient = gtl::FindOrNull(*gradients, node->long_name());
if (gradient) {
*gradient *= ratio_;

@ -483,7 +498,8 @@ class AsyncKnownRatio : public Node {
consumer_time = input_time;
producer_time = 0.0L;
if (gradients) {
for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) {
for (const auto& node :
CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
gradients->erase(node->long_name());
}

@ -528,7 +544,8 @@ class AsyncKnownRatio : public Node {
(*output_time_gradients)[long_name()] =
consumer_time_der + producer_time_der * inputs_time_der_sum;

for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) {
for (const auto& node :
CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
auto* gradient = gtl::FindOrNull(*gradients, node->long_name());
if (gradient) {
*gradient *= (ratio_ * producer_time_der);

@ -629,7 +646,8 @@ class UnknownRatio : public Node {
inputs_.front()->num_elements() == 0) {
(*output_times)[long_name()] = self_processing_time;
if (gradients) {
for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) {
for (const auto& node :
CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
gradients->erase(node->long_name());
}
}

@ -640,7 +658,8 @@ class UnknownRatio : public Node {
double ratio = static_cast<double>(inputs_.front()->num_elements()) /
static_cast<double>(num_elements_);
if (gradients) {
for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) {
for (const auto& node :
CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
auto* gradient = gtl::FindOrNull(*gradients, node->long_name());
if (gradient) {
*gradient *= ratio;

@ -917,7 +936,8 @@ void Node::CollectTunableParameters(
absl::flat_hash_map<string, std::shared_ptr<Parameter>>* parameters) const {
tf_shared_lock l(mu_);
// Collect tunable parameters from the leaves of the nodes tree to the root.
for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) {
for (const auto& node :
CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
tf_shared_lock l(node->mu_);
node->CollectTunableParametersHelper(parameters);
}

@ -928,7 +948,8 @@ string Node::DebugString() const {
absl::flat_hash_map<string, string> debug_strings;
tf_shared_lock l(mu_);
// Build up the debug string from the leaves of the nodes tree to the root.
for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) {
for (const auto& node :
CollectNodes(TraversalOrder::REVERSE_BFS, IsAnyNode)) {
tf_shared_lock l(node->mu_);
node->DebugStringHelper(&debug_strings);
}

@ -952,7 +973,7 @@ double Node::OutputTime(absl::flat_hash_map<string, double>* input_times,
// `nullptr`) and the output time for each node.
absl::flat_hash_map<string, double> output_time_gradients, output_times;
tf_shared_lock l(mu_);
auto nodes = CollectNodes(TraversalOrder::BFS);
auto nodes = CollectNodes(TraversalOrder::BFS, IsAutotuneNode);

// Computes and stores input time for each node from the root to leaves of the
// nodes tree.

@ -1001,7 +1022,8 @@ double Node::TotalBufferedBytes() const {
absl::flat_hash_map<string, double> total_bytes;
tf_shared_lock l(mu_);
// Compute total buffered bytes from the leaves of the nodes tree to the root.
for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) {
for (const auto& node :
CollectNodes(TraversalOrder::REVERSE_BFS, IsAnyNode)) {
tf_shared_lock l(node->mu_);
node->TotalBufferedBytesHelper(&total_bytes);
}

@ -1015,7 +1037,8 @@ double Node::TotalMaximumBufferedBytes() const {
tf_shared_lock l(mu_);
// Compute total maximum buffered bytes from the leaves of the nodes tree
// to the root.
for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) {
for (const auto& node :
CollectNodes(TraversalOrder::REVERSE_BFS, IsAnyNode)) {
tf_shared_lock l(node->mu_);
node->TotalMaximumBufferedBytesHelper(&total_bytes);
}

@ -1033,7 +1056,8 @@ double Node::TotalProcessingTime(

// Computes per-element CPU time spent in the subtree rooted in the node from
// the leaves of the nodes tree to the root.
for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) {
for (const auto& node :
CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
tf_shared_lock l(node->mu_);
node->TotalProcessingTimeLocked(processing_times, &total_processing_times);
}

@ -1123,14 +1147,17 @@ double Node::SelfProcessingTimeLocked() const {
static_cast<double>(num_elements_);
}

Node::NodeVector Node::CollectNodes(TraversalOrder order) const
Node::NodeVector Node::CollectNodes(
TraversalOrder order, bool collect_node(const std::shared_ptr<Node>)) const
TF_SHARED_LOCKS_REQUIRED(mu_) {
NodeVector node_vector;
std::list<std::shared_ptr<Node>> temp_list;

for (auto& input : inputs_) {
node_vector.push_back(input);
temp_list.push_back(input);
if (collect_node(input)) {
node_vector.push_back(input);
temp_list.push_back(input);
}
}

while (!temp_list.empty()) {
@ -1138,8 +1165,10 @@ Node::NodeVector Node::CollectNodes(TraversalOrder order) const
temp_list.pop_front();
tf_shared_lock l(cur_node->mu_);
for (auto& input : cur_node->inputs_) {
node_vector.push_back(input);
temp_list.push_back(input);
if (collect_node(input)) {
node_vector.push_back(input);
temp_list.push_back(input);
}
}
}
@ -485,8 +485,12 @@ class Node {

// Returns a vector of nodes of the subtree rooted in this node. The nodes are
// either in breadth-first search or reverse breadth-first search order
// depending on the `order` argument. The root node itself is not collected.
NodeVector CollectNodes(TraversalOrder order) const
// depending on the `order` argument. The nodes are collected based on the
// results of the `collect_node` predicate: if the predicate returns `false`
// for a given node, then the subtree rooted in this node is excluded. The
// root node itself is not collected.
NodeVector CollectNodes(TraversalOrder order,
bool collect_node(const std::shared_ptr<Node>)) const
TF_SHARED_LOCKS_REQUIRED(mu_);

// Collect tunable parameters for the node.
@ -531,7 +531,6 @@ Status SchedulerState::Init(const GrapplerItem* item,
initial_nodes->push_back(curr_node);
VLOG(3) << "Added ready node: " << curr_node->name();
}

feed_nodes.erase(curr_node->name());

if (IsPersistent(*curr_node)) {

@ -778,6 +777,10 @@ NodeState& SchedulerState::GetNodeStateOrCreateIt(const NodeDef* node) {
node_state.num_outputs_executed[-1] = 0;
node_state.outputs[-1] = {};

// Initialize time_scheduled to infinity, so we know whether it has been
// assigned a non-default value later.
node_state.time_scheduled = Costs::Duration().infinity();

return it->second;
}

@ -862,10 +865,16 @@ std::vector<const NodeDef*> SchedulerState::MarkNodeExecuted(
// Node is scheduled when the device is available AND all the inputs are
// ready; hence, time_scheduled is time_ready if time_ready > device curr
// time.
node_state.time_scheduled =
std::max(device.GetCurrTime(), node_state.time_ready);
// Override device curr time with the time_scheduled.
device.device_costs.execution_time = node_state.time_scheduled;
// NodeState times are assigned infinity at initialization. If they are
// still infinity here, we need to assign them. If not, it has been assigned
// already, so skip. This latter case may occur when a scheduler in-lines
// function calls, and thus schedules only function sub-nodes.
if (node_state.time_scheduled == Costs::Duration().infinity()) {
node_state.time_scheduled =
std::max(device.GetCurrTime(), node_state.time_ready);
// Override device curr time with the time_scheduled.
device.device_costs.execution_time = node_state.time_scheduled;
}
device.device_costs = CombineCosts(device.device_costs, total_node_costs);
auto curr_time = device.GetCurrTime();
node_state.time_finished = curr_time;

@ -1000,8 +1009,13 @@ Costs SchedulerState::Summary() const {
for (const auto& node_port : state.persistent_nodes) {
const auto* node = node_port.first;
const auto port = node_port.second;
const auto output_size =
CalculateOutputSize(node_map_.at(node).output_properties, port);
auto output_size = 0;
// Check if the node is in the node_map. It may be that the node executed
// on this device was executed by a different Scheduler.
if (node_map_.find(node) != node_map_.end()) {
output_size =
CalculateOutputSize(node_map_.at(node).output_properties, port);
}
persistent_memory_usage += output_size;
op_to_memory[node->op()] += output_size;
persistent_ops.insert(node->op());

@ -1048,8 +1062,12 @@ Costs SchedulerState::Summary() const {
for (const auto& node_port : state.mem_usage_snapshot_at_peak) {
const auto* node = node_port.first;
const auto port = node_port.second;
op_to_memory[node->op()] +=
CalculateOutputSize(node_map_.at(node).output_properties, port);
// Check if the node is in the node_map. It may be that the node executed
// on this device was executed by a different Scheduler.
if (node_map_.find(node) != node_map_.end()) {
op_to_memory[node->op()] +=
CalculateOutputSize(node_map_.at(node).output_properties, port);
}
}
Costs::NanoSeconds total_compute_time_ns;
bool is_total_cost_accurate = true;

@ -1132,6 +1150,12 @@ void SchedulerState::GenerateRunMetadata(RunMetadata* metadata) {
DeviceStepStats* device_stepstats = stepstats->add_dev_stats();
device_stepstats->set_device(device.first);
for (const auto& node_def : device.second.nodes_executed) {
// Only proceed if the node is in the node_map. This is to cover the case
// where a device has executed a node that is not in the node_map of
// this scheduler.
if (node_map_.find(node_def) == node_map_.end()) {
continue;
}
const NodeState& nodestate = node_map_.at(node_def);
NodeExecStats* node_stats = device_stepstats->add_node_stats();
uint64 total_output_size = 0;

@ -1229,6 +1253,12 @@ SchedulerState::GetPersistentMemoryUsage() const {
return result;
}

void SchedulerState::SetNodeStateTimeScheduled(const NodeDef* node) {
auto& node_state = node_map_.at(node);
auto& device = device_[node_state.device_name];
node_state.time_scheduled = device.GetCurrTime();
}

VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
const bool use_aggressive_shape_inference,
Cluster* cluster,
@ -371,6 +371,10 @@ class SchedulerState {
}

protected:
// Assigns the time_scheduled in the NodeState of node to the current
// execution_time of the device executing this node.
void SetNodeStateTimeScheduled(const NodeDef* node);

// This method can be used by a class derived from SchedulerState to
// access the device state map.
std::unordered_map<string, DeviceState>* GetMutableDeviceState() {
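The scheduler hunks above initialize time_scheduled to infinity and later assign it only while it is still infinity, so a derived scheduler (for example one that in-lines function calls and schedules the function's sub-nodes itself via SetNodeStateTimeScheduled) can pre-assign the value without it being overwritten. A self-contained sketch of that pattern, with a plain double standing in for Costs::Duration:

#include <algorithm>
#include <limits>

struct NodeState {
  double time_ready = 0.0;
  // Infinity means "not scheduled yet"; any finite value was assigned earlier,
  // e.g. by a derived scheduler that already handled this node.
  double time_scheduled = std::numeric_limits<double>::infinity();
};

void MarkScheduled(NodeState& node_state, double device_curr_time) {
  if (node_state.time_scheduled == std::numeric_limits<double>::infinity()) {
    // First encounter: the node runs once the device and all inputs are ready.
    node_state.time_scheduled =
        std::max(device_curr_time, node_state.time_ready);
  }
  // Otherwise keep the previously assigned value.
}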
@ -5228,6 +5228,24 @@ tf_kernel_library(
deps = STRING_DEPS,
)

tf_cc_test(
name = "as_string_op_test",
size = "small",
srcs = ["as_string_op_test.cc"],
deps = [
":as_string_op",
":ops_testutil",
":ops_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)

tf_kernel_library(
name = "unicode_ops",
prefix = "unicode_ops",
@ -65,9 +65,26 @@ class AsStringOp : public OpKernel {
OP_REQUIRES(ctx, !(scientific && shortest),
errors::InvalidArgument(
"Cannot select both scientific and shortest notation"));

format_ = "%";
if (!fill_string.empty()) {
switch (fill_string[0]) {
case ' ':
case '+':
case '-':
case '0':
case '#':
strings::Appendf(&format_, "%s", fill_string.c_str());
break;
default:
bool fill_not_supported = true;
OP_REQUIRES(ctx, !fill_not_supported,
errors::InvalidArgument("Fill argument not supported: \"",
fill_string, "\""));
}
}
if (width > -1) {
strings::Appendf(&format_, "%s%d", fill_string.c_str(), width);
strings::Appendf(&format_, "%d", width);
}
if (precision > -1) {
strings::Appendf(&format_, ".%d", precision);
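The kernel change above validates the fill character (one of ' ', '+', '-', '0', '#') and appends it to the printf-style format string before the width and precision, instead of splicing the raw fill string in together with the width. A standalone sketch of that format-string construction follows; snprintf stands in for the op's per-element formatting, and this is not the TensorFlow kernel itself.

#include <cstdio>
#include <stdexcept>
#include <string>

std::string BuildFormat(const std::string& fill, int width, int precision,
                        char conversion /* e.g. 'd' or 'f' */) {
  std::string format = "%";
  if (!fill.empty()) {
    switch (fill[0]) {
      case ' ': case '+': case '-': case '0': case '#':
        format += fill[0];  // a single printf flag character
        break;
      default:
        throw std::invalid_argument("Fill argument not supported: " + fill);
    }
  }
  if (width > -1) format += std::to_string(width);
  if (precision > -1) format += "." + std::to_string(precision);
  format += conversion;
  return format;
}

int main() {
  char buf[64];
  // "%04d" -> zero-padded to width 4, matching the FillWithZero expectation.
  std::snprintf(buf, sizeof(buf), BuildFormat("0", 4, -1, 'd').c_str(), -42);
  std::printf("%s\n", buf);  // prints -042
  return 0;
}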
tensorflow/core/kernels/as_string_op_test.cc (new file, 245 lines)
@ -0,0 +1,245 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace tensorflow {
namespace {

class AsStringGraphTest : public OpsTestBase {
protected:
Status Init(DataType input_type, const string& fill = "", int width = -1,
int precision = -1, bool scientific = false,
bool shortest = false) {
TF_CHECK_OK(NodeDefBuilder("op", "AsString")
.Input(FakeInput(input_type))
.Attr("fill", fill)
.Attr("precision", precision)
.Attr("scientific", scientific)
.Attr("shortest", shortest)
.Attr("width", width)
.Finalize(node_def()));
return InitOp();
}
};

TEST_F(AsStringGraphTest, Int8) {
TF_ASSERT_OK(Init(DT_INT8));

AddInputFromArray<int8>(TensorShape({3}), {-42, 0, 42});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({3}));
test::FillValues<tstring>(&expected, {"-42", "0", "42"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}

TEST_F(AsStringGraphTest, Int64) {
TF_ASSERT_OK(Init(DT_INT64));

AddInputFromArray<int64>(TensorShape({3}), {-42, 0, 42});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({3}));
test::FillValues<tstring>(&expected, {"-42", "0", "42"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}

TEST_F(AsStringGraphTest, FloatDefault) {
TF_ASSERT_OK(Init(DT_FLOAT));

AddInputFromArray<float>(TensorShape({4}), {-42, 0, 3.14159, 42});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({4}));
test::FillValues<tstring>(
&expected, {"-42.000000", "0.000000", "3.141590", "42.000000"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}

TEST_F(AsStringGraphTest, FloatScientific) {
TF_ASSERT_OK(Init(DT_FLOAT, /*fill=*/"", /*width=*/-1, /*precision=*/-1,
/*scientific=*/true));

AddInputFromArray<float>(TensorShape({4}), {-42, 0, 3.14159, 42});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({4}));
test::FillValues<tstring>(&expected, {"-4.200000e+01", "0.000000e+00",
"3.141590e+00", "4.200000e+01"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}

TEST_F(AsStringGraphTest, FloatShortest) {
TF_ASSERT_OK(Init(DT_FLOAT, /*fill=*/"", /*width=*/-1, /*precision=*/-1,
/*scientific=*/false, /*shortest=*/true));

AddInputFromArray<float>(TensorShape({4}), {-42, 0, 3.14159, 42});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({4}));
test::FillValues<tstring>(&expected, {"-42", "0", "3.14159", "42"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}

TEST_F(AsStringGraphTest, FloatPrecisionOnly) {
TF_ASSERT_OK(Init(DT_FLOAT, /*fill=*/"", /*width=*/-1, /*precision=*/2));

AddInputFromArray<float>(TensorShape({4}), {-42, 0, 3.14159, 42});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({4}));
test::FillValues<tstring>(&expected, {"-42.00", "0.00", "3.14", "42.00"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}

TEST_F(AsStringGraphTest, FloatWidthOnly) {
TF_ASSERT_OK(Init(DT_FLOAT, /*fill=*/"", /*width=*/5));

AddInputFromArray<float>(TensorShape({4}), {-42, 0, 3.14159, 42});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({4}));
test::FillValues<tstring>(
&expected, {"-42.000000", "0.000000", "3.141590", "42.000000"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}

TEST_F(AsStringGraphTest, Float_5_2_Format) {
TF_ASSERT_OK(Init(DT_FLOAT, /*fill=*/"", /*width=*/5, /*precision=*/2));

AddInputFromArray<float>(TensorShape({4}), {-42, 0, 3.14159, 42});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({4}));
test::FillValues<tstring>(&expected, {"-42.00", " 0.00", " 3.14", "42.00"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}

TEST_F(AsStringGraphTest, Complex) {
TF_ASSERT_OK(Init(DT_COMPLEX64, /*fill=*/"", /*width=*/5, /*precision=*/2));

AddInputFromArray<complex64>(TensorShape({3}), {{-4, 2}, {0}, {3.14159, -1}});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({3}));
test::FillValues<tstring>(
&expected, {"(-4.00, 2.00)", "( 0.00, 0.00)", "( 3.14,-1.00)"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}

TEST_F(AsStringGraphTest, Bool) {
TF_ASSERT_OK(Init(DT_BOOL));

AddInputFromArray<bool>(TensorShape({2}), {true, false});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({2}));
test::FillValues<tstring>(&expected, {"true", "false"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}

TEST_F(AsStringGraphTest, String) {
Status s = Init(DT_STRING);
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
ASSERT_TRUE(absl::StrContains(
s.error_message(),
"Value for attr 'T' of string is not in the list of allowed values"));
}

TEST_F(AsStringGraphTest, OnlyOneOfScientificAndShortest) {
Status s = Init(DT_FLOAT, /*fill=*/"", /*width=*/-1, /*precision=*/-1,
/*scientific=*/true, /*shortest=*/true);
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
ASSERT_TRUE(
absl::StrContains(s.error_message(),
"Cannot select both scientific and shortest notation"));
}

TEST_F(AsStringGraphTest, NoShortestForNonFloat) {
Status s = Init(DT_INT32, /*fill=*/"", /*width=*/-1, /*precision=*/-1,
/*scientific=*/false, /*shortest=*/true);
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
ASSERT_TRUE(absl::StrContains(
s.error_message(),
"scientific and shortest format not supported for datatype"));
}

TEST_F(AsStringGraphTest, NoScientificForNonFloat) {
Status s = Init(DT_INT32, /*fill=*/"", /*width=*/-1, /*precision=*/-1,
/*scientific=*/true);
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
ASSERT_TRUE(absl::StrContains(
s.error_message(),
"scientific and shortest format not supported for datatype"));
}

TEST_F(AsStringGraphTest, NoPrecisionForNonFloat) {
Status s = Init(DT_INT32, /*fill=*/"", /*width=*/-1, /*precision=*/5);
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
ASSERT_TRUE(absl::StrContains(s.error_message(),
"precision not supported for datatype"));
}

TEST_F(AsStringGraphTest, LongFill) {
Status s = Init(DT_INT32, /*fill=*/"asdf");
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
ASSERT_TRUE(absl::StrContains(s.error_message(),
"Fill string must be one or fewer characters"));
}

TEST_F(AsStringGraphTest, FillWithZero) {
TF_ASSERT_OK(Init(DT_INT64, /*fill=*/"0", /*width=*/4));

AddInputFromArray<int64>(TensorShape({3}), {-42, 0, 42});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({3}));
test::FillValues<tstring>(&expected, {"-042", "0000", "0042"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}

TEST_F(AsStringGraphTest, FillWithSpace) {
TF_ASSERT_OK(Init(DT_INT64, /*fill=*/" ", /*width=*/4));

AddInputFromArray<int64>(TensorShape({3}), {-42, 0, 42});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({3}));
test::FillValues<tstring>(&expected, {" -42", "   0", "  42"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}

TEST_F(AsStringGraphTest, FillWithChar1) {
TF_ASSERT_OK(Init(DT_INT64, /*fill=*/"-", /*width=*/4));

AddInputFromArray<int64>(TensorShape({3}), {-42, 0, 42});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({3}));
test::FillValues<tstring>(&expected, {"-42 ", "0   ", "42  "});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}

TEST_F(AsStringGraphTest, FillWithChar3) {
Status s = Init(DT_INT32, /*fill=*/"s");
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
ASSERT_TRUE(
absl::StrContains(s.error_message(), "Fill argument not supported"));
}

TEST_F(AsStringGraphTest, FillWithChar4) {
Status s = Init(DT_INT32, /*fill=*/"n");
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
ASSERT_TRUE(
absl::StrContains(s.error_message(), "Fill argument not supported"));
}

} // end namespace
} // end namespace tensorflow