Merge commit for internal changes
Conflict in tensorflow/contrib/layers/python/layers/layers.py: preserved both the indentation change on "decay" and the doc fix on "center".
commit 00f1d4369a
@@ -38,6 +38,13 @@ new_http_archive(
    sha256 = "b4c178fd6236dcf0a20d25d07c45eebe85281263978c6a6f1dfc49d75befc45f"
)

new_http_archive(
    name = "stylize",
    build_file = "models.BUILD",
    url = "https://storage.googleapis.com/download.tensorflow.org/models/stylize_v1.zip",
    sha256 = "3d374a730aef330424a356a8d4f04d8a54277c425e274ecb7d9c83aa912c6bfa"
)

# TENSORBOARD_BOWER_AUTOGENERATED_BELOW_THIS_LINE_DO_NOT_EDIT

new_http_archive(
configure (18 changes)
@@ -57,9 +57,27 @@ done
if is_windows; then
  TF_NEED_GCP=0
  TF_NEED_HDFS=0
  TF_NEED_JEMALLOC=0
  TF_NEED_OPENCL=0
fi

while [ "$TF_NEED_JEMALLOC" == "" ]; do
  read -p "Do you wish to use jemalloc as the malloc implementation? "\
"(Linux only) [Y/n] " INPUT
  case $INPUT in
    [Yy]* ) echo "jemalloc enabled on Linux"; TF_NEED_JEMALLOC=1;;
    [Nn]* ) echo "jemalloc disabled on Linux"; TF_NEED_JEMALLOC=0;;
    "" ) echo "jemalloc enabled on Linux"; TF_NEED_JEMALLOC=1;;
    * ) echo "Invalid selection: " $INPUT;;
  esac
done

if [ "$TF_NEED_JEMALLOC" == "1" ]; then
  sed -i -e "s/WITH_JEMALLOC = False/WITH_JEMALLOC = True/" tensorflow/core/platform/default/build_config.bzl
else
  sed -i -e "s/WITH_JEMALLOC = True/WITH_JEMALLOC = False/" tensorflow/core/platform/default/build_config.bzl
fi

while [ "$TF_NEED_GCP" == "" ]; do
  read -p "Do you wish to build TensorFlow with "\
"Google Cloud Platform support? [y/N] " INPUT
tensorflow/.clang-format (new file, 4 lines)
@@ -0,0 +1,4 @@
# Run manually to reformat a file:
# clang-format -i --style=file <file>
BasedOnStyle: Google
DerivePointerAlignment: false
@@ -190,6 +190,7 @@ filegroup(
        "//tensorflow/examples/image_retraining:all_files",
        "//tensorflow/examples/label_image:all_files",
        "//tensorflow/examples/learn:all_files",
        "//tensorflow/examples/saved_model:all_files",
        "//tensorflow/examples/tutorials/estimators:all_files",
        "//tensorflow/examples/tutorials/mnist:all_files",
        "//tensorflow/examples/tutorials/word2vec:all_files",
@@ -203,7 +204,6 @@ filegroup(
        "//tensorflow/python/debug:all_files",
        "//tensorflow/python/kernel_tests:all_files",
        "//tensorflow/python/saved_model:all_files",
        "//tensorflow/python/saved_model/example:all_files",
        "//tensorflow/python/tools:all_files",
        "//tensorflow/tensorboard:all_files",
        "//tensorflow/tensorboard/app:all_files",
@@ -6,6 +6,7 @@ licenses(["notice"]) # Apache 2.0
load(
    "//tensorflow:tensorflow.bzl",
    "tf_cc_test",
    "tf_copts",
    "tf_cuda_library",
    "tf_custom_op_library",
)
@@ -23,13 +24,19 @@ tf_cuda_library(
    name = "c_api",
    srcs = ["c_api.cc"],
    hdrs = ["c_api.h"],
    copts = tf_copts(),
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/cc/saved_model:loader",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
    ],
    deps = select({
        "//tensorflow:android": [
            "//tensorflow/core:android_tensorflow_lib_lite",
        ],
        "//conditions:default": [
            "//tensorflow/cc/saved_model:loader",
            "//tensorflow/core:core_cpu",
            "//tensorflow/core:framework",
            "//tensorflow/core:lib",
        ],
    }),
)

tf_cuda_library(
@@ -20,7 +20,9 @@ limitations under the License.
#include <memory>
#include <vector>

#ifndef __ANDROID__
#include "tensorflow/cc/saved_model/loader.h"
#endif
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/node_def_util.h"
@@ -37,6 +39,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/thread_annotations.h"
@@ -159,11 +162,13 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in,
    return InvalidArgument("Passing non-empty TF_Buffer is invalid.");
  }
  const auto proto_size = in.ByteSize();
  void* buf = malloc(proto_size);
  void* buf = tensorflow::port::Malloc(proto_size);
  in.SerializeToArray(buf, proto_size);
  out->data = buf;
  out->length = proto_size;
  out->data_deallocator = [](void* data, size_t length) { free(data); };
  out->data_deallocator = [](void* data, size_t length) {
    tensorflow::port::Free(data);
  };
  return Status::OK();
}

@@ -287,13 +292,15 @@ void TF_SetConfig(TF_SessionOptions* options, const void* proto,
TF_Buffer* TF_NewBuffer() { return new TF_Buffer{nullptr, 0, nullptr}; }

TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len) {
  void* copy = malloc(proto_len);
  void* copy = tensorflow::port::Malloc(proto_len);
  memcpy(copy, proto, proto_len);

  TF_Buffer* buf = new TF_Buffer;
  buf->data = copy;
  buf->length = proto_len;
  buf->data_deallocator = [](void* data, size_t length) { free(data); };
  buf->data_deallocator = [](void* data, size_t length) {
    tensorflow::port::Free(data);
  };
  return buf;
}

@@ -694,7 +701,7 @@ TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) {
TF_Buffer TF_GetOpList(TF_Library* lib_handle) { return lib_handle->op_list; }

void TF_DeleteLibraryHandle(TF_Library* lib_handle) {
  free(const_cast<void*>(lib_handle->op_list.data));
  tensorflow::port::Free(const_cast<void*>(lib_handle->op_list.data));
  delete lib_handle;
}

@@ -1704,6 +1711,7 @@ TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
  }
}

#ifndef __ANDROID__
TF_Session* TF_LoadSessionFromSavedModel(
    const TF_SessionOptions* session_options, const TF_Buffer* run_options,
    const char* export_dir, const char* const* tags, int tags_len,
@@ -1757,6 +1765,7 @@ TF_Session* TF_LoadSessionFromSavedModel(
  session->last_num_graph_nodes = graph->graph.num_node_ids();
  return session;
}
#endif  // __ANDROID__

void TF_CloseSession(TF_Session* s, TF_Status* status) {
  status->status = s->session->Close();
@@ -835,6 +835,10 @@ typedef struct TF_Session TF_Session;
extern TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opts,
                                 TF_Status* status);

#ifndef __ANDROID__
// TODO(ashankar): Remove the __ANDROID__ guard. This will require ensuring that
// the tensorflow/cc/saved_model:loader build target is Android friendly.

// This function creates a new TF_Session (which is created on success) using
// `session_options`, and then initializes state (restoring tensors and other
// assets) using `run_options`.
@@ -853,6 +857,7 @@ TF_Session* TF_LoadSessionFromSavedModel(
    const TF_SessionOptions* session_options, const TF_Buffer* run_options,
    const char* export_dir, const char* const* tags, int tags_len,
    TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status);
#endif  // __ANDROID__

// Close a session.
//
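The hunk above only shows the new declaration. As a rough, hedged illustration of how a caller might exercise it through the public C API, here is a minimal sketch; the export directory ("/tmp/my_model") and the "serve" tag are placeholder assumptions, not values taken from this diff.

// Hedged sketch: load a SavedModel via the C API entry point declared above.
#include <cstdio>
#include "tensorflow/c/c_api.h"

int main() {
  TF_Status* status = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();
  TF_SessionOptions* opts = TF_NewSessionOptions();
  TF_Buffer* meta_graph_def = TF_NewBuffer();
  const char* tags[] = {"serve"};  // placeholder tag, not from this diff

  TF_Session* session = TF_LoadSessionFromSavedModel(
      opts, /*run_options=*/nullptr, "/tmp/my_model", tags, /*tags_len=*/1,
      graph, meta_graph_def, status);
  if (TF_GetCode(status) != TF_OK) {
    // Loading failed; report the error message carried by the status.
    std::fprintf(stderr, "Load failed: %s\n", TF_Message(status));
  } else {
    TF_CloseSession(session, status);
    TF_DeleteSession(session, status);
  }

  TF_DeleteBuffer(meta_graph_def);
  TF_DeleteSessionOptions(opts);
  TF_DeleteGraph(graph);
  TF_DeleteStatus(status);
  return 0;
}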
@@ -204,23 +204,23 @@ Status RewriteAndPruneGraph(Graph* graph, const Config& config,
      string feed_id;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kFeedIdAttr, &feed_id));
      if (missing_feeds.erase(feed_id) == 0) {
        return errors::Aborted(kArgOp, " node found with unknown feed id: ",
                               feed_id);
        return errors::Aborted(kArgOp,
                               " node found with unknown feed id: ", feed_id);
      }
    } else if (n->type_string() == kRetvalOp) {
      string fetch_id;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kFetchIdAttr, &fetch_id));
      if (missing_fetches.erase(fetch_id) == 0) {
        return errors::Aborted(kRetvalOp, " node found with unknown fetch id: ",
                               fetch_id);
        return errors::Aborted(kRetvalOp,
                               " node found with unknown fetch id: ", fetch_id);
      }
    }
  }
  if (!missing_feeds.empty() || !missing_fetches.empty()) {
    return errors::Aborted("Post graph-pruning", ", missing feeds: ",
                           str_util::Join(missing_feeds, ", "),
                           ", missing fetches: ",
                           str_util::Join(missing_fetches, ", "));
    return errors::Aborted(
        "Post graph-pruning",
        ", missing feeds: ", str_util::Join(missing_feeds, ", "),
        ", missing fetches: ", str_util::Join(missing_fetches, ", "));
  }
  return Status::OK();
}
@@ -351,16 +351,19 @@ Status CompileXla(xla::LocalClient* client, const xla::Computation& computation,
  for (int i = 0; i < pshape->parameters_size(); ++i) {
    arg_layouts.push_back(pshape->mutable_parameters(i));
  }
  xla::StatusOr<std::unique_ptr<xla::AotCompilationResult>> aot_or =
      client->CompileAheadOfTime(computation, arg_layouts, pshape->result(),
                                 aot_opts);
  xla::LocalClient::AheadOfTimeComputationInstance instance;
  instance.computation = &computation;
  instance.argument_layouts = std::move(arg_layouts);
  instance.result_layout = &pshape->result();
  xla::StatusOr<std::vector<std::unique_ptr<xla::AotCompilationResult>>>
      aot_or = client->CompileAheadOfTime({instance}, aot_opts);
  if (!aot_or.ok()) {
    return errors::Unknown("XLA compilation failed: ",
                           aot_or.status().error_message());
  }
  compile_result->aot =
      xla::unique_ptr_static_cast<xla::cpu::CpuAotCompilationResult>(
          aot_or.ConsumeValueOrDie());
          std::move(aot_or.ValueOrDie().back()));
  compile_result->entry_point = aot_opts.entry_point_name();
  compile_result->pointer_size =
      xla::LocalClient::PointerSizeForTriple(aot_opts.triple());
@ -18,6 +18,9 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
from tensorflow.core.protobuf import saver_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import constant_op
|
||||
@ -27,22 +30,18 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import app
|
||||
from tensorflow.python.platform import flags as flags_lib
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
|
||||
flags = flags_lib
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_string('out_dir', '',
|
||||
'Output directory for graphs, checkpoints and savers.')
|
||||
FLAGS = None
|
||||
|
||||
|
||||
def tfadd():
|
||||
def tfadd(_):
|
||||
x = constant_op.constant([1], name='x_const')
|
||||
y = constant_op.constant([2], name='y_const')
|
||||
math_ops.add(x, y, name='x_y_sum')
|
||||
|
||||
|
||||
def tfadd_with_ckpt():
|
||||
def tfadd_with_ckpt(out_dir):
|
||||
x = array_ops.placeholder(dtypes.int32, name='x_hold')
|
||||
y = variables.Variable(constant_op.constant([0]), name='y_saved')
|
||||
math_ops.add(x, y, name='x_y_sum')
|
||||
@ -53,11 +52,11 @@ def tfadd_with_ckpt():
|
||||
sess.run(init_op)
|
||||
sess.run(y.assign(y + 42))
|
||||
# Without the checkpoint, the variable won't be set to 42.
|
||||
ckpt = '%s/test_graph_tfadd_with_ckpt.ckpt' % FLAGS.out_dir
|
||||
ckpt = '%s/test_graph_tfadd_with_ckpt.ckpt' % out_dir
|
||||
saver.save(sess, ckpt)
|
||||
|
||||
|
||||
def tfadd_with_ckpt_saver():
|
||||
def tfadd_with_ckpt_saver(out_dir):
|
||||
x = array_ops.placeholder(dtypes.int32, name='x_hold')
|
||||
y = variables.Variable(constant_op.constant([0]), name='y_saved')
|
||||
math_ops.add(x, y, name='x_y_sum')
|
||||
@ -68,27 +67,27 @@ def tfadd_with_ckpt_saver():
|
||||
sess.run(init_op)
|
||||
sess.run(y.assign(y + 42))
|
||||
# Without the checkpoint, the variable won't be set to 42.
|
||||
ckpt_file = '%s/test_graph_tfadd_with_ckpt_saver.ckpt' % FLAGS.out_dir
|
||||
ckpt_file = '%s/test_graph_tfadd_with_ckpt_saver.ckpt' % out_dir
|
||||
saver.save(sess, ckpt_file)
|
||||
# Without the SaverDef, the restore op won't be named correctly.
|
||||
saver_file = '%s/test_graph_tfadd_with_ckpt_saver.saver' % FLAGS.out_dir
|
||||
saver_file = '%s/test_graph_tfadd_with_ckpt_saver.saver' % out_dir
|
||||
with open(saver_file, 'w') as f:
|
||||
f.write(saver.as_saver_def().SerializeToString())
|
||||
|
||||
|
||||
def tfgather():
|
||||
def tfgather(_):
|
||||
params = array_ops.placeholder(dtypes.float32, name='params')
|
||||
indices = array_ops.placeholder(dtypes.int32, name='indices')
|
||||
array_ops.gather(params, indices, name='gather_output')
|
||||
|
||||
|
||||
def tfmatmul():
|
||||
def tfmatmul(_):
|
||||
x = array_ops.placeholder(dtypes.float32, name='x_hold')
|
||||
y = array_ops.placeholder(dtypes.float32, name='y_hold')
|
||||
math_ops.matmul(x, y, name='x_y_prod')
|
||||
|
||||
|
||||
def tfmatmulandadd():
|
||||
def tfmatmulandadd(_):
|
||||
# This tests multiple outputs.
|
||||
x = array_ops.placeholder(dtypes.float32, name='x_hold')
|
||||
y = array_ops.placeholder(dtypes.float32, name='y_hold')
|
||||
@ -96,24 +95,33 @@ def tfmatmulandadd():
|
||||
math_ops.add(x, y, name='x_y_sum')
|
||||
|
||||
|
||||
def write_graph(build_graph):
|
||||
def write_graph(build_graph, out_dir):
|
||||
"""Build a graph using build_graph and write it out."""
|
||||
g = ops.Graph()
|
||||
with g.as_default():
|
||||
build_graph()
|
||||
filename = '%s/test_graph_%s.pb' % (FLAGS.out_dir, build_graph.__name__)
|
||||
build_graph(out_dir)
|
||||
filename = '%s/test_graph_%s.pb' % (out_dir, build_graph.__name__)
|
||||
with open(filename, 'w') as f:
|
||||
f.write(g.as_graph_def().SerializeToString())
|
||||
|
||||
|
||||
def main(_):
|
||||
write_graph(tfadd)
|
||||
write_graph(tfadd_with_ckpt)
|
||||
write_graph(tfadd_with_ckpt_saver)
|
||||
write_graph(tfgather)
|
||||
write_graph(tfmatmul)
|
||||
write_graph(tfmatmulandadd)
|
||||
write_graph(tfadd, FLAGS.out_dir)
|
||||
write_graph(tfadd_with_ckpt, FLAGS.out_dir)
|
||||
write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir)
|
||||
write_graph(tfgather, FLAGS.out_dir)
|
||||
write_graph(tfmatmul, FLAGS.out_dir)
|
||||
write_graph(tfmatmulandadd, FLAGS.out_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run()
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.register('type', 'bool', lambda v: v.lower() == 'true')
|
||||
parser.add_argument(
|
||||
'--out_dir',
|
||||
type=str,
|
||||
default='',
|
||||
help='Output directory for graphs, checkpoints and savers.'
|
||||
)
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||
|
@ -41,12 +41,15 @@ const char* const kXlaClusterAttr = "_XlaCluster";
|
||||
|
||||
namespace {
|
||||
|
||||
bool HasXLAKernel(const NodeDef& node_def, DeviceType jit_device_type) {
|
||||
bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
|
||||
// _Send and _Recv should not be marked for compilation.
|
||||
if (node.IsSend() || node.IsRecv()) return false;
|
||||
|
||||
// There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient
|
||||
// is really a kind of function call and will be handled by
|
||||
// IsCompilableCall().
|
||||
if (node_def.op() == "SymbolicGradient") return false;
|
||||
return FindKernelDef(jit_device_type, node_def, nullptr, nullptr).ok();
|
||||
if (node.type_string() == "SymbolicGradient") return false;
|
||||
return FindKernelDef(jit_device_type, node.def(), nullptr, nullptr).ok();
|
||||
}
|
||||
|
||||
// Make sure we don't recurse infinitely on recursive functions.
|
||||
@ -125,7 +128,7 @@ bool IsCompilableCall(const NodeDef& call_def, DeviceType jit_device_type,
|
||||
return IsCompilableWhile(node->def(), jit_device_type, depth + 1,
|
||||
lib_runtime);
|
||||
}
|
||||
if (!HasXLAKernel(node->def(), jit_device_type) &&
|
||||
if (!HasXLAKernel(*node, jit_device_type) &&
|
||||
!IsCompilableCall(node->def(), jit_device_type, depth + 1,
|
||||
lib_runtime)) {
|
||||
VLOG(2) << "Function marking failed: unsupported op " << node->name()
|
||||
@ -168,7 +171,7 @@ Status FindCompilationCandidates(
|
||||
CHECK(XlaOpRegistry::GetJitDevice(device_type.type(), &jit_device_name,
|
||||
/*requires_jit=*/nullptr));
|
||||
DeviceType jit_device_type(*jit_device_name);
|
||||
if (!HasXLAKernel(node->def(), jit_device_type) &&
|
||||
if (!HasXLAKernel(*node, jit_device_type) &&
|
||||
!IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime.get())) {
|
||||
VLOG(2) << "Compilation rejected node: unsupported op " << node->name()
|
||||
<< ": " << node->def().op();
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/core/common_runtime/dma_helper.h"
|
||||
#include "tensorflow/core/platform/mem.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -41,7 +42,7 @@ void* XlaDeviceAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
|
||||
// Regardless of the size requested, always allocate a XlaGlobalData. Respect
|
||||
// the aligment request because there is alignment checking even for Tensors
|
||||
// whose data is never accessed.
|
||||
void* p = port::aligned_malloc(sizeof(XlaGlobalData), alignment);
|
||||
void* p = port::AlignedMalloc(sizeof(XlaGlobalData), alignment);
|
||||
VLOG(2) << "Allocated XLA device tensor " << p;
|
||||
return new (p) XlaGlobalData();
|
||||
}
|
||||
@ -50,7 +51,7 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) {
|
||||
XlaGlobalData* global_data = reinterpret_cast<XlaGlobalData*>(ptr);
|
||||
VLOG(2) << "Deallocated XLA device tensor " << ptr;
|
||||
global_data->~XlaGlobalData();
|
||||
port::aligned_free(ptr);
|
||||
port::AlignedFree(ptr);
|
||||
}
|
||||
|
||||
void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); }
|
||||
|
@ -45,7 +45,7 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options,
|
||||
name_prefix, &device);
|
||||
if (!status.ok()) {
|
||||
// Treat failures as non-fatal; there might not be a GPU in the machine.
|
||||
LOG(WARNING) << "Failed to create XLA_GPU device: " << status;
|
||||
VLOG(1) << "Failed to create XLA_GPU device: " << status;
|
||||
return Status::OK();
|
||||
}
|
||||
devices->push_back(device.release());
|
||||
|
@ -18,7 +18,9 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -32,29 +34,8 @@ from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import flags as flags_lib
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
flags = flags_lib
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_integer('batch_size', 128,
|
||||
'Inputs are fed in batches of this size, for both '
|
||||
'inference and training. Larger values cause the matmul '
|
||||
'in each LSTM cell to have higher dimensionality.')
|
||||
flags.DEFINE_integer('seq_length', 60,
|
||||
'Length of the unrolled sequence of LSTM cells in a layer.'
|
||||
'Larger values cause more LSTM matmuls to be run.')
|
||||
flags.DEFINE_integer('num_inputs', 1024,
|
||||
'Dimension of inputs that are fed into each LSTM cell.')
|
||||
flags.DEFINE_integer('num_nodes', 1024, 'Number of nodes in each LSTM cell.')
|
||||
flags.DEFINE_string('device', 'gpu',
|
||||
'TensorFlow device to assign ops to, e.g. "gpu", "cpu". '
|
||||
'For details see documentation for tf.Graph.device.')
|
||||
|
||||
flags.DEFINE_string('dump_graph_dir', '', 'If non-empty, dump graphs in '
|
||||
'*.pbtxt format to this directory.')
|
||||
|
||||
|
||||
def _DumpGraph(graph, basename):
|
||||
if FLAGS.dump_graph_dir:
|
||||
@ -290,4 +271,54 @@ class LSTMBenchmark(test.Benchmark):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.register('type', 'bool', lambda v: v.lower() == 'true')
|
||||
parser.add_argument(
|
||||
'--batch_size',
|
||||
type=int,
|
||||
default=128,
|
||||
help="""\
|
||||
Inputs are fed in batches of this size, for both inference and training.
|
||||
Larger values cause the matmul in each LSTM cell to have higher
|
||||
dimensionality.\
|
||||
"""
|
||||
)
|
||||
parser.add_argument(
|
||||
'--seq_length',
|
||||
type=int,
|
||||
default=60,
|
||||
help="""\
|
||||
Length of the unrolled sequence of LSTM cells in a layer.Larger values
|
||||
cause more LSTM matmuls to be run.\
|
||||
"""
|
||||
)
|
||||
parser.add_argument(
|
||||
'--num_inputs',
|
||||
type=int,
|
||||
default=1024,
|
||||
help='Dimension of inputs that are fed into each LSTM cell.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--num_nodes',
|
||||
type=int,
|
||||
default=1024,
|
||||
help='Number of nodes in each LSTM cell.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--device',
|
||||
type=str,
|
||||
default='gpu',
|
||||
help="""\
|
||||
TensorFlow device to assign ops to, e.g. "gpu", "cpu". For details see
|
||||
documentation for tf.Graph.device.\
|
||||
"""
|
||||
)
|
||||
parser.add_argument(
|
||||
'--dump_graph_dir',
|
||||
type=str,
|
||||
default='',
|
||||
help='If non-empty, dump graphs in *.pbtxt format to this directory.'
|
||||
)
|
||||
global FLAGS # pylint:disable=global-at-module-level
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
test.main(argv=[sys.argv[0]] + unparsed)
|
||||
|
@ -89,6 +89,27 @@ cc_library(
|
||||
|
||||
# Internal targets below this point.
|
||||
|
||||
cc_test(
|
||||
name = "xla_compiler_test",
|
||||
srcs = ["xla_compiler_test.cc"],
|
||||
deps = [
|
||||
":xla_compiler",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:function_ops",
|
||||
"//tensorflow/cc:ops",
|
||||
"//tensorflow/cc:sendrecv_ops",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "str_util_test",
|
||||
srcs = [
|
||||
|
@ -18,7 +18,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/client/computation_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/local_device.h"
|
||||
#include "tensorflow/core/framework/device_base.h"
|
||||
#include "tensorflow/core/platform/mem.h"
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -47,7 +48,7 @@ class XlaCompilationAllocator : public Allocator {
|
||||
// XlaExpression. Respect the aligment request because there is
|
||||
// alignment checking even for Tensors whose data is never
|
||||
// accessed.
|
||||
void* p = port::aligned_malloc(sizeof(XlaExpression), alignment);
|
||||
void* p = port::AlignedMalloc(sizeof(XlaExpression), alignment);
|
||||
XlaExpression* expression = reinterpret_cast<XlaExpression*>(p);
|
||||
new (expression) XlaExpression();
|
||||
return expression;
|
||||
@ -56,7 +57,7 @@ class XlaCompilationAllocator : public Allocator {
|
||||
void DeallocateRaw(void* ptr) override {
|
||||
XlaExpression* expression = reinterpret_cast<XlaExpression*>(ptr);
|
||||
expression->~XlaExpression();
|
||||
port::aligned_free(ptr);
|
||||
port::AlignedFree(ptr);
|
||||
}
|
||||
|
||||
// Make sure that even tensors with 0 elements have allocated
|
||||
|
@ -318,7 +318,7 @@ Status XlaCompiler::CompileGraph(string const& name,
|
||||
}
|
||||
|
||||
XlaContext* xla_context =
|
||||
new XlaContext(client(), name, allow_cpu_custom_calls_);
|
||||
new XlaContext(this, client(), name, allow_cpu_custom_calls_);
|
||||
core::ScopedUnref xla_context_unref(xla_context);
|
||||
|
||||
TF_RETURN_IF_ERROR(xla_context->BuildArguments(args, use_tuple_arg));
|
||||
@ -402,4 +402,15 @@ Status XlaCompiler::CompileGraph(string const& name,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status XlaCompiler::GetChannelHandle(const string& key,
|
||||
xla::ChannelHandle* channel) {
|
||||
mutex_lock lock(mu_);
|
||||
auto result = channels_.emplace(key, xla::ChannelHandle());
|
||||
if (result.second) {
|
||||
TF_ASSIGN_OR_RETURN(result.first->second, client_->CreateChannelHandle());
|
||||
}
|
||||
*channel = result.first->second;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -172,6 +172,12 @@ class XlaCompiler {
|
||||
XlaCompilationDevice* device() const { return device_; }
|
||||
const DeviceMgr* device_mgr() const { return &device_mgr_; }
|
||||
|
||||
// Retrieves the channel handle associated with `key`. Allocates
|
||||
// a new channel handle if none exists.
|
||||
// Channel handles can be used to communicate between different computations.
|
||||
// Computations that communicate should be compiled with the same XlaCompiler.
|
||||
Status GetChannelHandle(const string& key, xla::ChannelHandle* channel);
|
||||
|
||||
private:
|
||||
// Does the real work of Compile() and CompileToComputation().
|
||||
Status CompileFunctionBody(FunctionLibraryRuntime* function_library,
|
||||
@ -195,6 +201,8 @@ class XlaCompiler {
|
||||
XlaCompilationDevice* device_; // Owned by device_mgr_
|
||||
DeviceMgr device_mgr_;
|
||||
|
||||
std::unordered_map<string, xla::ChannelHandle> channels_ GUARDED_BY(mu_);
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
|
||||
};
|
||||
|
||||
|
107
tensorflow/compiler/tf2xla/xla_compiler_test.cc
Normal file
107
tensorflow/compiler/tf2xla/xla_compiler_test.cc
Normal file
@ -0,0 +1,107 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/ops/function_ops.h"
|
||||
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
class XlaCompilerTest : public ::testing::Test {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
client_ = xla::ClientLibrary::LocalClientOrDie();
|
||||
|
||||
XlaCompiler::Options options;
|
||||
options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
|
||||
options.client = client_;
|
||||
compiler_.reset(new XlaCompiler(options));
|
||||
|
||||
XlaOpRegistry::RegisterJitKernels();
|
||||
|
||||
FunctionDefLibrary flib;
|
||||
flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib));
|
||||
flr_.reset(NewFunctionLibraryRuntime(
|
||||
compiler_->device_mgr(), /*env=*/nullptr, compiler_->device(),
|
||||
TF_GRAPH_DEF_VERSION, flib_def_.get(), OptimizerOptions(),
|
||||
/*custom_kernel_creator=*/nullptr));
|
||||
}
|
||||
|
||||
xla::Client* client_;
|
||||
std::unique_ptr<XlaCompiler> compiler_;
|
||||
std::unique_ptr<FunctionLibraryDefinition> flib_def_;
|
||||
std::unique_ptr<FunctionLibraryRuntime> flr_;
|
||||
};
|
||||
|
||||
TEST_F(XlaCompilerTest, Simple) {
|
||||
// Builds a graph that adds two Tensors.
|
||||
Scope scope = Scope::NewRootScope().ExitOnError();
|
||||
auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
|
||||
auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1);
|
||||
auto c = ops::Add(scope.WithOpName("C"), a, b);
|
||||
auto d = ops::_Retval(scope.WithOpName("D"), c, 0);
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
TF_ASSERT_OK(scope.ToGraph(graph.get()));
|
||||
|
||||
// Builds a description of the arguments.
|
||||
std::vector<XlaCompiler::Argument> args(2);
|
||||
args[0].type = DT_INT32;
|
||||
args[0].shape = TensorShape({2});
|
||||
args[0].parameter = 0;
|
||||
args[1].type = DT_INT32;
|
||||
args[1].shape = TensorShape({2});
|
||||
args[1].parameter = 1;
|
||||
|
||||
// Compiles the graph.
|
||||
XlaCompiler::CompilationResult result;
|
||||
TF_ASSERT_OK(compiler_->CompileGraph("add", std::move(graph), flr_.get(),
|
||||
args, /*use_tuple_arg=*/false, &result));
|
||||
|
||||
// Tests that the generated computation works.
|
||||
std::unique_ptr<xla::Literal> param0_literal =
|
||||
xla::LiteralUtil::CreateR1<int32>({7, 42});
|
||||
std::unique_ptr<xla::Literal> param1_literal =
|
||||
xla::LiteralUtil::CreateR1<int32>({-3, 101});
|
||||
std::unique_ptr<xla::GlobalData> param0_data =
|
||||
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
|
||||
std::unique_ptr<xla::GlobalData> param1_data =
|
||||
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
|
||||
|
||||
std::unique_ptr<xla::GlobalData> actual =
|
||||
client_
|
||||
->Execute(result.computation, {param0_data.get(), param1_data.get()})
|
||||
.ConsumeValueOrDie();
|
||||
std::unique_ptr<xla::Literal> actual_literal =
|
||||
client_->Transfer(*actual).ConsumeValueOrDie();
|
||||
|
||||
std::unique_ptr<xla::Literal> expected_literal =
|
||||
xla::LiteralUtil::CreateR1<int32>({4, 143});
|
||||
xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -167,7 +167,7 @@ Status XlaContext::CollectResults(
|
||||
}
|
||||
}
|
||||
|
||||
if (handle.handle() > 0) {
|
||||
if (handle.handle() > 0 || has_side_effects_) {
|
||||
// Build the full computation. The return value is the handle
|
||||
// constructed above.
|
||||
xla::StatusOr<xla::Computation> computation_status = builder().Build();
|
||||
@ -190,9 +190,11 @@ Status XlaContext::CollectResults(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
XlaContext::XlaContext(xla::Client* client, const string& computation_name,
|
||||
XlaContext::XlaContext(XlaCompiler* compiler, xla::Client* client,
|
||||
const string& computation_name,
|
||||
bool allow_cpu_custom_calls)
|
||||
: xla_builder_(client, computation_name),
|
||||
: compiler_(compiler),
|
||||
xla_builder_(client, computation_name),
|
||||
allow_cpu_custom_calls_(allow_cpu_custom_calls) {}
|
||||
|
||||
const xla::ComputationDataHandle&
|
||||
@ -233,6 +235,11 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void XlaContext::AddSideEffects() {
|
||||
mutex_lock lock(mu_);
|
||||
has_side_effects_ = true;
|
||||
}
|
||||
|
||||
/* static */ const XlaExpression* XlaContext::CastExpressionFromTensor(
|
||||
const Tensor& tensor) {
|
||||
const XlaExpression* expression =
|
||||
|
@ -68,7 +68,7 @@ class XlaExpression {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaExpression);
|
||||
};
|
||||
|
||||
// The XlaContext is the datastructure accessible from
|
||||
// The XlaContext is the data structure accessible from
|
||||
// OpKernelContexts when evaluating a subgraph of Ops for JIT
|
||||
// compilation by XLA. When an Op is executed during JIT
|
||||
// compilation the input Tensors to the Op store handles to
|
||||
@ -132,8 +132,8 @@ class XlaContext : public ResourceBase {
|
||||
}
|
||||
|
||||
// Create a new XlaContext.
|
||||
XlaContext(xla::Client* client, const string& computation_name,
|
||||
bool allow_cpu_custom_calls);
|
||||
XlaContext(XlaCompiler* compiler, xla::Client* client,
|
||||
const string& computation_name, bool allow_cpu_custom_calls);
|
||||
|
||||
// Builds XLA computations for each of the arguments.
|
||||
// Should only be called once to initialize the arguments. Not thread-safe.
|
||||
@ -160,6 +160,9 @@ class XlaContext : public ResourceBase {
|
||||
Status AddConstRetval(int retval_index, DataType dtype,
|
||||
const xla::Literal& literal);
|
||||
|
||||
// Mark the computation as having side effects (i.e., Send operators).
|
||||
void AddSideEffects();
|
||||
|
||||
// Retrieves the ComputationDataHandle from an input Tensor to an Op. This
|
||||
// computation was constructed by an Op that executed previously and
|
||||
// created the output Tensor using CreateOutputTensorFromComputation
|
||||
@ -167,6 +170,8 @@ class XlaContext : public ResourceBase {
|
||||
static const xla::ComputationDataHandle& GetComputationFromTensor(
|
||||
const Tensor& tensor);
|
||||
|
||||
XlaCompiler* compiler() const { return compiler_; }
|
||||
|
||||
// Returns the ComputationBuilder that Ops use for compiling new
|
||||
// expressions.
|
||||
xla::ComputationBuilder& builder();
|
||||
@ -215,6 +220,8 @@ class XlaContext : public ResourceBase {
|
||||
// or CreateConstantOutputTensor.
|
||||
static const XlaExpression* GetExpressionFromTensor(const Tensor& tensor);
|
||||
|
||||
XlaCompiler* const compiler_;
|
||||
|
||||
mutable mutex mu_;
|
||||
|
||||
// The ComputationBuilder used to construct the subgraph's compiled
|
||||
@ -250,6 +257,9 @@ class XlaContext : public ResourceBase {
|
||||
// The non-data-dependent return values of the computation.
|
||||
std::vector<ConstRetVal> compile_time_constant_ GUARDED_BY(mu_);
|
||||
|
||||
// Does the computation have side effects, i.e., Send() calls?
|
||||
bool has_side_effects_ GUARDED_BY(mu_) = false;
|
||||
|
||||
// Cache of prebuilt computations indexed by their type.
|
||||
using ComputationMap = std::map<DataType, xla::Computation>;
|
||||
|
||||
|
@ -223,6 +223,10 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
|
||||
expression->set_constant_value(constant);
|
||||
}
|
||||
|
||||
void XlaOpKernelContext::SetOpHasSideEffects() {
|
||||
XlaContext::Get(context_).AddSideEffects();
|
||||
}
|
||||
|
||||
void XlaOpKernelContext::CtxFailure(Status s) { context_->CtxFailure(s); }
|
||||
void XlaOpKernelContext::CtxFailureWithWarning(Status s) {
|
||||
context_->CtxFailureWithWarning(s);
|
||||
|
@ -131,6 +131,9 @@ class XlaOpKernelContext {
|
||||
void SetStatus(const Status& status) { context_->SetStatus(status); }
|
||||
Status status() { return context_->status(); }
|
||||
|
||||
// Mark the op has having side effects (i.e., via Send).
|
||||
void SetOpHasSideEffects();
|
||||
|
||||
// Helper routines for the OP_REQUIRES macros
|
||||
void CtxFailure(Status s);
|
||||
void CtxFailureWithWarning(Status s);
|
||||
|
@ -314,12 +314,23 @@ tensorflow::Status LocalClient::ExecuteLocally(
|
||||
options, result);
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<AotCompilationResult>> LocalClient::CompileAheadOfTime(
|
||||
const Computation& computation,
|
||||
const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
|
||||
const Shape& result_layout, const AotCompilationOptions& options) {
|
||||
return local_service_->CompileAheadOfTime(
|
||||
computation.handle(), argument_layouts, result_layout, options);
|
||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||
LocalClient::CompileAheadOfTime(
|
||||
const tensorflow::gtl::ArraySlice<AheadOfTimeComputationInstance>
|
||||
computations,
|
||||
const AotCompilationOptions& options) {
|
||||
std::vector<LocalService::AheadOfTimeComputationInstance> service_instances;
|
||||
service_instances.reserve(computations.size());
|
||||
for (const AheadOfTimeComputationInstance& instance : computations) {
|
||||
service_instances.push_back({});
|
||||
LocalService::AheadOfTimeComputationInstance& service_instance =
|
||||
service_instances.back();
|
||||
TF_RET_CHECK(instance.computation != nullptr);
|
||||
service_instance.computation = instance.computation->handle();
|
||||
service_instance.argument_layouts = instance.argument_layouts;
|
||||
service_instance.result_layout = instance.result_layout;
|
||||
}
|
||||
return local_service_->CompileAheadOfTime(service_instances, options);
|
||||
}
|
||||
|
||||
int64 LocalClient::PointerSizeForTriple(tensorflow::StringPiece target_triple) {
|
||||
|
@@ -219,19 +219,26 @@ class LocalClient : public Client {
      const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
      const ExecutableBuildOptions& options);

  // Compiles the computation for ahead-of-time execution. This is intended for
  // use in static compilation. The |argument_layouts| parameter is used to
  // inform the compiler of the expected layout for arguments while
  // |result_layout| is used to signal the layout of the result. The |options|
  // parameter is used to request which target the compiler should emit code
  // for.
  // A description of a computation to compile using CompileAheadOfTime.
  struct AheadOfTimeComputationInstance {
    const Computation* computation;
    // Inform the compiler of the expected layout for arguments.
    std::vector<const Shape*> argument_layouts;
    // Specifies the expected result layout.
    const Shape* result_layout;
  };

  // Compiles a list of computations for ahead-of-time execution. This is
  // intended for use in static compilation. The |options| parameter describes
  // the target for which the compiler should emit code.
  //
  // TODO(b/31222190): This doesn't really belong in LocalClient. Move it to its
  // own library.
  StatusOr<std::unique_ptr<AotCompilationResult>> CompileAheadOfTime(
      const Computation& computation,
      const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
      const Shape& result_layout, const AotCompilationOptions& options);
  StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
  CompileAheadOfTime(
      const tensorflow::gtl::ArraySlice<AheadOfTimeComputationInstance>
          computations,
      const AotCompilationOptions& options);

  // Returns the size of a pointer in bytes for a given triple.
  static int64 PointerSizeForTriple(tensorflow::StringPiece triple);
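The struct and the list-based CompileAheadOfTime overload above replace the older single-computation overload. As a hedged sketch of how a caller might adapt, modeled loosely on the tfcompile change earlier in this diff, the fragment below builds one trivial computation and compiles it ahead of time; the CpuAotCompilationOptions constructor arguments, the headers, and the shape/layout handling are assumptions about the surrounding XLA API, not part of this diff.

// Hedged sketch only: exact headers and option constructors may differ.
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/shape_util.h"

int main() {
  xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie();

  // Build a trivial computation: add two constant f32[2] vectors.
  xla::ComputationBuilder builder(client, "aot_add");
  builder.Add(builder.ConstantR1<float>({1.0f, 2.0f}),
              builder.ConstantR1<float>({3.0f, 4.0f}));
  xla::Computation computation = builder.Build().ConsumeValueOrDie();

  // Describe what to compile: no runtime arguments, f32[2] result.
  xla::Shape result_shape = xla::ShapeUtil::MakeShape(xla::F32, {2});
  xla::LocalClient::AheadOfTimeComputationInstance instance;
  instance.computation = &computation;
  instance.argument_layouts = {};          // no parameters in this toy example
  instance.result_layout = &result_shape;  // may need an explicit layout

  // Assumed target options; the constructor signature here is a guess.
  xla::cpu::CpuAotCompilationOptions aot_opts(
      "x86_64-unknown-linux-gnu", /*cpu_name=*/"", /*features=*/"",
      /*entry_point_name=*/"entry",
      xla::cpu::CpuAotCompilationOptions::RelocationModel::Static);

  // The new API takes a list of instances and returns one result per instance.
  auto results =
      client->CompileAheadOfTime({instance}, aot_opts).ConsumeValueOrDie();
  std::unique_ptr<xla::AotCompilationResult> result = std::move(results.back());
  (void)result;
  return 0;
}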
@ -360,4 +360,20 @@ tensorflow::Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src,
|
||||
}
|
||||
}
|
||||
|
||||
/* static */ bool LayoutUtil::AreDimensionsConsecutive(
|
||||
const Layout& layout, tensorflow::gtl::ArraySlice<int64> dims) {
|
||||
std::vector<int64> positions_in_layout;
|
||||
for (int64 dim : dims) {
|
||||
positions_in_layout.push_back(
|
||||
PositionInContainer(layout.minor_to_major(), dim));
|
||||
}
|
||||
std::sort(positions_in_layout.begin(), positions_in_layout.end());
|
||||
for (size_t i = 1; i < positions_in_layout.size(); ++i) {
|
||||
if (1 != positions_in_layout[i] - positions_in_layout[i - 1]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@@ -144,6 +144,11 @@ class LayoutUtil {
  // except that the element type is ignored.
  static bool LayoutsInShapesEqual(const Shape& lhs, const Shape& rhs);

  // Returns whether the given dimensions are consecutive in the given layout,
  // not necessarily in the order given.
  static bool AreDimensionsConsecutive(const Layout& layout,
                                       tensorflow::gtl::ArraySlice<int64> dims);

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(LayoutUtil);
};
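The declaration above is implemented in layout_util.cc earlier in this diff. As a small self-contained illustration of the semantics (the positions of the given dimensions in minor-to-major order must form an unbroken run, regardless of the order the dimensions are listed in), here is a hedged standalone sketch that mirrors that logic using plain std::vector in place of the XLA Layout type.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Standalone sketch of the AreDimensionsConsecutive check: positions of `dims`
// in the minor-to-major order must be consecutive after sorting.
bool AreDimensionsConsecutive(const std::vector<int64_t>& minor_to_major,
                              const std::vector<int64_t>& dims) {
  std::vector<int64_t> positions;
  for (int64_t dim : dims) {
    auto it = std::find(minor_to_major.begin(), minor_to_major.end(), dim);
    positions.push_back(it - minor_to_major.begin());
  }
  std::sort(positions.begin(), positions.end());
  for (size_t i = 1; i < positions.size(); ++i) {
    if (positions[i] - positions[i - 1] != 1) return false;
  }
  return true;
}

int main() {
  // For minor_to_major = {2, 0, 1, 3}: dims {0, 1} sit at positions 1 and 2
  // (consecutive), while dims {2, 1} sit at positions 0 and 2 (not).
  std::vector<int64_t> layout = {2, 0, 1, 3};
  std::cout << AreDimensionsConsecutive(layout, {0, 1}) << "\n";  // prints 1
  std::cout << AreDimensionsConsecutive(layout, {2, 1}) << "\n";  // prints 0
  return 0;
}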
@ -136,6 +136,12 @@ class LiteralUtil {
|
||||
const Literal& literal, tensorflow::gtl::ArraySlice<int64> start_indices,
|
||||
tensorflow::gtl::ArraySlice<int64> limit_indices);
|
||||
|
||||
// Creates a literal with a prepended dimension with bound "times"; e.g. a
|
||||
// f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from the input
|
||||
// literal replicated four times.
|
||||
template <typename NativeT>
|
||||
static std::unique_ptr<Literal> Replicate(const Literal& input, int64 times);
|
||||
|
||||
// Create a literal by converting each element in an original literal to a new
|
||||
// type.
|
||||
template <typename NativeSrcT, typename NativeDestT>
|
||||
@ -999,6 +1005,30 @@ LiteralUtil::CreateFullWithMonotonicDim0MajorLayout(
|
||||
return literal;
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
/* static */ std::unique_ptr<Literal> LiteralUtil::Replicate(
|
||||
const Literal& input, int64 times) {
|
||||
std::vector<int64> bounds = {times};
|
||||
bounds.insert(bounds.end(), input.shape().dimensions().begin(),
|
||||
input.shape().dimensions().end());
|
||||
auto literal = MakeUnique<Literal>();
|
||||
*literal->mutable_shape() =
|
||||
ShapeUtil::MakeShape(input.shape().element_type(), bounds);
|
||||
Reserve(ShapeUtil::ElementsIn(literal->shape()), literal.get());
|
||||
for (int64 index = 0; index < ShapeUtil::ElementsIn(input.shape()); ++index) {
|
||||
const std::vector<int64> element_indices =
|
||||
IndexUtil::LinearIndexToMultidimensionalIndex(input.shape(), index);
|
||||
const auto element = Get<NativeT>(input, element_indices);
|
||||
for (int64 sample = 0; sample < times; ++sample) {
|
||||
std::vector<int64> output_indices = {sample};
|
||||
output_indices.insert(output_indices.end(), element_indices.begin(),
|
||||
element_indices.end());
|
||||
Set<NativeT>(literal.get(), output_indices, element);
|
||||
}
|
||||
}
|
||||
return literal;
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
|
||||
|
@ -749,11 +749,11 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
|
||||
TF_RET_CHECK(LayoutUtil::HasLayout(filter_shape));
|
||||
TF_RET_CHECK(LayoutUtil::HasLayout(convolution_shape));
|
||||
|
||||
// Require 1x1 filter in the spatial dimensions (so no need to extract image
|
||||
// patches).
|
||||
if (filter_shape.dimensions(dnums.kernel_spatial_dimensions(0)) != 1 ||
|
||||
filter_shape.dimensions(dnums.kernel_spatial_dimensions(1)) != 1) {
|
||||
return Status::OK();
|
||||
// Require the spatial dimensions in the kernel to have a bound of one.
|
||||
for (int64 i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) {
|
||||
if (filter_shape.dimensions(dnums.kernel_spatial_dimensions(i)) != 1) {
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
// Stride ignores part of the output, which matrix multiplication does not do,
|
||||
@ -782,9 +782,9 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
|
||||
input_shape.layout().minor_to_major(0) != dnums.feature_dimension() ||
|
||||
// The input feature dimension should come later in the minor-to-major
|
||||
// order.
|
||||
(PositionInContainer(AsInt64Slice(filter_shape.layout().minor_to_major()),
|
||||
(PositionInContainer(filter_shape.layout().minor_to_major(),
|
||||
dnums.kernel_input_feature_dimension()) <
|
||||
PositionInContainer(AsInt64Slice(filter_shape.layout().minor_to_major()),
|
||||
PositionInContainer(filter_shape.layout().minor_to_major(),
|
||||
dnums.kernel_output_feature_dimension()))) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -234,4 +234,8 @@ StatusOr<bool> Backend::devices_equivalent(int device_ordinal_a,
|
||||
executor_b->GetDeviceDescription().name());
|
||||
}
|
||||
|
||||
Status Backend::ResetDevices() {
|
||||
return transfer_manager_->ResetDevices(stream_executors_);
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -149,6 +149,9 @@ class Backend {
|
||||
// used for scheduling work. For other platforms, returns NULL.
|
||||
const Eigen::ThreadPoolDevice* eigen_intra_op_thread_pool_device() const;
|
||||
|
||||
// Resets the devices associated with this backend.
|
||||
Status ResetDevices();
|
||||
|
||||
private:
|
||||
struct EigenThreadPoolWrapper;
|
||||
Backend(int64 replica_count, perftools::gputools::Platform* platform,
|
||||
|
@ -128,10 +128,11 @@ class Compiler {
|
||||
|
||||
// Compiles the HLO module for ahead-of-time execution. This is intended for
|
||||
// use in static compilation.
|
||||
virtual StatusOr<std::unique_ptr<AotCompilationResult>> CompileAheadOfTime(
|
||||
std::unique_ptr<HloModule> module,
|
||||
std::unique_ptr<HloModuleConfig> module_config, HloDumper dump_hlo,
|
||||
const AotCompilationOptions& options) = 0;
|
||||
virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||
CompileAheadOfTime(
|
||||
std::vector<std::unique_ptr<HloModule>> module,
|
||||
std::vector<std::unique_ptr<HloModuleConfig>> module_config,
|
||||
HloDumper dump_hlo, const AotCompilationOptions& options) = 0;
|
||||
|
||||
/////
|
||||
// The Compiler class also serves as a point to register compiler objects
|
||||
|
@ -478,10 +478,13 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> CpuCompiler::Compile(
|
||||
"Compilation of multiple HLO modules is not yet supported on CPU.");
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<AotCompilationResult>> CpuCompiler::CompileAheadOfTime(
|
||||
std::unique_ptr<HloModule> hlo_module,
|
||||
std::unique_ptr<HloModuleConfig> module_config, HloDumper dump_hlo,
|
||||
const AotCompilationOptions& aot_options) {
|
||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||
CpuCompiler::CompileAheadOfTime(
|
||||
std::vector<std::unique_ptr<HloModule>> hlo_modules,
|
||||
std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
|
||||
HloDumper dump_hlo, const AotCompilationOptions& aot_options) {
|
||||
TF_RET_CHECK(hlo_modules.size() == module_configs.size());
|
||||
|
||||
if (aot_options.PlatformId() != se::host::kHostPlatformId) {
|
||||
return InvalidArgument("Incompatible AOT compilation platform");
|
||||
}
|
||||
@ -549,72 +552,78 @@ StatusOr<std::unique_ptr<AotCompilationResult>> CpuCompiler::CompileAheadOfTime(
|
||||
const llvm::DataLayout& data_layout = llvm_module.getDataLayout();
|
||||
int64 pointer_size = data_layout.getPointerSize();
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
RunHloPasses(hlo_module.get(), module_config.get(), dump_hlo));
|
||||
std::vector<std::unique_ptr<AotCompilationResult>> results;
|
||||
for (int i = 0; i < hlo_modules.size(); ++i) {
|
||||
HloModule* hlo_module = hlo_modules[i].get();
|
||||
HloModuleConfig* module_config = module_configs[i].get();
|
||||
|
||||
SequentialHloOrdering::HloModuleSequence module_sequence =
|
||||
CreateModuleSequence(hlo_module.get());
|
||||
// Run buffer analysis on the HLO graph. This analysis figures out which
|
||||
// temporary buffers are required to run the computation.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<BufferAssignment> assignment,
|
||||
BufferAssigner::Run(
|
||||
hlo_module.get(),
|
||||
MakeUnique<SequentialHloOrdering>(hlo_module.get(), module_sequence),
|
||||
pointer_size));
|
||||
TF_RETURN_IF_ERROR(RunHloPasses(hlo_module, module_config, dump_hlo));
|
||||
|
||||
IrEmitter ir_emitter(*hlo_module, *module_config, *assignment, &llvm_module,
|
||||
/*hlo_to_profile_idx=*/nullptr);
|
||||
HloComputation* computation = hlo_module->entry_computation();
|
||||
for (auto embedded_computation :
|
||||
computation->MakeEmbeddedComputationsList()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
ir_emitter
|
||||
.EmitComputation(embedded_computation, embedded_computation->name(),
|
||||
/*is_entry_computation=*/false,
|
||||
&module_sequence.at(embedded_computation))
|
||||
.status());
|
||||
}
|
||||
const string& entry_point_name = options.entry_point_name();
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
llvm::Function * entry_function,
|
||||
ir_emitter.EmitComputation(computation, entry_point_name,
|
||||
/*is_entry_computation=*/true));
|
||||
SequentialHloOrdering::HloModuleSequence module_sequence =
|
||||
CreateModuleSequence(hlo_module);
|
||||
// Run buffer analysis on the HLO graph. This analysis figures out which
|
||||
// temporary buffers are required to run the computation.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<BufferAssignment> assignment,
|
||||
BufferAssigner::Run(hlo_module, MakeUnique<SequentialHloOrdering>(
|
||||
hlo_module, module_sequence),
|
||||
pointer_size));
|
||||
|
||||
entry_function->setName(llvm_ir::AsStringRef(entry_point_name));
|
||||
|
||||
Disassembler disassembler(*target_machine);
|
||||
CompilerFunctor compiler_functor(target_machine.get(), &disassembler,
|
||||
opt_level, CompilerFunctor::AllIntrinsics());
|
||||
llvm::object::OwningBinary<llvm::object::ObjectFile> object_file =
|
||||
compiler_functor(llvm_module);
|
||||
llvm::StringRef object_file_data_ref = object_file.getBinary()->getData();
|
||||
ObjectFileData object_file_data(object_file_data_ref.begin(),
|
||||
object_file_data_ref.end());
|
||||
|
||||
BufferSizes buffer_sizes;
|
||||
for (const BufferAllocation& allocation : assignment->Allocations()) {
|
||||
// Callers don't need to allocate temporary buffers for parameters.
|
||||
if (allocation.is_entry_computation_parameter()) {
|
||||
buffer_sizes.push_back(-1);
|
||||
continue;
|
||||
IrEmitter ir_emitter(*hlo_module, *module_config, *assignment, &llvm_module,
|
||||
/*hlo_to_profile_idx=*/nullptr);
|
||||
HloComputation* computation = hlo_module->entry_computation();
|
||||
for (auto embedded_computation :
|
||||
computation->MakeEmbeddedComputationsList()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
ir_emitter
|
||||
.EmitComputation(embedded_computation,
|
||||
embedded_computation->name(),
|
||||
/*is_entry_computation=*/false,
|
||||
&module_sequence.at(embedded_computation))
|
||||
.status());
|
||||
}
|
||||
// Callers don't need to allocate anything for thread-local temporary
|
||||
// buffers. They are lowered to allocas.
|
||||
if (allocation.is_thread_local()) {
|
||||
buffer_sizes.push_back(-1);
|
||||
continue;
|
||||
const string& entry_point_name = options.entry_point_name();
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
llvm::Function * entry_function,
|
||||
ir_emitter.EmitComputation(computation, entry_point_name,
|
||||
/*is_entry_computation=*/true));
|
||||
|
||||
entry_function->setName(llvm_ir::AsStringRef(entry_point_name));
|
||||
|
||||
Disassembler disassembler(*target_machine);
|
||||
CompilerFunctor compiler_functor(target_machine.get(), &disassembler,
|
||||
opt_level,
|
||||
CompilerFunctor::AllIntrinsics());
|
||||
llvm::object::OwningBinary<llvm::object::ObjectFile> object_file =
|
||||
compiler_functor(llvm_module);
|
||||
llvm::StringRef object_file_data_ref = object_file.getBinary()->getData();
|
||||
ObjectFileData object_file_data(object_file_data_ref.begin(),
|
||||
object_file_data_ref.end());
|
||||
|
||||
BufferSizes buffer_sizes;
|
||||
for (const BufferAllocation& allocation : assignment->Allocations()) {
|
||||
// Callers don't need to allocate temporary buffers for parameters.
|
||||
if (allocation.is_entry_computation_parameter()) {
|
||||
buffer_sizes.push_back(-1);
|
||||
continue;
|
||||
}
|
||||
// Callers don't need to allocate anything for thread-local temporary
|
||||
// buffers. They are lowered to allocas.
|
||||
if (allocation.is_thread_local()) {
|
||||
buffer_sizes.push_back(-1);
|
||||
continue;
|
||||
}
|
||||
buffer_sizes.push_back(allocation.size());
|
||||
}
|
||||
buffer_sizes.push_back(allocation.size());
|
||||
|
||||
TF_ASSIGN_OR_RETURN(const BufferAllocation* result_allocation,
|
||||
assignment->GetUniqueTopLevelOutputAllocation());
|
||||
|
||||
results.emplace_back(MakeUnique<CpuAotCompilationResult>(
|
||||
std::move(object_file_data), std::move(buffer_sizes),
|
||||
result_allocation->index()));
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(const BufferAllocation* result_allocation,
|
||||
assignment->GetUniqueTopLevelOutputAllocation());
|
||||
|
||||
return std::unique_ptr<AotCompilationResult>(
|
||||
MakeUnique<CpuAotCompilationResult>(std::move(object_file_data),
|
||||
std::move(buffer_sizes),
|
||||
result_allocation->index()));
|
||||
return std::move(results);
|
||||
}
|
||||
|
||||
se::Platform::Id CpuCompiler::PlatformId() const {
|
||||
|
@ -123,10 +123,11 @@ class CpuCompiler : public Compiler {
|
||||
HloDumper dump_hlo,
|
||||
std::vector<perftools::gputools::StreamExecutor*> stream_exec) override;
|
||||
|
||||
StatusOr<std::unique_ptr<AotCompilationResult>> CompileAheadOfTime(
|
||||
std::unique_ptr<HloModule> module,
|
||||
std::unique_ptr<HloModuleConfig> module_config, HloDumper dump_hlo,
|
||||
const AotCompilationOptions& options) override;
|
||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||
CompileAheadOfTime(
|
||||
std::vector<std::unique_ptr<HloModule>> module,
|
||||
std::vector<std::unique_ptr<HloModuleConfig>> module_config,
|
||||
HloDumper dump_hlo, const AotCompilationOptions& options) override;
|
||||
|
||||
perftools::gputools::Platform::Id PlatformId() const override;
|
||||
|
||||
|
@ -160,7 +160,9 @@ Status GenericTransferManager::TransferLiteralToInfeed(
|
||||
return Unimplemented("Infeed is not supported on GPU (b/30467474)");
|
||||
}
|
||||
|
||||
Status GenericTransferManager::ResetDevice(se::StreamExecutor* executor) {
|
||||
Status GenericTransferManager::ResetDevices(
|
||||
tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
|
||||
executors) {
|
||||
return Unimplemented(
|
||||
"Device reset is not yet supported on CPU and GPU (b/30481585)");
|
||||
}
|
||||
|
@ -55,7 +55,9 @@ class GenericTransferManager : public TransferManager {
|
||||
Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor,
|
||||
const Literal& literal) override;
|
||||
|
||||
Status ResetDevice(perftools::gputools::StreamExecutor* executor) override;
|
||||
Status ResetDevices(
|
||||
tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
|
||||
executors) override;
|
||||
|
||||
StatusOr<std::vector<perftools::gputools::DeviceMemoryBase>>
|
||||
ShallowCopyTupleFromDevice(
|
||||
|
@ -312,10 +312,11 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> GpuCompiler::Compile(
|
||||
"Compilation of multiple HLO modules is not yet supported on GPU.");
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<AotCompilationResult>> GpuCompiler::CompileAheadOfTime(
|
||||
std::unique_ptr<HloModule> module,
|
||||
std::unique_ptr<HloModuleConfig> module_config, HloDumper dump_hlo,
|
||||
const AotCompilationOptions& options) {
|
||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||
GpuCompiler::CompileAheadOfTime(
|
||||
std::vector<std::unique_ptr<HloModule>> module,
|
||||
std::vector<std::unique_ptr<HloModuleConfig>> module_config,
|
||||
HloDumper dump_hlo, const AotCompilationOptions& options) {
|
||||
return Unimplemented("not yet implemented: GpuCompiler::CompileAheadOfTime");
|
||||
}
|
||||
|
||||
|
@ -52,10 +52,11 @@ class GpuCompiler : public Compiler {
|
||||
HloDumper dump_hlo,
|
||||
std::vector<perftools::gputools::StreamExecutor*> stream_exec) override;
|
||||
|
||||
StatusOr<std::unique_ptr<AotCompilationResult>> CompileAheadOfTime(
|
||||
std::unique_ptr<HloModule> module,
|
||||
std::unique_ptr<HloModuleConfig> module_config, HloDumper dump_hlo,
|
||||
AotCompilationOptions const& options) override;
|
||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||
CompileAheadOfTime(
|
||||
std::vector<std::unique_ptr<HloModule>> module,
|
||||
std::vector<std::unique_ptr<HloModuleConfig>> module_config,
|
||||
HloDumper dump_hlo, AotCompilationOptions const& options) override;
|
||||
|
||||
perftools::gputools::Platform::Id PlatformId() const override;
|
||||
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
|
||||
#include "external/llvm/include/llvm/IR/Module.h"
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
@ -121,8 +122,22 @@ bool IsReductionToVector(const HloInstruction& reduce) {
|
||||
return false;
|
||||
}
|
||||
const HloInstruction* input = reduce.operand(0);
|
||||
return ShapeUtil::Rank(input->shape()) > 1 &&
|
||||
ShapeUtil::Rank(reduce.shape()) == 1;
|
||||
std::vector<int64> dims_to_keep;
|
||||
for (int64 dim = 0; dim < input->shape().dimensions().size(); ++dim) {
|
||||
if (!std::count(reduce.dimensions().begin(), reduce.dimensions().end(),
|
||||
dim)) {
|
||||
dims_to_keep.push_back(dim);
|
||||
}
|
||||
}
|
||||
return LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
|
||||
dims_to_keep) &&
|
||||
ShapeUtil::Equal(reduce.shape(), ShapeUtil::FilterDimensions(
|
||||
[&dims_to_keep](int64 dim) {
|
||||
return std::count(
|
||||
dims_to_keep.begin(),
|
||||
dims_to_keep.end(), dim);
|
||||
},
|
||||
input->shape()));
|
||||
}
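The check above only admits reductions whose kept dimensions sit next to each other in the input layout. A minimal standalone sketch of that contiguity test follows, using plain vectors as an assumed simplification; the real code goes through LayoutUtil::AreDimensionsConsecutive and ShapeUtil::FilterDimensions on xla::Shape.

// Standalone sketch (not the LayoutUtil API): the kept dimensions qualify for
// the reduction-to-vector path only if their positions in the minor-to-major
// layout order form one consecutive block.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

bool DimsConsecutiveInLayout(const std::vector<int64_t>& minor_to_major,
                             const std::vector<int64_t>& dims_to_keep) {
  if (dims_to_keep.empty()) return true;
  // Collect the layout positions of the kept dimensions.
  std::vector<int64_t> positions;
  for (int64_t dim : dims_to_keep) {
    auto it = std::find(minor_to_major.begin(), minor_to_major.end(), dim);
    positions.push_back(std::distance(minor_to_major.begin(), it));
  }
  std::sort(positions.begin(), positions.end());
  // Consecutive positions span exactly positions.size() layout slots.
  return positions.back() - positions.front() + 1 ==
         static_cast<int64_t>(positions.size());
}

int main() {
  // Shape [2,3,4,5] with minor_to_major={2,0,1,3}: keeping dims {0,2} works
  // (layout positions 0 and 1), keeping dims {0,3} does not (positions 1, 3).
  const std::vector<int64_t> minor_to_major = {2, 0, 1, 3};
  std::cout << DimsConsecutiveInLayout(minor_to_major, {0, 2}) << "\n";  // 1
  std::cout << DimsConsecutiveInLayout(minor_to_major, {0, 3}) << "\n";  // 0
  return 0;
}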
|
||||
|
||||
// This emits a device-side call to
|
||||
|
@ -1047,8 +1047,9 @@ Status IrEmitterUnnested::EmitRowReduction(
|
||||
// Figures out whether `reduce` is a row or column reduction, and which
|
||||
// dimensions to reduce, and calls either `EmitRowReduction` or
|
||||
// `EmitColumnReduction` as appropriate.
|
||||
// Prerequisite: the shape of `reduce` has rank 1 and, if `reduce` is fused, the
|
||||
// fused subgraph is pure elementwise.
|
||||
// Prerequisite: all the dimensions to keep are contiguous in the input layout
|
||||
// and, if `reduce` is fused, the fused subgraph is pure
|
||||
// elementwise.
|
||||
Status IrEmitterUnnested::EmitReductionToVector(
|
||||
HloInstruction* reduce, const Shape& input_shape,
|
||||
const llvm_ir::ElementGenerator& input_gen,
|
||||
@ -1063,25 +1064,39 @@ Status IrEmitterUnnested::EmitReductionToVector(
|
||||
<< reduce->ToString();
|
||||
|
||||
// Specialize multi-dimensional-array-to-vector reduction.
|
||||
//
|
||||
// TODO(b/33239522): we could use the same algorithm for general reduction
|
||||
// as long as the input dimensions to keep are adjacent in the layout and
|
||||
// have the same relative layout as their corresponding output dimensions.
|
||||
// For example, reducing shape [2,3,4,5] with minor_to_major={2,0,1,3} to
|
||||
// shape [2,4] with minor_to_major={1,0} can be implemented as a column
|
||||
// reduction from shape [15,8] to shape [8].
|
||||
int64 input_dim_to_keep = -1;
|
||||
std::vector<int64> input_dims_to_keep;
|
||||
for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape);
|
||||
++input_dim) {
|
||||
if (std::find(dimensions_to_reduce.begin(), dimensions_to_reduce.end(),
|
||||
input_dim) == dimensions_to_reduce.end()) {
|
||||
input_dim_to_keep = input_dim;
|
||||
break;
|
||||
input_dims_to_keep.push_back(input_dim);
|
||||
}
|
||||
}
|
||||
CHECK_NE(-1, input_dim_to_keep);
|
||||
|
||||
if (LayoutUtil::Minor(input_shape.layout(), 0) == input_dim_to_keep) {
|
||||
// Sort the dimensions to keep from minor to major, to facilitate checking
|
||||
// whether another dimension is major or minor of them.
|
||||
std::sort(input_dims_to_keep.begin(), input_dims_to_keep.end(),
|
||||
[&input_shape](int64 dim_a, int64 dim_b) {
|
||||
return PositionInContainer(input_shape.layout().minor_to_major(),
|
||||
dim_a) <
|
||||
PositionInContainer(input_shape.layout().minor_to_major(),
|
||||
dim_b);
|
||||
});
|
||||
// Now, if output rank is at least 1, `input_dims_to_keep.front()` is
|
||||
// minormost and `input_dims_to_keep.back()` is majormost.
|
||||
|
||||
// If the dimensions to keep are minormost, emit a column reduction. As all
|
||||
// the dimensions to keep are contiguous, by prerequisite of
|
||||
// `EmitReductionToVector`, we only need to check whether the minormost
|
||||
// dimension of the input is to keep.
|
||||
//
|
||||
// If the output is scalar, we could emit either a row or a column reduction.
|
||||
// Some tests have shown that scalar reduction is no more efficient than row
|
||||
// reduction, and it is simpler to emit as a column reduction, so we emit a column
|
||||
// reduction in this case.
|
||||
if (input_dims_to_keep.empty() ||
|
||||
input_dims_to_keep.front() ==
|
||||
LayoutUtil::Minor(input_shape.layout(), 0)) {
|
||||
// Column reduction. Treat the result of "input" as a matrix whose width
|
||||
// is the most minor dimension and height the product of other dimensions,
|
||||
// and treat "reduce" as a column reduction of the input matrix.
|
||||
@ -1091,7 +1106,8 @@ Status IrEmitterUnnested::EmitReductionToVector(
|
||||
int64 height = 1;
|
||||
for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape);
|
||||
++input_dim) {
|
||||
if (input_dim != input_dim_to_keep) {
|
||||
if (!std::count(input_dims_to_keep.begin(), input_dims_to_keep.end(),
|
||||
input_dim)) {
|
||||
height *= input_shape.dimensions(input_dim);
|
||||
}
|
||||
}
|
||||
@ -1108,22 +1124,19 @@ Status IrEmitterUnnested::EmitReductionToVector(
|
||||
int64 width = 1;
|
||||
for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape);
|
||||
++input_dim) {
|
||||
if (PositionInContainer(
|
||||
AsInt64Slice(input_shape.layout().minor_to_major()), input_dim) >
|
||||
PositionInContainer(
|
||||
AsInt64Slice(input_shape.layout().minor_to_major()),
|
||||
input_dim_to_keep)) {
|
||||
if (PositionInContainer(input_shape.layout().minor_to_major(),
|
||||
input_dim) >
|
||||
PositionInContainer(input_shape.layout().minor_to_major(),
|
||||
input_dims_to_keep.back())) {
|
||||
depth *= input_shape.dimensions(input_dim);
|
||||
} else if (PositionInContainer(
|
||||
AsInt64Slice(input_shape.layout().minor_to_major()),
|
||||
input_dim) <
|
||||
PositionInContainer(
|
||||
AsInt64Slice(input_shape.layout().minor_to_major()),
|
||||
input_dim_to_keep)) {
|
||||
} else if (PositionInContainer(input_shape.layout().minor_to_major(),
|
||||
input_dim) <
|
||||
PositionInContainer(input_shape.layout().minor_to_major(),
|
||||
input_dims_to_keep.front())) {
|
||||
width *= input_shape.dimensions(input_dim);
|
||||
}
|
||||
}
|
||||
int64 height = input_shape.dimensions(input_dim_to_keep);
|
||||
const int64 height = ShapeUtil::ElementsIn(reduce->shape());
|
||||
return EmitRowReduction(depth, height, width, reduce, input_shape,
|
||||
input_gen, init_value_gen, reducer);
|
||||
}
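Tracing the arithmetic of the TODO(b/33239522) example above, under the same assumptions (a sketch, not the XLA emitter): keeping dims {0,2} of shape [2,3,4,5] with minor_to_major={2,0,1,3} keeps the two minormost dimensions of the layout, so the reduce can be viewed as a column reduction of a [3*5, 2*4] = [15,8] matrix down to [8].

// Hypothetical worked example only; the real code splits the input into
// depth/height/width per dimension rather than a single product.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::vector<int64_t> dims = {2, 3, 4, 5};
  const std::vector<bool> keep = {true, false, true, false};  // keep dims 0, 2
  int64_t width = 1, height = 1;
  for (size_t i = 0; i < dims.size(); ++i) {
    (keep[i] ? width : height) *= dims[i];  // kept dims form the matrix width
  }
  std::cout << "column reduction from [" << height << "," << width << "] to ["
            << width << "]\n";  // prints: column reduction from [15,8] to [8]
  return 0;
}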
|
||||
|
@ -206,42 +206,49 @@ tensorflow::Status LocalService::ExecuteLocally(
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<AotCompilationResult>>
|
||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||
LocalService::CompileAheadOfTime(
|
||||
const ComputationHandle& computation,
|
||||
const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
|
||||
const Shape& result_layout, const AotCompilationOptions& options) {
|
||||
TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
|
||||
computation_tracker_.Resolve(computation));
|
||||
VersionedComputationHandle versioned_handle =
|
||||
user_computation->GetVersionedHandle();
|
||||
const tensorflow::gtl::ArraySlice<AheadOfTimeComputationInstance>
|
||||
computations,
|
||||
const AotCompilationOptions& options) {
|
||||
std::vector<std::unique_ptr<HloModule>> hlo_modules;
|
||||
std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
|
||||
for (const AheadOfTimeComputationInstance& instance : computations) {
|
||||
TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
|
||||
computation_tracker_.Resolve(instance.computation));
|
||||
VersionedComputationHandle versioned_handle =
|
||||
user_computation->GetVersionedHandle();
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<HloModule> hlo_module,
|
||||
computation_tracker_.BuildHloModule(versioned_handle,
|
||||
/*include_unused_parameters=*/true));
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module,
|
||||
computation_tracker_.BuildHloModule(
|
||||
versioned_handle,
|
||||
/*include_unused_parameters=*/true));
|
||||
hlo_modules.push_back(std::move(hlo_module));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::shared_ptr<const ProgramShape> program_shape,
|
||||
user_computation->ComputeProgramShape(versioned_handle.version));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::shared_ptr<const ProgramShape> program_shape,
|
||||
user_computation->ComputeProgramShape(versioned_handle.version));
|
||||
|
||||
auto module_config = MakeUnique<HloModuleConfig>(*program_shape);
|
||||
auto* computation_layout = module_config->mutable_entry_computation_layout();
|
||||
for (int i = 0; i < argument_layouts.size(); ++i) {
|
||||
const Shape& argument_layout = *argument_layouts[i];
|
||||
if (ShapeUtil::IsTuple(argument_layout)) {
|
||||
return Unimplemented("tuple arguments not supported yet");
|
||||
module_configs.push_back(MakeUnique<HloModuleConfig>(*program_shape));
|
||||
HloModuleConfig* module_config = module_configs.back().get();
|
||||
auto* computation_layout =
|
||||
module_config->mutable_entry_computation_layout();
|
||||
for (int i = 0; i < instance.argument_layouts.size(); ++i) {
|
||||
const Shape& argument_layout = *instance.argument_layouts[i];
|
||||
if (ShapeUtil::IsTuple(argument_layout)) {
|
||||
return Unimplemented("tuple arguments not supported yet");
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
|
||||
argument_layout));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
|
||||
argument_layout));
|
||||
computation_layout->mutable_result_layout()->CopyLayoutFromShape(
|
||||
*instance.result_layout));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
computation_layout->mutable_result_layout()->CopyLayoutFromShape(
|
||||
result_layout));
|
||||
|
||||
return execute_backend_->compiler()
|
||||
->CompileAheadOfTime(std::move(hlo_module), std::move(module_config),
|
||||
->CompileAheadOfTime(std::move(hlo_modules), std::move(module_configs),
|
||||
MakeHloDumper(), options)
|
||||
.ConsumeValueOrDie();
|
||||
}
|
||||
@ -426,8 +433,9 @@ StatusOr<std::unique_ptr<ShapedBuffer>> LocalService::ExecuteLocallyInternal(
|
||||
} else {
|
||||
se::StreamExecutor* stream_executor;
|
||||
if (options.device_ordinal() >= 0) {
|
||||
TF_ASSIGN_OR_RETURN(stream_executor, execute_backend_->stream_executor(
|
||||
options.device_ordinal()));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
stream_executor,
|
||||
execute_backend_->stream_executor(options.device_ordinal()));
|
||||
} else {
|
||||
stream_executor = execute_backend_->default_stream_executor();
|
||||
}
|
||||
|
@ -139,13 +139,21 @@ class LocalService : public Service {
|
||||
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
|
||||
const LocalExecuteOptions& options, ShapedBuffer* result_buffer);
|
||||
|
||||
// Compiles the computation for ahead-of-time execution. This is intended for
|
||||
// use in static compilation. See |LocalClient::CompileAheadOfTime| for
|
||||
// additional details.
|
||||
StatusOr<std::unique_ptr<AotCompilationResult>> CompileAheadOfTime(
|
||||
const ComputationHandle& computation,
|
||||
const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
|
||||
const Shape& result_layout, const AotCompilationOptions& Options);
|
||||
// A description of a computation to compile using CompileAheadOfTime.
|
||||
struct AheadOfTimeComputationInstance {
|
||||
ComputationHandle computation;
|
||||
std::vector<const Shape*> argument_layouts;
|
||||
const Shape* result_layout = nullptr;
|
||||
};
|
||||
|
||||
// Compiles a list of computations for ahead-of-time execution. This is
|
||||
// intended for use in static compilation. See
|
||||
// |LocalClient::CompileAheadOfTime| for additional details.
|
||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||
CompileAheadOfTime(
|
||||
const tensorflow::gtl::ArraySlice<AheadOfTimeComputationInstance>
|
||||
computations,
|
||||
const AotCompilationOptions& Options);
|
||||
|
||||
// Builds an Executable with the given argument layouts and options. If
|
||||
// result_layout is non-null, then the executable is compiled to produce a
|
||||
|
@ -1019,16 +1019,7 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
|
||||
|
||||
tensorflow::Status Service::ResetDevice(const ResetDeviceRequest* arg,
|
||||
ResetDeviceResponse* result) {
|
||||
int first_device_ordinal = arg->has_device_handle()
|
||||
? arg->device_handle().handle()
|
||||
: execute_backend_->default_device_ordinal();
|
||||
TF_ASSIGN_OR_RETURN(auto executors,
|
||||
execute_backend_->Replicas(first_device_ordinal));
|
||||
for (se::StreamExecutor* executor : executors) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
execute_backend_->transfer_manager()->ResetDevice(executor));
|
||||
}
|
||||
return tensorflow::Status::OK();
|
||||
return execute_backend_->ResetDevices();
|
||||
}
|
||||
|
||||
tensorflow::Status Service::TransferToClientInProcess(
|
||||
|
@ -162,7 +162,15 @@ class Service : public ServiceInterface {
|
||||
const TransferToInfeedRequest* arg,
|
||||
TransferToInfeedResponse* result) override;
|
||||
|
||||
// Resets the device, clearing all existing state on the device.
|
||||
// Resets devices, clearing all existing state on all the devices associated
|
||||
// with this service (including memory allocated on the devices).
|
||||
//
|
||||
// ResetDevice may only be called where no previous Execution state on the
|
||||
// device is used by the next Execution.
|
||||
//
|
||||
// ResetDevice should be called before an Execution that expects the device to
|
||||
// be in the reset state. For example, if the prior Execution modifies device
|
||||
// state (e.g., architectural state) that the next Execution depends on.
|
||||
tensorflow::Status ResetDevice(const ResetDeviceRequest* arg,
|
||||
ResetDeviceResponse* result) override;
|
||||
|
||||
|
@ -1319,9 +1319,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
||||
// Permute(dimensions,input) computes output[dimensions[i]]=input[i]. However,
|
||||
// we need output[i]=input[dimensions[i]] which is
|
||||
// Permute(Inverse(dimensions),input).
|
||||
return ShapeUtil::MakeShape(operand.element_type(),
|
||||
Permute(InversePermutation(dimensions),
|
||||
AsInt64Slice(operand.dimensions())));
|
||||
return ShapeUtil::PermuteDimensions(InversePermutation(dimensions), operand);
|
||||
}
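The comment above is the whole trick; a standalone sketch (simplified stand-ins for Permute and InversePermutation, not the helpers in xla/util.h) applied to the {1,2,3,0} permutation exercised by the Transpose test further down:

// Permute places input[i] at output[dimensions[i]], while transpose shape
// inference wants output[i] = input[dimensions[i]], hence the inverse.
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> Permute(const std::vector<int64_t>& dimensions,
                             const std::vector<int64_t>& input) {
  std::vector<int64_t> output(input.size());
  for (size_t i = 0; i < input.size(); ++i) output[dimensions[i]] = input[i];
  return output;
}

std::vector<int64_t> InversePermutation(const std::vector<int64_t>& p) {
  std::vector<int64_t> inverse(p.size());
  for (size_t i = 0; i < p.size(); ++i) inverse[p[i]] = i;
  return inverse;
}

int main() {
  const std::vector<int64_t> dims = {2, 3, 4, 5};         // input dimensions
  const std::vector<int64_t> permutation = {1, 2, 3, 0};  // as in the test below
  // Desired transpose output: output[i] = dims[permutation[i]] = {3, 4, 5, 2}.
  for (int64_t d : Permute(InversePermutation(permutation), dims)) {
    std::cout << d << " ";  // prints 3 4 5 2
  }
  std::cout << "\n";
  return 0;
}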
|
||||
|
||||
/* static */ StatusOr<Shape> ShapeInference::InferSelectShape(
|
||||
|
@ -1125,8 +1125,8 @@ TEST_F(ShapeInferenceTest, Transpose) {
|
||||
ShapeInference::InferTransposeShape(a_shape, {1, 2, 3, 0});
|
||||
EXPECT_IS_OK(inferred_shape_and_status);
|
||||
Shape inferred_shape = inferred_shape_and_status.ValueOrDie();
|
||||
EXPECT_TRUE(ShapeUtil::Equal(inferred_shape,
|
||||
ShapeUtil::MakeShape(F32, {3, 4, 5, 2})));
|
||||
EXPECT_TRUE(ShapeUtil::Compatible(inferred_shape,
|
||||
ShapeUtil::MakeShape(F32, {3, 4, 5, 2})));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
@ -63,8 +64,10 @@ class TransferManager {
|
||||
perftools::gputools::StreamExecutor* executor,
|
||||
const Literal& literal) = 0;
|
||||
|
||||
// Resets the device that the given executor runs on.
|
||||
virtual Status ResetDevice(perftools::gputools::StreamExecutor* executor) = 0;
|
||||
// Resets the devices associated with this transfer manager.
|
||||
virtual Status ResetDevices(
|
||||
tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
|
||||
executor) = 0;
|
||||
|
||||
// Shallow copy a tuple from the device and create a DeviceMemoryBase object
|
||||
// for each element in the tuple. A DeviceMemoryBase object refers to the
|
||||
|
@ -984,4 +984,38 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
|
||||
check_input_unit_indices(output_shape, input_shape);
|
||||
}
|
||||
|
||||
/* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete,
|
||||
Shape shape) {
|
||||
shape.mutable_dimensions()->erase(shape.dimensions().begin() + dim_to_delete);
|
||||
if (LayoutUtil::HasLayout(shape)) {
|
||||
Layout* layout = shape.mutable_layout();
|
||||
for (size_t i = 0; i < layout->minor_to_major().size();) {
|
||||
if (layout->minor_to_major(i) == dim_to_delete) {
|
||||
layout->mutable_minor_to_major()->erase(
|
||||
layout->minor_to_major().begin() + i);
|
||||
continue;
|
||||
}
|
||||
if (layout->minor_to_major(i) > dim_to_delete) {
|
||||
(*layout->mutable_minor_to_major())[i] -= 1;
|
||||
}
|
||||
++i;
|
||||
}
|
||||
}
|
||||
return shape;
|
||||
}
|
||||
|
||||
/* static */ Shape ShapeUtil::FilterDimensions(
|
||||
const std::function<bool(int64)>& p, Shape shape) {
|
||||
std::vector<int64> dims_to_delete;
|
||||
for (int64 i = shape.dimensions().size() - 1; i >= 0; --i) {
|
||||
if (!p(i)) {
|
||||
dims_to_delete.push_back(i);
|
||||
}
|
||||
}
|
||||
for (int64 dim : dims_to_delete) {
|
||||
shape = DeleteDimension(dim, shape);
|
||||
}
|
||||
return shape;
|
||||
}
|
||||
|
||||
} // namespace xla
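A simplified sketch of the semantics just implemented, operating on a plain vector of dimension sizes instead of the real xla::Shape proto (illustration only; names and types are stand-ins). It reproduces the `DeleteDimension(1, T[m, n, k]) = T[m, k]` and `FilterDimensions((< 2), T[m, n, k]) = T[m, n]` examples documented in the header below.

#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

std::vector<int64_t> DeleteDimension(int64_t dim_to_delete,
                                     std::vector<int64_t> dims) {
  dims.erase(dims.begin() + dim_to_delete);
  return dims;
}

std::vector<int64_t> FilterDimensions(const std::function<bool(int64_t)>& p,
                                      std::vector<int64_t> dims) {
  // Delete from the back so earlier indices stay valid, mirroring the
  // implementation above.
  for (int64_t i = static_cast<int64_t>(dims.size()) - 1; i >= 0; --i) {
    if (!p(i)) dims = DeleteDimension(i, dims);
  }
  return dims;
}

int main() {
  const std::vector<int64_t> shape = {7, 8, 9};  // stands in for T[m, n, k]
  for (int64_t d : DeleteDimension(1, shape)) std::cout << d << " ";  // 7 9
  std::cout << "\n";
  for (int64_t d : FilterDimensions([](int64_t i) { return i < 2; }, shape)) {
    std::cout << d << " ";  // 7 8, i.e. T[m, n]
  }
  std::cout << "\n";
  return 0;
}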
|
||||
|
@ -374,6 +374,19 @@ class ShapeUtil {
|
||||
static bool ReshapeIsBitcast(const Shape& input_shape,
|
||||
const Shape& output_shape);
|
||||
|
||||
// Returns a shape with the given dimension deleted.
|
||||
// For example:
|
||||
// • `DeleteDimension(1, T[m, n, k]) = T[m, k]`
|
||||
static Shape DeleteDimension(int64 dim_to_delete, Shape shape);
|
||||
|
||||
// Returns a shape with all the dimensions of the input shape for which `p`
|
||||
// returns true.
|
||||
// For example:
|
||||
// • `FilterDimensions((< 2), T[m, n, k]) = T[m, n]`
|
||||
// • `FilterDimensions(is_even_number, T[m, n, k]) = T[m, k]`
|
||||
static Shape FilterDimensions(const std::function<bool(int64)>& p,
|
||||
Shape shape);
|
||||
|
||||
private:
|
||||
// Recursive helper for comparing the equality of two shapes. Returns true if
|
||||
// the shapes are the same. If compare_layouts is true, then layouts must also
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
@ -72,16 +73,19 @@ int main(int argc, char** argv) {
|
||||
|
||||
llvm::Triple triple(xla::llvm_ir::AsStringRef(triple_string));
|
||||
|
||||
xla::Computation computation = builder.Build().ConsumeValueOrDie();
|
||||
xla::LocalClient::AheadOfTimeComputationInstance instance{
|
||||
&computation, /*argument_layouts=*/{&opaque_shape}, &r0f32};
|
||||
|
||||
xla::cpu::CpuAotCompilationOptions options(
|
||||
triple_string,
|
||||
/*cpu_name=*/"", /*features=*/"", "SumAndDouble",
|
||||
xla::cpu::CpuAotCompilationOptions::RelocationModel::Static);
|
||||
|
||||
auto results =
|
||||
client->CompileAheadOfTime({instance}, options).ConsumeValueOrDie();
|
||||
auto result = xla::unique_ptr_static_cast<xla::cpu::CpuAotCompilationResult>(
|
||||
client
|
||||
->CompileAheadOfTime(builder.Build().ValueOrDie(),
|
||||
/*argument_layouts=*/{&opaque_shape}, r0f32,
|
||||
options)
|
||||
.ConsumeValueOrDie());
|
||||
std::move(results.front()));
|
||||
// We should have two buffers, one for the result and one temporary buffer,
|
||||
// and both should be float-sized. It's lame to hard-code this, but we need
|
||||
// local_client_aot_test.cc to be able to easily invoke the function.
|
||||
|
@ -176,12 +176,6 @@ std::vector<int64> ComposePermutations(tensorflow::gtl::ArraySlice<int64> p1,
|
||||
return output;
|
||||
}
|
||||
|
||||
int64 PositionInContainer(tensorflow::gtl::ArraySlice<int64> container,
|
||||
int64 value) {
|
||||
return std::find(container.begin(), container.end(), value) -
|
||||
container.begin();
|
||||
}
|
||||
|
||||
PaddingConfig MakeNoPaddingConfig(int64 rank) {
|
||||
PaddingConfig padding_config;
|
||||
for (int64 dnum = 0; dnum < rank; ++dnum) {
|
||||
|
@ -183,8 +183,11 @@ std::vector<int64> InversePermutation(
|
||||
std::vector<int64> ComposePermutations(tensorflow::gtl::ArraySlice<int64> p1,
|
||||
tensorflow::gtl::ArraySlice<int64> p2);
|
||||
|
||||
int64 PositionInContainer(tensorflow::gtl::ArraySlice<int64> container,
|
||||
int64 value);
|
||||
template <typename Container>
|
||||
int64 PositionInContainer(const Container& container, int64 value) {
|
||||
return std::distance(container.begin(),
|
||||
std::find(container.begin(), container.end(), value));
|
||||
}
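As a quick usage note for the templated helper above (a standalone copy for illustration; the real declaration stays in tensorflow/compiler/xla/util.h and uses XLA's int64): it returns the offset of `value` within `container`, or `container.size()` when the value is absent, which is how the GPU emitter compares layout positions of dimensions.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

template <typename Container>
int64_t PositionInContainer(const Container& container, int64_t value) {
  return std::distance(container.begin(),
                       std::find(container.begin(), container.end(), value));
}

int main() {
  const std::vector<int64_t> minor_to_major = {2, 0, 1, 3};
  // Dimension 1 sits at layout position 2; an absent value yields size() == 4.
  std::cout << PositionInContainer(minor_to_major, 1) << "\n";  // 2
  std::cout << PositionInContainer(minor_to_major, 7) << "\n";  // 4
  return 0;
}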
|
||||
|
||||
// Returns a PaddingConfig object that represents no padding for the given rank.
|
||||
PaddingConfig MakeNoPaddingConfig(int64 rank);
|
||||
|
@ -33,6 +33,7 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/java/src/main/native",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -1,6 +1,10 @@
|
||||
TensorFlow-Android-Inference
|
||||
============================
|
||||
Android Java interface to the TensorFlow native APIs
|
||||
This directory contains CMake support for building the Android Java Inference
|
||||
interface to the TensorFlow native APIs.
|
||||
|
||||
See [tensorflow/contrib/android](..) for more details about the library, and
|
||||
instructions for building with Bazel.
|
||||
|
||||
Usage
|
||||
-----
|
||||
@ -24,9 +28,9 @@ Note: this makes native code in the lib traceable from your app.
|
||||
|
||||
Dependencies
|
||||
------------
|
||||
TensorFlow-Android-Inference depends on the TensorFlow static libs already built in your
|
||||
local TensorFlow repo directory. For Linux/Mac OS, build_all_android.sh is used
|
||||
in build.gradle to build it. It DOES take time to build the core libs;
|
||||
TensorFlow-Android-Inference depends on the TensorFlow static libs already built
|
||||
in your local TensorFlow repo directory. For Linux/Mac OS, build_all_android.sh
|
||||
is used in build.gradle to build it. It DOES take time to build the core libs;
|
||||
so, by default, it is commented out to avoid confusion (otherwise
|
||||
Android Studio would appear to hang while opening the project).
|
||||
To enable it, refer to the comment in
|
||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import time
|
||||
|
||||
from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
|
||||
@ -31,12 +32,8 @@ from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import flags
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
flags.DEFINE_integer("batch_size", 64, "batch size.")
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
class CudnnRNNBenchmark(test.Benchmark):
|
||||
"""Benchmarks Cudnn LSTM and other related models.
|
||||
|
@ -0,0 +1,98 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_PLATFORM_HEXAGON_SOC_INTERFACE_H_
|
||||
#define TENSORFLOW_PLATFORM_HEXAGON_SOC_INTERFACE_H_
|
||||
|
||||
#include <inttypes.h>
|
||||
|
||||
// Declaration of APIs provided by the hexagon shared library. This header is shared
|
||||
// by both the hexagon library built with the Qualcomm SDK and TensorFlow.
|
||||
// All functions defined here must have prefix "soc_interface" to avoid
|
||||
// naming conflicts.
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#else
|
||||
#include <stdbool.h>
|
||||
#endif // __cplusplus
|
||||
// Returns the version of loaded hexagon wrapper shared library.
|
||||
// You should assert that the version matches the expected version before
|
||||
// calling APIs defined in this header.
|
||||
int soc_interface_GetWrapperVersion();
|
||||
// Returns the version of hexagon binary.
|
||||
// You should assert that the version matches the expected version before
|
||||
// calling APIs defined in this header.
|
||||
int soc_interface_GetSocControllerVersion();
|
||||
// Initialize SOC
|
||||
bool soc_interface_Init();
|
||||
// Finalize SOC
|
||||
bool soc_interface_Finalize();
|
||||
// Execute graph on SOC
|
||||
bool soc_interface_ExecuteGraph();
|
||||
// Teardown graph setup
|
||||
bool soc_interface_TeardownGraph();
|
||||
// Send input data to SOC
|
||||
bool soc_interface_FillInputNodeFloat(int x, int y, int z, int d,
|
||||
const uint8_t* const buf,
|
||||
uint64_t buf_size);
|
||||
// Load output data from SOC
|
||||
bool soc_interface_ReadOutputNodeFloat(const char* const node_name,
|
||||
uint8_t** buf, uint64_t* buf_size);
|
||||
// Setup graph
|
||||
// TODO(satok): Remove and use runtime version
|
||||
bool soc_interface_setupDummyGraph(int version);
|
||||
|
||||
// Allocate memory for params of node inputs and node outputs
|
||||
bool soc_interface_AllocateNodeInputAndNodeOutputArray(int total_input_count,
|
||||
int total_output_count);
|
||||
|
||||
// Release memory for params of node inputs and node outputs
|
||||
bool soc_interface_ReleaseNodeInputAndNodeOutputArray();
|
||||
|
||||
// Set one node's inputs and return pointer to that struct
|
||||
void* soc_interface_SetOneNodeInputs(int input_count, const int* const node_id,
|
||||
const int* const port);
|
||||
|
||||
// Set one node's outputs and return pointer to that struct
|
||||
void* soc_interface_SetOneNodeOutputs(int output_count, int* max_size);
|
||||
|
||||
// Append const node to the graph
|
||||
bool soc_interface_AppendConstNode(const char* const name, int node_id,
|
||||
int batch, int height, int width, int depth,
|
||||
const uint8_t* const data, int data_length);
|
||||
|
||||
// Append node to the graph
|
||||
bool soc_interface_AppendNode(const char* const name, int node_id, int op_id,
|
||||
int padding_id, const void* const inputs,
|
||||
int inputs_count, const void* const outputs,
|
||||
int outputs_count);
|
||||
|
||||
// Instantiate graph
|
||||
bool soc_interface_InstantiateGraph();
|
||||
|
||||
// Construct graph
|
||||
bool soc_interface_ConstructGraph();
|
||||
|
||||
// Set log level
|
||||
void soc_interface_SetLogLevel(int log_level);
|
||||
|
||||
// Set debug flag
|
||||
void soc_interface_SetDebugFlag(uint64_t flag);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_PLATFORM_HEXAGON_SOC_INTERFACE_H_
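A hedged sketch of how a caller might drive the C API declared above. Only the function names and signatures come from this header; the call order, the expected-version constant, and the include path are assumptions for illustration, and the stub implementations added below currently return false/-1.

// Assumed usage flow: check versions, Init, build/run a graph, tear down.
#include <cstdio>

#include "soc_interface.h"  // path depends on the build setup

int main() {
  constexpr int kExpectedWrapperVersion = 1;  // hypothetical expected version
  if (soc_interface_GetWrapperVersion() != kExpectedWrapperVersion) {
    std::fprintf(stderr, "Unexpected hexagon wrapper version\n");
    return 1;
  }
  if (!soc_interface_Init()) {
    std::fprintf(stderr, "soc_interface_Init failed\n");
    return 1;
  }
  // ... build the graph with soc_interface_AppendConstNode / AppendNode,
  // then soc_interface_ConstructGraph() and soc_interface_InstantiateGraph() ...
  if (!soc_interface_ExecuteGraph()) {
    std::fprintf(stderr, "soc_interface_ExecuteGraph failed\n");
  }
  soc_interface_TeardownGraph();
  soc_interface_Finalize();
  return 0;
}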
|
124
tensorflow/contrib/hvx/hexagon_controller/src_soc_interface/soc_interface.c
Executable file
124
tensorflow/contrib/hvx/hexagon_controller/src_soc_interface/soc_interface.c
Executable file
@ -0,0 +1,124 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "soc_interface.h"
|
||||
|
||||
int soc_interface_GetWrapperVersion() {
|
||||
// TODO(satok): implement
|
||||
return -1;
|
||||
}
|
||||
|
||||
int soc_interface_GetSocControllerVersion() {
|
||||
// TODO(satok): implement
|
||||
return -1;
|
||||
}
|
||||
|
||||
bool soc_interface_Init() {
|
||||
// TODO(satok): implement
|
||||
return false;
|
||||
}
|
||||
|
||||
bool soc_interface_Finalize() {
|
||||
// TODO(satok): implement
|
||||
return false;
|
||||
}
|
||||
|
||||
bool soc_interface_ExecuteGraph() {
|
||||
// TODO(satok): implement
|
||||
return false;
|
||||
}
|
||||
|
||||
bool soc_interface_TeardownGraph() {
|
||||
// TODO(satok): implement
|
||||
return false;
|
||||
}
|
||||
|
||||
bool soc_interface_FillInputNodeFloat(
|
||||
int x, int y, int z, int d, const uint8_t* const buf, uint64_t buf_size) {
|
||||
// TODO(satok): implement
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO(satok): Remove and use runtime version
|
||||
bool soc_interface_ReadOutputNodeFloat(
|
||||
const char* const node_name, uint8_t** buf, uint64_t *buf_size) {
|
||||
// TODO(satok): implement
|
||||
return false;
|
||||
}
|
||||
|
||||
bool soc_interface_SetupGraphDummy(int version) {
|
||||
// TODO(satok): implement
|
||||
return false;
|
||||
}
|
||||
|
||||
bool soc_interface_AllocateNodeInputAndNodeOutputArray(
|
||||
int total_input_count, int total_output_count) {
|
||||
// TODO(satok): implement
|
||||
return false;
|
||||
}
|
||||
|
||||
bool soc_interface_ReleaseNodeInputAndNodeOutputArray() {
|
||||
// TODO(satok): implement
|
||||
return false;
|
||||
}
|
||||
|
||||
void* soc_interface_SetOneNodeInputs(
|
||||
int input_count, const int* const node_id, const int* const port) {
|
||||
// TODO(satok): implement
|
||||
return 0;
|
||||
}
|
||||
|
||||
void* soc_interface_SetOneNodeOutputs(int output_count, int* max_size) {
|
||||
// TODO(satok): implement
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Append const node to the graph
|
||||
bool soc_interface_AppendConstNode(
|
||||
const char* const name, int node_id, int batch, int height, int width,
|
||||
int depth, const uint8_t* const data, int data_length) {
|
||||
// TODO(satok): implement
|
||||
return false;
|
||||
}
|
||||
|
||||
// Append node to the graph
|
||||
bool soc_interface_AppendNode(
|
||||
const char* const name, int node_id, int ops_id, int padding_id,
|
||||
const void* const inputs, int inputs_count, const void* const outputs,
|
||||
int outputs_count) {
|
||||
// TODO(satok): implement
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
// Instantiate graph
|
||||
bool soc_interface_InstantiateGraph() {
|
||||
// TODO(satok): implement
|
||||
return false;
|
||||
}
|
||||
|
||||
// Construct graph
|
||||
bool soc_interface_ConstructGraph() {
|
||||
// TODO(satok): implement
|
||||
return false;
|
||||
}
|
||||
|
||||
void soc_interface_SetLogLevel(int log_level) {
|
||||
// TODO(satok): implement
|
||||
}
|
||||
|
||||
void soc_interface_SetDebugFlag(uint64_t flag) {
|
||||
// TODO(satok): implement
|
||||
}
|
@ -173,11 +173,12 @@ def _fused_batch_norm(
|
||||
`data_format` is `NHWC` and the second dimension if `data_format` is
|
||||
`NCHW`.
|
||||
decay: decay for the moving average. Reasonable values for `decay` are close
|
||||
to 1.0, typically in the multiple-nines range: 0.999, 0.99, 0.9, etc. Lower
|
||||
`decay` value (recommend trying `decay`=0.9) if model experiences reasonably
|
||||
good training performance but poor validation and/or test performance.
|
||||
center: If True, add offset of `beta` to normalized tensor. If False, `beta`
|
||||
is ignored.
|
||||
to 1.0, typically in the multiple-nines range: 0.999, 0.99, 0.9, etc.
|
||||
Lower `decay` value (recommend trying `decay`=0.9) if model experiences
|
||||
reasonably good training performance but poor validation and/or test
|
||||
performance.
|
||||
center: If True, add offset of `beta` to normalized tensor. If False,
|
||||
`beta` is ignored.
|
||||
scale: If True, multiply by `gamma`. If False, `gamma` is
|
||||
not used. When the next layer is linear (also e.g. `nn.relu`), this can be
|
||||
disabled since the scaling can be done by the next layer.
|
||||
@ -632,16 +633,12 @@ def batch_norm(
|
||||
if need_moments:
|
||||
# Calculate the moments based on the individual batch.
|
||||
if batch_weights is None:
|
||||
# Use a copy of moving_mean as a shift to compute more reliable moments.
|
||||
shift = math_ops.add(moving_mean, 0)
|
||||
if data_format == DATA_FORMAT_NCHW:
|
||||
shift = array_ops.reshape(shift, params_shape_broadcast)
|
||||
mean, variance = nn.moments(inputs, moments_axes, shift=shift,
|
||||
keep_dims=True)
|
||||
mean, variance = nn.moments(inputs, moments_axes, keep_dims=True)
|
||||
mean = array_ops.reshape(mean, [-1])
|
||||
variance = array_ops.reshape(variance, [-1])
|
||||
else:
|
||||
mean, variance = nn.moments(inputs, moments_axes, shift=shift)
|
||||
mean, variance = nn.moments(inputs, moments_axes)
|
||||
else:
|
||||
if data_format == DATA_FORMAT_NCHW:
|
||||
mean, variance = nn.weighted_moments(inputs, moments_axes,
|
||||
@ -1385,7 +1382,7 @@ def fully_connected(inputs,
|
||||
Raises:
|
||||
ValueError: if x has rank less than 2 or if its last dimension is not set.
|
||||
"""
|
||||
if not (isinstance(num_outputs, six.integer_types)):
|
||||
if not isinstance(num_outputs, six.integer_types):
|
||||
raise ValueError('num_outputs should be int or long, got %s.', num_outputs)
|
||||
|
||||
layer_variable_getter = _build_variable_getter({'bias': 'biases'})
|
||||
|
@ -2356,7 +2356,7 @@ class BatchNormTest(test.TestCase):
|
||||
else:
|
||||
image_shape = (batch_size, channels, height, width)
|
||||
axis = (0, 2, 3)
|
||||
image_values = np.random.rand(*image_shape) + 2
|
||||
image_values = np.random.rand(*image_shape) + 256
|
||||
expected_mean = np.mean(image_values, axis=axis)
|
||||
expected_var = np.var(image_values, axis=axis)
|
||||
if fused:
|
||||
@ -2393,9 +2393,9 @@ class BatchNormTest(test.TestCase):
|
||||
# The outputs should be close to 0.0 mean and 1.0 variance
|
||||
self.assertAllClose(
|
||||
np.mean(
|
||||
np_output, axis=axis), [0] * channels, rtol=0.1, atol=0.1)
|
||||
np_output, axis=axis), [0] * channels, rtol=0.001, atol=0.001)
|
||||
self.assertAllClose(
|
||||
np.var(np_output, axis=axis), [1] * channels, rtol=0.1, atol=0.1)
|
||||
np.var(np_output, axis=axis), [1] * channels, rtol=0.01, atol=0.01)
|
||||
# The gradients should change slowly while updating moving_mean.
|
||||
max_diff = np.max(np.abs(images_gradients_value - new_images_gradients))
|
||||
self.assertGreaterEqual(max_diff, 0.0)
|
||||
@ -2558,25 +2558,29 @@ class LayerNormTest(test.TestCase):
|
||||
# output_train and output_eval should be the same.
|
||||
self.assertAllClose(sess.run([output_train]), sess.run([output_eval]))
|
||||
|
||||
def doOutputTest(self, input_shape):
|
||||
with self.test_session() as sess:
|
||||
input_values = np.random.rand(*input_shape)
|
||||
inputs = constant_op.constant(
|
||||
input_values, shape=input_shape, dtype=dtypes.float32)
|
||||
output_op = _layers.layer_norm(inputs, scope='LN')
|
||||
# Initialize all variables
|
||||
sess.run(variables_lib.global_variables_initializer())
|
||||
# The mean and variance of the output should be close to 0 and 1
|
||||
# respectively.
|
||||
moments_axis = tuple([i for i in range(1, len(input_shape))])
|
||||
outputs = sess.run(output_op)
|
||||
expected_mean = np.zeros(input_shape[0])
|
||||
expected_var = np.ones(input_shape[0])
|
||||
mean = np.mean(outputs, axis=moments_axis)
|
||||
var = np.var(outputs, axis=moments_axis)
|
||||
tol = 1e-5
|
||||
self.assertAllClose(mean, expected_mean, rtol=tol, atol=tol)
|
||||
self.assertAllClose(var, expected_var, rtol=tol, atol=tol)
|
||||
def doOutputTest(self, input_shape, tol=1e-3):
|
||||
for mu in [0.0, 1e2]:
|
||||
for sigma in [1.0, 0.1]:
|
||||
input_values = np.random.rand(*input_shape) * sigma + mu
|
||||
expected_mean = np.zeros(input_shape[0])
|
||||
expected_var = np.ones(input_shape[0])
|
||||
with ops.Graph().as_default() as g:
|
||||
with self.test_session(graph=g) as sess:
|
||||
inputs = constant_op.constant(input_values, shape=input_shape,
|
||||
dtype=dtypes.float32)
|
||||
output_op = _layers.layer_norm(inputs, scope='LN')
|
||||
# Initialize all variables
|
||||
sess.run(variables_lib.global_variables_initializer())
|
||||
# The mean and variance of the output should be close to 0 and 1
|
||||
# respectively.
|
||||
moments_axis = tuple([i for i in range(1, len(input_shape))])
|
||||
outputs = sess.run(output_op)
|
||||
# Make sure that there are no NaNs
|
||||
self.assertFalse(np.isnan(outputs).any())
|
||||
mean = np.mean(outputs, axis=moments_axis)
|
||||
var = np.var(outputs, axis=moments_axis)
|
||||
self.assertAllClose(mean, expected_mean, rtol=tol, atol=tol)
|
||||
self.assertAllClose(var, expected_var, rtol=tol, atol=tol)
|
||||
|
||||
def testOutput2DInput(self):
|
||||
self.doOutputTest((10, 300))
|
||||
@ -2584,6 +2588,12 @@ class LayerNormTest(test.TestCase):
|
||||
def testOutput4DInput(self):
|
||||
self.doOutputTest((100, 10, 10, 3))
|
||||
|
||||
def testOutputSmallInput(self):
|
||||
self.doOutputTest((10, 10, 10, 30))
|
||||
|
||||
def testOutputBigInput(self):
|
||||
self.doOutputTest((1, 100, 100, 1))
|
||||
|
||||
|
||||
class MaxPool2DTest(test.TestCase):
|
||||
|
||||
|
@ -65,7 +65,7 @@ def l1_regularizer(scale, scope=None):
|
||||
my_scale = ops.convert_to_tensor(scale,
|
||||
dtype=weights.dtype.base_dtype,
|
||||
name='scale')
|
||||
return standard_ops.mul(
|
||||
return standard_ops.multiply(
|
||||
my_scale,
|
||||
standard_ops.reduce_sum(standard_ops.abs(weights)),
|
||||
name=name)
|
||||
@ -104,7 +104,7 @@ def l2_regularizer(scale, scope=None):
|
||||
my_scale = ops.convert_to_tensor(scale,
|
||||
dtype=weights.dtype.base_dtype,
|
||||
name='scale')
|
||||
return standard_ops.mul(my_scale, nn.l2_loss(weights), name=name)
|
||||
return standard_ops.multiply(my_scale, nn.l2_loss(weights), name=name)
|
||||
|
||||
return l2
|
||||
|
||||
|
@ -407,14 +407,15 @@ class BaseEstimator(
|
||||
raise ValueError('Can not provide both steps and max_steps.')
|
||||
_verify_input_args(x, y, input_fn, None, batch_size)
|
||||
if x is not None:
|
||||
return SKCompat(self).fit(x, y, batch_size, steps, max_steps, monitors)
|
||||
SKCompat(self).fit(x, y, batch_size, steps, max_steps, monitors)
|
||||
return self
|
||||
|
||||
if max_steps is not None:
|
||||
try:
|
||||
start_step = load_variable(self._model_dir, ops.GraphKeys.GLOBAL_STEP)
|
||||
if max_steps <= start_step:
|
||||
logging.info('Skipping training since max_steps has already saved.')
|
||||
return None
|
||||
return self
|
||||
except: # pylint: disable=bare-except
|
||||
pass
|
||||
|
||||
|
@ -21,7 +21,6 @@ from __future__ import print_function
|
||||
import json
|
||||
import os
|
||||
|
||||
from tensorflow.contrib.framework import deprecated
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.training import server_lib
|
||||
|
||||
@ -256,79 +255,30 @@ class RunConfig(ClusterConfig):
|
||||
def tf_config(self):
|
||||
return self._tf_config
|
||||
|
||||
@tf_config.setter
|
||||
@deprecated(
|
||||
'2017-01-08',
|
||||
'RunConfig will be made immutable, please pass all args to constructor.')
|
||||
def tf_config(self, value):
|
||||
self._tf_config = value
|
||||
|
||||
@property
|
||||
def tf_random_seed(self):
|
||||
return self._tf_random_seed
|
||||
|
||||
@tf_random_seed.setter
|
||||
@deprecated(
|
||||
'2017-01-08',
|
||||
'RunConfig will be made immutable, please pass all args to constructor.')
|
||||
def tf_random_seed(self, value):
|
||||
self._tf_random_seed = value
|
||||
|
||||
@property
|
||||
def save_summary_steps(self):
|
||||
return self._save_summary_steps
|
||||
|
||||
@save_summary_steps.setter
|
||||
@deprecated(
|
||||
'2017-01-08',
|
||||
'RunConfig will be made immutable, please pass all args to constructor.')
|
||||
def save_summary_steps(self, value):
|
||||
self._save_summary_steps = value
|
||||
|
||||
@property
|
||||
def save_checkpoints_secs(self):
|
||||
return self._save_checkpoints_secs
|
||||
|
||||
@save_checkpoints_secs.setter
|
||||
@deprecated(
|
||||
'2017-01-08',
|
||||
'RunConfig will be made immutable, please pass all args to constructor.')
|
||||
def save_checkpoints_secs(self, value):
|
||||
self._save_checkpoints_secs = value
|
||||
|
||||
@property
|
||||
def save_checkpoints_steps(self):
|
||||
return self._save_checkpoints_steps
|
||||
|
||||
@save_checkpoints_steps.setter
|
||||
@deprecated(
|
||||
'2017-01-08',
|
||||
'RunConfig will be made immutable, please pass all args to constructor.')
|
||||
def save_checkpoints_steps(self, value):
|
||||
self._save_checkpoints_steps = value
|
||||
|
||||
@property
|
||||
def keep_checkpoint_max(self):
|
||||
return self._keep_checkpoint_max
|
||||
|
||||
@keep_checkpoint_max.setter
|
||||
@deprecated(
|
||||
'2017-01-08',
|
||||
'RunConfig will be made immutable, please pass all args to constructor.')
|
||||
def keep_checkpoint_max(self, value):
|
||||
self._keep_checkpoint_max = value
|
||||
|
||||
@property
|
||||
def keep_checkpoint_every_n_hours(self):
|
||||
return self._keep_checkpoint_every_n_hours
|
||||
|
||||
@keep_checkpoint_every_n_hours.setter
|
||||
@deprecated(
|
||||
'2017-01-08',
|
||||
'RunConfig will be made immutable, please pass all args to constructor.')
|
||||
def keep_checkpoint_every_n_hours(self, value):
|
||||
self._keep_checkpoint_every_n_hours = value
|
||||
|
||||
|
||||
def _count_ps(cluster_spec):
|
||||
"""Counts the number of parameter servers in cluster_spec."""
|
||||
|
@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Implementations of different data feeders to provide data for TF trainer."""
|
||||
|
||||
# TODO(ipolosukhin): Replace this module with feed-dict queue runners & queues.
|
||||
@ -37,13 +36,13 @@ from tensorflow.python.platform import tf_logging as logging
|
||||
from .pandas_io import HAS_PANDAS, extract_pandas_data, extract_pandas_matrix, extract_pandas_labels
|
||||
from .dask_io import HAS_DASK, extract_dask_data, extract_dask_labels
|
||||
|
||||
|
||||
# pylint: enable=g-multiple-import,g-bad-import-order
|
||||
|
||||
|
||||
def _get_in_out_shape(x_shape, y_shape, n_classes, batch_size=None):
|
||||
"""Returns shape for input and output of the data feeder."""
|
||||
x_is_dict, y_is_dict = isinstance(x_shape, dict), y_shape is not None and isinstance(y_shape, dict)
|
||||
x_is_dict, y_is_dict = isinstance(
|
||||
x_shape, dict), y_shape is not None and isinstance(y_shape, dict)
|
||||
if y_is_dict and n_classes is not None:
|
||||
assert (isinstance(n_classes, dict))
|
||||
|
||||
@ -76,8 +75,11 @@ def _get_in_out_shape(x_shape, y_shape, n_classes, batch_size=None):
|
||||
if not y_is_dict:
|
||||
output_shape = out_el_shape(y_shape, n_classes)
|
||||
else:
|
||||
output_shape = dict([(k, out_el_shape(v, n_classes[k] if n_classes is not None and k in n_classes else None))
|
||||
for k, v in list(y_shape.items())])
|
||||
output_shape = dict([
|
||||
(k, out_el_shape(v, n_classes[k]
|
||||
if n_classes is not None and k in n_classes else None))
|
||||
for k, v in list(y_shape.items())
|
||||
])
|
||||
|
||||
return input_shape, output_shape, batch_size
|
||||
|
||||
@ -99,8 +101,12 @@ def _is_iterable(x):
|
||||
return hasattr(x, 'next') or hasattr(x, '__next__')
|
||||
|
||||
|
||||
def setup_train_data_feeder(
|
||||
x, y, n_classes, batch_size=None, shuffle=True, epochs=None):
|
||||
def setup_train_data_feeder(x,
|
||||
y,
|
||||
n_classes,
|
||||
batch_size=None,
|
||||
shuffle=True,
|
||||
epochs=None):
|
||||
"""Create data feeder, to sample inputs from dataset.
|
||||
|
||||
If `x` and `y` are iterators, use `StreamingDataFeeder`.
|
||||
@ -108,10 +114,13 @@ def setup_train_data_feeder(
|
||||
Args:
|
||||
x: numpy, pandas or Dask matrix or dictionary of aforementioned. Also
|
||||
supports iterables.
|
||||
y: numpy, pandas or Dask array or dictionary of aforementioned. Also supports
|
||||
y: numpy, pandas or Dask array or dictionary of aforementioned. Also
|
||||
supports
|
||||
iterables.
|
||||
n_classes: number of classes. Must be None or same type as y. In case, `y` is `dict`
|
||||
(or iterable which returns dict) such that `n_classes[key] = n_classes for y[key]`
|
||||
n_classes: number of classes. Must be None or same type as y. In case, `y`
|
||||
is `dict`
|
||||
(or iterable which returns dict) such that `n_classes[key] = n_classes for
|
||||
y[key]`
|
||||
batch_size: size to split data into parts. Must be >= 1.
|
||||
shuffle: Whether to shuffle the inputs.
|
||||
epochs: Number of epochs to run.
|
||||
@ -127,7 +136,7 @@ def setup_train_data_feeder(
|
||||
# pylint: disable=g-import-not-at-top
|
||||
import dask.dataframe as dd
|
||||
if (isinstance(x, (dd.Series, dd.DataFrame)) and
|
||||
(y is None or isinstance(y, (dd.Series, dd.DataFrame)))):
|
||||
(y is None or isinstance(y, (dd.Series, dd.DataFrame)))):
|
||||
data_feeder_cls = DaskDataFeeder
|
||||
else:
|
||||
data_feeder_cls = DataFeeder
|
||||
@ -140,7 +149,7 @@ def setup_train_data_feeder(
|
||||
'streaming learning to work.')
|
||||
return StreamingDataFeeder(x, y, n_classes, batch_size)
|
||||
return data_feeder_cls(
|
||||
x, y, n_classes, batch_size, shuffle=shuffle, epochs=epochs)
|
||||
x, y, n_classes, batch_size, shuffle=shuffle, epochs=epochs)
|
||||
|
||||
|
||||
def _batch_data(x, batch_size=None):
|
||||
@ -150,7 +159,8 @@ def _batch_data(x, batch_size=None):
|
||||
x_first_el = six.next(x)
|
||||
x = itertools.chain([x_first_el], x)
|
||||
|
||||
chunk = dict([(k, []) for k in list(x_first_el.keys())]) if isinstance(x_first_el, dict) else []
|
||||
chunk = dict([(k, []) for k in list(x_first_el.keys())]) if isinstance(
|
||||
x_first_el, dict) else []
|
||||
chunk_filled = False
|
||||
for data in x:
|
||||
if isinstance(data, dict):
|
||||
@ -161,7 +171,8 @@ def _batch_data(x, batch_size=None):
|
||||
chunk_filled = True
|
||||
if chunk_filled:
|
||||
yield chunk
|
||||
chunk = dict([(k, []) for k in list(x_first_el.keys())]) if isinstance(x_first_el, dict) else []
|
||||
chunk = dict([(k, []) for k in list(x_first_el.keys())]) if isinstance(
|
||||
x_first_el, dict) else []
|
||||
chunk_filled = False
|
||||
else:
|
||||
chunk.append(data)
|
||||
@ -259,16 +270,21 @@ def _access(data, iloc):
|
||||
def _check_dtype(dtype):
|
||||
if dtypes.as_dtype(dtype) == dtypes.float64:
|
||||
logging.warn(
|
||||
'float64 is not supported by many models, consider casting to float32.')
|
||||
'float64 is not supported by many models, consider casting to float32.')
|
||||
return dtype
|
||||
|
||||
|
||||
class DataFeeder(object):
|
||||
"""Data feeder is an example class to sample data for TF trainer."""
|
||||
|
||||
def __init__(
|
||||
self, x, y, n_classes, batch_size=None, shuffle=True, random_state=None,
|
||||
epochs=None):
|
||||
def __init__(self,
|
||||
x,
|
||||
y,
|
||||
n_classes,
|
||||
batch_size=None,
|
||||
shuffle=True,
|
||||
random_state=None,
|
||||
epochs=None):
|
||||
"""Initializes a DataFeeder instance.
|
||||
|
||||
Args:
|
||||
@ -299,29 +315,33 @@ class DataFeeder(object):
|
||||
input_dtype: DType of input (or dictionary of shapes).
|
||||
output_dtype: DType of output (or dictionary of shapes.
|
||||
"""
|
||||
x_is_dict, y_is_dict = isinstance(x, dict), y is not None and isinstance(y, dict)
|
||||
x_is_dict, y_is_dict = isinstance(x, dict), y is not None and isinstance(
|
||||
y, dict)
|
||||
if isinstance(y, list):
|
||||
y = np.array(y)
|
||||
|
||||
self._x = dict([(k, check_array(v, v.dtype)) for k, v in list(x.items())]) if x_is_dict else check_array(x, x.dtype)
|
||||
self._x = dict([(k, check_array(v, v.dtype)) for k, v in list(x.items())
|
||||
]) if x_is_dict else check_array(x, x.dtype)
|
||||
self._y = None if y is None else \
|
||||
dict([(k, check_array(v, v.dtype)) for k, v in list(y.items())]) if x_is_dict else check_array(y, y.dtype)
|
||||
|
||||
# self.n_classes is not None means we're converting raw target indices to one-hot.
|
||||
if n_classes is not None:
|
||||
if not y_is_dict:
|
||||
y_dtype = (np.int64 if n_classes is not None and n_classes > 1 else np.float32)
|
||||
y_dtype = (np.int64
|
||||
if n_classes is not None and n_classes > 1 else np.float32)
|
||||
self._y = (None if y is None else check_array(y, dtype=y_dtype))
|
||||
|
||||
self.n_classes = n_classes
|
||||
self.max_epochs = epochs
|
||||
|
||||
x_shape = dict([(k, v.shape) for k, v in list(self._x.items())]) if x_is_dict else self._x.shape
|
||||
y_shape = dict(
|
||||
[(k, v.shape) for k, v in list(self._y.items())]) if y_is_dict else None if y is None else self._y.shape
|
||||
x_shape = dict([(k, v.shape) for k, v in list(self._x.items())
|
||||
]) if x_is_dict else self._x.shape
|
||||
y_shape = dict([(k, v.shape) for k, v in list(self._y.items())
|
||||
]) if y_is_dict else None if y is None else self._y.shape
|
||||
|
||||
self.input_shape, self.output_shape, self._batch_size = _get_in_out_shape(
|
||||
x_shape, y_shape, n_classes, batch_size)
|
||||
x_shape, y_shape, n_classes, batch_size)
|
||||
|
||||
# Input dtype matches dtype of x.
|
||||
self._input_dtype = dict([(k, _check_dtype(v.dtype)) for k, v in list(self._x.items())]) if x_is_dict \
|
||||
@ -339,9 +359,10 @@ class DataFeeder(object):
|
||||
|
||||
self._shuffle = shuffle
|
||||
self.random_state = np.random.RandomState(
|
||||
42) if random_state is None else random_state
|
||||
42) if random_state is None else random_state
|
||||
|
||||
num_samples = list(self._x.values())[0].shape[0] if x_is_dict else self._x.shape[0]
|
||||
num_samples = list(self._x.values())[0].shape[
|
||||
0] if x_is_dict else self._x.shape[0]
|
||||
if self._shuffle:
|
||||
self.indices = self.random_state.permutation(num_samples)
|
||||
else:
|
||||
@ -380,8 +401,8 @@ class DataFeeder(object):
|
||||
Returns:
|
||||
The epoch placeholder.
|
||||
"""
|
||||
self._epoch_placeholder = array_ops.placeholder(dtypes.int32, [1],
|
||||
name='epoch')
|
||||
self._epoch_placeholder = array_ops.placeholder(
|
||||
dtypes.int32, [1], name='epoch')
|
||||
return self._epoch_placeholder
|
||||
|
||||
def input_builder(self):
|
||||
@ -398,19 +419,17 @@ class DataFeeder(object):
|
||||
placeholder = {}
|
||||
for key in list(shape.keys()):
|
||||
placeholder[key] = array_ops.placeholder(
|
||||
dtypes.as_dtype(dtype[key]),
|
||||
[None] + shape[key][1:],
|
||||
name=name_prepend + '_' + key
|
||||
)
|
||||
dtypes.as_dtype(dtype[key]), [None] + shape[key][1:],
|
||||
name=name_prepend + '_' + key)
|
||||
else:
|
||||
placeholder = array_ops.placeholder(
|
||||
dtypes.as_dtype(dtype),
|
||||
[None] + shape[1:],
|
||||
name=name_prepend)
|
||||
dtypes.as_dtype(dtype), [None] + shape[1:], name=name_prepend)
|
||||
return placeholder
|
||||
|
||||
self._input_placeholder = get_placeholder(self.input_shape, self._input_dtype, 'input')
|
||||
self._output_placeholder = get_placeholder(self.output_shape, self._output_dtype, 'output')
|
||||
self._input_placeholder = get_placeholder(self.input_shape,
|
||||
self._input_dtype, 'input')
|
||||
self._output_placeholder = get_placeholder(self.output_shape,
|
||||
self._output_dtype, 'output')
|
||||
return self._input_placeholder, self._output_placeholder
|
||||
|
||||
def set_placeholders(self, input_placeholder, output_placeholder):
|
||||
@ -432,9 +451,9 @@ class DataFeeder(object):
|
||||
A `dict` with data feed params while training.
|
||||
"""
|
||||
return {
|
||||
'epoch': self.epoch,
|
||||
'offset': self.offset,
|
||||
'batch_size': self._batch_size
|
||||
'epoch': self.epoch,
|
||||
'offset': self.offset,
|
||||
'batch_size': self._batch_size
|
||||
}
|
||||
|
||||
def get_feed_dict_fn(self):
|
||||
@ -444,12 +463,13 @@ class DataFeeder(object):
|
||||
A function that when called samples a random subset of batch size
|
||||
from `x` and `y`.
|
||||
"""
|
||||
x_is_dict, y_is_dict = isinstance(self._x, dict), self._y is not None and isinstance(self._y, dict)
|
||||
x_is_dict, y_is_dict = isinstance(
|
||||
self._x, dict), self._y is not None and isinstance(self._y, dict)
|
||||
|
||||
# Assign input features from random indices.
|
||||
def extract(data, indices):
|
||||
return (np.array(_access(data, indices)).reshape((indices.shape[0], 1))
|
||||
if len(data.shape) == 1 else _access(data, indices))
|
||||
return (np.array(_access(data, indices)).reshape((indices.shape[0], 1)) if
|
||||
len(data.shape) == 1 else _access(data, indices))
|
||||
|
||||
# assign labels from random indices
|
||||
def assign_label(data, shape, dtype, n_classes, indices):
|
||||
@ -481,19 +501,22 @@ class DataFeeder(object):
|
||||
feed_dict[self._epoch_placeholder.name] = [self.epoch]
|
||||
|
||||
# Take next batch of indices.
|
||||
x_len = list(self._x.values())[0].shape[0] if x_is_dict else self._x.shape[0]
|
||||
x_len = list(self._x.values())[0].shape[
|
||||
0] if x_is_dict else self._x.shape[0]
|
||||
end = min(x_len, self.offset + self._batch_size)
|
||||
batch_indices = self.indices[self.offset:end]
|
||||
|
||||
# adding input placeholder
|
||||
feed_dict.update(
|
||||
dict([(self._input_placeholder[k].name, extract(v, batch_indices)) for k, v in list(self._x.items())])
|
||||
if x_is_dict else {self._input_placeholder.name: extract(self._x, batch_indices)})
|
||||
dict([(self._input_placeholder[k].name, extract(v, batch_indices))
|
||||
for k, v in list(self._x.items())]) if x_is_dict else
|
||||
{self._input_placeholder.name: extract(self._x, batch_indices)})
|
||||
|
||||
# move offset and reset it if necessary
|
||||
self.offset += self._batch_size
|
||||
if self.offset >= x_len:
|
||||
self.indices = self.random_state.permutation(x_len) if self._shuffle else np.array(range(x_len))
|
||||
self.indices = self.random_state.permutation(
|
||||
x_len) if self._shuffle else np.array(range(x_len))
|
||||
self.offset = 0
|
||||
self.epoch += 1
|
||||
|
||||
@ -504,15 +527,19 @@ class DataFeeder(object):
# adding output placeholders
if y_is_dict:
for k, v in list(self._y.items()):
n_classes = (
self.n_classes[k] if k in self.n_classes else None) if self.n_classes is not None else None
n_classes = (self.n_classes[k] if k in self.n_classes else
None) if self.n_classes is not None else None
shape, dtype = self.output_shape[k], self._output_dtype[k]
feed_dict.update(
{self._output_placeholder[k].name: assign_label(v, shape, dtype, n_classes, batch_indices)})
feed_dict.update({
self._output_placeholder[k].name:
assign_label(v, shape, dtype, n_classes, batch_indices)
})
else:
shape, dtype, n_classes = self.output_shape, self._output_dtype, self.n_classes
feed_dict.update(
{self._output_placeholder.name: assign_label(self._y, shape, dtype, n_classes, batch_indices)})
feed_dict.update({
self._output_placeholder.name:
assign_label(self._y, shape, dtype, n_classes, batch_indices)
})

return feed_dict
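
For orientation, a minimal sketch of how the callable returned by get_feed_dict_fn is consumed. The input_builder()/get_feed_dict_fn() usage mirrors the DataFeederTest hunk further below; `df`, `train_op`, `loss` and `num_steps` are hypothetical names for an existing DataFeeder and model, not part of this change.

# Sketch only (TF 1.x style); `df` is an existing DataFeeder, `train_op`/`loss`
# are hypothetical tensors built on top of the placeholders it provides.
import tensorflow as tf

inp, out = df.input_builder()      # creates the input/output placeholders
feed_fn = df.get_feed_dict_fn()    # zero-argument callable
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for _ in range(num_steps):       # num_steps is illustrative
    # Each call advances offset/epoch and returns a fresh feed_dict.
    _, loss_value = sess.run([train_op, loss], feed_dict=feed_fn())
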
@ -566,41 +593,56 @@ class StreamingDataFeeder(DataFeeder):
self._y = None
self.n_classes = n_classes

x_is_dict, y_is_dict = isinstance(x_first_el, dict), y is not None and isinstance(y_first_el, dict)
x_is_dict = isinstance(x_first_el, dict)
y_is_dict = y is not None and isinstance(y_first_el, dict)
if y_is_dict and n_classes is not None:
assert (isinstance(n_classes, dict))
assert isinstance(n_classes, dict)

# extract shapes for first_elements
x_first_el_shape = dict([(k, [1] + list(v.shape)) for k, v in list(x_first_el.items())]) if x_is_dict \
else [1] + list(x_first_el.shape)
if x_is_dict:
x_first_el_shape = dict(
[(k, [1] + list(v.shape)) for k, v in list(x_first_el.items())])
else:
x_first_el_shape = [1] + list(x_first_el.shape)

y_first_el_shape = dict([(k, [1] + list(v.shape)) for k, v in list(y_first_el.items())]) if y_is_dict \
else ([1] + list(y_first_el[0].shape if isinstance(y_first_el, list) else y_first_el.shape)
if y is not None else None)
if y_is_dict:
y_first_el_shape = dict(
[(k, [1] + list(v.shape)) for k, v in list(y_first_el.items())])
elif y is None:
y_first_el_shape = None
else:
y_first_el_shape = ([1] + list(y_first_el[0].shape if isinstance(
y_first_el, list) else y_first_el.shape))

self.input_shape, self.output_shape, self._batch_size = _get_in_out_shape(x_first_el_shape, y_first_el_shape,
n_classes, batch_size)
self.input_shape, self.output_shape, self._batch_size = _get_in_out_shape(
x_first_el_shape, y_first_el_shape, n_classes, batch_size)

# Input dtype of x_first_el.
self._input_dtype = dict([(k, _check_dtype(v.dtype)) for k, v in list(x_first_el.items())]) if x_is_dict \
else _check_dtype(x_first_el.dtype)
if x_is_dict:
self._input_dtype = dict(
[(k, _check_dtype(v.dtype)) for k, v in list(x_first_el.items())])
else:
self._input_dtype = _check_dtype(x_first_el.dtype)

# Output dtype of y_first_el.
def check_y_dtype(el):
if isinstance(el, list) or isinstance(el, np.ndarray):
if isinstance(el, np.ndarray) and el.ndim == 0:
return el.dtype
else:
return _check_dtype(np.dtype(type(el[0])))
if isinstance(el, np.ndarray):
return el.dtype
elif isinstance(el, list):
return check_y_dtype(el[0])
else:
return _check_dtype(np.dtype(type(el)))

# Output types are floats, due to both softmaxes and regression req.
if n_classes is not None and (y is None or not y_is_dict) and n_classes > 0:
self._output_dtype = np.float32
elif y_is_dict:
self._output_dtype = dict(
[(k, check_y_dtype(v)) for k, v in list(y_first_el.items())])
elif y is None:
self._output_dtype = None
else:
self._output_dtype = dict([(k, check_y_dtype(v)) for k, v in list(y_first_el.items())]) if y_is_dict \
else (check_y_dtype(y_first_el) if y is not None else None)
self._output_dtype = check_y_dtype(y_first_el)

def get_feed_params(self):
"""Function returns a `dict` with data feed params while training.
@ -627,13 +669,17 @@ class StreamingDataFeeder(DataFeeder):
"""

def init_array(shape, dtype):
"""Initialize array of given shape or dict of shapes and dtype."""
if shape is None:
return None
elif isinstance(shape, dict):
return dict([(k, np.zeros(shape[k], dtype[k]))
for k in list(shape.keys())])
else:
return dict([(k, np.zeros(shape[k], dtype[k])) for k in list(shape.keys())]) if isinstance(shape, dict) else \
np.zeros(shape, dtype=dtype)
return np.zeros(shape, dtype=dtype)

def put_data_array(dest, index, source=None, n_classes=None):
"""Puts data array into container."""
if source is None:
dest = dest[:index]
elif n_classes is not None and n_classes > 1:
@ -650,12 +696,13 @@ class StreamingDataFeeder(DataFeeder):
return dest

def put_data_array_or_dict(holder, index, data=None, n_classes=None):
"""Puts data array or data dictionary into container."""
if holder is None:
return None
if isinstance(holder, dict):
if data is None:
data = {k: None for k in holder.keys()}
assert (isinstance(data, dict))
assert isinstance(data, dict)
for k in holder.keys():
num_classes = n_classes[k] if (n_classes is not None and
k in n_classes) else None
@ -688,12 +735,18 @@ class StreamingDataFeeder(DataFeeder):
out = put_data_array_or_dict(out, i, next_out, self.n_classes)

# creating feed_dict
feed_dict = dict([(self._input_placeholder[k].name, inp[k]) for k in list(self._input_placeholder.keys())]) if \
isinstance(inp, dict) else {self._input_placeholder.name: inp}
if isinstance(inp, dict):
feed_dict = dict([(self._input_placeholder[k].name, inp[k])
for k in list(self._input_placeholder.keys())])
else:
feed_dict = {self._input_placeholder.name: inp}
if self._y is not None:
feed_dict.update(
dict([(self._output_placeholder[k].name, out[k]) for k in list(self._output_placeholder.keys())]) \
if isinstance(out, dict) else {self._output_placeholder.name: out})
if isinstance(out, dict):
feed_dict.update(
dict([(self._output_placeholder[k].name, out[k])
for k in list(self._output_placeholder.keys())]))
else:
feed_dict.update({self._output_placeholder.name: out})

return feed_dict
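
A minimal sketch of the generator-driven path, mirroring the x_iter/y_iter generators in the DataFeederTest hunk below; the contrib.learn module path and the n_classes/batch_size values are assumptions made only for illustration.

# Sketch only: module path and constructor arguments are assumptions.
import numpy as np
from tensorflow.contrib.learn.python.learn.learn_io import data_feeder

def x_iter():
  yield np.array([[1, 2]])
  yield np.array([[3, 4]])

def y_iter():
  yield np.array([[1], [2]])
  yield np.array([[2], [2]])

df = data_feeder.StreamingDataFeeder(x_iter(), y_iter(), n_classes=0, batch_size=2)
inp, out = df.input_builder()          # placeholders shaped from the first elements
feed_dict = df.get_feed_dict_fn()()    # pulls the next batch from the generators
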
@ -708,8 +761,14 @@ class DaskDataFeeder(object):
memory and still do random seeks for sampling of batches.
"""

def __init__(self, x, y, n_classes, batch_size, shuffle=True,
random_state=None, epochs=None):
def __init__(self,
x,
y,
n_classes,
batch_size,
shuffle=True,
random_state=None,
epochs=None):
"""Initializes a DaskDataFeeder instance.

Args:
@ -732,10 +791,14 @@ class DaskDataFeeder(object):
output_shape: shape of the output.
input_dtype: dtype of input.
output_dtype: dtype of output.

Raises:
ValueError: if `x` or `y` are `dict`, as they are not supported currently.
"""

if isinstance(x, dict) or isinstance(y, dict):
raise ValueError("DaskDataFeeder does not support dictionaries at the moment.")
raise ValueError(
'DaskDataFeeder does not support dictionaries at the moment.')

# pylint: disable=invalid-name,super-init-not-called
import dask.dataframe as dd  # pylint: disable=g-import-not-at-top
@ -763,7 +826,7 @@ class DaskDataFeeder(object):
self._shuffle = shuffle
self.epochs = epochs
self.input_shape, self.output_shape, self._batch_size = _get_in_out_shape(
x_shape, y_shape, n_classes, batch_size)
x_shape, y_shape, n_classes, batch_size)
self.sample_fraction = self._batch_size / float(x_count)
self._input_dtype = _check_dtype(self._x.dtypes[0])
self._output_dtype = _check_dtype(self._y.dtypes[self._y_columns])
@ -797,8 +860,8 @@ class DaskDataFeeder(object):
# TODO(ipolosukhin): option for with/without replacement (dev version of
# dask)
sample = self.df.random_split(
[self.sample_fraction, 1 - self.sample_fraction],
random_state=self.random_state)
[self.sample_fraction, 1 - self.sample_fraction],
random_state=self.random_state)
inp = extract_pandas_matrix(sample[0][self._x_columns].compute()).tolist()
out = extract_pandas_matrix(sample[0][self._y_columns].compute())
# convert to correct dtype
@ -811,7 +874,6 @@ class DaskDataFeeder(object):
out_max = self._y.max().compute().values[0]
encoded_out = np.zeros((out.size, out_max + 1), dtype=self._output_dtype)
encoded_out[np.arange(out.size), out] = 1
return {input_placeholder.name: inp,
output_placeholder.name: encoded_out}
return {input_placeholder.name: inp, output_placeholder.name: encoded_out}

return _feed_dict_fn
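
A minimal construction sketch for DaskDataFeeder under assumed data; the column names, sizes, and contrib.learn import path are hypothetical, and passing dicts would raise the ValueError enforced above.

# Sketch only: data, column names and module path are illustrative assumptions.
import numpy as np
import pandas as pd
import dask.dataframe as dd
from tensorflow.contrib.learn.python.learn.learn_io import data_feeder

x = dd.from_pandas(pd.DataFrame({'feature': np.random.rand(100)}), npartitions=4)
y = dd.from_pandas(pd.DataFrame({'label': np.random.randint(0, 2, 100)}), npartitions=4)
feeder = data_feeder.DaskDataFeeder(x, y, n_classes=2, batch_size=16)
# The feed function returned by get_feed_dict_fn samples roughly
# sample_fraction (= batch_size / row count) of the frame per call.
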
@ -253,20 +253,20 @@ class DataFeederTest(test.TestCase):
|
||||
inp, out = df.input_builder()
|
||||
feed_dict_fn = df.get_feed_dict_fn()
|
||||
feed_dict = feed_dict_fn()
|
||||
self._assertAllClose(inp, [[1, 2], [3, 4]], feed_dict, 'name')
|
||||
self._assertAllClose(out, [1, 2], feed_dict, 'name')
|
||||
self._assertAllClose(inp, [[[1, 2]], [[3, 4]]], feed_dict, 'name')
|
||||
self._assertAllClose(out, [[[1], [2]], [[2], [2]]], feed_dict, 'name')
|
||||
|
||||
def x_iter(wrap_dict=False):
|
||||
yield np.array([1, 2]) if not wrap_dict else self._wrap_dict(
|
||||
np.array([1, 2]), 'in')
|
||||
yield np.array([3, 4]) if not wrap_dict else self._wrap_dict(
|
||||
np.array([3, 4]), 'in')
|
||||
yield np.array([[1, 2]]) if not wrap_dict else self._wrap_dict(
|
||||
np.array([[1, 2]]), 'in')
|
||||
yield np.array([[3, 4]]) if not wrap_dict else self._wrap_dict(
|
||||
np.array([[3, 4]]), 'in')
|
||||
|
||||
def y_iter(wrap_dict=False):
|
||||
yield np.array([1]) if not wrap_dict else self._wrap_dict(
|
||||
np.array([1]), 'out')
|
||||
yield np.array([2]) if not wrap_dict else self._wrap_dict(
|
||||
np.array([2]), 'out')
|
||||
yield np.array([[1], [2]]) if not wrap_dict else self._wrap_dict(
|
||||
np.array([[1], [2]]), 'out')
|
||||
yield np.array([[2], [2]]) if not wrap_dict else self._wrap_dict(
|
||||
np.array([[2], [2]]), 'out')
|
||||
|
||||
func(
|
||||
data_feeder.StreamingDataFeeder(
|
||||
|
@ -25,6 +25,7 @@ import six
|
||||
from tensorflow.contrib.framework import tensor_util as contrib_tensor_util
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
@ -139,10 +140,11 @@ class LinearOperatorDerivedClassTest(test.TestCase):
|
||||
|
||||
def test_to_dense(self):
|
||||
self._maybe_skip("to_dense")
|
||||
with self.test_session() as sess:
|
||||
for use_placeholder in False, True:
|
||||
for shape in self._shapes_to_test:
|
||||
for dtype in self._dtypes_to_test:
|
||||
for use_placeholder in False, True:
|
||||
for shape in self._shapes_to_test:
|
||||
for dtype in self._dtypes_to_test:
|
||||
with self.test_session(graph=ops.Graph()) as sess:
|
||||
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
|
||||
operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
|
||||
shape, dtype, use_placeholder=use_placeholder)
|
||||
op_dense = operator.to_dense()
|
||||
@ -153,14 +155,15 @@ class LinearOperatorDerivedClassTest(test.TestCase):
|
||||
|
||||
def test_det(self):
|
||||
self._maybe_skip("det")
|
||||
with self.test_session() as sess:
|
||||
for use_placeholder in False, True:
|
||||
for shape in self._shapes_to_test:
|
||||
for dtype in self._dtypes_to_test:
|
||||
if dtype.is_complex:
|
||||
self.skipTest(
|
||||
"tf.matrix_determinant does not work with complex, so this "
|
||||
"test is being skipped.")
|
||||
for use_placeholder in False, True:
|
||||
for shape in self._shapes_to_test:
|
||||
for dtype in self._dtypes_to_test:
|
||||
if dtype.is_complex:
|
||||
self.skipTest(
|
||||
"tf.matrix_determinant does not work with complex, so this "
|
||||
"test is being skipped.")
|
||||
with self.test_session(graph=ops.Graph()) as sess:
|
||||
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
|
||||
operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
|
||||
shape, dtype, use_placeholder=use_placeholder)
|
||||
op_det = operator.determinant()
|
||||
@ -173,11 +176,12 @@ class LinearOperatorDerivedClassTest(test.TestCase):
|
||||
|
||||
def test_apply(self):
|
||||
self._maybe_skip("apply")
|
||||
with self.test_session() as sess:
|
||||
for use_placeholder in False, True:
|
||||
for shape in self._shapes_to_test:
|
||||
for dtype in self._dtypes_to_test:
|
||||
for adjoint in False, True:
|
||||
for use_placeholder in False, True:
|
||||
for shape in self._shapes_to_test:
|
||||
for dtype in self._dtypes_to_test:
|
||||
for adjoint in False, True:
|
||||
with self.test_session(graph=ops.Graph()) as sess:
|
||||
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
|
||||
operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
|
||||
shape, dtype, use_placeholder=use_placeholder)
|
||||
x = self._make_x(operator, adjoint=adjoint)
|
||||
@ -191,11 +195,12 @@ class LinearOperatorDerivedClassTest(test.TestCase):
|
||||
|
||||
def test_solve(self):
|
||||
self._maybe_skip("solve")
|
||||
with self.test_session() as sess:
|
||||
for use_placeholder in False, True:
|
||||
for shape in self._shapes_to_test:
|
||||
for dtype in self._dtypes_to_test:
|
||||
for adjoint in False, True:
|
||||
for use_placeholder in False, True:
|
||||
for shape in self._shapes_to_test:
|
||||
for dtype in self._dtypes_to_test:
|
||||
for adjoint in False, True:
|
||||
with self.test_session(graph=ops.Graph()) as sess:
|
||||
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
|
||||
operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
|
||||
shape, dtype, use_placeholder=use_placeholder)
|
||||
rhs = self._make_rhs(operator, adjoint=adjoint)
|
||||
@ -209,10 +214,11 @@ class LinearOperatorDerivedClassTest(test.TestCase):
|
||||
|
||||
def test_add_to_tensor(self):
|
||||
self._maybe_skip("add_to_tensor")
|
||||
with self.test_session() as sess:
|
||||
for use_placeholder in False, True:
|
||||
for shape in self._shapes_to_test:
|
||||
for dtype in self._dtypes_to_test:
|
||||
for use_placeholder in False, True:
|
||||
for shape in self._shapes_to_test:
|
||||
for dtype in self._dtypes_to_test:
|
||||
with self.test_session(graph=ops.Graph()) as sess:
|
||||
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
|
||||
operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
|
||||
shape, dtype, use_placeholder=use_placeholder)
|
||||
op_plus_2mat = operator.add_to_tensor(2 * mat)
|
||||
|
@ -21,7 +21,7 @@ echo "false")
|
||||
|
||||
# Hexagon integration
|
||||
ifdef HEXAGON_LIBS
|
||||
LIBGEMM_WRAPPER := $(HEXAGON_LIBS)/libgemm_wrapper.so
|
||||
LIBGEMM_WRAPPER := $(HEXAGON_LIBS)/libhexagon_controller.so
|
||||
ifeq ($(shell test -f $(LIBGEMM_WRAPPER) 2> /dev/null; echo $$?), 0)
|
||||
$(info "Use hexagon libs at " $(LIBGEMM_WRAPPER))
|
||||
else
|
||||
@ -271,7 +271,7 @@ ifeq ($(TARGET),ANDROID)
|
||||
|
||||
ifdef HEXAGON_LIBS
|
||||
INCLUDES += -I$(HEXAGON_INCLUDE)
|
||||
LIBS += -lgemm_wrapper
|
||||
LIBS += -lhexagon_controller
|
||||
LDFLAGS += -L$(HEXAGON_LIBS)
|
||||
CXXFLAGS += -DUSE_HEXAGON_LIBS
|
||||
endif
|
||||
|
@ -22,21 +22,32 @@ usage() {
|
||||
echo "-s [sub_makefiles] sub makefiles separated by white space"
|
||||
echo "-t [build_target] build target for Android makefile [default=all]"
|
||||
echo "-T only build tensorflow"
|
||||
echo "-x use hexagon library located at ../hexagon/<libs and include>"
|
||||
echo "-x use hexagon library located at tensorflow/contrib/makefile/downloads/hexagon"
|
||||
echo "-X download hexagon deps and run hexagon_graph_execution"
|
||||
exit 1
|
||||
}
|
||||
|
||||
download_and_push() {
|
||||
URL="$1"
|
||||
LOCAL_DEST="$2"
|
||||
ANDROID_DEST="$3"
|
||||
curl -Ls "${URL}" -o "${LOCAL_DEST}"
|
||||
adb shell mkdir -p "${ANDROID_DEST}"
|
||||
adb push "${LOCAL_DEST}" "${ANDROID_DEST}"
|
||||
}
|
||||
|
||||
if [[ -z "${NDK_ROOT}" ]]; then
|
||||
echo "NDK_ROOT should be set as an environment variable" 1>&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
while getopts "s:t:Tx" opt_name; do
|
||||
while getopts "s:t:TxX" opt_name; do
|
||||
case "$opt_name" in
|
||||
s) SUB_MAKEFILES="${OPTARG}";;
|
||||
t) BUILD_TARGET="${OPTARG}";;
|
||||
T) ONLY_MAKE_TENSORFLOW="true";;
|
||||
x) USE_HEXAGON="true";;
|
||||
X) DOWNLOAD_AND_USE_HEXAGON="true";;
|
||||
*) usage;;
|
||||
esac
|
||||
done
|
||||
@ -49,6 +60,8 @@ cd ${SCRIPT_DIR}/../../../
|
||||
source "${SCRIPT_DIR}/build_helper.subr"
|
||||
JOB_COUNT="${JOB_COUNT:-$(get_job_count)}"
|
||||
|
||||
HEXAGON_DOWNLOAD_PATH="tensorflow/contrib/makefile/downloads/hexagon"
|
||||
|
||||
if [[ "${ONLY_MAKE_TENSORFLOW}" != "true" ]]; then
|
||||
# Remove any old files first.
|
||||
make -f tensorflow/contrib/makefile/Makefile clean
|
||||
@ -63,10 +76,30 @@ else
|
||||
make -f tensorflow/contrib/makefile/Makefile clean_except_protobuf_libs
|
||||
fi
|
||||
|
||||
if [[ "${DOWNLOAD_AND_USE_HEXAGON}" == "true" ]]; then
|
||||
URL_BASE="https://storage.googleapis.com/download.tensorflow.org"
|
||||
|
||||
rm -rf "${HEXAGON_DOWNLOAD_PATH}"
|
||||
mkdir -p "${HEXAGON_DOWNLOAD_PATH}/libs"
|
||||
|
||||
download_and_push "${URL_BASE}/deps/hexagon/libhexagon_controller.so" \
|
||||
"${HEXAGON_DOWNLOAD_PATH}/libs/libhexagon_controller.so" "/data/local/tmp"
|
||||
|
||||
download_and_push "${URL_BASE}/deps/hexagon/libhexagon_nn_skel.so" \
|
||||
"${HEXAGON_DOWNLOAD_PATH}/libs/libhexagon_nn_skel.so" "/vendor/lib/rfsa/adsp"
|
||||
|
||||
download_and_push "${URL_BASE}/example_images/img_299x299.jpg" \
|
||||
"${HEXAGON_DOWNLOAD_PATH}/img_299x299.jpg" "/data/local/tmp"
|
||||
|
||||
USE_HEXAGON="true"
|
||||
SUB_MAKEFILES="$(pwd)/tensorflow/contrib/makefile/sub_makefiles/hexagon_graph_execution/Makefile.in"
|
||||
BUILD_TARGET="hexagon_graph_execution"
|
||||
fi
|
||||
|
||||
if [[ "${USE_HEXAGON}" == "true" ]]; then
|
||||
HEXAGON_PARENT_DIR=$(cd ../hexagon && pwd)
|
||||
HEXAGON_PARENT_DIR=$(cd "${HEXAGON_DOWNLOAD_PATH}" && pwd)
|
||||
HEXAGON_LIBS="${HEXAGON_PARENT_DIR}/libs"
|
||||
HEXAGON_INCLUDE=$(cd tensorflow/core/platform/hexagon && pwd)
|
||||
HEXAGON_INCLUDE=$(cd "tensorflow/core/platform/hexagon" && pwd)
|
||||
fi
|
||||
|
||||
if [[ -z "${BUILD_TARGET}" ]]; then
|
||||
@ -80,3 +113,14 @@ else
|
||||
HEXAGON_LIBS="${HEXAGON_LIBS}" HEXAGON_INCLUDE="${HEXAGON_INCLUDE}" \
|
||||
SUB_MAKEFILES="${SUB_MAKEFILES}" "${BUILD_TARGET}"
|
||||
fi
|
||||
|
||||
if [[ "${DOWNLOAD_AND_USE_HEXAGON}" == "true" ]]; then
|
||||
ANDROID_EXEC_FILE_MODE=755
|
||||
echo "Run hexagon_graph_execution"
|
||||
adb push -p "./tensorflow/contrib/makefile/gen/bin/hexagon_graph_execution" "/data/local/tmp/"
|
||||
adb wait-for-device
|
||||
adb shell chmod "${ANDROID_EXEC_FILE_MODE}" "/data/local/tmp/hexagon_graph_execution"
|
||||
adb wait-for-device
|
||||
adb shell 'LD_LIBRARY_PATH=/data/local/tmp:$LD_LIBRARY_PATH' \
|
||||
"/data/local/tmp/hexagon_graph_execution"
|
||||
fi
|
||||
|
@ -4486,7 +4486,7 @@ class StreamingMeanIOUTest(test.TestCase):
|
||||
num_classes)
|
||||
sess.run(variables.local_variables_initializer())
|
||||
confusion_matrix = update_op.eval()
|
||||
self.assertAllEqual([[3, 2], [0, 5]], confusion_matrix)
|
||||
self.assertAllEqual([[3, 0], [2, 5]], confusion_matrix)
|
||||
desired_miou = np.mean([3. / 5., 5. / 7.])
|
||||
self.assertAlmostEqual(desired_miou, miou.eval())
|
||||
|
||||
@ -4509,7 +4509,7 @@ class StreamingMeanIOUTest(test.TestCase):
|
||||
miou, update_op = metrics.streaming_mean_iou(predictions, labels,
|
||||
num_classes)
|
||||
sess.run(variables.local_variables_initializer())
|
||||
self.assertAllEqual([[0, 40], [0, 0]], update_op.eval())
|
||||
self.assertAllEqual([[0, 0], [40, 0]], update_op.eval())
|
||||
self.assertEqual(0., miou.eval())
|
||||
|
||||
def testResultsWithSomeMissing(self):
|
||||
@ -4540,7 +4540,7 @@ class StreamingMeanIOUTest(test.TestCase):
|
||||
miou, update_op = metrics.streaming_mean_iou(
|
||||
predictions, labels, num_classes, weights=weights)
|
||||
sess.run(variables.local_variables_initializer())
|
||||
self.assertAllEqual([[2, 2], [0, 4]], update_op.eval())
|
||||
self.assertAllEqual([[2, 0], [2, 4]], update_op.eval())
|
||||
desired_miou = np.mean([2. / 4., 4. / 6.])
|
||||
self.assertAlmostEqual(desired_miou, miou.eval())
|
||||
|
||||
|
@ -84,12 +84,14 @@ load(
|
||||
"//tensorflow/core:platform/default/build_config.bzl",
|
||||
"tf_proto_library",
|
||||
"tf_proto_library_cc",
|
||||
"tf_additional_core_deps",
|
||||
"tf_additional_lib_defines",
|
||||
"tf_additional_lib_deps",
|
||||
"tf_additional_lib_hdrs",
|
||||
"tf_additional_lib_srcs",
|
||||
"tf_additional_minimal_lib_srcs",
|
||||
"tf_additional_proto_hdrs",
|
||||
"tf_additional_proto_srcs",
|
||||
"tf_additional_lib_deps",
|
||||
"tf_additional_stream_executor_srcs",
|
||||
"tf_additional_cupti_wrapper_deps",
|
||||
"tf_additional_libdevice_data",
|
||||
@ -1127,12 +1129,13 @@ cc_library(
|
||||
"platform/tracing.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
defines = tf_additional_lib_defines(),
|
||||
linkopts = ["-ldl"],
|
||||
deps = [
|
||||
deps = tf_additional_lib_deps() + [
|
||||
":lib_proto_parsing",
|
||||
":protos_all_cc",
|
||||
"//tensorflow/core/platform/default/build_config:platformlib",
|
||||
"//third_party/eigen3",
|
||||
"//tensorflow/core/platform/default/build_config:platformlib",
|
||||
"@zlib_archive//:zlib",
|
||||
],
|
||||
)
|
||||
@ -1352,7 +1355,7 @@ tf_cuda_library(
|
||||
":protos_all_cc",
|
||||
"//third_party/eigen3",
|
||||
"//tensorflow/core/kernels:required",
|
||||
] + tf_additional_lib_deps(),
|
||||
] + tf_additional_core_deps(),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
|
@ -215,7 +215,7 @@ Status CUPTIManager::DisableTrace() {
|
||||
void CUPTIManager::InternalBufferRequested(uint8_t **buffer, size_t *size,
|
||||
size_t *maxNumRecords) {
|
||||
VLOG(2) << "BufferRequested";
|
||||
void *p = port::aligned_malloc(kBufferSize, kBufferAlignment);
|
||||
void *p = port::AlignedMalloc(kBufferSize, kBufferAlignment);
|
||||
*size = kBufferSize;
|
||||
*buffer = reinterpret_cast<uint8_t *>(p);
|
||||
*maxNumRecords = 0;
|
||||
@ -246,7 +246,7 @@ void CUPTIManager::InternalBufferCompleted(CUcontext ctx, uint32_t streamId,
|
||||
LOG(WARNING) << "Dropped " << dropped << " activity records";
|
||||
}
|
||||
}
|
||||
port::aligned_free(buffer);
|
||||
port::AlignedFree(buffer);
|
||||
}
|
||||
|
||||
CUPTIManager *GetCUPTIManager() {
|
||||
|
@ -171,9 +171,9 @@ class BasicCPUAllocator : public SubAllocator {
|
||||
~BasicCPUAllocator() override {}
|
||||
|
||||
void* Alloc(size_t alignment, size_t num_bytes) override {
|
||||
return port::aligned_malloc(num_bytes, alignment);
|
||||
return port::AlignedMalloc(num_bytes, alignment);
|
||||
}
|
||||
void Free(void* ptr, size_t num_bytes) override { port::aligned_free(ptr); }
|
||||
void Free(void* ptr, size_t num_bytes) override { port::AlignedFree(ptr); }
|
||||
};
|
||||
|
||||
// Allocator for pinned CPU RAM that is made known to CUDA for the
|
||||
|
@ -1,8 +1,12 @@
|
||||
# Description:
|
||||
# TensorFlow Debugger (tfdbg).
|
||||
#
|
||||
# Public Android targets:
|
||||
# filegroup ":android_srcs" - Debugger source files for Android.
|
||||
# Public target(s):
|
||||
#
|
||||
# ":debug" - Depending on this target causes a concrete implementation of
|
||||
# DebuggerState to be constructed at initialization time, enabling
|
||||
# TensorFlow Debugger (tfdbg) support. For details, please see
|
||||
# core/common_runtime/debugger_state_interface.h.
|
||||
|
||||
package(
|
||||
default_visibility = ["//tensorflow:internal"],
|
||||
@ -39,14 +43,12 @@ tf_proto_library_cc(
|
||||
protodeps = ["//tensorflow/core:protos_all"],
|
||||
)
|
||||
|
||||
# Depending on this target causes a concrete DebuggerState implementation
|
||||
# to be registered at initialization time. For details, please see
|
||||
# core/common_runtime/debugger_state_interface.h.
|
||||
cc_library(
|
||||
name = "debug",
|
||||
srcs = ["debug.cc"],
|
||||
copts = tf_copts(),
|
||||
linkstatic = 1,
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":debug_graph_utils",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
|
@ -275,6 +275,7 @@ cc_library(
|
||||
"//tensorflow/core/distributed_runtime:server_lib",
|
||||
"//tensorflow/core/distributed_runtime:worker_env",
|
||||
"@grpc//:grpc++_unsecure",
|
||||
"@grpc//:grpc_unsecure",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "grpc++/grpc++.h"
|
||||
#include "grpc++/security/credentials.h"
|
||||
#include "grpc++/server_builder.h"
|
||||
#include "grpc/support/alloc.h"
|
||||
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
@ -41,6 +42,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/mem.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -304,6 +306,11 @@ class GrpcServerFactory : public ServerFactory {
|
||||
class GrpcServerRegistrar {
|
||||
public:
|
||||
GrpcServerRegistrar() {
|
||||
gpr_allocation_functions alloc_fns;
|
||||
alloc_fns.malloc_fn = port::Malloc;
|
||||
alloc_fns.realloc_fn = port::Realloc;
|
||||
alloc_fns.free_fn = port::Free;
|
||||
gpr_set_allocation_functions(alloc_fns);
|
||||
ServerFactory::Register("GRPC_SERVER", new GrpcServerFactory());
|
||||
}
|
||||
};
|
||||
|
@ -68,7 +68,7 @@ class CPUAllocator : public Allocator {
|
||||
string Name() override { return "cpu"; }
|
||||
|
||||
void* AllocateRaw(size_t alignment, size_t num_bytes) override {
|
||||
void* p = port::aligned_malloc(num_bytes, alignment);
|
||||
void* p = port::AlignedMalloc(num_bytes, alignment);
|
||||
if (cpu_allocator_collect_stats) {
|
||||
const std::size_t alloc_size = port::MallocExtension_GetAllocatedSize(p);
|
||||
mutex_lock l(mu_);
|
||||
@ -89,7 +89,7 @@ class CPUAllocator : public Allocator {
|
||||
mutex_lock l(mu_);
|
||||
stats_.bytes_in_use -= alloc_size;
|
||||
}
|
||||
port::aligned_free(ptr);
|
||||
port::AlignedFree(ptr);
|
||||
}
|
||||
|
||||
void GetStats(AllocatorStats* stats) override {
|
||||
|
@ -211,43 +211,6 @@ Status AddRetName(NameInfoIndex* name_info, const string& ret,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BuildNodeOutputIndex(const FunctionDef::Node& node,
|
||||
const InstantiateAttrValueMap& attrs,
|
||||
GetFunctionSignature get_function,
|
||||
const int arg_index, NameInfoIndex* name_info) {
|
||||
const OpDef* node_sig = nullptr;
|
||||
TF_RETURN_IF_ERROR(get_function(node.op(), &node_sig));
|
||||
if (node_sig->output_arg_size() == 0) {
|
||||
// This node produces no output.
|
||||
if (node.ret_size() != 1) {
|
||||
return errors::InvalidArgument("Expect one ret name.");
|
||||
}
|
||||
return AddRetName(name_info, node.ret(0), {false, arg_index, 0, false, {}});
|
||||
}
|
||||
const int num_retval = node_sig->output_arg_size();
|
||||
if (num_retval != node.ret_size()) {
|
||||
return errors::InvalidArgument("Malformed function node (#ret): ",
|
||||
num_retval, " vs. ", node.ret_size());
|
||||
}
|
||||
int start = 0;
|
||||
bool is_type_list;
|
||||
DataTypeVector dtypes;
|
||||
for (int i = 0; i < num_retval; ++i) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
ArgNumType(attrs, node_sig->output_arg(i), &is_type_list, &dtypes));
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddRetName(name_info, node.ret(i),
|
||||
{false, arg_index, start, is_type_list, dtypes}));
|
||||
for (int j = 0; j < static_cast<int>(dtypes.size()); ++j) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddRetName(name_info, strings::StrCat(node.ret(i), ":", j),
|
||||
{false, arg_index, start + j, false, {dtypes[j]}}));
|
||||
}
|
||||
start += dtypes.size();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BuildNodeOutputIndex(const NodeDef& node,
|
||||
const InstantiateAttrValueMap& attrs,
|
||||
GetFunctionSignature get_function,
|
||||
@ -280,85 +243,6 @@ Status BuildNodeOutputIndex(const NodeDef& node,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status InstantiateNode(const FunctionDef::Node& fnode,
|
||||
const InstantiateAttrValueMap& attrs,
|
||||
GetFunctionSignature get_function,
|
||||
const NameInfoIndex& name_info, GraphDef* gdef) {
|
||||
const OpDef* fnode_sig = nullptr;
|
||||
TF_CHECK_OK(get_function(fnode.op(), &fnode_sig));
|
||||
NodeDef* gnode = gdef->add_node();
|
||||
gnode->set_name(Name(gdef->node_size() - 1));
|
||||
gnode->set_op(fnode.op());
|
||||
|
||||
// Input
|
||||
const int num_args = fnode_sig->input_arg_size();
|
||||
bool is_type_list;
|
||||
DataTypeVector dtypes;
|
||||
int fnode_arg_index = 0;
|
||||
for (int i = 0; i < num_args; ++i) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
ArgNumType(attrs, fnode_sig->input_arg(i), &is_type_list, &dtypes));
|
||||
if (!is_type_list) {
|
||||
const NameInfoItem* item =
|
||||
gtl::FindOrNull(name_info, fnode.arg(fnode_arg_index));
|
||||
if (item == nullptr) {
|
||||
return errors::InvalidArgument("arg[", i, "] is not found: ",
|
||||
ProtoShortDebugString(fnode));
|
||||
}
|
||||
if (dtypes != item->dtypes) {
|
||||
return errors::InvalidArgument("Invalid arg(", i,
|
||||
") for function arg: ",
|
||||
DataTypeSliceString(dtypes), " vs. ",
|
||||
DataTypeSliceString(item->dtypes), ".");
|
||||
}
|
||||
for (size_t j = 0; j < dtypes.size(); ++j) {
|
||||
if (item->is_func_arg) {
|
||||
gnode->add_input(Name(item->nid + j));
|
||||
} else {
|
||||
gnode->add_input(Name(item->nid, item->idx + j));
|
||||
}
|
||||
}
|
||||
++fnode_arg_index;
|
||||
} else {
|
||||
for (size_t j = 0; j < dtypes.size(); ++j) {
|
||||
const NameInfoItem* item =
|
||||
gtl::FindOrNull(name_info, fnode.arg(fnode_arg_index + j));
|
||||
if (item == nullptr) {
|
||||
return errors::InvalidArgument("arg[", i + j, "] is not found: ",
|
||||
ProtoShortDebugString(fnode));
|
||||
}
|
||||
if (item->dtypes.size() != 1 || (item->dtypes[0] != dtypes[j])) {
|
||||
return errors::InvalidArgument(
|
||||
"Invalid typelist arg(", i + j, ") for function arg: ",
|
||||
DataTypeSliceString(dtypes), " vs. ",
|
||||
DataTypeSliceString(item->dtypes), ".");
|
||||
}
|
||||
if (item->is_func_arg) {
|
||||
gnode->add_input(Name(item->nid));
|
||||
} else {
|
||||
gnode->add_input(Name(item->nid, item->idx));
|
||||
}
|
||||
}
|
||||
fnode_arg_index += dtypes.size();
|
||||
}
|
||||
}
|
||||
// Control deps.
|
||||
for (int i = 0; i < fnode.dep_size(); ++i) {
|
||||
const NameInfoItem* item = gtl::FindOrNull(name_info, fnode.dep(i));
|
||||
if (item == nullptr) {
|
||||
return errors::InvalidArgument("dep[", i, "] is not found.");
|
||||
}
|
||||
gnode->add_input(Dep(item->nid));
|
||||
}
|
||||
|
||||
// Attrs.
|
||||
for (const auto& p : attrs) {
|
||||
(*gnode->mutable_attr())[p.first] = p.second;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status InstantiateNode(const NodeDef& fnode,
|
||||
const InstantiateAttrValueMap& attrs,
|
||||
GetFunctionSignature get_function,
|
||||
@ -448,38 +332,6 @@ Status InstantiateNode(const NodeDef& fnode,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// FunctionDef::Node version
|
||||
Status AddReturnNode(const OpDef::ArgDef& ret_def,
|
||||
const InstantiateAttrValueMap& attrs,
|
||||
const NameInfoIndex& name_info, int* ret_index,
|
||||
InstantiationResult* result) {
|
||||
bool is_type_list;
|
||||
DataTypeVector dtypes;
|
||||
TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &is_type_list, &dtypes));
|
||||
CHECK_GE(dtypes.size(), size_t{1});
|
||||
const NameInfoItem* item = gtl::FindOrNull(name_info, ret_def.name());
|
||||
if (item == nullptr) {
|
||||
return errors::InvalidArgument("ret is not found.");
|
||||
}
|
||||
if (dtypes != item->dtypes) {
|
||||
return errors::InvalidArgument("Invalid ret types ", ret_def.name(), " : ",
|
||||
DataTypeVectorString(dtypes), " vs. ",
|
||||
DataTypeVectorString(item->dtypes));
|
||||
}
|
||||
GraphDef* gdef = &result->gdef;
|
||||
for (size_t i = 0; i < dtypes.size(); ++i) {
|
||||
NodeDef* gnode = gdef->add_node();
|
||||
gnode->set_name(Name(gdef->node_size() - 1));
|
||||
gnode->set_op("_Retval");
|
||||
gnode->add_input(Name(item->nid, item->idx + i));
|
||||
AddAttr("T", dtypes[i], gnode);
|
||||
AddAttr("index", (*ret_index)++, gnode);
|
||||
result->ret_types.push_back(dtypes[i]);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// NodeDef version
|
||||
Status AddReturnNode(const OpDef::ArgDef& ret_def,
|
||||
const InstantiateAttrValueMap& attrs,
|
||||
const ::tensorflow::protobuf::Map<string, string>& ret_map,
|
||||
@ -561,38 +413,6 @@ string Print(const AttrValue& attr_value) {
|
||||
return SummarizeAttrValue(attr_value);
|
||||
}
|
||||
|
||||
string Print(const FunctionDef::Node& node) {
|
||||
string out;
|
||||
for (int i = 0; i < node.ret_size(); ++i) {
|
||||
const auto& name = node.ret(i);
|
||||
if (i > 0) strings::StrAppend(&out, ", ");
|
||||
strings::StrAppend(&out, name);
|
||||
}
|
||||
strings::StrAppend(&out, " = ", node.op());
|
||||
if (node.attr_size() > 0) {
|
||||
std::vector<string> entries;
|
||||
for (auto p : node.attr()) {
|
||||
entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
|
||||
}
|
||||
sort(entries.begin(), entries.end());
|
||||
strings::StrAppend(&out, "[", str_util::Join(entries, ", "), "]");
|
||||
}
|
||||
strings::StrAppend(&out, "(");
|
||||
for (int i = 0; i < node.arg_size(); ++i) {
|
||||
if (i > 0) strings::StrAppend(&out, ", ");
|
||||
strings::StrAppend(&out, node.arg(i));
|
||||
}
|
||||
strings::StrAppend(&out, ")");
|
||||
if (node.dep_size() > 0) {
|
||||
strings::StrAppend(&out, " @ ");
|
||||
for (int i = 0; i < node.dep_size(); ++i) {
|
||||
if (i > 0) strings::StrAppend(&out, ", ");
|
||||
strings::StrAppend(&out, node.dep(i));
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
// TODO(josh11b): Merge this with SummarizeNodeDef().
|
||||
string Print(const NodeDef& n) {
|
||||
string out;
|
||||
@ -650,17 +470,11 @@ string Print(const FunctionDef& fdef) {
|
||||
strings::StrAppend(&out, Print(sig.output_arg(i)));
|
||||
}
|
||||
strings::StrAppend(&out, ") {\n");
|
||||
if (fdef.node_def_size() > 0 || fdef.ret_size() > 0) {
|
||||
for (const auto& n : fdef.node_def()) {
|
||||
strings::StrAppend(&out, " ", Print(n), "\n");
|
||||
}
|
||||
for (const auto& r : fdef.ret()) {
|
||||
strings::StrAppend(&out, " return ", r.first, " = ", r.second, "\n");
|
||||
}
|
||||
} else { // TODO(josh11b): Eventually remove this case.
|
||||
for (const auto& n : fdef.node()) {
|
||||
strings::StrAppend(&out, " ", Print(n), "\n");
|
||||
}
|
||||
for (const auto& n : fdef.node_def()) {
|
||||
strings::StrAppend(&out, " ", Print(n), "\n");
|
||||
}
|
||||
for (const auto& r : fdef.ret()) {
|
||||
strings::StrAppend(&out, " return ", r.first, " = ", r.second, "\n");
|
||||
}
|
||||
strings::StrAppend(&out, "}\n");
|
||||
return out;
|
||||
@ -772,92 +586,47 @@ Status InstantiateFunction(const FunctionDef& fdef,
|
||||
// Makes a copy of all attrs in fdef and substitutes placeholders.
|
||||
// After this step, every attr is bound to a concrete value.
|
||||
std::vector<InstantiateAttrValueMap> node_attrs;
|
||||
if (fdef.node_def_size() > 0 || fdef.ret_size() > 0) {
|
||||
node_attrs.resize(fdef.node_def_size());
|
||||
for (int i = 0; i < fdef.node_def_size(); ++i) {
|
||||
for (auto attr : fdef.node_def(i).attr()) {
|
||||
if (!SubstitutePlaceholders(substitute, &attr.second)) {
|
||||
return errors::InvalidArgument("Failed to bind all placeholders in ",
|
||||
SummarizeAttrValue(attr.second));
|
||||
}
|
||||
if (!node_attrs[i].insert(attr).second) {
|
||||
return errors::Internal("Somehow duplicated: ", attr.first);
|
||||
}
|
||||
node_attrs.resize(fdef.node_def_size());
|
||||
for (int i = 0; i < fdef.node_def_size(); ++i) {
|
||||
for (auto attr : fdef.node_def(i).attr()) {
|
||||
if (!SubstitutePlaceholders(substitute, &attr.second)) {
|
||||
return errors::InvalidArgument("Failed to bind all placeholders in ",
|
||||
SummarizeAttrValue(attr.second));
|
||||
}
|
||||
if (!node_attrs[i].insert(attr).second) {
|
||||
return errors::Internal("Somehow duplicated: ", attr.first);
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddDefaultAttrs(fdef.node_def(i).op(), get_function, &node_attrs[i]));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddDefaultAttrs(fdef.node_def(i).op(), get_function, &node_attrs[i]));
|
||||
}
|
||||
|
||||
for (int i = 0; i < fdef.node_def_size(); ++i) {
|
||||
s = BuildNodeOutputIndex(fdef.node_def(i), node_attrs[i], get_function,
|
||||
gdef->node_size() + i, &name_info);
|
||||
if (!s.ok()) {
|
||||
errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i)));
|
||||
return s;
|
||||
}
|
||||
for (int i = 0; i < fdef.node_def_size(); ++i) {
|
||||
s = BuildNodeOutputIndex(fdef.node_def(i), node_attrs[i], get_function,
|
||||
gdef->node_size() + i, &name_info);
|
||||
if (!s.ok()) {
|
||||
errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i)));
|
||||
return s;
|
||||
}
|
||||
// Emits one gdef.node for each fdef.node_def.
|
||||
for (int i = 0; i < fdef.node_def_size(); ++i) {
|
||||
s = InstantiateNode(fdef.node_def(i), node_attrs[i], get_function,
|
||||
name_info, gdef);
|
||||
if (!s.ok()) {
|
||||
errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i)));
|
||||
return s;
|
||||
}
|
||||
}
|
||||
// Emits one gdef.node for each fdef.node_def.
|
||||
for (int i = 0; i < fdef.node_def_size(); ++i) {
|
||||
s = InstantiateNode(fdef.node_def(i), node_attrs[i], get_function,
|
||||
name_info, gdef);
|
||||
if (!s.ok()) {
|
||||
errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i)));
|
||||
return s;
|
||||
}
|
||||
}
|
||||
|
||||
// Emits nodes for the function's return values.
|
||||
int ret_index = 0;
|
||||
for (const OpDef::ArgDef& ret_def : sig.output_arg()) {
|
||||
s = AddReturnNode(ret_def, attr_values, fdef.ret(), name_info, &ret_index,
|
||||
result);
|
||||
if (!s.ok()) {
|
||||
errors::AppendToMessage(&s, "In function output ", Print(ret_def));
|
||||
return s;
|
||||
}
|
||||
}
|
||||
} else { // TODO(josh11b): Eventually remove this case.
|
||||
node_attrs.resize(fdef.node_size());
|
||||
for (int i = 0; i < fdef.node_size(); ++i) {
|
||||
for (auto attr : fdef.node(i).attr()) {
|
||||
if (!SubstitutePlaceholders(substitute, &attr.second)) {
|
||||
return errors::InvalidArgument("Failed to bind all placeholders in ",
|
||||
SummarizeAttrValue(attr.second));
|
||||
}
|
||||
if (!node_attrs[i].insert(attr).second) {
|
||||
return errors::Internal("Somehow duplicated: ", attr.first);
|
||||
}
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddDefaultAttrs(fdef.node(i).op(), get_function, &node_attrs[i]));
|
||||
}
|
||||
|
||||
for (int i = 0; i < fdef.node_size(); ++i) {
|
||||
s = BuildNodeOutputIndex(fdef.node(i), node_attrs[i], get_function,
|
||||
gdef->node_size() + i, &name_info);
|
||||
if (!s.ok()) {
|
||||
errors::AppendToMessage(&s, "In ", Print(fdef.node(i)));
|
||||
return s;
|
||||
}
|
||||
}
|
||||
// Emits one gdef.node for each fdef.node.
|
||||
for (int i = 0; i < fdef.node_size(); ++i) {
|
||||
s = InstantiateNode(fdef.node(i), node_attrs[i], get_function, name_info,
|
||||
gdef);
|
||||
if (!s.ok()) {
|
||||
errors::AppendToMessage(&s, "In ", Print(fdef.node(i)));
|
||||
return s;
|
||||
}
|
||||
}
|
||||
|
||||
// Emits nodes for the function's return values.
|
||||
int ret_index = 0;
|
||||
for (const OpDef::ArgDef& ret_def : sig.output_arg()) {
|
||||
s = AddReturnNode(ret_def, attr_values, name_info, &ret_index, result);
|
||||
if (!s.ok()) {
|
||||
errors::AppendToMessage(&s, "In function output ", Print(ret_def));
|
||||
return s;
|
||||
}
|
||||
// Emits nodes for the function's return values.
|
||||
int ret_index = 0;
|
||||
for (const OpDef::ArgDef& ret_def : sig.output_arg()) {
|
||||
s = AddReturnNode(ret_def, attr_values, fdef.ret(), name_info, &ret_index,
|
||||
result);
|
||||
if (!s.ok()) {
|
||||
errors::AppendToMessage(&s, "In function output ", Print(ret_def));
|
||||
return s;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -30,61 +30,7 @@ message FunctionDef {
|
||||
// Attributes specific to this function definition.
|
||||
map<string, AttrValue> attr = 5;
|
||||
|
||||
// TO BE REPLACED
|
||||
|
||||
// The body of the function.
|
||||
repeated Node node = 2; // function.node.ret[*] are unique.
|
||||
|
||||
// A node is a multi-value assignment:
|
||||
// (ret[0], ret[1], ...) = func(arg[0], arg[1], ...)
|
||||
//
|
||||
// By convention, "func" is resolved by consulting with a user-defined
|
||||
// library first. If not resolved, "func" is assumed to be a builtin op.
|
||||
message Node {
|
||||
// This node produces multiple outputs. They are named ret[0],
|
||||
// ret[1], ..., etc.
|
||||
//
|
||||
// REQUIRES: function.node.ret[*] are unique across all nodes.
|
||||
// REQUIRES: ret.size == func/op def's number of output args.
|
||||
repeated string ret = 1;
|
||||
|
||||
// The op/function name.
|
||||
string op = 2;
|
||||
|
||||
// Arguments passed to this func/op.
|
||||
//
|
||||
// arg[i] must be either one of
|
||||
// function.signature.input_args[*].name or one of
|
||||
// function.node[*].ret[*].
|
||||
//
|
||||
// REQUIRES: arg.size == func/op def's number of input args.
|
||||
repeated string arg = 3;
|
||||
|
||||
// Control dependencies.
|
||||
//
|
||||
// dep[i] must be one of function.node[*].ret[*] or one of
|
||||
// function.signature.input_args[*].name.
|
||||
repeated string dep = 4;
|
||||
|
||||
// Attrs.
|
||||
//
|
||||
// 'attr' maps names defined by 'func's attr defs to attr values.
|
||||
// attr values may have placeholders which are substituted
|
||||
// recursively by concrete values when this node is instantiated.
|
||||
// These placeholders must name an attr listed in the FunctionDef's
|
||||
// signature.
|
||||
map<string, AttrValue> attr = 5;
|
||||
}
|
||||
|
||||
// WILL REPLACE THE ABOVE
|
||||
|
||||
// If node_def is present, and the consumer is at GraphDef version
|
||||
// >= 12, then these fields are used and `node` is ignored. If the
|
||||
// consumer's GraphDef version is < 12 or this field is empty, then
|
||||
// `node` is used. This allows producers to fill both fields to
|
||||
// remain compatible with old consumers. At some future GraphDef
|
||||
// version, `node` will be ignored even if `node_def` is empty.
|
||||
// TODO(josh11b): Finish this transition.
|
||||
// NOTE: field id 2 deleted on Jan 11, 2016, GraphDef version 21.
|
||||
|
||||
// In both of the following fields, there is the need to specify an
|
||||
// output that is used as either the input to another node (in
|
||||
@ -120,6 +66,10 @@ message FunctionDef {
|
||||
// The body of the function. Unlike the NodeDefs in a GraphDef, attrs
|
||||
// may have values of type `placeholder` and the `input` field uses
|
||||
// the "output" format above.
|
||||
|
||||
// By convention, "op" in node_def is resolved by consulting with a
|
||||
// user-defined library first. If not resolved, "func" is assumed to
|
||||
// be a builtin op.
|
||||
repeated NodeDef node_def = 3;
|
||||
|
||||
// A mapping from the output arg names from `signature` to the
|
||||
|
@ -48,52 +48,8 @@ y: A scalar in type T.
|
||||
|
||||
static InstantiateAttrValueMap kNoAttrs;
|
||||
|
||||
TEST(TFunc, SquarePlusOneOld) {
|
||||
auto fdef = FDH::Define( // Create a FunctionDef using Function::Nodes.
|
||||
// Name
|
||||
"SquarePlusOne",
|
||||
// Args
|
||||
{"x: T"},
|
||||
// Return values
|
||||
{"y: T"},
|
||||
// Attrs
|
||||
{"T: {float, double, int32, int64}"},
|
||||
// Nodes
|
||||
{// a = Square<T>(x)
|
||||
{{"a"}, "Square", {"x"}, {{"T", "$T"}}},
|
||||
// o = One<T>()
|
||||
// NOTE: We can also have a Cast<Tin, Tout>(x) instead.
|
||||
{{"o"}, "One", {}, {{"T", "$T"}}},
|
||||
// y = Add<T>(a, o)
|
||||
{{"y"}, "Add", {"a", "o"}, {{"T", "$T"}}}});
|
||||
|
||||
const char* e = R"P(
|
||||
SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
|
||||
a = Square[T=$T](x)
|
||||
o = One[T=$T]()
|
||||
y = Add[T=$T](a:y:0, o:y:0)
|
||||
return y = y:z:0
|
||||
}
|
||||
)P";
|
||||
EXPECT_EQ(DebugString(fdef), e);
|
||||
|
||||
// Instantiate one with T=float
|
||||
InstantiationResult result;
|
||||
TF_ASSERT_OK(InstantiateFunction(fdef, {{"T", DT_FLOAT}}, GetOpSig, &result));
|
||||
const char* e2 = R"P(
|
||||
(n0:float) -> (n3:float) {
|
||||
n1 = Square[T=float](n0)
|
||||
n2 = One[T=float]()
|
||||
n3 = Add[T=float](n1, n2)
|
||||
}
|
||||
)P";
|
||||
EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
|
||||
EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
|
||||
EXPECT_EQ(DebugString(result.gdef), e2);
|
||||
}
|
||||
|
||||
TEST(TFunc, SquarePlusOneNodeDef) {
|
||||
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
|
||||
TEST(TFunc, SquarePlusOne) {
|
||||
auto fdef = FDH::Create(
|
||||
// Name
|
||||
"SquarePlusOne",
|
||||
// Inputs
|
||||
@ -138,8 +94,8 @@ SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
|
||||
EXPECT_EQ(DebugString(result.gdef), e2);
|
||||
}
|
||||
|
||||
TEST(TFunc, ControlDepNodeDef) {
|
||||
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
|
||||
TEST(TFunc, ControlDep) {
|
||||
auto fdef = FDH::Create(
|
||||
// Name
|
||||
"ControlDep",
|
||||
// Inputs
|
||||
@ -190,44 +146,8 @@ REGISTER_OP("HasDefaultType")
|
||||
// This verifies that a function using an op before a type attr (with
|
||||
// a default) is added, still works. This is important for backwards
|
||||
// compatibility.
|
||||
TEST(TFunc, MissingTypeAttrOld) {
|
||||
auto fdef = FDH::Define( // Create a FunctionDef using Function::Nodes.
|
||||
// Name
|
||||
"BackCompat",
|
||||
// Args
|
||||
{},
|
||||
// Return values
|
||||
{"y: float"},
|
||||
// Attrs
|
||||
{},
|
||||
// Nodes
|
||||
{// y = HasDefaultType(x), T missing, defaults to float
|
||||
{{"y"}, "HasDefaultType", {}, {}}});
|
||||
|
||||
const char* e = R"P(
|
||||
BackCompat() -> (y:float) {
|
||||
y = HasDefaultType()
|
||||
return y = y:out:0
|
||||
}
|
||||
)P";
|
||||
EXPECT_EQ(DebugString(fdef), e);
|
||||
|
||||
InstantiationResult result;
|
||||
TF_ASSERT_OK(
|
||||
InstantiateFunction(fdef, InstantiateAttrValueMap{}, GetOpSig, &result));
|
||||
// Should get T=float from Op's default.
|
||||
const char* e2 = R"P(
|
||||
() -> (n0:float) {
|
||||
n0 = HasDefaultType[T=float]()
|
||||
}
|
||||
)P";
|
||||
EXPECT_EQ(result.arg_types, DataTypeVector());
|
||||
EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
|
||||
EXPECT_EQ(DebugString(result.gdef), e2);
|
||||
}
|
||||
|
||||
TEST(TFunc, MissingTypeAttrNodeDef) {
|
||||
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
|
||||
TEST(TFunc, MissingTypeAttr) {
|
||||
auto fdef = FDH::Create(
|
||||
// Name
|
||||
"BackCompat",
|
||||
// Args
|
||||
@ -264,11 +184,8 @@ BackCompat() -> (y:float) {
|
||||
EXPECT_EQ(DebugString(result.gdef), e2);
|
||||
}
|
||||
|
||||
TEST(TFunc, NTimesTNodeDef) {
|
||||
// Note that the equivalent FunctionDef using FunctionDef::Node requires
|
||||
// using a _ListToArray to package up the two inputs to AddN as a single
|
||||
// N*T edge.
|
||||
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
|
||||
TEST(TFunc, NTimesT) {
|
||||
auto fdef = FDH::Create(
|
||||
// Name
|
||||
"NTimesT",
|
||||
// Inputs
|
||||
@ -790,8 +707,8 @@ TEST(InstantiateErrors, TypeList_Missing_Arg) {
|
||||
"input unknown is not found");
|
||||
}
|
||||
|
||||
TEST(InstantiateErrors, NodeDef_TooManyInputs) {
|
||||
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
|
||||
TEST(InstantiateErrors, TooManyInputs) {
|
||||
auto fdef = FDH::Create(
|
||||
// Name
|
||||
"TooManyInputs",
|
||||
// Inputs
|
||||
@ -811,8 +728,8 @@ TEST(InstantiateErrors, NodeDef_TooManyInputs) {
|
||||
"Expected input[2] == 'x' to be a control input.");
|
||||
}
|
||||
|
||||
TEST(InstantiateErrors, NodeDef_TooFewInputs) {
|
||||
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
|
||||
TEST(InstantiateErrors, TooFewInputs) {
|
||||
auto fdef = FDH::Create(
|
||||
// Name
|
||||
"TooFewInputs",
|
||||
// Inputs
|
||||
@ -832,8 +749,8 @@ TEST(InstantiateErrors, NodeDef_TooFewInputs) {
|
||||
"Attempt to access beyond input size: 2 >= 2");
|
||||
}
|
||||
|
||||
TEST(InstantiateErrors, NodeDef_TooManyInputsFromArray1) {
|
||||
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
|
||||
TEST(InstantiateErrors, TooManyInputsFromArray1) {
|
||||
auto fdef = FDH::Create(
|
||||
// Name
|
||||
"TooManyInputsFromArray",
|
||||
// Inputs
|
||||
@ -860,8 +777,8 @@ TEST(InstantiateErrors, NodeDef_TooManyInputsFromArray1) {
|
||||
"Expected input[1] == 'y' to be a control input.");
|
||||
}
|
||||
|
||||
TEST(InstantiateErrors, NodeDef_TooManyInputsFromArray2) {
|
||||
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
|
||||
TEST(InstantiateErrors, TooManyInputsFromArray2) {
|
||||
auto fdef = FDH::Create(
|
||||
// Name
|
||||
"TooManyInputsFromArray",
|
||||
// Inputs
|
||||
@ -888,8 +805,8 @@ TEST(InstantiateErrors, NodeDef_TooManyInputsFromArray2) {
|
||||
"Input a:output too long for inputs");
|
||||
}
|
||||
|
||||
TEST(InstantiateErrors, NodeDef_TypeMismatch) {
|
||||
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
|
||||
TEST(InstantiateErrors, TypeMismatch) {
|
||||
auto fdef = FDH::Create(
|
||||
// Name
|
||||
"TypeMismatch",
|
||||
// Inputs
|
||||
|
@ -178,14 +178,8 @@ void OpsUsedByGraph(const GraphDef& graph_def,
|
||||
while (!functions_to_process.empty()) {
|
||||
const FunctionDef* fun = functions_to_process.back();
|
||||
functions_to_process.pop_back();
|
||||
if (fun->node_def_size() > 0) {
|
||||
for (const auto& node : fun->node_def()) {
|
||||
mark_op_as_used(node.op());
|
||||
}
|
||||
} else { // TODO(josh11b): Eventually drop support for this.
|
||||
for (const auto& node : fun->node()) {
|
||||
mark_op_as_used(node.op());
|
||||
}
|
||||
for (const auto& node : fun->node_def()) {
|
||||
mark_op_as_used(node.op());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/mem.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -91,7 +92,7 @@ Status LoadLibrary(const char* library_filename, void** result,
|
||||
}
|
||||
string str;
|
||||
library.op_list.SerializeToString(&str);
|
||||
char* str_buf = reinterpret_cast<char*>(malloc(str.length()));
|
||||
char* str_buf = reinterpret_cast<char*>(port::Malloc(str.length()));
|
||||
memcpy(str_buf, str.data(), str.length());
|
||||
*buf = str_buf;
|
||||
*len = str.length();
|
||||
|
@ -185,6 +185,17 @@ Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||
std::vector<NameAttrList>* value) {
|
||||
const AttrValue* attr_value;
|
||||
TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
|
||||
TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(func)"));
|
||||
for (const auto& v : attr_value->list().func()) {
|
||||
value->emplace_back(v);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
namespace { // Helper for InOutTypesForNode().
|
||||
|
||||
Status AddArgToSig(const NodeDef& node_def, const OpDef::ArgDef& arg_def,
|
||||
|
@ -150,6 +150,9 @@ Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||
const NameAttrList** value); // type: "func"
|
||||
|
||||
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||
std::vector<NameAttrList>* value); // type: "list(func)"
|
||||
|
||||
// Computes the input and output types for a specific node.
|
||||
// REQUIRES: ValidateOpDef(op_def).ok()
|
||||
Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/mem.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -27,7 +28,7 @@ class TestableSizeTrackingAllocator : public Allocator {
|
||||
public:
|
||||
string Name() override { return "test"; }
|
||||
void* AllocateRaw(size_t /*alignment*/, size_t num_bytes) override {
|
||||
void* ptr = malloc(num_bytes);
|
||||
void* ptr = port::Malloc(num_bytes);
|
||||
size_map_[ptr] = num_bytes;
|
||||
return ptr;
|
||||
}
|
||||
@ -35,7 +36,7 @@ class TestableSizeTrackingAllocator : public Allocator {
|
||||
const auto& iter = size_map_.find(ptr);
|
||||
EXPECT_NE(size_map_.end(), iter);
|
||||
size_map_.erase(iter);
|
||||
free(ptr);
|
||||
port::Free(ptr);
|
||||
}
|
||||
bool TracksAllocationSizes() override { return true; }
|
||||
size_t RequestedSize(void* ptr) override {
|
||||
|
@ -254,6 +254,10 @@ Node* Identity(Graph* g, Node* input, int index) {
|
||||
|
||||
Node* Add(Graph* g, Node* in0, Node* in1) { return Binary(g, "Add", in0, in1); }
|
||||
|
||||
Node* Reverse(Graph* g, Node* tensor, Node* axis) {
|
||||
return Binary(g, "ReverseV2", tensor, axis);
|
||||
}
|
||||
|
||||
Node* Error(Graph* g, Node* input, const string& errmsg) {
|
||||
Node* ret;
|
||||
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Error")
|
||||
|
@ -100,6 +100,9 @@ Node* Multi(Graph* g, const string& func, gtl::ArraySlice<Node*> ins);
|
||||
// Adds a binary add node in "g" doing in0 + in1.
|
||||
Node* Add(Graph* g, Node* in0, Node* in1);
|
||||
|
||||
// Reverses <axis> dimensions of <tensor>>
|
||||
Node* Reverse(Graph* g, Node* tensor, Node* axis);
|
||||
|
||||
// Generates random unit uniform distribution of the input shape.
|
||||
Node* RandomUniform(Graph* g, Node* input, DataType dtype);
|
||||
|
||||
|
@ -256,6 +256,15 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "stage_op",
|
||||
srcs = ["stage_op.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "queue_base",
|
||||
srcs = ["queue_base.cc"],
|
||||
@ -1161,6 +1170,7 @@ cc_library(
|
||||
":session_ops",
|
||||
":sparse_conditional_accumulator_op",
|
||||
":stack_ops",
|
||||
":stage_op",
|
||||
":tensor_array_ops",
|
||||
],
|
||||
)
|
||||
@ -3228,6 +3238,7 @@ tf_kernel_library(
|
||||
prefix = "training_ops",
|
||||
deps = [
|
||||
":bounds_check",
|
||||
":variable_ops",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:training_ops_op_lib",
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/platform/mem.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
@ -44,9 +45,9 @@ class LaunchConv2DOp {
|
||||
template <class T, size_t size>
|
||||
struct Im2ColBufferResource : public ResourceBase {
|
||||
Im2ColBufferResource<T, size>() {
|
||||
data = static_cast<T*>(malloc(size * sizeof(T)));
|
||||
data = static_cast<T*>(port::Malloc(size * sizeof(T)));
|
||||
}
|
||||
~Im2ColBufferResource<T, size>() { free(data); }
|
||||
~Im2ColBufferResource<T, size>() { port::Free(data); }
|
||||
// This mutex ensures that only a single operation at a time is able to use
|
||||
// the buffer memory held by this resource.
|
||||
mutex mu;
|
||||
|
@ -26,9 +26,9 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
|
||||
const bool SHOW_DBG_IN_SOC = false;
|
||||
const bool DBG_DUMP_RESULT = false;
|
||||
const bool DBG_USE_DUMMY_INPUT = false;
|
||||
const bool DBG_USE_SAMPLE_INPUT = false;
|
||||
const bool DBG_SHOW_RESULT = false;
|
||||
const int64 FLAG_ENABLE_PANDA_BINARY_INPUT = 0x01;
|
||||
|
||||
#ifdef USE_HEXAGON_LIBS
|
||||
@ -169,7 +169,7 @@ bool HexagonControlWrapper::SetupGraph(
|
||||
return soc_interface_ConstructGraph();
|
||||
|
||||
// Keep following comment to use dummy graph construction
|
||||
// return soc_interface_SetupGraphDummy(3 /* inception version */);
|
||||
// return soc_interface_setupDummyGraph(3 /* inception version */);
|
||||
}
|
||||
|
||||
bool HexagonControlWrapper::ExecuteGraph() {
|
||||
@ -213,7 +213,7 @@ bool HexagonControlWrapper::ReadOutputNode(
|
||||
// TODO: Accept all results
|
||||
std::get<2>(output) = DT_FLOAT;
|
||||
outputs->emplace_back(output);
|
||||
if (DBG_DUMP_RESULT) {
|
||||
if (DBG_SHOW_RESULT) {
|
||||
const int byte_size = std::get<1>(output);
|
||||
const int element_count = byte_size / sizeof(float);
|
||||
const float* float_array = reinterpret_cast<float*>(std::get<0>(output));
|
||||
|
@ -27,23 +27,83 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/bounds_check.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/util/work_sharder.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
||||
namespace {
|
||||
|
||||
// Reverse rows (middle dimension) of a three dimensional tensor.
|
||||
// NUM_CHANNELS can be <= 0 to compute it dynamically from <input>
|
||||
// Otherwise, it must equal input.dim_size(2) and is used as a compile-time
|
||||
// constant.
|
||||
template <int NUM_CHANNELS>
|
||||
void ReverseRows(OpKernelContext* context, const Tensor& input,
|
||||
Tensor* result) {
|
||||
auto work = [&input, result](int64 start, int64 end) {
|
||||
const int64 inner_size =
|
||||
NUM_CHANNELS > 0 ? NUM_CHANNELS : input.dim_size(2);
|
||||
const int64 middle_size = input.dim_size(1);
|
||||
const int64 row_size = inner_size * middle_size;
|
||||
DCHECK_EQ(input.dim_size(2), inner_size);
|
||||
|
||||
const int32* in_ptr = input.bit_casted_tensor<int32, 3>().data();
|
||||
int32* out_ptr = result->bit_casted_tensor<int32, 3>().data();
|
||||
|
||||
in_ptr += start * row_size;
|
||||
out_ptr += start * row_size;
|
||||
|
||||
for (int outer_dim = start; outer_dim < end; ++outer_dim) {
|
||||
out_ptr += row_size;
|
||||
int remaining = middle_size;
|
||||
while (remaining > 0) {
|
||||
out_ptr -= inner_size;
|
||||
memcpy(out_ptr, in_ptr, inner_size * sizeof(float));
|
||||
in_ptr += inner_size;
|
||||
--remaining;
|
||||
}
|
||||
|
||||
out_ptr += row_size;
|
||||
}
|
||||
};
|
||||
|
||||
// Shard across outer dimension.
|
||||
const int64 N = input.dim_size(0);
|
||||
const int64 cost_per_unit = input.NumElements() / N;
|
||||
auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
|
||||
Shard(worker_threads->num_threads, worker_threads->workers, N, cost_per_unit,
|
||||
std::move(work));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
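For reference, a naive sketch of what ReverseRows computes, independent of the pointer-walking code above; the container and element type here are illustrative, not the kernel's.

#include <cstddef>
#include <vector>

// For an {outer, middle, inner} tensor stored row-major in "in", writes
// out[n][m][c] = in[n][middle - 1 - m][c]; this is the reference behaviour the
// optimized ReverseRows path should match.
void NaiveReverseRows(const std::vector<float>& in, std::vector<float>* out,
                      size_t outer, size_t middle, size_t inner) {
  out->resize(in.size());
  for (size_t n = 0; n < outer; ++n) {
    for (size_t m = 0; m < middle; ++m) {
      const float* src = &in[(n * middle + (middle - 1 - m)) * inner];
      float* dst = &(*out)[(n * middle + m) * inner];
      for (size_t c = 0; c < inner; ++c) dst[c] = src[c];
    }
  }
}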
|
||||
template <typename Device, typename T, int NDIMS>
|
||||
void HandleReverseCase(OpKernelContext* context,
|
||||
typename TTypes<bool, 1>::ConstTensor dims,
|
||||
Tensor* result) {
|
||||
const Tensor& input = context->input(0);
|
||||
|
||||
// Use optimized reverse if possible.
|
||||
if (NDIMS == 3 && std::is_same<Device, CPUDevice>::value &&
|
||||
std::is_same<T, float>::value && (!dims(0) && dims(1) && !dims(2))) {
|
||||
if (input.dim_size(2) == 3) {
|
||||
ReverseRows<3>(context, input, result);
|
||||
} else {
|
||||
ReverseRows<-1>(context, input, result);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
typename Eigen::array<bool, NDIMS> axes_di;
|
||||
for (int i = 0; i < NDIMS; i++) {
|
||||
axes_di[i] = dims(i);
|
||||
}
|
||||
functor::Reverse<Device, T, NDIMS>()(context->eigen_device<Device>(),
|
||||
context->input(0).tensor<T, NDIMS>(),
|
||||
axes_di, result->tensor<T, NDIMS>());
|
||||
input.tensor<T, NDIMS>(), axes_di,
|
||||
result->tensor<T, NDIMS>());
|
||||
}
|
||||
|
||||
template <typename Device, typename T>
|
||||
@ -105,13 +165,26 @@ class ReverseOp : public OpKernel {
|
||||
template <typename Device, typename T, int NDIMS>
|
||||
void HandleReverseV2Case(OpKernelContext* context,
|
||||
const gtl::ArraySlice<bool>& axes, Tensor* result) {
|
||||
const Tensor& input = context->input(0);
|
||||
|
||||
// Use optimized reverse if possible.
|
||||
if (NDIMS == 3 && std::is_same<Device, CPUDevice>::value &&
|
||||
std::is_same<T, float>::value && (!axes[0] && axes[1] && !axes[2])) {
|
||||
if (input.dim_size(2) == 3) {
|
||||
ReverseRows<3>(context, input, result);
|
||||
} else {
|
||||
ReverseRows<-1>(context, input, result);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
typename Eigen::array<bool, NDIMS> axes_di;
|
||||
for (int i = 0; i < NDIMS; i++) {
|
||||
axes_di[i] = axes[i];
|
||||
}
|
||||
functor::Reverse<Device, T, NDIMS>()(context->eigen_device<Device>(),
|
||||
context->input(0).tensor<T, NDIMS>(),
|
||||
axes_di, result->tensor<T, NDIMS>());
|
||||
input.tensor<T, NDIMS>(), axes_di,
|
||||
result->tensor<T, NDIMS>());
|
||||
}
|
||||
|
||||
template <typename Device, typename T>
|
||||
@ -158,6 +231,11 @@ class ReverseV2Op : public OpKernel {
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(0, input.shape(), &output));
|
||||
|
||||
// TODO(cwhipkey): we can do dimension folding to reduce, e.g., a reverse of
|
||||
// a single dimension to the dims=3 or dims=2 case, regardless of the number
|
||||
// of dimensions in the tensor. This would let some ops use faster
|
||||
// lower-dimension code (and use optimized versions).
|
||||
|
||||
#define HANDLE_REVERSE(NDIMS) \
|
||||
case NDIMS: \
|
||||
HandleReverseV2Case<Device, T, NDIMS>(context, axes_dense, output); \
|
||||
|
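A minimal sketch of the dimension-folding idea mentioned in the TODO above, under the assumption that only a single axis is reversed; FoldAroundAxis is a hypothetical helper, not part of the kernel.

#include <cstdint>
#include <vector>

// Collapses all dimensions before and after the single reversed axis, so that
// reversing "axis" of an arbitrary-rank shape becomes reversing the middle
// dimension of a 3-D shape {outer, shape[axis], inner}, which the optimized
// ReverseRows path can handle.
std::vector<int64_t> FoldAroundAxis(const std::vector<int64_t>& shape,
                                    int axis) {
  int64_t outer = 1, inner = 1;
  for (int i = 0; i < axis; ++i) outer *= shape[i];
  for (int i = axis + 1; i < static_cast<int>(shape.size()); ++i) {
    inner *= shape[i];
  }
  return {outer, shape[axis], inner};
}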
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/fake_input.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
@ -31,6 +32,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
@ -109,5 +111,104 @@ TEST_F(ReverseOpTest, Reverse_1234) {
|
||||
test::ExpectTensorEqual<float>(expected, *params_tensor);
|
||||
}
|
||||
|
||||
static SessionOptions GetOptions(int intra_threads) {
|
||||
SessionOptions opts;
|
||||
opts.config.set_intra_op_parallelism_threads(intra_threads);
|
||||
opts.config.set_inter_op_parallelism_threads(1);
|
||||
return opts;
|
||||
}
|
||||
|
||||
// Creates a Graph which reverses a 3D float tensor of the given shape
|
||||
// along dimension "reverse_axis".
|
||||
static Graph* Reverse(TensorShape shape, int reverse_axis) {
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
Tensor data(DT_FLOAT, shape);
|
||||
data.flat<float>().setRandom();
|
||||
Tensor axes(DT_INT32, TensorShape({1}));
|
||||
axes.flat<int32>()(0) = reverse_axis;
|
||||
test::graph::Reverse(g, test::graph::Constant(g, data),
|
||||
test::graph::Constant(g, axes));
|
||||
return g;
|
||||
}
|
||||
|
||||
static void RunReverseRowsBenchmark(int iters, int outer_dim, int middle_dim,
|
||||
int intra_threads, int channels) {
|
||||
SessionOptions opts = GetOptions(intra_threads);
|
||||
TensorShape shape{outer_dim, middle_dim, channels};
|
||||
const int64 num_items = static_cast<int64>(iters) * shape.num_elements();
|
||||
testing::ItemsProcessed(num_items);
|
||||
testing::BytesProcessed(num_items * sizeof(float));
|
||||
testing::UseRealTime();
|
||||
test::Benchmark("cpu", Reverse(shape, 1), &opts).Run(iters);
|
||||
}
|
||||
|
||||
static void BM_ReverseRowsOf1Channel_1T(int iters, int outer_dim,
|
||||
int middle_dim) {
|
||||
RunReverseRowsBenchmark(iters, outer_dim, middle_dim, 1 /* intra_threads */,
|
||||
1 /* channels */);
|
||||
}
|
||||
|
||||
BENCHMARK(BM_ReverseRowsOf1Channel_1T)
|
||||
->ArgPair(288, 288)
|
||||
->ArgPair(1024, 1024)
|
||||
->ArgPair(10 * 1024, 1024);
|
||||
|
||||
static void BM_ReverseRowsOf1Channel_4T(int iters, int outer_dim,
|
||||
int middle_dim) {
|
||||
RunReverseRowsBenchmark(iters, outer_dim, middle_dim, 4 /* intra_threads */,
|
||||
1 /* channels */);
|
||||
}
|
||||
|
||||
BENCHMARK(BM_ReverseRowsOf1Channel_4T)
|
||||
->ArgPair(288, 288)
|
||||
->ArgPair(1024, 1024)
|
||||
->ArgPair(10 * 1024, 1024);
|
||||
|
||||
static void BM_ReverseRowsOf3Channels_1T(int iters, int outer_dim,
|
||||
int middle_dim) {
|
||||
RunReverseRowsBenchmark(iters, outer_dim, middle_dim, 1 /* intra_threads */,
|
||||
3 /* channels */);
|
||||
}
|
||||
|
||||
BENCHMARK(BM_ReverseRowsOf3Channels_1T)
|
||||
->ArgPair(288, 288)
|
||||
->ArgPair(224, 224)
|
||||
->ArgPair(1024, 1024)
|
||||
->ArgPair(10 * 1024, 1024);
|
||||
|
||||
static void BM_ReverseRowsOf3Channels_4T(int iters, int outer_dim,
|
||||
int middle_dim) {
|
||||
RunReverseRowsBenchmark(iters, outer_dim, middle_dim, 4 /* intra_threads */,
|
||||
3 /* channels */);
|
||||
}
|
||||
|
||||
BENCHMARK(BM_ReverseRowsOf3Channels_4T)
|
||||
->ArgPair(288, 288)
|
||||
->ArgPair(224, 224)
|
||||
->ArgPair(1024, 1024)
|
||||
->ArgPair(10 * 1024, 1024);
|
||||
|
||||
static void BM_ReverseRowsOf4Channels_1T(int iters, int outer_dim,
|
||||
int middle_dim) {
|
||||
RunReverseRowsBenchmark(iters, outer_dim, middle_dim, 1 /* intra_threads */,
|
||||
4 /* channels */);
|
||||
}
|
||||
|
||||
BENCHMARK(BM_ReverseRowsOf4Channels_1T)
|
||||
->ArgPair(288, 288)
|
||||
->ArgPair(1024, 1024)
|
||||
->ArgPair(10 * 1024, 1024);
|
||||
|
||||
static void BM_ReverseRowsOf4Channels_4T(int iters, int outer_dim,
|
||||
int middle_dim) {
|
||||
RunReverseRowsBenchmark(iters, outer_dim, middle_dim, 4 /* intra_threads */,
|
||||
4 /* channels */);
|
||||
}
|
||||
|
||||
BENCHMARK(BM_ReverseRowsOf4Channels_4T)
|
||||
->ArgPair(288, 288)
|
||||
->ArgPair(1024, 1024)
|
||||
->ArgPair(10 * 1024, 1024);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
130
tensorflow/core/kernels/stage_op.cc
Normal file
@ -0,0 +1,130 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <deque>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
class Buffer : public ResourceBase {
|
||||
public:
|
||||
explicit Buffer() {}
|
||||
|
||||
typedef std::vector<Tensor> Tuple;
|
||||
|
||||
// the Buffer takes ownership of the Tuple
|
||||
void Put(Tuple* tuple) {
|
||||
mutex_lock l(mu_);
|
||||
buf_.push_back(std::move(*tuple));
|
||||
non_empty_cond_var_.notify_one(); // maybe possible to optimize by reducing
|
||||
// how often this signal is sent
|
||||
}
|
||||
|
||||
void Get(Tuple* tuple) { // TODO(zhifengc): Support cancellation.
|
||||
mutex_lock l(mu_);
|
||||
while (buf_.empty()) {
|
||||
non_empty_cond_var_.wait(l);
|
||||
}
|
||||
|
||||
*tuple = std::move(buf_.front());
|
||||
buf_.pop_front();
|
||||
}
|
||||
|
||||
string DebugString() {
|
||||
mutex_lock l(mu_);
|
||||
return strings::StrCat("Staging size: ", buf_.size());
|
||||
}
|
||||
|
||||
private:
|
||||
mutex mu_;
|
||||
condition_variable non_empty_cond_var_;
|
||||
std::deque<Tuple> buf_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
Status CreateBuffer(Buffer** ret) {
|
||||
*ret = new Buffer;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetBuffer(OpKernelContext* ctx, const NodeDef& ndef, Buffer** buf) {
|
||||
auto rm = ctx->resource_manager();
|
||||
ContainerInfo cinfo;
|
||||
TF_RETURN_IF_ERROR(cinfo.Init(rm, ndef, true /* use name() */));
|
||||
TF_RETURN_IF_ERROR(rm->LookupOrCreate<Buffer>(cinfo.container(), cinfo.name(),
|
||||
buf, CreateBuffer));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
class StageOp : public OpKernel {
|
||||
public:
|
||||
explicit StageOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
Buffer* buf = nullptr;
|
||||
OP_REQUIRES_OK(ctx, GetBuffer(ctx, def(), &buf));
|
||||
core::ScopedUnref scope(buf);
|
||||
Buffer::Tuple tuple;
|
||||
for (int i = 0; i < ctx->num_inputs(); ++i) {
|
||||
tuple.push_back(ctx->input(i));
|
||||
}
|
||||
buf->Put(&tuple);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("Stage").Device(DEVICE_CPU), StageOp);
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER_KERNEL_BUILDER(Name("Stage").Device(DEVICE_GPU), StageOp);
|
||||
#endif
|
||||
|
||||
class UnstageOp : public OpKernel {
|
||||
public:
|
||||
explicit UnstageOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
// Using this op in such a way that it blocks forever
|
||||
// is an error. As such, cancellation is not handled.
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
Buffer* buf = nullptr;
|
||||
OP_REQUIRES_OK(ctx, GetBuffer(ctx, def(), &buf));
|
||||
core::ScopedUnref scope(buf);
|
||||
Buffer::Tuple tuple;
|
||||
buf->Get(&tuple);
|
||||
OP_REQUIRES(
|
||||
ctx, tuple.size() == ctx->num_outputs(),
|
||||
errors::InvalidArgument("Mismatch stage/unstage: ", tuple.size(),
|
||||
" vs. ", ctx->num_outputs()));
|
||||
for (int i = 0; i < tuple.size(); ++i) {
|
||||
ctx->set_output(i, tuple[i]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("Unstage").Device(DEVICE_CPU), UnstageOp);
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER_KERNEL_BUILDER(Name("Unstage").Device(DEVICE_GPU), UnstageOp);
|
||||
#endif
|
||||
|
||||
} // namespace tensorflow
|
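A self-contained sketch of the blocking FIFO behaviour that Stage/Unstage rely on, written with std::condition_variable instead of TensorFlow's mutex types; it mirrors Buffer::Put/Get above but is illustrative, not the kernel code.

#include <condition_variable>
#include <deque>
#include <mutex>
#include <vector>

class SimpleStagingBuffer {
 public:
  using Tuple = std::vector<int>;  // stands in for a tuple of tensors

  // Producer side: enqueue and wake one waiting consumer.
  void Put(Tuple tuple) {
    std::lock_guard<std::mutex> l(mu_);
    buf_.push_back(std::move(tuple));
    non_empty_.notify_one();
  }

  // Consumer side: block until an element is available, then dequeue it.
  Tuple Get() {
    std::unique_lock<std::mutex> l(mu_);
    non_empty_.wait(l, [this] { return !buf_.empty(); });
    Tuple t = std::move(buf_.front());
    buf_.pop_front();
    return t;
  }

 private:
  std::mutex mu_;
  std::condition_variable non_empty_;
  std::deque<Tuple> buf_;
};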
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/kernels/bounds_check.h"
|
||||
#include "tensorflow/core/kernels/variable_ops.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -292,10 +293,26 @@ struct ApplyCenteredRMSProp<CPUDevice, T> {
|
||||
|
||||
} // namespace functor
|
||||
|
||||
mutex* GetMutex(OpKernelContext* ctx, int input) {
|
||||
if (ctx->input_dtype(input) == DT_RESOURCE) {
|
||||
Var* var;
|
||||
if (LookupResource(ctx, HandleFromInput(ctx, input), &var).ok()) {
|
||||
return var->mu();
|
||||
} else {
|
||||
ctx->CtxFailureWithWarning(
|
||||
errors::Internal("Invalid variable reference."));
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
return ctx->input_ref_mutex(input);
|
||||
}
|
||||
|
||||
// MaybeLockMutexesInOrder is a helper function to acquire mutexes in address
|
||||
// order to mitigate deadlock. Returns a vector of acquired mutexes.
|
||||
// Safe to pass duplicates - will only lock each distinct mutex once.
|
||||
// If do_lock is false, returns immediately.
|
||||
// order to mitigate deadlock. Returns a vector of acquired mutexes. Safe to
|
||||
// pass duplicates - will only lock each distinct mutex once. If do_lock is
|
||||
// false, returns immediately. Note that this silently doesn't lock mutexes for
|
||||
// invalid variable references; in all usages this is followed by GetInputTensor
|
||||
// which will signal a failure.
|
||||
std::vector<mutex_lock> MaybeLockMutexesInOrder(
|
||||
OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids) {
|
||||
std::vector<mutex_lock> locks;
|
||||
@ -305,7 +322,7 @@ std::vector<mutex_lock> MaybeLockMutexesInOrder(
|
||||
std::vector<mutex*> mutexes;
|
||||
std::vector<int> acquire_order;
|
||||
for (auto input : input_ids) {
|
||||
auto* mutex = ctx->input_ref_mutex(input);
|
||||
mutex* mutex = GetMutex(ctx, input);
|
||||
// Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3).
|
||||
if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) {
|
||||
acquire_order.push_back(input);
|
||||
@ -316,11 +333,41 @@ std::vector<mutex_lock> MaybeLockMutexesInOrder(
|
||||
[&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });
|
||||
|
||||
for (auto input : acquire_order) {
|
||||
locks.emplace_back(*ctx->input_ref_mutex(input));
|
||||
mutex* mu = GetMutex(ctx, input);
|
||||
if (mu != nullptr) {
|
||||
locks.emplace_back(*mu);
|
||||
}
|
||||
}
|
||||
return locks;
|
||||
}
|
||||
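A standalone sketch of the address-ordered locking strategy described in the comment above, written against std::mutex; the names are illustrative and not part of training_ops.cc.

#include <algorithm>
#include <mutex>
#include <vector>

// Locks every distinct mutex exactly once, always in increasing address
// order, so that two callers passing overlapping sets cannot deadlock.
std::vector<std::unique_lock<std::mutex>> LockInAddressOrder(
    std::vector<std::mutex*> mutexes) {
  std::sort(mutexes.begin(), mutexes.end());
  mutexes.erase(std::unique(mutexes.begin(), mutexes.end()), mutexes.end());
  std::vector<std::unique_lock<std::mutex>> locks;
  locks.reserve(mutexes.size());
  for (std::mutex* m : mutexes) {
    locks.emplace_back(*m);  // std::unique_lock locks on construction
  }
  return locks;
}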
|
||||
Status GetInputTensor(OpKernelContext* ctx, int input, bool lock_held,
|
||||
Tensor* out) {
|
||||
if (ctx->input_dtype(input) == DT_RESOURCE) {
|
||||
Var* var;
|
||||
if (LookupResource(ctx, HandleFromInput(ctx, input), &var).ok()) {
|
||||
if (lock_held) {
|
||||
*out = *var->tensor();
|
||||
} else {
|
||||
mutex_lock ml(*var->mu());
|
||||
*out = *var->tensor();
|
||||
}
|
||||
return Status::OK();
|
||||
} else {
|
||||
return errors::Internal("Invalid variable reference.");
|
||||
}
|
||||
}
|
||||
*out = ctx->mutable_input(input, lock_held);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
|
||||
int output) {
|
||||
if (ctx->input_dtype(input) != DT_RESOURCE) {
|
||||
ctx->forward_ref_input_to_ref_output(input, output);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Device, typename T>
|
||||
class ApplyGradientDescentOp : public OpKernel {
|
||||
public:
|
||||
@ -330,7 +377,8 @@ class ApplyGradientDescentOp : public OpKernel {
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0});
|
||||
Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
||||
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
@ -351,7 +399,7 @@ class ApplyGradientDescentOp : public OpKernel {
|
||||
functor::ApplyGradientDescent<Device, T>()(
|
||||
device, var.flat<T>(), alpha.scalar<T>(), delta.flat<T>());
|
||||
|
||||
ctx->forward_ref_input_to_ref_output(0, 0);
|
||||
MaybeForwardRefInputToRefOutput(ctx, 0, 0);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -361,7 +409,11 @@ class ApplyGradientDescentOp : public OpKernel {
|
||||
#define REGISTER_KERNELS(D, T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("ApplyGradientDescent").Device(DEVICE_##D).TypeConstraint<T>("T"), \
|
||||
ApplyGradientDescentOp<D##Device, T>);
|
||||
ApplyGradientDescentOp<D##Device, T>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("ResourceApplyGradientDescent") \
|
||||
.Device(DEVICE_##D) \
|
||||
.TypeConstraint<T>("T"), \
|
||||
ApplyGradientDescentOp<D##Device, T>);
|
||||
#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
|
||||
|
||||
TF_CALL_half(REGISTER_CPU_KERNELS);
|
||||
@ -406,7 +458,7 @@ class ApplyAdadeltaOp : public OpKernel {
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
if (use_exclusive_lock_) {
|
||||
mutex_lock l1(*ctx->input_ref_mutex(0));
|
||||
mutex_lock l1(*GetMutex(ctx, 0));
|
||||
// Don't try to acquire a lock on the second ref as they share the same
|
||||
// mutex.
|
||||
//
|
||||
@ -419,16 +471,20 @@ class ApplyAdadeltaOp : public OpKernel {
|
||||
if (!ctx->status().ok()) return;
|
||||
DoCompute(ctx);
|
||||
}
|
||||
ctx->forward_ref_input_to_ref_output(0, 0);
|
||||
MaybeForwardRefInputToRefOutput(ctx, 0, 0);
|
||||
}
|
||||
|
||||
private:
|
||||
bool use_exclusive_lock_;
|
||||
|
||||
void DoValidate(OpKernelContext* ctx) {
|
||||
Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
|
||||
Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
|
||||
Tensor accum_update = ctx->mutable_input(2, use_exclusive_lock_);
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
||||
Tensor accum;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
|
||||
Tensor accum_update;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
GetInputTensor(ctx, 2, use_exclusive_lock_, &accum_update));
|
||||
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
@ -474,9 +530,13 @@ class ApplyAdadeltaOp : public OpKernel {
|
||||
|
||||
void DoCompute(OpKernelContext* ctx) {
|
||||
const Device& device = ctx->template eigen_device<Device>();
|
||||
Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
|
||||
Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
|
||||
Tensor accum_update = ctx->mutable_input(2, use_exclusive_lock_);
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
||||
Tensor accum;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
|
||||
Tensor accum_update;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
GetInputTensor(ctx, 2, use_exclusive_lock_, &accum_update));
|
||||
|
||||
const Tensor& lr = ctx->input(3);
|
||||
const Tensor& rho = ctx->input(4);
|
||||
@ -492,9 +552,12 @@ class ApplyAdadeltaOp : public OpKernel {
|
||||
using CPUDevice = Eigen::ThreadPoolDevice;
|
||||
using GPUDevice = Eigen::GpuDevice;
|
||||
|
||||
#define REGISTER_KERNELS(D, T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("ApplyAdadelta").Device(DEVICE_##D).TypeConstraint<T>("T"), \
|
||||
#define REGISTER_KERNELS(D, T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("ApplyAdadelta").Device(DEVICE_##D).TypeConstraint<T>("T"), \
|
||||
ApplyAdadeltaOp<D##Device, T>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("ResourceApplyAdadelta").Device(DEVICE_##D).TypeConstraint<T>("T"), \
|
||||
ApplyAdadeltaOp<D##Device, T>);
|
||||
#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
|
||||
|
||||
@ -536,7 +599,7 @@ class SparseApplyAdadeltaOp : public OpKernel {
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
||||
mutex* mu_var = ctx->input_ref_mutex(0);
|
||||
mutex* mu_var = GetMutex(ctx, 0);
|
||||
// mu_accum is actually the same mutex as mu_var since currently we use a
|
||||
// global mutex.
|
||||
//
|
||||
@ -544,9 +607,14 @@ class SparseApplyAdadeltaOp : public OpKernel {
|
||||
if (use_exclusive_lock_) {
|
||||
mu_var->lock();
|
||||
}
|
||||
Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
|
||||
Tensor accum_grad = ctx->mutable_input(1, use_exclusive_lock_);
|
||||
Tensor accum_update = ctx->mutable_input(2, use_exclusive_lock_);
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
||||
Tensor accum_grad;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
GetInputTensor(ctx, 1, use_exclusive_lock_, &accum_grad));
|
||||
Tensor accum_update;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
GetInputTensor(ctx, 2, use_exclusive_lock_, &accum_update));
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
@ -642,7 +710,7 @@ class SparseApplyAdadeltaOp : public OpKernel {
|
||||
mu_var->unlock();
|
||||
}
|
||||
|
||||
ctx->forward_ref_input_to_ref_output(0, 0);
|
||||
MaybeForwardRefInputToRefOutput(ctx, 0, 0);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -654,6 +722,11 @@ class SparseApplyAdadeltaOp : public OpKernel {
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<Tindices>("Tindices"), \
|
||||
SparseApplyAdadeltaOp<T, Tindices>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyAdadelta") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<Tindices>("Tindices"), \
|
||||
SparseApplyAdadeltaOp<T, Tindices>);
|
||||
#define REGISTER_CPU_KERNELS(T) \
|
||||
REGISTER_KERNELS(T, int32); \
|
||||
@ -677,7 +750,8 @@ class ApplyProximalGradientDescentOp : public OpKernel {
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0});
|
||||
Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
||||
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
@ -710,17 +784,21 @@ class ApplyProximalGradientDescentOp : public OpKernel {
|
||||
device, var.flat<T>(), alpha.scalar<T>(), l1.scalar<T>(),
|
||||
l2.scalar<T>(), delta.flat<T>());
|
||||
|
||||
ctx->forward_ref_input_to_ref_output(0, 0);
|
||||
MaybeForwardRefInputToRefOutput(ctx, 0, 0);
|
||||
}
|
||||
|
||||
private:
|
||||
bool use_exclusive_lock_;
|
||||
};
|
||||
|
||||
#define REGISTER_KERNELS(D, T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("ApplyProximalGradientDescent") \
|
||||
.Device(DEVICE_##D) \
|
||||
.TypeConstraint<T>("T"), \
|
||||
#define REGISTER_KERNELS(D, T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("ApplyProximalGradientDescent") \
|
||||
.Device(DEVICE_##D) \
|
||||
.TypeConstraint<T>("T"), \
|
||||
ApplyProximalGradientDescentOp<D##Device, T>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("ResourceApplyProximalGradientDescent") \
|
||||
.Device(DEVICE_##D) \
|
||||
.TypeConstraint<T>("T"), \
|
||||
ApplyProximalGradientDescentOp<D##Device, T>);
|
||||
|
||||
REGISTER_KERNELS(CPU, float);
|
||||
@ -738,7 +816,8 @@ class SparseApplyProximalGradientDescentOp : public OpKernel {
|
||||
|
||||
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
||||
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
||||
Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
|
||||
errors::InvalidArgument("var must be at least 1 dimensional"));
|
||||
|
||||
@ -846,18 +925,23 @@ class SparseApplyProximalGradientDescentOp : public OpKernel {
|
||||
}
|
||||
}
|
||||
|
||||
ctx->forward_ref_input_to_ref_output(0, 0);
|
||||
MaybeForwardRefInputToRefOutput(ctx, 0, 0);
|
||||
}
|
||||
|
||||
private:
|
||||
bool use_exclusive_lock_;
|
||||
};
|
||||
|
||||
#define REGISTER_KERNELS(T, Tindices) \
|
||||
REGISTER_KERNEL_BUILDER(Name("SparseApplyProximalGradientDescent") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<Tindices>("Tindices"), \
|
||||
#define REGISTER_KERNELS(T, Tindices) \
|
||||
REGISTER_KERNEL_BUILDER(Name("SparseApplyProximalGradientDescent") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<Tindices>("Tindices"), \
|
||||
SparseApplyProximalGradientDescentOp<T, Tindices>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyProximalGradientDescent") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<Tindices>("Tindices"), \
|
||||
SparseApplyProximalGradientDescentOp<T, Tindices>);
|
||||
|
||||
REGISTER_KERNELS(float, int32);
|
||||
@ -875,8 +959,10 @@ class ApplyAdagradOp : public OpKernel {
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
||||
Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
|
||||
Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
||||
Tensor accum;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
@ -905,7 +991,7 @@ class ApplyAdagradOp : public OpKernel {
|
||||
functor::ApplyAdagrad<Device, T>()(device, var.flat<T>(), accum.flat<T>(),
|
||||
lr.scalar<T>(), grad.flat<T>());
|
||||
|
||||
ctx->forward_ref_input_to_ref_output(0, 0);
|
||||
MaybeForwardRefInputToRefOutput(ctx, 0, 0);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -915,9 +1001,12 @@ class ApplyAdagradOp : public OpKernel {
|
||||
using CPUDevice = Eigen::ThreadPoolDevice;
|
||||
using GPUDevice = Eigen::GpuDevice;
|
||||
|
||||
#define REGISTER_KERNELS(D, T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("ApplyAdagrad").Device(DEVICE_##D).TypeConstraint<T>("T"), \
|
||||
#define REGISTER_KERNELS(D, T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("ApplyAdagrad").Device(DEVICE_##D).TypeConstraint<T>("T"), \
|
||||
ApplyAdagradOp<D##Device, T>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("ResourceApplyAdagrad").Device(DEVICE_##D).TypeConstraint<T>("T"), \
|
||||
ApplyAdagradOp<D##Device, T>);
|
||||
#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
|
||||
|
||||
@ -957,8 +1046,10 @@ class ApplyProximalAdagradOp : public OpKernel {
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
||||
Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
|
||||
Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
||||
Tensor accum;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
@ -1004,7 +1095,7 @@ class ApplyProximalAdagradOp : public OpKernel {
|
||||
device, var.flat<T>(), accum.flat<T>(), lr.scalar<T>(), l1.scalar<T>(),
|
||||
l2.scalar<T>(), grad.flat<T>());
|
||||
|
||||
ctx->forward_ref_input_to_ref_output(0, 0);
|
||||
MaybeForwardRefInputToRefOutput(ctx, 0, 0);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -1017,7 +1108,11 @@ using GPUDevice = Eigen::GpuDevice;
|
||||
#define REGISTER_KERNELS(D, T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("ApplyProximalAdagrad").Device(DEVICE_##D).TypeConstraint<T>("T"), \
|
||||
ApplyProximalAdagradOp<D##Device, T>);
|
||||
ApplyProximalAdagradOp<D##Device, T>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("ResourceApplyProximalAdagrad") \
|
||||
.Device(DEVICE_##D) \
|
||||
.TypeConstraint<T>("T"), \
|
||||
ApplyProximalAdagradOp<D##Device, T>);
|
||||
|
||||
REGISTER_KERNELS(CPU, float);
|
||||
REGISTER_KERNELS(CPU, double);
|
||||
@ -1053,8 +1148,10 @@ class SparseApplyAdagradOp : public OpKernel {
|
||||
|
||||
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
||||
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
||||
Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
|
||||
Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
||||
Tensor accum;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
@ -1142,7 +1239,7 @@ class SparseApplyAdagradOp : public OpKernel {
|
||||
}
|
||||
}
|
||||
|
||||
ctx->forward_ref_input_to_ref_output(0, 0);
|
||||
MaybeForwardRefInputToRefOutput(ctx, 0, 0);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -1154,6 +1251,11 @@ class SparseApplyAdagradOp : public OpKernel {
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<Tindices>("Tindices"), \
|
||||
SparseApplyAdagradOp<T, Tindices>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyAdagrad") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<Tindices>("Tindices"), \
|
||||
SparseApplyAdagradOp<T, Tindices>);
|
||||
#define REGISTER_CPU_KERNELS(T) \
|
||||
REGISTER_KERNELS(T, int32); \
|
||||
@ -1177,8 +1279,10 @@ class SparseApplyProximalAdagradOp : public OpKernel {
|
||||
|
||||
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
||||
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
||||
Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
|
||||
Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
||||
Tensor accum;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
@ -1311,18 +1415,23 @@ class SparseApplyProximalAdagradOp : public OpKernel {
|
||||
}
|
||||
}
|
||||
|
||||
ctx->forward_ref_input_to_ref_output(0, 0);
|
||||
MaybeForwardRefInputToRefOutput(ctx, 0, 0);
|
||||
}
|
||||
|
||||
private:
|
||||
bool use_exclusive_lock_;
|
||||
};
|
||||
|
||||
#define REGISTER_KERNELS(T, Tindices) \
|
||||
REGISTER_KERNEL_BUILDER(Name("SparseApplyProximalAdagrad") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<Tindices>("Tindices"), \
|
||||
#define REGISTER_KERNELS(T, Tindices) \
|
||||
REGISTER_KERNEL_BUILDER(Name("SparseApplyProximalAdagrad") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<Tindices>("Tindices"), \
|
||||
SparseApplyProximalAdagradOp<T, Tindices>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyProximalAdagrad") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<Tindices>("Tindices"), \
|
||||
SparseApplyProximalAdagradOp<T, Tindices>);
|
||||
|
||||
REGISTER_KERNELS(float, int32);
|
||||
@ -1340,9 +1449,14 @@ class ApplyAdagradDAOp : public OpKernel {
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
||||
Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
|
||||
Tensor gradient_accum = ctx->mutable_input(1, use_exclusive_lock_);
|
||||
Tensor gradient_squared_accum = ctx->mutable_input(2, use_exclusive_lock_);
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
||||
Tensor gradient_accum;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &gradient_accum));
|
||||
Tensor gradient_squared_accum;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_,
|
||||
&gradient_squared_accum));
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
@ -1399,7 +1513,7 @@ class ApplyAdagradDAOp : public OpKernel {
|
||||
global_step.scalar<int64>()(), l1.scalar<T>(), l2.scalar<T>(),
|
||||
grad.flat<T>());
|
||||
|
||||
ctx->forward_ref_input_to_ref_output(0, 0);
|
||||
MaybeForwardRefInputToRefOutput(ctx, 0, 0);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -1428,9 +1542,14 @@ class SparseApplyAdagradDAOp : public OpKernel {
|
||||
|
||||
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
||||
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
||||
Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
|
||||
Tensor gradient_accum = ctx->mutable_input(1, use_exclusive_lock_);
|
||||
Tensor gradient_squared_accum = ctx->mutable_input(2, use_exclusive_lock_);
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
||||
Tensor gradient_accum;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &gradient_accum));
|
||||
Tensor gradient_squared_accum;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_,
|
||||
&gradient_squared_accum));
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
@ -1580,7 +1699,7 @@ class SparseApplyAdagradDAOp : public OpKernel {
|
||||
}
|
||||
}
|
||||
|
||||
ctx->forward_ref_input_to_ref_output(0, 0);
|
||||
MaybeForwardRefInputToRefOutput(ctx, 0, 0);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -1592,6 +1711,11 @@ class SparseApplyAdagradDAOp : public OpKernel {
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<Tindices>("Tindices"), \
|
||||
SparseApplyAdagradDAOp<T, Tindices>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyAdagradDA") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<Tindices>("Tindices"), \
|
||||
SparseApplyAdagradDAOp<T, Tindices>);
|
||||
|
||||
REGISTER_KERNELS(float, int32);
|
||||
@ -1610,9 +1734,12 @@ class ApplyFtrlOp : public OpKernel {
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
|
||||
|
||||
Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
|
||||
Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
|
||||
Tensor linear = ctx->mutable_input(2, use_exclusive_lock_);
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
||||
Tensor accum;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
|
||||
Tensor linear;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &linear));
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
@ -1677,7 +1804,7 @@ class ApplyFtrlOp : public OpKernel {
|
||||
lr.scalar<T>(), l1.scalar<T>(),
|
||||
l2.scalar<T>(), lr_power.scalar<T>());
|
||||
|
||||
ctx->forward_ref_input_to_ref_output(0, 0);
|
||||
MaybeForwardRefInputToRefOutput(ctx, 0, 0);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -1687,9 +1814,12 @@ class ApplyFtrlOp : public OpKernel {
|
||||
using CPUDevice = Eigen::ThreadPoolDevice;
|
||||
using GPUDevice = Eigen::GpuDevice;
|
||||
|
||||
#define REGISTER_KERNELS(D, T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("ApplyFtrl").Device(DEVICE_##D).TypeConstraint<T>("T"), \
|
||||
#define REGISTER_KERNELS(D, T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("ApplyFtrl").Device(DEVICE_##D).TypeConstraint<T>("T"), \
|
||||
ApplyFtrlOp<D##Device, T>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("ResourceApplyFtrl").Device(DEVICE_##D).TypeConstraint<T>("T"), \
|
||||
ApplyFtrlOp<D##Device, T>);
|
||||
#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
|
||||
|
||||
@ -1710,9 +1840,12 @@ class SparseApplyFtrlOp : public OpKernel {
|
||||
|
||||
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
||||
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
|
||||
Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
|
||||
Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
|
||||
Tensor linear = ctx->mutable_input(2, use_exclusive_lock_);
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
||||
Tensor accum;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
|
||||
Tensor linear;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &linear));
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
@ -1874,18 +2007,23 @@ class SparseApplyFtrlOp : public OpKernel {
|
||||
}
|
||||
}
|
||||
|
||||
ctx->forward_ref_input_to_ref_output(0, 0);
|
||||
MaybeForwardRefInputToRefOutput(ctx, 0, 0);
|
||||
}
|
||||
|
||||
private:
|
||||
bool use_exclusive_lock_;
|
||||
};
|
||||
|
||||
#define REGISTER_KERNELS(T, Tindices) \
|
||||
REGISTER_KERNEL_BUILDER(Name("SparseApplyFtrl") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<Tindices>("Tindices"), \
|
||||
#define REGISTER_KERNELS(T, Tindices) \
|
||||
REGISTER_KERNEL_BUILDER(Name("SparseApplyFtrl") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<Tindices>("Tindices"), \
|
||||
SparseApplyFtrlOp<CPUDevice, T, Tindices>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyFtrl") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<Tindices>("Tindices"), \
|
||||
SparseApplyFtrlOp<CPUDevice, T, Tindices>);
|
||||
#define REGISTER_CPU_KERNELS(T) \
|
||||
REGISTER_KERNELS(T, int32); \
|
||||
@ -1909,8 +2047,10 @@ class ApplyMomentumOp : public OpKernel {
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
||||
|
||||
Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
|
||||
Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
||||
Tensor accum;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
@ -1944,7 +2084,7 @@ class ApplyMomentumOp : public OpKernel {
|
||||
functor::ApplyMomentum<Device, T>()(device, var.flat<T>(), accum.flat<T>(),
|
||||
lr.scalar<T>(), grad.flat<T>(),
|
||||
momentum.scalar<T>(), use_nesterov_);
|
||||
ctx->forward_ref_input_to_ref_output(0, 0);
|
||||
MaybeForwardRefInputToRefOutput(ctx, 0, 0);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -1955,9 +2095,12 @@ class ApplyMomentumOp : public OpKernel {
|
||||
using CPUDevice = Eigen::ThreadPoolDevice;
|
||||
using GPUDevice = Eigen::GpuDevice;
|
||||
|
||||
#define REGISTER_KERNELS(D, T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("ApplyMomentum").Device(DEVICE_##D).TypeConstraint<T>("T"), \
|
||||
#define REGISTER_KERNELS(D, T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("ApplyMomentum").Device(DEVICE_##D).TypeConstraint<T>("T"), \
|
||||
ApplyMomentumOp<D##Device, T>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("ResourceApplyMomentum").Device(DEVICE_##D).TypeConstraint<T>("T"), \
|
||||
ApplyMomentumOp<D##Device, T>);
|
||||
#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
|
||||
|
||||
@ -2001,8 +2144,10 @@ class SparseApplyMomentumOp : public OpKernel {
|
||||
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
||||
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
|
||||
|
||||
Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
|
||||
Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
||||
Tensor accum;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
@ -2072,7 +2217,7 @@ class SparseApplyMomentumOp : public OpKernel {
|
||||
}
|
||||
}
|
||||
|
||||
ctx->forward_ref_input_to_ref_output(0, 0);
|
||||
MaybeForwardRefInputToRefOutput(ctx, 0, 0);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -2085,6 +2230,11 @@ class SparseApplyMomentumOp : public OpKernel {
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<Tindices>("Tindices"), \
|
||||
SparseApplyMomentumOp<T, Tindices>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyMomentum") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<Tindices>("Tindices"), \
|
||||
SparseApplyMomentumOp<T, Tindices>);
|
||||
#define REGISTER_CPU_KERNELS(T) \
|
||||
REGISTER_KERNELS(T, int32); \
|
||||
@ -2107,9 +2257,12 @@ class ApplyAdamOp : public OpKernel {
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
|
||||
|
||||
Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
|
||||
Tensor m = ctx->mutable_input(1, use_exclusive_lock_);
|
||||
Tensor v = ctx->mutable_input(2, use_exclusive_lock_);
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
||||
Tensor m;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &m));
|
||||
Tensor v;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &v));
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
errors::FailedPrecondition(
|
||||
@ -2171,7 +2324,7 @@ class ApplyAdamOp : public OpKernel {
|
||||
beta1.scalar<T>(), beta2.scalar<T>(),
|
||||
epsilon.scalar<T>(), grad.flat<T>());
|
||||
|
||||
ctx->forward_ref_input_to_ref_output(0, 0);
|
||||
MaybeForwardRefInputToRefOutput(ctx, 0, 0);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -2181,9 +2334,12 @@ class ApplyAdamOp : public OpKernel {
|
||||
using CPUDevice = Eigen::ThreadPoolDevice;
|
||||
using GPUDevice = Eigen::GpuDevice;
|
||||
|
||||
#define REGISTER_KERNELS(D, T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("ApplyAdam").Device(DEVICE_##D).TypeConstraint<T>("T"), \
|
||||
#define REGISTER_KERNELS(D, T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("ApplyAdam").Device(DEVICE_##D).TypeConstraint<T>("T"), \
|
||||
ApplyAdamOp<D##Device, T>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("ResourceApplyAdam").Device(DEVICE_##D).TypeConstraint<T>("T"), \
|
||||
ApplyAdamOp<D##Device, T>);
|
||||
#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
|
||||
|
||||
@ -2236,9 +2392,12 @@ class ApplyRMSPropOp : public OpKernel {
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
|
||||
|
||||
Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
|
||||
Tensor ms = ctx->mutable_input(1, use_exclusive_lock_);
|
||||
Tensor mom = ctx->mutable_input(2, use_exclusive_lock_);
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
||||
Tensor ms;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &ms));
|
||||
Tensor mom;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &mom));
|
||||
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
@ -2294,7 +2453,7 @@ class ApplyRMSPropOp : public OpKernel {
|
||||
rho.scalar<T>(), momentum.scalar<T>(),
|
||||
epsilon.scalar<T>(), grad.flat<T>());
|
||||
|
||||
ctx->forward_ref_input_to_ref_output(0, 0);
|
||||
MaybeForwardRefInputToRefOutput(ctx, 0, 0);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -2312,10 +2471,14 @@ class ApplyCenteredRMSPropOp : public OpKernel {
|
||||
auto locks =
|
||||
MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2, 3});
|
||||
|
||||
Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
|
||||
Tensor mg = ctx->mutable_input(1, use_exclusive_lock_);
|
||||
Tensor ms = ctx->mutable_input(2, use_exclusive_lock_);
|
||||
Tensor mom = ctx->mutable_input(3, use_exclusive_lock_);
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
||||
Tensor mg;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &mg));
|
||||
Tensor ms;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &ms));
|
||||
Tensor mom;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 3, use_exclusive_lock_, &mom));
|
||||
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
@ -2379,7 +2542,7 @@ class ApplyCenteredRMSPropOp : public OpKernel {
|
||||
device, var.flat<T>(), mg.flat<T>(), ms.flat<T>(), mom.flat<T>(),
|
||||
lr.scalar<T>(), rho.scalar<T>(), momentum.scalar<T>(),
|
||||
epsilon.scalar<T>(), grad.flat<T>());
|
||||
ctx->forward_ref_input_to_ref_output(0, 0);
|
||||
MaybeForwardRefInputToRefOutput(ctx, 0, 0);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -2395,7 +2558,14 @@ using GPUDevice = Eigen::GpuDevice;
|
||||
ApplyRMSPropOp<D##Device, T>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("ApplyCenteredRMSProp").Device(DEVICE_##D).TypeConstraint<T>("T"), \
|
||||
ApplyCenteredRMSPropOp<D##Device, T>);
|
||||
ApplyCenteredRMSPropOp<D##Device, T>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("ResourceApplyRMSProp").Device(DEVICE_##D).TypeConstraint<T>("T"), \
|
||||
ApplyRMSPropOp<D##Device, T>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("ResourceApplyCenteredRMSProp") \
|
||||
.Device(DEVICE_##D) \
|
||||
.TypeConstraint<T>("T"), \
|
||||
ApplyCenteredRMSPropOp<D##Device, T>);
|
||||
#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
|
||||
|
||||
TF_CALL_half(REGISTER_CPU_KERNELS);
|
||||
@ -2449,9 +2619,12 @@ class SparseApplyRMSPropOp : public OpKernel {
|
||||
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
|
||||
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
|
||||
|
||||
Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
|
||||
Tensor ms = ctx->mutable_input(1, use_exclusive_lock_);
|
||||
Tensor mom = ctx->mutable_input(2, use_exclusive_lock_);
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
||||
Tensor ms;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &ms));
|
||||
Tensor mom;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &mom));
|
||||
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
@ -2552,7 +2725,7 @@ class SparseApplyRMSPropOp : public OpKernel {
|
||||
}
|
||||
}
|
||||
|
||||
ctx->forward_ref_input_to_ref_output(0, 0);
|
||||
MaybeForwardRefInputToRefOutput(ctx, 0, 0);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -2572,10 +2745,14 @@ class SparseApplyCenteredRMSPropOp : public OpKernel {
|
||||
auto locks =
|
||||
MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2, 3});
|
||||
|
||||
Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
|
||||
Tensor mg = ctx->mutable_input(1, use_exclusive_lock_);
|
||||
Tensor ms = ctx->mutable_input(2, use_exclusive_lock_);
|
||||
Tensor mom = ctx->mutable_input(3, use_exclusive_lock_);
|
||||
Tensor var;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
|
||||
Tensor mg;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &mg));
|
||||
Tensor ms;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &ms));
|
||||
Tensor mom;
|
||||
OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 3, use_exclusive_lock_, &mom));
|
||||
|
||||
OP_REQUIRES(
|
||||
ctx, var.IsInitialized(),
|
||||
@ -2685,23 +2862,33 @@ class SparseApplyCenteredRMSPropOp : public OpKernel {
|
||||
}
|
||||
}
|
||||
|
||||
ctx->forward_ref_input_to_ref_output(0, 0);
|
||||
MaybeForwardRefInputToRefOutput(ctx, 0, 0);
|
||||
}
|
||||
|
||||
private:
|
||||
bool use_exclusive_lock_;
|
||||
};
|
||||
|
||||
#define REGISTER_KERNELS(T, Tindices) \
|
||||
REGISTER_KERNEL_BUILDER(Name("SparseApplyRMSProp") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<Tindices>("Tindices"), \
|
||||
SparseApplyRMSPropOp<T, Tindices>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("SparseApplyCenteredRMSProp") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<Tindices>("Tindices"), \
|
||||
#define REGISTER_KERNELS(T, Tindices) \
|
||||
REGISTER_KERNEL_BUILDER(Name("SparseApplyRMSProp") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<Tindices>("Tindices"), \
|
||||
SparseApplyRMSPropOp<T, Tindices>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("SparseApplyCenteredRMSProp") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<Tindices>("Tindices"), \
|
||||
SparseApplyCenteredRMSPropOp<T, Tindices>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyRMSProp") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<Tindices>("Tindices"), \
|
||||
SparseApplyRMSPropOp<T, Tindices>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyCenteredRMSProp") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<Tindices>("Tindices"), \
|
||||
SparseApplyCenteredRMSPropOp<T, Tindices>);
|
||||
|
||||
REGISTER_KERNELS(Eigen::half, int32);
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
@ -48,7 +49,8 @@ Arena::Arena(const size_t block_size)
|
||||
overflow_blocks_(NULL) {
|
||||
assert(block_size > kDefaultAlignment);
|
||||
|
||||
first_blocks_[0].mem = reinterpret_cast<char*>(malloc(block_size_));
|
||||
first_blocks_[0].mem =
|
||||
reinterpret_cast<char*>(port::AlignedMalloc(block_size_, sizeof(void*)));
|
||||
|
||||
first_blocks_[0].size = block_size_;
|
||||
|
||||
@ -59,7 +61,9 @@ Arena::~Arena() {
|
||||
FreeBlocks();
|
||||
assert(overflow_blocks_ == NULL); // FreeBlocks() should do that
|
||||
// The first X blocks stay allocated always by default. Delete them now.
|
||||
for (size_t i = 0; i < blocks_alloced_; ++i) free(first_blocks_[i].mem);
|
||||
for (size_t i = 0; i < blocks_alloced_; ++i) {
|
||||
port::AlignedFree(first_blocks_[i].mem);
|
||||
}
|
||||
}
|
||||
|
||||
// Returns true iff it advances freestart_ to the first position
|
||||
@ -162,8 +166,11 @@ Arena::AllocatedBlock* Arena::AllocNewBlock(const size_t block_size,
|
||||
|
||||
// Must be a multiple of kDefaultAlignment, unless requested
|
||||
// alignment is 1, in which case we don't care at all.
|
||||
const uint32 adjusted_alignment =
|
||||
uint32 adjusted_alignment =
|
||||
(alignment > 1 ? LeastCommonMultiple(alignment, kDefaultAlignment) : 1);
|
||||
// Required minimum alignment for port::AlignedMalloc().
|
||||
adjusted_alignment =
|
||||
std::max(adjusted_alignment, static_cast<uint32>(sizeof(void*)));
|
||||
|
||||
CHECK_LE(adjusted_alignment, static_cast<uint32>(1 << 20))
|
||||
<< "Alignment on boundaries greater than 1MB not supported.";
|
||||
@ -171,16 +178,12 @@ Arena::AllocatedBlock* Arena::AllocNewBlock(const size_t block_size,
|
||||
// If block_size > alignment we force block_size to be a multiple
|
||||
// of alignment; if block_size < alignment we make no adjustment.
|
||||
size_t adjusted_block_size = block_size;
|
||||
if (adjusted_alignment > 1) {
|
||||
if (adjusted_block_size > adjusted_alignment) {
|
||||
const uint32 excess = adjusted_block_size % adjusted_alignment;
|
||||
adjusted_block_size += (excess > 0 ? adjusted_alignment - excess : 0);
|
||||
}
|
||||
block->mem = reinterpret_cast<char*>(
|
||||
port::aligned_malloc(adjusted_block_size, adjusted_alignment));
|
||||
} else {
|
||||
block->mem = reinterpret_cast<char*>(malloc(adjusted_block_size));
|
||||
if (adjusted_block_size > adjusted_alignment) {
|
||||
const uint32 excess = adjusted_block_size % adjusted_alignment;
|
||||
adjusted_block_size += (excess > 0 ? adjusted_alignment - excess : 0);
|
||||
}
|
||||
block->mem = reinterpret_cast<char*>(
|
||||
port::AlignedMalloc(adjusted_block_size, adjusted_alignment));
|
||||
block->size = adjusted_block_size;
|
||||
CHECK(NULL != block->mem) << "block_size=" << block_size
|
||||
<< " adjusted_block_size=" << adjusted_block_size
|
||||
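A small sketch of the size/alignment adjustment performed above, assuming hypothetical names; it rounds the block size up to a multiple of the alignment and enforces the sizeof(void*) minimum that an AlignedMalloc-style allocator requires.

#include <algorithm>
#include <cstddef>

size_t AdjustBlockSize(size_t block_size, size_t alignment) {
  // Aligned allocators typically require at least pointer alignment,
  // so clamp the requested alignment up to sizeof(void*).
  alignment = std::max(alignment, sizeof(void*));
  if (block_size > alignment) {
    const size_t excess = block_size % alignment;
    if (excess > 0) block_size += alignment - excess;  // round up to a multiple
  }
  return block_size;
}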
@ -242,7 +245,7 @@ void* Arena::GetMemoryFallback(const size_t size, const int alignment) {
|
||||
|
||||
void Arena::FreeBlocks() {
|
||||
for (size_t i = 1; i < blocks_alloced_; ++i) { // keep first block alloced
|
||||
free(first_blocks_[i].mem);
|
||||
port::AlignedFree(first_blocks_[i].mem);
|
||||
first_blocks_[i].mem = NULL;
|
||||
first_blocks_[i].size = 0;
|
||||
}
|
||||
@ -250,7 +253,7 @@ void Arena::FreeBlocks() {
|
||||
if (overflow_blocks_ != NULL) {
|
||||
std::vector<AllocatedBlock>::iterator it;
|
||||
for (it = overflow_blocks_->begin(); it != overflow_blocks_->end(); ++it) {
|
||||
free(it->mem);
|
||||
port::AlignedFree(it->mem);
|
||||
}
|
||||
delete overflow_blocks_; // These should be used very rarely
|
||||
overflow_blocks_ = NULL;
|
||||
|
@ -45,6 +45,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/manual_constructor.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/types.h"

#include <initializer_list>  // NOLINT(build/include_order)
@ -353,7 +354,7 @@ class InlinedVector {
    size_t n = size();
    Destroy(base, n);
    if (!is_inline()) {
      free(base);
      port::Free(base);
    }
  }

@ -434,7 +435,7 @@ class InlinedVector {
    }

    T* src = data();
    T* dst = static_cast<T*>(malloc(target * sizeof(T)));
    T* dst = static_cast<T*>(port::Malloc(target * sizeof(T)));

    // Need to copy elem before discarding src since it might alias src.
    InitType{}(dst + s, std::forward<Args>(args)...);
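Here the grow path of InlinedVector swaps raw malloc/free for the port wrappers, so heap-backed storage is always allocated and released by the same allocator family. A stripped-down sketch of that allocate-copy-release pattern, using hypothetical Malloc/Free stand-ins rather than the real InlinedVector internals:

#include <cstdlib>
#include <cstring>
#include <new>

// Illustrative stand-ins for port::Malloc / port::Free.
inline void* Malloc(std::size_t size) { return std::malloc(size); }
inline void Free(void* ptr) { std::free(ptr); }

// Grow a heap buffer of trivially copyable elements from `old_size` to
// `new_capacity`, mirroring the allocate-copy-release structure of the
// InlinedVector grow path above.
template <typename T>
T* GrowBuffer(T* src, std::size_t old_size, std::size_t new_capacity) {
  T* dst = static_cast<T*>(Malloc(new_capacity * sizeof(T)));
  if (dst == nullptr) throw std::bad_alloc();
  std::memcpy(dst, src, old_size * sizeof(T));
  Free(src);  // released through the same allocator that produced it
  return dst;
}

int main() {
  int* buf = static_cast<int*>(Malloc(4 * sizeof(int)));
  for (int i = 0; i < 4; ++i) buf[i] = i;
  buf = GrowBuffer(buf, /*old_size=*/4, /*new_capacity=*/8);
  Free(buf);
  return 0;
}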
@ -30,7 +30,7 @@ limitations under the License.
#include <utility>

#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mem.h"  // For aligned_malloc/aligned_free
#include "tensorflow/core/platform/mem.h"

namespace tensorflow {
namespace gtl {
@ -127,9 +127,9 @@ class ManualConstructor {
  // Support users creating arrays of ManualConstructor<>s. This ensures that
  // the array itself has the correct alignment.
  static void* operator new[](size_t size) {
    return port::aligned_malloc(size, TF_LIB_GTL_ALIGN_OF(Type));
    return port::AlignedMalloc(size, TF_LIB_GTL_ALIGN_OF(Type));
  }
  static void operator delete[](void* mem) { port::aligned_free(mem); }
  static void operator delete[](void* mem) { port::AlignedFree(mem); }

  inline Type* get() { return reinterpret_cast<Type*>(space_); }
  inline const Type* get() const {
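The ManualConstructor hunk only renames aligned_malloc/aligned_free to the AlignedMalloc/AlignedFree spellings used elsewhere in this change, but the reason the class defines its own array new/delete is worth spelling out: the default operator new[] historically made no promise about alignments stricter than the platform default. A small self-contained sketch of that idea, with a hypothetical over-aligned type and std::aligned_alloc standing in for port::AlignedMalloc:

#include <cstdint>
#include <cstdio>
#include <cstdlib>

// Hypothetical over-aligned payload: alignof(Slab) == 64.
struct alignas(64) Slab {
  float lanes[16];

  // Class-specific array new/delete that guarantee the alignment, in the same
  // spirit as the ManualConstructor operators above.
  static void* operator new[](std::size_t size) {
    const std::size_t a = alignof(Slab);
    return std::aligned_alloc(a, (size + a - 1) / a * a);  // size rounded up
  }
  static void operator delete[](void* mem) { std::free(mem); }
};

int main() {
  Slab* slabs = new Slab[4];
  std::printf("aligned to 64 bytes: %s\n",
              reinterpret_cast<std::uintptr_t>(slabs) % 64 == 0 ? "yes" : "no");
  delete[] slabs;
  return 0;
}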
File diff suppressed because it is too large
@ -2180,4 +2180,35 @@ Delete the tensor specified by its handle in the session.
handle: The handle for a tensor stored in the session state.
)doc");

REGISTER_OP("Stage")
    .Input("values: dtypes")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(shape_inference::UnknownShape)
    .SetIsStateful()
    .Doc(R"doc(
Stage values similar to a lightweight Enqueue. The basic functionality of this
Op is similar to a queue with many fewer capabilities and options. This Op is
optimized for performance.

values: a list of tensors
container: If non-empty, this queue is placed in the given container. Otherwise,
  a default container is used.
shared_name: It is necessary to match this name to the matching Unstage Op.
)doc");

REGISTER_OP("Unstage")
    .Output("values: dtypes")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(shape_inference::UnknownShape)
    .SetIsStateful()
    .Doc(R"doc(
Op is similar to a lightweight Dequeue. The basic functionality is similar to
dequeue with many fewer capabilities and options. This Op is optimized for
performance.
)doc");

} // namespace tensorflow
File diff suppressed because it is too large
@ -64,24 +64,20 @@ REGISTER_OP("DenseToDenseSetOperation")
      }
      // The following should stay in sync with `ComputeDenseToDense` shape
      // assertions in kernels/set_kernels.cc.
      // Dimension n contains the set values to be compared, so ranks and the
      // first n-1 dimensions of inputs and output must match.
      // Dimension n contains the set values to be compared, so ranks must be
      // >= 2, and the first n-1 dimensions of inputs and output must be
      // compatible.
      DimensionHandle output_rank;
      ShapeHandle input0_shape = c->input(0);
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(input0_shape, 2, &input0_shape));
      if (c->RankKnown(input0_shape)) {
        const int32 input0_rank = c->Rank(input0_shape);
        if (input0_rank < 2) {
          return errors::InvalidArgument("Input 0, expected rank >= 2, got ",
                                         input0_rank, ".");
        }
        ShapeHandle input1_shape = c->input(1);
        TF_RETURN_IF_ERROR(
            c->WithRank(input1_shape, input0_rank, &input1_shape));
        if (c->RankKnown(input1_shape)) {
          // If both ranks are specified, the first n-1 dims must be compatible.
          const int32 rank = c->Rank(input1_shape);
          if (input0_rank != rank) {
            return errors::InvalidArgument("Ranks do not match: input 0 ",
                                           input0_rank, ", input 1 ", rank,
                                           ".");
          }
          ShapeHandle group0_shape;
          TF_RETURN_IF_ERROR(
              c->Subshape(input0_shape, 0, rank - 1, &group0_shape));
@ -95,28 +91,16 @@ REGISTER_OP("DenseToDenseSetOperation")
        output_rank = c->MakeDim(input0_rank);
      } else {
        ShapeHandle input1_shape = c->input(1);
        TF_RETURN_IF_ERROR(c->WithRankAtLeast(input1_shape, 2, &input1_shape));
        if (c->RankKnown(input1_shape)) {
          const int32 input1_rank = c->Rank(input1_shape);
          if (input1_rank < 2) {
            return errors::InvalidArgument("Input 0, expected rank >= 2, got ",
                                           input1_rank, ".");
          }
          output_rank = c->MakeDim(input1_rank);
          output_rank = c->MakeDim(c->Rank(input1_shape));
        } else {
          output_rank = c->UnknownDim();
        }
      }
      DimensionHandle output_num_elements = c->Dim(input0_shape, 0);
      if (!c->ValueKnown(output_num_elements)) {
        ShapeHandle input1_shape = c->input(1);
        output_num_elements = c->Dim(input1_shape, 0);
        if (!c->ValueKnown(output_num_elements)) {
          output_num_elements = c->UnknownDim();
        }
      }

      c->set_output(0, c->Matrix(output_num_elements, output_rank));
      c->set_output(1, c->Vector(output_num_elements));
      c->set_output(0, c->Matrix(c->UnknownDim(), output_rank));
      c->set_output(1, c->Vector(c->UnknownDim()));
      c->set_output(2, c->Vector(output_rank));
      return Status::OK();
    })
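The rewritten DenseToDense shape function above relaxes the original checks: WithRankAtLeast replaces the hand-written rank test, the leading n-1 dimensions only need to merge (be compatible) rather than match exactly, and the number of output rows is left unknown because it is data dependent. The hypothetical op below (ExamplePairShape is not part of the patch) condenses that structure using only InferenceContext calls that appear in the hunk; it is a sketch that builds inside the TensorFlow source tree.

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

REGISTER_OP("ExamplePairShape")
    .Input("set1: int64")
    .Input("set2: int64")
    .Output("result_indices: int64")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input0 = c->input(0);
      ShapeHandle input1 = c->input(1);
      // Both operands must have rank >= 2; the last dimension holds the sets.
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(input0, 2, &input0));
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(input1, 2, &input1));

      DimensionHandle output_rank = c->UnknownDim();
      if (c->RankKnown(input0)) {
        const int32 rank = c->Rank(input0);
        // If both ranks are known they must agree, and the leading n-1
        // dimensions need only be compatible, not identical.
        TF_RETURN_IF_ERROR(c->WithRank(input1, rank, &input1));
        ShapeHandle group0, group1, unused;
        TF_RETURN_IF_ERROR(c->Subshape(input0, 0, rank - 1, &group0));
        TF_RETURN_IF_ERROR(c->Subshape(input1, 0, rank - 1, &group1));
        TF_RETURN_IF_ERROR(c->Merge(group0, group1, &unused));
        output_rank = c->MakeDim(rank);
      } else if (c->RankKnown(input1)) {
        output_rank = c->MakeDim(c->Rank(input1));
      }
      // The number of result rows is data dependent, so it stays unknown.
      c->set_output(0, c->Matrix(c->UnknownDim(), output_rank));
      return Status::OK();
    });

}  // namespace tensorflow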
@ -159,30 +143,30 @@ REGISTER_OP("DenseToSparseSetOperation")
      }
      // The following should stay in sync with `ComputeDenseToSparse` shape
      // assertions in kernels/set_kernels.cc.
      // Dimension n contains the set values to be compared, so ranks and the
      // first n-1 dimensions of inputs and output must match.
      DimensionHandle output_rank;
      // Ranks must be compatible, and be >= 2.
      ShapeHandle input1_shape_shape = c->input(3);
      TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
          c, c->input(1), c->input(2), input1_shape_shape));

      DimensionHandle input1_rank_dim = c->Dim(input1_shape_shape, 0);

      DimensionHandle output_rank_dim;
      ShapeHandle input0_shape = c->input(0);
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(input0_shape, 2, &input0_shape));
      if (c->RankKnown(input0_shape)) {
        const int32 input0_rank = c->Rank(input0_shape);
        if (input0_rank < 2) {
          return errors::InvalidArgument("Input 0, expected rank >= 2, got ",
                                         input0_rank, ".");
        }
        output_rank = c->MakeDim(input0_rank);
        TF_RETURN_IF_ERROR(
            c->WithValue(input1_rank_dim, input0_rank, &input1_rank_dim));
        output_rank_dim = c->MakeDim(input0_rank);
      } else if (c->ValueKnown(input1_rank_dim)) {
        output_rank_dim = input1_rank_dim;
      } else {
        output_rank = c->UnknownDim();
      }
      TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
          c, c->input(1), c->input(2), c->input(3)));
      DimensionHandle output_num_elements = c->Dim(input0_shape, 0);
      if (!c->ValueKnown(output_num_elements)) {
        output_num_elements = c->UnknownDim();
        output_rank_dim = c->UnknownDim();
      }

      c->set_output(0, c->Matrix(output_num_elements, output_rank));
      c->set_output(1, c->Vector(output_num_elements));
      c->set_output(2, c->Vector(output_rank));
      c->set_output(0, c->Matrix(c->UnknownDim(), output_rank_dim));
      c->set_output(1, c->Vector(c->UnknownDim()));
      c->set_output(2, c->Vector(output_rank_dim));
      return Status::OK();
    })
    .Doc(R"doc(
@ -239,13 +223,40 @@ REGISTER_OP("SparseToSparseSetOperation")
      }
      // The following should stay in sync with `ComputeSparseToSparse` shape
      // assertions in kernels/set_kernels.cc.
      // Ranks must be compatible, and be >= 2.
      ShapeHandle input0_shape_shape = c->input(2);
      ShapeHandle input1_shape_shape = c->input(5);
      TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
          c, c->input(0), c->input(1), c->input(2)));
          c, c->input(0), c->input(1), input0_shape_shape));
      TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
          c, c->input(3), c->input(4), c->input(5)));
      c->set_output(0, c->Matrix(c->UnknownDim(), c->UnknownDim()));
          c, c->input(3), c->input(4), input1_shape_shape));

      DimensionHandle input0_rank_dim = c->Dim(input0_shape_shape, 0);
      DimensionHandle input1_rank_dim = c->Dim(input1_shape_shape, 0);
      DimensionHandle output_rank_dim;
      if (c->ValueKnown(input0_rank_dim)) {
        const int32 input0_rank = c->Value(input0_rank_dim);
        if (input0_rank < 2) {
          return errors::InvalidArgument("Input 0, expected rank >= 2, got ",
                                         input0_rank, ".");
        }
        TF_RETURN_IF_ERROR(
            c->WithValue(input1_rank_dim, input0_rank, &input1_rank_dim));
        output_rank_dim = input0_rank_dim;
      } else if (c->ValueKnown(input1_rank_dim)) {
        const int32 input1_rank = c->Value(input1_rank_dim);
        if (input1_rank < 2) {
          return errors::InvalidArgument("Input 1, expected rank >= 2, got ",
                                         input1_rank, ".");
        }
        output_rank_dim = input1_rank_dim;
      } else {
        output_rank_dim = c->UnknownDim();
      }

      c->set_output(0, c->Matrix(c->UnknownDim(), output_rank_dim));
      c->set_output(1, c->Vector(c->UnknownDim()));
      c->set_output(2, c->Vector(c->UnknownDim()));
      c->set_output(2, c->Vector(output_rank_dim));
      return Status::OK();
    })
    .Doc(R"doc(
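For the sparse variants the rank does not come from a dense operand's shape but from the sparse tensor's shape input: after ValidateSparseTensor checks each (indices, values, shape) triple, c->Dim(shape, 0) is that operand's rank as a DimensionHandle whose value may or may not be known, and WithValue enforces agreement when both are known. The hypothetical op below (ExampleSparsePairRank is not part of the patch) isolates just that logic; like the previous sketch, it is assumed to build inside the TensorFlow tree.

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;

REGISTER_OP("ExampleSparsePairRank")
    .Input("set1_indices: int64")
    .Input("set1_values: int64")
    .Input("set1_shape: int64")
    .Input("set2_indices: int64")
    .Input("set2_values: int64")
    .Input("set2_shape: int64")
    .Output("result_shape: int64")
    .SetShapeFn([](InferenceContext* c) {
      // Validate both (indices, values, shape) triples first.
      TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
          c, c->input(0), c->input(1), c->input(2)));
      TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
          c, c->input(3), c->input(4), c->input(5)));
      // The rank of each sparse operand is the length of its shape vector,
      // i.e. dimension 0 of the shape input's shape.
      DimensionHandle rank0 = c->Dim(c->input(2), 0);
      DimensionHandle rank1 = c->Dim(c->input(5), 0);
      DimensionHandle output_rank = c->UnknownDim();
      if (c->ValueKnown(rank0)) {
        // When both ranks are known, WithValue enforces that they agree.
        TF_RETURN_IF_ERROR(c->WithValue(rank1, c->Value(rank0), &rank1));
        output_rank = rank0;
      } else if (c->ValueKnown(rank1)) {
        output_rank = rank1;
      }
      c->set_output(0, c->Vector(output_rank));
      return Status::OK();
    });

}  // namespace tensorflow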
@ -34,16 +34,16 @@ TEST(SetOpsTest, DenseToDenseShape) {
  INFER_OK(op, "?;?", "[?,?];[?];[?]");

  // Invalid rank.
  INFER_ERROR("expected rank >= 2", op, "[?];?");
  INFER_ERROR("expected rank >= 2", op, "?;[?]");
  INFER_ERROR("expected rank >= 2", op, "[2];?");
  INFER_ERROR("expected rank >= 2", op, "?;[2]");
  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[?];?");
  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "?;[?]");
  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[2];?");
  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "?;[2]");

  // Mismatched ranks.
  INFER_ERROR("Ranks do not match", op, "[?,?];[?,?,?]");
  INFER_ERROR("Ranks do not match", op, "[?,?,?];[?,?]");
  INFER_ERROR("Ranks do not match", op, "[2,1];[2,1,2]");
  INFER_ERROR("Ranks do not match", op, "[2,1,2];[2,1]");
  INFER_ERROR("Shape must be rank 2 but is rank 3", op, "[?,?];[?,?,?]");
  INFER_ERROR("Shape must be rank 3 but is rank 2", op, "[?,?,?];[?,?]");
  INFER_ERROR("Shape must be rank 2 but is rank 3", op, "[2,1];[2,1,2]");
  INFER_ERROR("Shape must be rank 3 but is rank 2", op, "[2,1,2];[2,1]");

  // Rank 2, unknown dims.
  INFER_OK(op, "[?,?];?", "[?,2];[?];[2]");
@ -55,26 +55,26 @@ TEST(SetOpsTest, DenseToDenseShape) {
  INFER_OK(op, "?;[?,?,?,?]", "[?,4];[?];[4]");
  INFER_OK(op, "[?,?,?,?];[?,?,?,?]", "[?,4];[?];[4]");

  // Known dimension 0.
  INFER_OK(op, "[4,?,?,?];?", "[d0_0,4];[d0_0];[4]");
  INFER_OK(op, "?;[4,?,?,?]", "[d1_0,4];[d1_0];[4]");
  INFER_OK(op, "[4,?,?,?];[?,?,?,?]", "[d0_0,4];[d0_0];[4]");
  INFER_OK(op, "[?,?,?,?];[4,?,?,?]", "[d1_0,4];[d1_0];[4]");
  INFER_OK(op, "[4,?,?,?];[4,?,?,?]", "[d0_0,4];[d0_0];[4]");
  // Known rank for 1 input.
  INFER_OK(op, "[5,3,2,1];?", "[?,4];[?];[4]");
  INFER_OK(op, "?;[5,3,2,1]", "[?,4];[?];[4]");
  INFER_OK(op, "[5,3,2,1];[?,?,?,?]", "[?,4];[?];[4]");
  INFER_OK(op, "[?,?,?,?];[5,3,2,1]", "[?,4];[?];[4]");
  INFER_OK(op, "[5,3,2,1];[?,?,?,?]", "[?,4];[?];[4]");

  // Mismatched known n-1 dims.
  // Mismatched n-1 dims.
  INFER_ERROR("Dimension 0 in both shapes must be equal", op,
              "[4,?,2,?];[3,1,?,5]");
  INFER_ERROR("Dimension 2 in both shapes must be equal", op,
              "[4,3,2,1];[4,3,3,1]");

  // Matched known n-1 dims.
  INFER_OK(op, "[4,5,6,7];[?,?,?,?]", "[d0_0,4];[d0_0];[4]");
  INFER_OK(op, "[4,5,6,7];[?,?,?,4]", "[d0_0,4];[d0_0];[4]");
  INFER_OK(op, "[?,?,?,?];[4,5,6,7]", "[d1_0,4];[d1_0];[4]");
  INFER_OK(op, "[4,?,2,?];[?,1,?,5]", "[d0_0,4];[d0_0];[4]");
  INFER_OK(op, "[4,5,6,7];[4,?,6,?]", "[d0_0,4];[d0_0];[4]");
  INFER_OK(op, "[4,5,6,7];[4,5,6,4]", "[d0_0,4];[d0_0];[4]");
  // Matched n-1 dims.
  INFER_OK(op, "[4,5,6,7];[?,?,?,?]", "[?,4];[?];[4]");
  INFER_OK(op, "[4,5,6,7];[?,?,?,4]", "[?,4];[?];[4]");
  INFER_OK(op, "[?,?,?,?];[4,5,6,7]", "[?,4];[?];[4]");
  INFER_OK(op, "[4,?,2,?];[?,1,?,5]", "[?,4];[?];[4]");
  INFER_OK(op, "[4,5,6,7];[4,?,6,?]", "[?,4];[?];[4]");
  INFER_OK(op, "[4,5,6,7];[4,5,6,4]", "[?,4];[?];[4]");
}
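The updated expectations above use the shape-inference test shorthand: each INFER_OK / INFER_ERROR argument is a semicolon-separated list of input shapes, "?" is an unknown shape or dimension, and a d<i>_<j> token in the expected output refers back to dimension j of input i. A minimal test in that style, exercising the already-registered op exactly as the file above does (it assumes the usual test setup, i.e. linking against the ops library and a gtest main):

#include "tensorflow/core/framework/shape_inference_testutil.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {

// Illustrative only: re-checks two of the DenseToDenseSetOperation
// expectations using the same string shorthand as the tests above.
TEST(SetOpsShapeNotationExample, DenseToDense) {
  ShapeInferenceTestOp op("DenseToDenseSetOperation");
  // Fully unknown inputs propagate to a [?,?] indices matrix, a [?] values
  // vector, and a [?] output-shape vector.
  INFER_OK(op, "?;?", "[?,?];[?];[?]");
  // A rank-1 operand is rejected by WithRankAtLeast.
  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[2];?");
}

}  // namespace tensorflow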
TEST(SetOpsTest, DenseToSparseShape_InvalidNumberOfInputs) {
@ -89,35 +89,37 @@ TEST(SetOpsTest, DenseToSparseShape) {

  // Unknown shapes.
  INFER_OK(op, "?;?;?;?", "[?,?];[?];[?]");
  INFER_OK(op, "?;[?,?];[?];[?]", "[?,?];[?];[?]");

  // Invalid rank.
  INFER_ERROR("expected rank >= 2", op, "[?];?;?;?");
  INFER_ERROR("expected rank >= 2", op, "[?];[?,?];[?];[?]");
  INFER_ERROR("expected rank >= 2", op, "[?];[5,3];[5];[3]");
  INFER_ERROR("expected rank >= 2", op, "[2];?;?;?");
  INFER_ERROR("expected rank >= 2", op, "[2];[?,?];[?];[?]");
  INFER_ERROR("expected rank >= 2", op, "[2];[5,3];[5];[3]");
  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[?];?;?;?");
  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
              "[?];[?,?];[?];[?]");
  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
              "[?];[5,3];[5];[3]");
  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[2];?;?;?");
  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
              "[2];[?,?];[?];[?]");
  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
              "[2];[5,3];[5];[3]");

  // Rank 2, unknown dims.
  // Unknown sparse rank.
  INFER_OK(op, "[?,?];?;?;?", "[?,2];[?];[2]");
  INFER_OK(op, "[?,?];[?,?];[?];[?]", "[?,2];[?];[2]");
  INFER_OK(op, "[?,?];[5,3];[5];[3]", "[?,2];[?];[2]");

  // Rank 4, unknown dims.
  INFER_OK(op, "[?,?,?,?];?;?;?", "[?,4];[?];[4]");
  INFER_OK(op, "[?,?,?,?];[?,?];[?];[?]", "[?,4];[?];[4]");
  INFER_OK(op, "[?,?,?,?];[5,3];[5];[3]", "[?,4];[?];[4]");
  // Unknown dense rank.
  INFER_OK(op, "?;[?,2];[?];[2]", "[?,d3_0];[?];[d3_0]");
  INFER_OK(op, "?;[5,2];[5];[2]", "[?,d3_0];[?];[d3_0]");

  // Known dimension 0.
  INFER_OK(op, "[4,?,?,?];?;?;?", "[d0_0,4];[d0_0];[4]");
  INFER_OK(op, "[4,?,?,?];[?,?];[?];[?]", "[d0_0,4];[d0_0];[4]");
  INFER_OK(op, "[4,?,?,?];[5,3];[5];[3]", "[d0_0,4];[d0_0];[4]");
  // Known both ranks.
  INFER_OK(op, "[?,?];[5,2];[5];[2]", "[?,2];[?];[2]");
  INFER_OK(op, "[4,3];[5,2];[5];[2]", "[?,2];[?];[2]");

  // Invalid input sparse tensor.
  INFER_ERROR("elements in index (5) and values (6) do not match", op,
              "[?,?];[5,3];[6];[3]");
              "?;[5,3];[6];[3]");
  INFER_ERROR("rank (3) and shape rank (4) do not match", op,
              "[?,?];[5,3];[5];[4]");
              "?;[5,3];[5];[4]");
}

TEST(SetOpsTest, SparseToSparseShape_InvalidNumberOfInputs) {
@ -128,7 +130,21 @@ TEST(SetOpsTest, SparseToSparseShape_InvalidNumberOfInputs) {

TEST(SetOpsTest, SparseToSparseShape) {
  ShapeInferenceTestOp op("SparseToSparseSetOperation");

  // Unknown.
  INFER_OK(op, "?;?;?;?;?;?", "[?,?];[?];[?]");
  INFER_OK(op, "[?,?];[?];[?];[?,?];[?];[?]", "[?,?];[?];[?]");
  INFER_OK(op, "?;?;?;[?,?];[?];[?]", "[?,?];[?];[?]");
  INFER_OK(op, "[?,?];[?];[?];?;?;?", "[?,?];[?];[?]");

  // Known rank for 1 input.
  INFER_OK(op, "[?,2];[?];[2];?;?;?", "[?,d2_0];[?];[d2_0]");
  INFER_OK(op, "?;?;?;[?,2];[?];[2]", "[?,d5_0];[?];[d5_0]");
  INFER_OK(op, "[?,2];[?];[2];[?,?];[?];[?]", "[?,d2_0];[?];[d2_0]");
  INFER_OK(op, "[?,?];[?];[?];[?,2];[?];[2]", "[?,d5_0];[?];[d5_0]");

  // Known rank for both inputs.
  INFER_OK(op, "[?,2];[?];[2];[?,2];[?];[2]", "[?,d2_0];[?];[d2_0]");
}

} // end namespace tensorflow
Some files were not shown because too many files have changed in this diff