Branch 174861804 (#14326)
* Add ImportGraphDefTest.testMultipleImport to importer_test.py This tests the name deduping behavior of import_graph_def. This behavior is actually defined by the op creation logic, not import_graph_def, but I added a test here since the C++ ImportGraphDef function must emulate it (and presumably we'd like to maintain the import_graph_def behavior moving forward). PiperOrigin-RevId: 174536014 * Apply lib_internal defines to both lib_internal and lib_internal_impl Should fix checkpoint reading with snappy compression. Will follow up with testing for this sort of checkpoint issue. PiperOrigin-RevId: 174538693 * n/a (internal change only) PiperOrigin-RevId: 174539513 * A few changes to ApiDef generation: - Create a separate api_def_*.pbtxt file for each op. - Add attribute and argument descriptions to ApiDef. - Apply overrides based on op_gen_overrides.pbtxt file. PiperOrigin-RevId: 174540421 * Add uniquify_names option to ImportGraphDef. This option allows ImportGraphDef to mimic the behavior of the Python import_graph_def function, which automatically creates unique node names instead of raising an exception (this is due to the Python op construction logic, not import_graph_def directly). This change is a steps towards switching import_graph_def to use the C API version. PiperOrigin-RevId: 174541334 * Fix bad_color param on tf.contrib.summary.image PiperOrigin-RevId: 174549117 * Hlo parser: support control-predecessors. Also, - Changed from printing control-sucessors to printing control-predecessors because predecessors are defined before use. - Surround the predecessors with {}. PiperOrigin-RevId: 174552224 * Support pad node. PiperOrigin-RevId: 174581035 * Add tf.contrib.framework.sort, wrapping tf.nn.top_k (#288). Comparable to np.sort, but their "kind" parameter is not implemented (only one sort algorithm) and "order" is not applicable (tensors do not have fields). PiperOrigin-RevId: 174588000 * [TF2XLA] Don't change output port for control dependency in CopySubgraph. If the output is being squashed then we want control output 0, except where the input is a control dependency. PiperOrigin-RevId: 174633829 * Use latest nsync: allows running bazel after having downloaded for "make" build The downloads directory for the make build is within the source tree seen by bazel, which means that BUILD files (by whatever name) without those downloaded trees must all be valid in their new location, or not recognized by bazel as being BUILD files. The new version of nsync handles that, and this change pulls in that new version. PiperOrigin-RevId: 174652898 * Add profiling support to Service::ExecuteParallel. PiperOrigin-RevId: 174682772 * Replicate `Estimator.model_fn` across available GPUs. def replicate_model_fn(model_fn, optimizer_fn, devices=None): """Replicate `Estimator.model_fn` over GPUs. ... I tested that it seems to give the right result on cnn_mnist.py on 1 CPU, 1 real GPU, 4 allow_soft_placement=True GPUs. Some measurements on CNN MNIST across steps 19300-20000: 1) no replicate_model_fn call: global_step/sec: 156.254 global_step/sec: 155.074 global_step/sec: 155.74 global_step/sec: 153.636 global_step/sec: 157.218 global_step/sec: 159.644 2) replicate across one hardware GPU: global_step/sec: 158.171 global_step/sec: 165.618 global_step/sec: 162.773 global_step/sec: 159.204 global_step/sec: 162.289 global_step/sec: 167.173 3) replicate across 4 software GPUs on one hardware GPU (soft placement): global_step/sec: 75.47 global_step/sec: 76.16 global_step/sec: 75.18 Loss numbers didn't change across the three configurations. PiperOrigin-RevId: 174704385 * Enables wrapping input pipeline into tf.while_loop for all users. PiperOrigin-RevId: 174708213 * SerializeIterator: do not unref the resource until we're finished using it. This change avoids a potential use-after-free error if the resource is concurrently serialized and destroyed (e.g. by a DestroyResourceOp or Session::Reset()). PiperOrigin-RevId: 174713115 * Improve error message when a function is already defined with the same name and different hash string. PiperOrigin-RevId: 174715563 * Fix generate_examples build - Add -march=native to host_copts and host_cxxopts in configure.py - Make string.h for abstracting string differences at core interpreter level - Use tensorflow special arg parse instead of flags - Switch to using tool instead of data for dependency - Fix python3 compatibility + Use six.StringIO instead of StringIO.StringIO + Use print_function + Properly set binary flags on TempFile's used in toco_convert - Misc other path fixes PiperOrigin-RevId: 174717673 * Add input format agnostic way to parse HLOs. PiperOrigin-RevId: 174719153 * Remove misleading comment from Eigen build file. PiperOrigin-RevId: 174719222 * Basic plumbing for calling C API from import_graph_def() PiperOrigin-RevId: 174724070 * Memory leak detected when running a heap checker in our tests. PiperOrigin-RevId: 174726228 * [tpu:profiler] Support the Input Pipeline Analyzer tool in TPU profiler (WIP) o. move input pipeline analyzer related proto for grpc between red and green VMs o. rename perftools.gputools.profiler.collector::TfStatsHelperResult to tensorflow::tpu::TfOpStats. PiperOrigin-RevId: 174730411 * Clean up some reference cycles in eager mode. ResourceVariables enter graph mode to get a handle. We should probably revisit that, but in the meantime we can break the resulting reference cycles. PiperOrigin-RevId: 174732964 * Improved encoding on shapes in grappler. PiperOrigin-RevId: 174733491 * [tf.data] Remove unused members from IteratorContext. PiperOrigin-RevId: 174734277 * Refactor helper functions a bit for virtual gpu changes later. PiperOrigin-RevId: 174735029 * Fix invalid flush_secs argument. PiperOrigin-RevId: 174745329 * Replace the implementation of tf.flags with absl.flags. Previous tf.flags implementation is based on argparse. It contains -h/--help flags, which displays all flags. absl.app's --help flag only displays flags defined in the main module. There is a --helpfull flag that displays all flags. So added --helpshort --helpfull flags. app.run now raises SystemError on unknown flags (fixes #11195). Accessing flags before flags are parsed will now raise an UnparsedFlagAccessError, instead of causing implicit flag parsing previously. PiperOrigin-RevId: 174747028 * Fold Transpose into Matmul and SparseMatmul. Fold ConjugateTranspose in BatchMatmul. PiperOrigin-RevId: 174750173 * BUGFIX: special_math.ndtri didn't work with dynamic shapes. This was due to use of constant_op.constant(..., shape=p.shape), where sometimes p was a Tensor of unknown shape. PiperOrigin-RevId: 174764744 * Create a routine that can collapse a subgraph into a fused op PiperOrigin-RevId: 174765540 * Force CUDA runtime initialization only when device count is larger than 0. PiperOrigin-RevId: 174767565 * Remove use of xrange which is not python3 compatible. PiperOrigin-RevId: 174768741 * More thoroughly disable the should_use_result decorator when executing eagerly. It was creating reference cycles. Adds a test that TensorArrays create no reference cycles in eager mode. PiperOrigin-RevId: 174768765 * Fix device querying in Keras backend. PiperOrigin-RevId: 174769308 * Fix race bug in AdaptiveSharedBatchScheduler. In ASBSQueue::Schedule, when a new batch is created, it was added to the scheduler outside of the queue's lock. This was done to prevent any unforeseen interactions between the queue lock and scheduler lock. However, this wasn't being done in a thread safe way. PiperOrigin-RevId: 174769383 * Supports multi-dimensional logits and labels in multi class head. PiperOrigin-RevId: 174770444 * Refactor eager benchmarks to subclass Benchmark. PiperOrigin-RevId: 174770787 * Add `parallel_interleave` to tf/contrib/data/__init__.py so that it is directly addressable from tf.contrib.data. PiperOrigin-RevId: 174771870 * Fix DepthToSpaceGrad and SpaceToDepthGrad on data_format NCHW. This fixes #14243. PiperOrigin-RevId: 174772870 * Allow for an old_row_vocab_size, in case a subset of the old_row_vocab_file was used during the checkpoint creation (as is allowed in FeatureColumn._VocabularyListCategoricalColumn). PiperOrigin-RevId: 174781749 * Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 174781987 * [BufferAssignment] Sort allocation's "Assigned" objects before converting to a proto. This makes the buffer assignment's proto dump deterministic. RELNOTES: BufferAssignment's protocol buffer dump is now deterministic. PiperOrigin-RevId: 174783549 * [TF TensorArray] allow reading from an unwritten index if fully defined element_shape is given. This allows one to write to only some indices of a TensorArray before calling stack. Elements that were not written to are treated as all zero tensors. PiperOrigin-RevId: 174783569 * Remove binary dependency from optimize_for_inference_lib PiperOrigin-RevId: 174787363 * Update ops-related pbtxt files. PiperOrigin-RevId: 174787397 * Automated g4 rollback of changelist 174523638 PiperOrigin-RevId: 174788331 * Skip non-existent fetch nodes PiperOrigin-RevId: 174795864 * Automated g4 rollback of changelist 174735029 PiperOrigin-RevId: 174796480 * Add InceptionResNetV2 to tf.keras and update applications module to match Keras 2.0.9. PiperOrigin-RevId: 174796893 * Fix for LLVM API changes for fast math (https://reviews.llvm.org/rL317488). PiperOrigin-RevId: 174799735 * [TF:XLA] Add two disabled tests with while ops that permute tuple elements. These tests permute the tuple elements of a 3-tuple in each iteration in the following cyclic manner (132), i.e. a shift to the left. The first test just return the result tuple, the second returns the sum of all tuple elements (which is expected to be constant 6, no matter which permutation) Both tests are disabled for now because they fail on all back-ends. PiperOrigin-RevId: 174806092 * Refactor function Optimize. PiperOrigin-RevId: 174813300 * Add a unit test for gradient computation with layout optimizer. PiperOrigin-RevId: 174814136 * Previously if ComputeConstant seen a parameter it failed to proceed. After this change we can specify a list of parameters to it and if we specify enough then it will do the computation. The primary goal of this change is to make the HloEvaluator usable with ComputationBuilder from tests through ComputeConstant in cases where the input is a parameter (fed by a literal). PiperOrigin-RevId: 174845108 * Use nesting to reduce the number of modules listed in the API TOC. PiperOrigin-RevId: 174846842 * Added CPU matrix exponential op to TensorFlow. Uses Eigen's unsupported implementation. PiperOrigin-RevId: 174858966 * variables_to_restore: Differentiate python variables by string name rather than object. variables_to_restore ensured that duplicate variables weren't added to the return map by comparing python variable object. Normally there is only one Variable object for each underlying variable, so this wasn't a problem. But when one initializes a graph by importing a GraphDef, duplicate python Variable objects are created for each occurrence of a variable in a collection (say, global variables and moving average variables). This change fixes variables_to_restore to work with an imported graph def by not comparing Variable objects. PiperOrigin-RevId: 174861804
This commit is contained in:
parent
00e097241e
commit
4e69e02241
@ -25,10 +25,12 @@ import re
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
# pylint: disable=g-import-not-at-top
|
||||
try:
|
||||
from shutil import which
|
||||
except ImportError:
|
||||
from distutils.spawn import find_executable as which
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
||||
_TF_BAZELRC = os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
||||
'.tf_configure.bazelrc')
|
||||
@ -485,7 +487,10 @@ def set_cc_opt_flags(environ_cp):
|
||||
cc_opt_flags = get_from_env_or_user_or_default(environ_cp, 'CC_OPT_FLAGS',
|
||||
question, default_cc_opt_flags)
|
||||
for opt in cc_opt_flags.split():
|
||||
write_to_bazelrc('build:opt --cxxopt=%s --copt=%s' % (opt, opt))
|
||||
host_opt = '-march=native' # It should be safe on the same build host.
|
||||
write_to_bazelrc(
|
||||
'build:opt --cxxopt=%s --copt=%s' % (opt, opt) +
|
||||
' --host_cxxopt=%s --host_copt=%s' % (host_opt, host_opt))
|
||||
|
||||
|
||||
def set_tf_cuda_clang(environ_cp):
|
||||
|
@ -130,7 +130,9 @@ Status CopySubgraph(const Graph& graph, const Frame* frame,
|
||||
stack.push_back(src);
|
||||
}
|
||||
Node* src_copy = (*node_map)[e->src()->id()];
|
||||
int src_output = squash_src_outputs[e->src()->id()] ? 0 : e->src_output();
|
||||
int src_output = squash_src_outputs[e->src()->id()] && !e->IsControlEdge()
|
||||
? 0
|
||||
: e->src_output();
|
||||
Node* dst_copy = (*node_map)[e->dst()->id()];
|
||||
output->AddEdge(src_copy, src_output, dst_copy, e->dst_input());
|
||||
}
|
||||
|
@ -77,18 +77,6 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice(
|
||||
out_shape.dim_sizes());
|
||||
}
|
||||
|
||||
// Degenerate case: single slice.
|
||||
if (num_indices == 1) {
|
||||
auto index = builder->Reshape(indices, {1});
|
||||
auto start_index = builder->Pad(
|
||||
index, XlaHelpers::Zero(builder, index_type),
|
||||
xla::MakeEdgePaddingConfig(
|
||||
{{input_shape_pre_axis.dims(), input_shape_post_axis.dims()}}));
|
||||
auto slice =
|
||||
builder->DynamicSlice(input, start_index, slice_shape.dim_sizes());
|
||||
return builder->Reshape(slice, out_shape.dim_sizes());
|
||||
}
|
||||
|
||||
// Specify the shape of the loop-carried Tensor tuple.
|
||||
xla::PrimitiveType ptype;
|
||||
TF_CHECK_OK(DataTypeToPrimitiveType(dtype, &ptype));
|
||||
|
@ -1309,7 +1309,7 @@ Status ComputationBuilder::SetReturnValue(
|
||||
}
|
||||
|
||||
StatusOr<bool> ComputationBuilder::IsConstant(
|
||||
const ComputationDataHandle& operand) {
|
||||
const ComputationDataHandle& operand, int64 num_parameters) {
|
||||
if (!first_error_.ok()) {
|
||||
return first_error_;
|
||||
}
|
||||
@ -1317,6 +1317,7 @@ StatusOr<bool> ComputationBuilder::IsConstant(
|
||||
IsConstantRequest request;
|
||||
*request.mutable_computation() = computation_.handle();
|
||||
*request.mutable_operand() = operand;
|
||||
request.set_num_parameters(num_parameters);
|
||||
IsConstantResponse response;
|
||||
|
||||
VLOG(2) << "making IsConstant request";
|
||||
@ -1330,7 +1331,8 @@ StatusOr<bool> ComputationBuilder::IsConstant(
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<Literal>> ComputationBuilder::ComputeConstant(
|
||||
const ComputationDataHandle& operand, const Layout* output_layout) {
|
||||
const ComputationDataHandle& operand, const Layout* output_layout,
|
||||
tensorflow::gtl::ArraySlice<Literal> parameters) {
|
||||
if (!first_error_.ok()) {
|
||||
return first_error_;
|
||||
}
|
||||
@ -1341,6 +1343,9 @@ StatusOr<std::unique_ptr<Literal>> ComputationBuilder::ComputeConstant(
|
||||
if (output_layout != nullptr) {
|
||||
*request.mutable_output_layout() = *output_layout;
|
||||
}
|
||||
for (const auto& param : parameters) {
|
||||
*request.add_parameters() = param.ToProto();
|
||||
}
|
||||
|
||||
ComputeConstantResponse response;
|
||||
|
||||
|
@ -746,11 +746,12 @@ class ComputationBuilder {
|
||||
ComputationDataHandle Recv(const Shape& shape, const ChannelHandle& handle);
|
||||
|
||||
// Returns true if 'operand' is a compile-time constant. A compile-time
|
||||
// constant does not depend on parameters, or on stateful operators such
|
||||
// as `RngNormal` or `Infeed`. Unlike `ComputeConstant`, `IsConstant` tests
|
||||
// whether a computation is a compile-time constant without evaluating the
|
||||
// computation.
|
||||
StatusOr<bool> IsConstant(const ComputationDataHandle& operand);
|
||||
// constant does not depend on parameters with higher index then
|
||||
// `num_parameters`, or on stateful operators such as `RngNormal` or `Infeed`.
|
||||
// Unlike `ComputeConstant`, `IsConstant` tests whether a computation is a
|
||||
// compile-time constant without evaluating the computation.
|
||||
StatusOr<bool> IsConstant(const ComputationDataHandle& operand,
|
||||
int64 num_parameters = 0);
|
||||
|
||||
// Normalizes operand across spatial and batch dimensions for each feature.
|
||||
//
|
||||
@ -795,7 +796,7 @@ class ComputationBuilder {
|
||||
float epsilon, int64 feature_index);
|
||||
|
||||
// Computes the value of a constant indicated by a
|
||||
// ComputationDataHandle.
|
||||
// ComputationDataHandle using a non-optimized interpreter on the host.
|
||||
//
|
||||
// The operand must be from the computation currently being built -
|
||||
// i.e., returned from this builder with no intervening call to
|
||||
@ -803,8 +804,11 @@ class ComputationBuilder {
|
||||
// that may stop working at any time.
|
||||
//
|
||||
// The operand must represent a constant value, which in this case
|
||||
// means that it must not statically depend on a parameter to the
|
||||
// computation that is being built.
|
||||
// means that it must not statically depend on any parameter of the
|
||||
// computation that is being built other then the ones specified on the
|
||||
// paramtere list. The parameters in the list will be indexed by their
|
||||
// parameter id property so the number of parameters specified should be at
|
||||
// least as many as the largest used parameter index.
|
||||
//
|
||||
// `IsConstant` can be used to test whether a computation is a compile-time
|
||||
// constant without evaluation it. `ComputeConstant` only succeeds for
|
||||
@ -822,7 +826,8 @@ class ComputationBuilder {
|
||||
// will be stored using that layout.
|
||||
StatusOr<std::unique_ptr<Literal>> ComputeConstant(
|
||||
const ComputationDataHandle& operand,
|
||||
const Layout* output_layout = nullptr);
|
||||
const Layout* output_layout = nullptr,
|
||||
tensorflow::gtl::ArraySlice<Literal> parameters = {});
|
||||
|
||||
// Returns a new ComputationBuilder whose resultant Computation is used only
|
||||
// by this ComputationBuilder. The sub-ComputationBuilder has the same
|
||||
|
@ -101,6 +101,11 @@ BufferAllocationProto BufferAllocation::ToProto() const {
|
||||
proto_assigned->set_offset(buffer_offset_size.second.offset);
|
||||
proto_assigned->set_size(buffer_offset_size.second.size);
|
||||
}
|
||||
std::sort(proto.mutable_assigned()->begin(), proto.mutable_assigned()->end(),
|
||||
[](const BufferAllocationProto::Assigned& assign1,
|
||||
const BufferAllocationProto::Assigned& assign2) {
|
||||
return assign1.logical_buffer_id() < assign2.logical_buffer_id();
|
||||
});
|
||||
return proto;
|
||||
}
|
||||
|
||||
|
@ -52,7 +52,7 @@ llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module,
|
||||
llvm::IRBuilder<> ir_builder(vector_tanh_body);
|
||||
|
||||
llvm::FastMathFlags fast_math_flags;
|
||||
fast_math_flags.setUnsafeAlgebra();
|
||||
fast_math_flags.setFast();
|
||||
ir_builder.setFastMathFlags(fast_math_flags);
|
||||
|
||||
llvm::Value* input = &*vector_tanh_function->arg_begin();
|
||||
|
@ -88,6 +88,16 @@ class Executable {
|
||||
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>>
|
||||
arguments);
|
||||
|
||||
// Populates `hlo_execution_profile` from `executor`. This is implicit in any
|
||||
// Execute* API call that takes a hlo_execution_profile argument, but must be
|
||||
// called explicitly for other (async, for example) variants after the stream
|
||||
// has completed.
|
||||
virtual Status PopulateExecutionProfile(
|
||||
HloExecutionProfile* hlo_execution_profile,
|
||||
perftools::gputools::StreamExecutor* executor) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Convenience wrapper for calling Executable::ExecuteOnStream. Sets up a
|
||||
// timer for the execution, sets up HLO profiling if enabled, and fills in the
|
||||
// given ExecutionProfile if non-null. The ExecuteOnStream overloads have
|
||||
|
@ -1901,12 +1901,13 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const {
|
||||
if (has_sharding()) {
|
||||
extra.push_back(StrCat("sharding=", sharding().ToString()));
|
||||
}
|
||||
if (!control_successors_.empty()) {
|
||||
extra.push_back(StrCat(
|
||||
"control-successors=",
|
||||
Join(control_successors_, ", ", [](string* out, HloInstruction* succ) {
|
||||
StrAppend(out, succ->name());
|
||||
})));
|
||||
if (!control_predecessors_.empty()) {
|
||||
extra.push_back(StrCat("control-predecessors={",
|
||||
Join(control_predecessors_, ", ",
|
||||
[](string* out, HloInstruction* pre) {
|
||||
StrAppend(out, pre->name());
|
||||
}),
|
||||
"}"));
|
||||
}
|
||||
return extra;
|
||||
}
|
||||
|
@ -41,11 +41,21 @@ namespace se = ::perftools::gputools;
|
||||
namespace xla {
|
||||
|
||||
/*static*/ StatusOr<std::unique_ptr<HloModule>>
|
||||
HloRunner::ReadModuleFromHloProtoFile(const char* filename,
|
||||
HloRunner::ReadModuleFromHloProtoFile(const std::string& filename,
|
||||
const DebugOptions& debug_options) {
|
||||
HloProto proto;
|
||||
TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
|
||||
filename, &proto));
|
||||
|
||||
const Status s =
|
||||
tensorflow::ReadBinaryProto(tensorflow::Env::Default(), filename, &proto);
|
||||
|
||||
if (!s.ok()) {
|
||||
const Status s2 =
|
||||
tensorflow::ReadTextProto(tensorflow::Env::Default(), filename, &proto);
|
||||
if (!s2.ok()) {
|
||||
return Status(s2.code(), s.error_message() + "\n" + s2.error_message());
|
||||
}
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
HloModuleConfig config,
|
||||
HloModule::CreateModuleConfigFromProto(proto.hlo_module()));
|
||||
@ -56,7 +66,7 @@ HloRunner::ReadModuleFromHloProtoFile(const char* filename,
|
||||
}
|
||||
|
||||
/*static*/ StatusOr<std::unique_ptr<HloModule>>
|
||||
HloRunner::ReadModuleFromHloTextDumpFile(const char* filename,
|
||||
HloRunner::ReadModuleFromHloTextDumpFile(const std::string& filename,
|
||||
const DebugOptions& debug_options) {
|
||||
string hlo_string;
|
||||
TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(),
|
||||
@ -66,6 +76,19 @@ HloRunner::ReadModuleFromHloTextDumpFile(const char* filename,
|
||||
return tools::Parse(hlo_string, config);
|
||||
}
|
||||
|
||||
/*static*/ StatusOr<std::unique_ptr<HloModule>> HloRunner::ReadModule(
|
||||
const std::string& filename, const DebugOptions& debug_options) {
|
||||
auto module = HloRunner::ReadModuleFromHloProtoFile(filename, debug_options);
|
||||
if (module.ok()) {
|
||||
return module;
|
||||
}
|
||||
const std::string e = module.status().error_message();
|
||||
module = HloRunner::ReadModuleFromHloTextDumpFile(filename, debug_options);
|
||||
return module.ok() ? std::move(module)
|
||||
: Status(module.status().code(),
|
||||
e + "\n" + module.status().error_message());
|
||||
}
|
||||
|
||||
// Define this in .cc file to avoid having to include eigen or forward declare
|
||||
// these types in the header.
|
||||
struct HloRunner::EigenThreadPoolWrapper {
|
||||
|
@ -44,15 +44,23 @@ class HloRunner {
|
||||
|
||||
~HloRunner();
|
||||
|
||||
// Reads the binary proto file in xla.HloProto format, creates and returns the
|
||||
// HloModule.
|
||||
// Reads the proto file in xla.HloProto format, creates and returns the
|
||||
// HloModule. Will try to parse the filename as binary proto, then try as
|
||||
// text proto if that fails.
|
||||
static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromHloProtoFile(
|
||||
const char* filename, const DebugOptions& debug_options);
|
||||
const std::string& filename, const DebugOptions& debug_options);
|
||||
|
||||
// Reads the hlo text dump file in HloModule::ToString format, creates and
|
||||
// returns the HloModule.
|
||||
static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromHloTextDumpFile(
|
||||
const char* filename, const DebugOptions& debug_options);
|
||||
const std::string& filename, const DebugOptions& debug_options);
|
||||
|
||||
// Tries to parse the filename specified first as binary proto format, then
|
||||
// as a textual proto format, then textual IR, then gives up if both fail.
|
||||
// ReadModuleFromHloProtoFile or ReadModuleFromHloTextDumpFile should be used
|
||||
// explicitly when you know the format, this if you don't.
|
||||
static StatusOr<std::unique_ptr<HloModule>> ReadModule(
|
||||
const std::string& filename, const DebugOptions& debug_options);
|
||||
|
||||
// Executes the given module with given literals as input and returns the
|
||||
// result as a Literal. The LiteralPtr type accepts Literal* or
|
||||
|
@ -555,8 +555,9 @@ int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout) {
|
||||
llvm::FastMathFlags GetFastMathFlags(bool fast_math_enabled) {
|
||||
llvm::FastMathFlags flags;
|
||||
if (fast_math_enabled) {
|
||||
// UnsafeAlgebra implies NoInfs, NoNaNs, NoSignedZeros, and AllowReciprocal.
|
||||
flags.setUnsafeAlgebra();
|
||||
// Fast implies AllowReassoc, NoInfs, NoNaNs, NoSignedZeros,
|
||||
// AllowReciprocal, AllowContract, and ApproxFunc.
|
||||
flags.setFast();
|
||||
}
|
||||
return flags;
|
||||
}
|
||||
|
@ -490,14 +490,20 @@ Service::ExecuteParallelAndRegisterResult(
|
||||
std::vector<perftools::gputools::DeviceMemoryBase>>
|
||||
arguments,
|
||||
Backend* backend, tensorflow::gtl::ArraySlice<DeviceHandle> device_handles,
|
||||
tensorflow::gtl::ArraySlice<string> result_tags) {
|
||||
tensorflow::gtl::ArraySlice<string> result_tags,
|
||||
ExecutionProfile* profile) {
|
||||
// Streams where the computation are launched, so we can wait on the streams
|
||||
// to complete.
|
||||
std::vector<Pool<se::Stream>::SmartPtr> streams;
|
||||
std::vector<std::unique_ptr<perftools::gputools::Timer>> timers;
|
||||
|
||||
// Global data handles for the computation results, one for each computation.
|
||||
std::vector<GlobalDataHandle> result_handles;
|
||||
|
||||
// Device ID to stream executor, populated only with devices that are being
|
||||
// profiled.
|
||||
std::map<int64, se::Stream*> index_to_profiled_streams;
|
||||
|
||||
TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
|
||||
backend->computation_placer()->AssignDevices(
|
||||
options_.number_of_replicas(), executables.size()));
|
||||
@ -510,6 +516,21 @@ Service::ExecuteParallelAndRegisterResult(
|
||||
backend->BorrowStream(replicas[replica]));
|
||||
streams.push_back(std::move(stream));
|
||||
|
||||
if (replica == 0 && profile != nullptr) {
|
||||
timers.emplace_back(
|
||||
new perftools::gputools::Timer(streams.back()->parent()));
|
||||
streams.back()
|
||||
->InitTimer(timers.back().get())
|
||||
.ThenStartTimer(timers.back().get());
|
||||
CHECK(timers.front() != nullptr);
|
||||
}
|
||||
|
||||
if (replica == 0 &&
|
||||
executables[i]->module_config().debug_options().xla_hlo_profile() &&
|
||||
executables[i]->hlo_profiling_enabled()) {
|
||||
index_to_profiled_streams[i] = streams.back().get();
|
||||
}
|
||||
|
||||
// Set up run options.
|
||||
ExecutableRunOptions options;
|
||||
options.set_stream(streams.back().get());
|
||||
@ -526,6 +547,10 @@ Service::ExecuteParallelAndRegisterResult(
|
||||
perftools::gputools::DeviceMemoryBase result,
|
||||
executables[i]->ExecuteAsyncOnStream(&run_options, arguments[i]));
|
||||
|
||||
if (replica == 0 && profile != nullptr) {
|
||||
streams.back()->ThenStopTimer(timers.back().get());
|
||||
}
|
||||
|
||||
// All replicas share the same device address for the result allocation,
|
||||
// so only one of the replicas need to register the result handle.
|
||||
if (replica == 0) {
|
||||
@ -543,6 +568,69 @@ Service::ExecuteParallelAndRegisterResult(
|
||||
}
|
||||
}
|
||||
|
||||
// For every stream that had profiling enabled, obtain and debug-dump the HLO
|
||||
// profile.
|
||||
for (auto& index_to_profiled_stream : index_to_profiled_streams) {
|
||||
int64 device = index_to_profiled_stream.first;
|
||||
se::Stream* stream = index_to_profiled_stream.second;
|
||||
HloExecutionProfile hlo_profile;
|
||||
TF_RETURN_IF_ERROR(executables[device]->PopulateExecutionProfile(
|
||||
&hlo_profile, stream->parent()));
|
||||
|
||||
std::unordered_set<const xla::HloComputation*> profiled_computations =
|
||||
hlo_profile.profiled_computations();
|
||||
// To ensure we have print the profiles in a stable order, iterate over the
|
||||
// computations in post order.
|
||||
auto& module = executables[device]->module();
|
||||
std::list<xla::HloComputation*> all_computations =
|
||||
module.MakeComputationPostOrder();
|
||||
for (xla::HloComputation* computation : all_computations) {
|
||||
if (profiled_computations.count(computation) > 0) {
|
||||
string profile_string = hlo_profile.ToString(
|
||||
*computation, streams[0]->parent()->GetDeviceDescription(),
|
||||
executables[device]->CreateCostAnalysis().get());
|
||||
if (!profile_string.empty()) {
|
||||
LOG(INFO) << "HLO profile for execution on device " << device
|
||||
<< ":\n";
|
||||
XLA_LOG_LINES(tensorflow::INFO, profile_string);
|
||||
}
|
||||
}
|
||||
}
|
||||
hlo_graph_dumper::MaybeDumpHloModule(module, "Service::Execute",
|
||||
&hlo_profile);
|
||||
}
|
||||
|
||||
if (profile != nullptr) {
|
||||
CHECK(!timers.empty());
|
||||
std::vector<uint64> timer_nanoseconds;
|
||||
timer_nanoseconds.reserve(timers.size());
|
||||
for (auto& timer : timers) {
|
||||
timer_nanoseconds.push_back(timer->Nanoseconds());
|
||||
}
|
||||
uint64 nanoseconds =
|
||||
*std::max_element(timer_nanoseconds.begin(), timer_nanoseconds.end());
|
||||
|
||||
// Merge in run-time profile information from execution_profile on the
|
||||
// zeroth device.
|
||||
profile->MergeFrom(executables[0]->execution_profile());
|
||||
|
||||
// Overall execution time (in nanoseconds) from the executor timer.
|
||||
profile->set_compute_and_transfer_time_ns(nanoseconds);
|
||||
|
||||
// TODO(b/28123297): On GPU we end up including transfer time in
|
||||
// the compute time this way. Instead, we should get the correct
|
||||
// value by measuring it. Setting the field here at least lets
|
||||
// benchmarks provide *some* value for GPU computations.
|
||||
//
|
||||
// TODO(b/28447609): The value in compute_and_transfer_time_ns is actually
|
||||
// the compute time without the transfer time, so this way we get the
|
||||
// correct compute time. We should instead have the correct value for
|
||||
// compute_and_transfer_time and set compute_time to the compute time.
|
||||
if (profile->compute_time_ns() == 0) {
|
||||
profile->set_compute_time_ns(profile->compute_and_transfer_time_ns());
|
||||
}
|
||||
}
|
||||
|
||||
return result_handles;
|
||||
}
|
||||
|
||||
@ -715,14 +803,16 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
|
||||
|
||||
// Execute the generated executables in parallel and return the device
|
||||
// handles for each computation's output.
|
||||
ExecutionProfile profile;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::vector<GlobalDataHandle> outputs,
|
||||
ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments,
|
||||
execute_backend_.get(), device_handles,
|
||||
computation_names));
|
||||
computation_names, &profile));
|
||||
for (const GlobalDataHandle& output : outputs) {
|
||||
ExecuteResponse response;
|
||||
*response.mutable_output() = output;
|
||||
*response.mutable_profile() = profile;
|
||||
*result->add_responses() = response;
|
||||
}
|
||||
|
||||
@ -1082,8 +1172,9 @@ tensorflow::Status Service::IsConstant(const IsConstantRequest* arg,
|
||||
return InvalidArgument("computations may not be empty");
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(bool is_constant,
|
||||
user_computation->IsConstant(arg->operand()));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
bool is_constant,
|
||||
user_computation->IsConstant(arg->operand(), arg->num_parameters()));
|
||||
|
||||
result->set_is_constant(is_constant);
|
||||
return tensorflow::Status::OK();
|
||||
@ -1101,8 +1192,9 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg,
|
||||
return InvalidArgument("computations may not be empty");
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(bool is_constant,
|
||||
user_computation->IsConstant(arg->operand()));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
bool is_constant,
|
||||
user_computation->IsConstant(arg->operand(), arg->parameters_size()));
|
||||
if (!is_constant) {
|
||||
return InvalidArgument("Operand to ComputeConstant depends on parameter.");
|
||||
}
|
||||
@ -1141,8 +1233,18 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg,
|
||||
/*include_unreachable_instructions=*/
|
||||
false));
|
||||
|
||||
std::vector<Literal> parameters(arg->parameters_size());
|
||||
for (int64 i = 0; i < arg->parameters_size(); ++i) {
|
||||
parameters[i] = Literal(arg->parameters(i));
|
||||
}
|
||||
std::vector<const Literal*> parameter_ptrs;
|
||||
std::transform(parameters.begin(), parameters.end(),
|
||||
std::back_inserter(parameter_ptrs),
|
||||
[](const Literal& literal) { return &literal; });
|
||||
|
||||
HloEvaluator evaluator;
|
||||
TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate(*module, {}));
|
||||
TF_ASSIGN_OR_RETURN(auto result_literal,
|
||||
evaluator.Evaluate(*module, parameter_ptrs));
|
||||
// Since the shape_with_output_layout option in ExecutionOption is
|
||||
// non-effective to the Evaluator results, explicit relayout here.
|
||||
if (arg->has_output_layout()) {
|
||||
|
@ -327,7 +327,8 @@ class Service : public ServiceInterface {
|
||||
arguments,
|
||||
Backend* backend,
|
||||
tensorflow::gtl::ArraySlice<DeviceHandle> device_handles,
|
||||
tensorflow::gtl::ArraySlice<string> result_tags);
|
||||
tensorflow::gtl::ArraySlice<string> result_tags,
|
||||
ExecutionProfile* profile);
|
||||
|
||||
// Convenience function for adding a function to a user computation.
|
||||
template <typename RequestT, typename ResponseT>
|
||||
|
@ -1482,14 +1482,15 @@ UserComputation::ComputeProgramShape(
|
||||
|
||||
namespace {
|
||||
|
||||
// A visitor which checks whether an operation is a compile-time constant. That
|
||||
// is, the operation does not depend on any parameter instructions. The visitor
|
||||
// walks the computation starting at a given operation and sets is_constant to
|
||||
// false iff a parameter or RNG operation is encountered.
|
||||
void ConstantVisitor(const SessionComputation& session_computation,
|
||||
const ComputationDataHandle& handle,
|
||||
std::set<int64>* visited, bool* is_constant) {
|
||||
if (visited->count(handle.handle()) != 0 || !*is_constant) {
|
||||
// A visitor which checks whether an operation is pure functional meaning that
|
||||
// it doesn't depend on any parameter with an index higher then num_parameters.
|
||||
// The visitor walks the computation starting at a given operation and sets
|
||||
// is_functional to false iff a parameter or RNG operation is encountered.
|
||||
void PureFunctionalVisitor(const SessionComputation& session_computation,
|
||||
const ComputationDataHandle& handle,
|
||||
int64 num_parameters, std::set<int64>* visited,
|
||||
bool* is_functional) {
|
||||
if (visited->count(handle.handle()) != 0 || !*is_functional) {
|
||||
return;
|
||||
}
|
||||
|
||||
@ -1497,7 +1498,7 @@ void ConstantVisitor(const SessionComputation& session_computation,
|
||||
session_computation.requests().at(handle.handle());
|
||||
switch (request.request().op_case()) {
|
||||
case OpRequest::kRngRequest:
|
||||
*is_constant = false;
|
||||
*is_functional = false;
|
||||
break;
|
||||
|
||||
case OpRequest::kConstantRequest:
|
||||
@ -1506,41 +1507,43 @@ void ConstantVisitor(const SessionComputation& session_computation,
|
||||
case OpRequest::kGetTupleElementRequest: {
|
||||
const GetTupleElementRequest& get_tuple_element_request =
|
||||
request.request().get_tuple_element_request();
|
||||
ConstantVisitor(session_computation, get_tuple_element_request.operand(),
|
||||
visited, is_constant);
|
||||
PureFunctionalVisitor(session_computation,
|
||||
get_tuple_element_request.operand(), num_parameters,
|
||||
visited, is_functional);
|
||||
break;
|
||||
}
|
||||
|
||||
case OpRequest::kSliceRequest: {
|
||||
const SliceRequest& slice_request = request.request().slice_request();
|
||||
ConstantVisitor(session_computation, slice_request.operand(), visited,
|
||||
is_constant);
|
||||
PureFunctionalVisitor(session_computation, slice_request.operand(),
|
||||
num_parameters, visited, is_functional);
|
||||
break;
|
||||
}
|
||||
|
||||
case OpRequest::kDynamicSliceRequest: {
|
||||
const DynamicSliceRequest& dynamic_slice_request =
|
||||
request.request().dynamic_slice_request();
|
||||
ConstantVisitor(session_computation, dynamic_slice_request.operand(),
|
||||
visited, is_constant);
|
||||
ConstantVisitor(session_computation,
|
||||
dynamic_slice_request.start_indices(), visited,
|
||||
is_constant);
|
||||
PureFunctionalVisitor(session_computation,
|
||||
dynamic_slice_request.operand(), num_parameters,
|
||||
visited, is_functional);
|
||||
PureFunctionalVisitor(session_computation,
|
||||
dynamic_slice_request.start_indices(),
|
||||
num_parameters, visited, is_functional);
|
||||
break;
|
||||
}
|
||||
|
||||
case OpRequest::kDynamicUpdateSliceRequest: {
|
||||
const DynamicUpdateSliceRequest& dynamic_update_slice_request =
|
||||
request.request().dynamic_update_slice_request();
|
||||
ConstantVisitor(session_computation,
|
||||
dynamic_update_slice_request.operand(), visited,
|
||||
is_constant);
|
||||
ConstantVisitor(session_computation,
|
||||
dynamic_update_slice_request.update(), visited,
|
||||
is_constant);
|
||||
ConstantVisitor(session_computation,
|
||||
dynamic_update_slice_request.start_indices(), visited,
|
||||
is_constant);
|
||||
PureFunctionalVisitor(session_computation,
|
||||
dynamic_update_slice_request.operand(),
|
||||
num_parameters, visited, is_functional);
|
||||
PureFunctionalVisitor(session_computation,
|
||||
dynamic_update_slice_request.update(),
|
||||
num_parameters, visited, is_functional);
|
||||
PureFunctionalVisitor(session_computation,
|
||||
dynamic_update_slice_request.start_indices(),
|
||||
num_parameters, visited, is_functional);
|
||||
break;
|
||||
}
|
||||
|
||||
@ -1549,7 +1552,8 @@ void ConstantVisitor(const SessionComputation& session_computation,
|
||||
request.request().concatenate_request();
|
||||
for (const ComputationDataHandle& handle :
|
||||
concatenate_request.operands()) {
|
||||
ConstantVisitor(session_computation, handle, visited, is_constant);
|
||||
PureFunctionalVisitor(session_computation, handle, num_parameters,
|
||||
visited, is_functional);
|
||||
}
|
||||
break;
|
||||
}
|
||||
@ -1557,61 +1561,63 @@ void ConstantVisitor(const SessionComputation& session_computation,
|
||||
case OpRequest::kConvolveRequest: {
|
||||
const ConvolveRequest& convolve_request =
|
||||
request.request().convolve_request();
|
||||
ConstantVisitor(session_computation, convolve_request.lhs(), visited,
|
||||
is_constant);
|
||||
ConstantVisitor(session_computation, convolve_request.rhs(), visited,
|
||||
is_constant);
|
||||
PureFunctionalVisitor(session_computation, convolve_request.lhs(),
|
||||
num_parameters, visited, is_functional);
|
||||
PureFunctionalVisitor(session_computation, convolve_request.rhs(),
|
||||
num_parameters, visited, is_functional);
|
||||
break;
|
||||
}
|
||||
|
||||
case OpRequest::kCrossReplicaSumRequest: {
|
||||
// TODO(b/33009255): Implmement constant folding for cross replica sum.
|
||||
*is_constant = false;
|
||||
*is_functional = false;
|
||||
break;
|
||||
}
|
||||
|
||||
case OpRequest::kInfeedRequest: {
|
||||
*is_constant = false;
|
||||
*is_functional = false;
|
||||
break;
|
||||
}
|
||||
|
||||
case OpRequest::kOutfeedRequest: {
|
||||
*is_constant = false;
|
||||
*is_functional = false;
|
||||
break;
|
||||
}
|
||||
|
||||
case OpRequest::kCallRequest: {
|
||||
const CallRequest& call_request = request.request().call_request();
|
||||
for (const ComputationDataHandle& handle : call_request.operands()) {
|
||||
ConstantVisitor(session_computation, handle, visited, is_constant);
|
||||
PureFunctionalVisitor(session_computation, handle, num_parameters,
|
||||
visited, is_functional);
|
||||
}
|
||||
// TODO(b/32495713): We aren't checking the to_apply computation itself,
|
||||
// so we conservatively say that computations containing the Call op
|
||||
// cannot be constant. We cannot set is_constant=false in other similar
|
||||
// cannot be constant. We cannot set is_functional=false in other similar
|
||||
// cases since we're already relying on IsConstant to return true.
|
||||
*is_constant = false;
|
||||
*is_functional = false;
|
||||
break;
|
||||
}
|
||||
|
||||
case OpRequest::kCustomCallRequest: {
|
||||
*is_constant = false;
|
||||
*is_functional = false;
|
||||
break;
|
||||
}
|
||||
|
||||
case OpRequest::kSendRequest: {
|
||||
*is_constant = false;
|
||||
*is_functional = false;
|
||||
break;
|
||||
}
|
||||
|
||||
case OpRequest::kRecvRequest: {
|
||||
*is_constant = false;
|
||||
*is_functional = false;
|
||||
break;
|
||||
}
|
||||
|
||||
case OpRequest::kMapRequest: {
|
||||
const MapRequest& map_request = request.request().map_request();
|
||||
for (const ComputationDataHandle& handle : map_request.operands()) {
|
||||
ConstantVisitor(session_computation, handle, visited, is_constant);
|
||||
PureFunctionalVisitor(session_computation, handle, num_parameters,
|
||||
visited, is_functional);
|
||||
}
|
||||
// TODO(b/32495713): We aren't checking the to_apply computation itself.
|
||||
break;
|
||||
@ -1619,10 +1625,10 @@ void ConstantVisitor(const SessionComputation& session_computation,
|
||||
|
||||
case OpRequest::kReduceRequest: {
|
||||
const ReduceRequest& reduce_request = request.request().reduce_request();
|
||||
ConstantVisitor(session_computation, reduce_request.operand(), visited,
|
||||
is_constant);
|
||||
ConstantVisitor(session_computation, reduce_request.init_value(), visited,
|
||||
is_constant);
|
||||
PureFunctionalVisitor(session_computation, reduce_request.operand(),
|
||||
num_parameters, visited, is_functional);
|
||||
PureFunctionalVisitor(session_computation, reduce_request.init_value(),
|
||||
num_parameters, visited, is_functional);
|
||||
// TODO(b/32495713): We aren't checking the to_apply computation itself.
|
||||
break;
|
||||
}
|
||||
@ -1630,10 +1636,12 @@ void ConstantVisitor(const SessionComputation& session_computation,
|
||||
case OpRequest::kReduceWindowRequest: {
|
||||
const ReduceWindowRequest& reduce_window_request =
|
||||
request.request().reduce_window_request();
|
||||
ConstantVisitor(session_computation, reduce_window_request.operand(),
|
||||
visited, is_constant);
|
||||
ConstantVisitor(session_computation, reduce_window_request.init_value(),
|
||||
visited, is_constant);
|
||||
PureFunctionalVisitor(session_computation,
|
||||
reduce_window_request.operand(), num_parameters,
|
||||
visited, is_functional);
|
||||
PureFunctionalVisitor(session_computation,
|
||||
reduce_window_request.init_value(), num_parameters,
|
||||
visited, is_functional);
|
||||
// TODO(b/32495713): We aren't checking the to_apply computation itself.
|
||||
break;
|
||||
}
|
||||
@ -1641,13 +1649,15 @@ void ConstantVisitor(const SessionComputation& session_computation,
|
||||
case OpRequest::kSelectAndScatterRequest: {
|
||||
const SelectAndScatterRequest& select_and_scatter_request =
|
||||
request.request().select_and_scatter_request();
|
||||
ConstantVisitor(session_computation, select_and_scatter_request.operand(),
|
||||
visited, is_constant);
|
||||
ConstantVisitor(session_computation, select_and_scatter_request.source(),
|
||||
visited, is_constant);
|
||||
ConstantVisitor(session_computation,
|
||||
select_and_scatter_request.init_value(), visited,
|
||||
is_constant);
|
||||
PureFunctionalVisitor(session_computation,
|
||||
select_and_scatter_request.operand(),
|
||||
num_parameters, visited, is_functional);
|
||||
PureFunctionalVisitor(session_computation,
|
||||
select_and_scatter_request.source(), num_parameters,
|
||||
visited, is_functional);
|
||||
PureFunctionalVisitor(session_computation,
|
||||
select_and_scatter_request.init_value(),
|
||||
num_parameters, visited, is_functional);
|
||||
// TODO(b/32495713): We aren't checking the select and scatter
|
||||
// computations themselves.
|
||||
break;
|
||||
@ -1656,76 +1666,80 @@ void ConstantVisitor(const SessionComputation& session_computation,
|
||||
case OpRequest::kBroadcastRequest: {
|
||||
const BroadcastRequest& broadcast_request =
|
||||
request.request().broadcast_request();
|
||||
ConstantVisitor(session_computation, broadcast_request.operand(), visited,
|
||||
is_constant);
|
||||
PureFunctionalVisitor(session_computation, broadcast_request.operand(),
|
||||
num_parameters, visited, is_functional);
|
||||
break;
|
||||
}
|
||||
|
||||
case OpRequest::kReshapeRequest: {
|
||||
const ReshapeRequest& reshape_request =
|
||||
request.request().reshape_request();
|
||||
ConstantVisitor(session_computation, reshape_request.operand(), visited,
|
||||
is_constant);
|
||||
PureFunctionalVisitor(session_computation, reshape_request.operand(),
|
||||
num_parameters, visited, is_functional);
|
||||
break;
|
||||
}
|
||||
|
||||
case OpRequest::kReverseRequest: {
|
||||
const ReverseRequest& reverse_request =
|
||||
request.request().reverse_request();
|
||||
ConstantVisitor(session_computation, reverse_request.operand(), visited,
|
||||
is_constant);
|
||||
PureFunctionalVisitor(session_computation, reverse_request.operand(),
|
||||
num_parameters, visited, is_functional);
|
||||
break;
|
||||
}
|
||||
|
||||
case OpRequest::kPadRequest: {
|
||||
const PadRequest& pad_request = request.request().pad_request();
|
||||
ConstantVisitor(session_computation, pad_request.operand(), visited,
|
||||
is_constant);
|
||||
ConstantVisitor(session_computation, pad_request.padding_value(), visited,
|
||||
is_constant);
|
||||
PureFunctionalVisitor(session_computation, pad_request.operand(),
|
||||
num_parameters, visited, is_functional);
|
||||
PureFunctionalVisitor(session_computation, pad_request.padding_value(),
|
||||
num_parameters, visited, is_functional);
|
||||
break;
|
||||
}
|
||||
|
||||
case OpRequest::kParameterRequest: {
|
||||
*is_constant = false;
|
||||
const ParameterRequest& parameter_request =
|
||||
request.request().parameter_request();
|
||||
if (parameter_request.parameter() >= num_parameters) {
|
||||
*is_functional = false;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case OpRequest::kConvertRequest: {
|
||||
const ConvertRequest& convert_request =
|
||||
request.request().convert_request();
|
||||
ConstantVisitor(session_computation, convert_request.operand(), visited,
|
||||
is_constant);
|
||||
PureFunctionalVisitor(session_computation, convert_request.operand(),
|
||||
num_parameters, visited, is_functional);
|
||||
break;
|
||||
}
|
||||
|
||||
case OpRequest::kWhileRequest: {
|
||||
const WhileRequest& while_request = request.request().while_request();
|
||||
ConstantVisitor(session_computation, while_request.init(), visited,
|
||||
is_constant);
|
||||
PureFunctionalVisitor(session_computation, while_request.init(),
|
||||
num_parameters, visited, is_functional);
|
||||
// TODO(b/32495713): We aren't checking the condition and body
|
||||
// computations themselves.
|
||||
*is_constant = false;
|
||||
*is_functional = false;
|
||||
break;
|
||||
}
|
||||
|
||||
case OpRequest::kTernaryOpRequest: {
|
||||
const TernaryOpRequest& ternary_op_request =
|
||||
request.request().ternary_op_request();
|
||||
ConstantVisitor(session_computation, ternary_op_request.lhs(), visited,
|
||||
is_constant);
|
||||
ConstantVisitor(session_computation, ternary_op_request.rhs(), visited,
|
||||
is_constant);
|
||||
ConstantVisitor(session_computation, ternary_op_request.ehs(), visited,
|
||||
is_constant);
|
||||
PureFunctionalVisitor(session_computation, ternary_op_request.lhs(),
|
||||
num_parameters, visited, is_functional);
|
||||
PureFunctionalVisitor(session_computation, ternary_op_request.rhs(),
|
||||
num_parameters, visited, is_functional);
|
||||
PureFunctionalVisitor(session_computation, ternary_op_request.ehs(),
|
||||
num_parameters, visited, is_functional);
|
||||
break;
|
||||
}
|
||||
|
||||
case OpRequest::kTransposeRequest: {
|
||||
const TransposeRequest& transpose_request =
|
||||
request.request().transpose_request();
|
||||
ConstantVisitor(session_computation, transpose_request.operand(), visited,
|
||||
is_constant);
|
||||
PureFunctionalVisitor(session_computation, transpose_request.operand(),
|
||||
num_parameters, visited, is_functional);
|
||||
break;
|
||||
}
|
||||
|
||||
@ -1734,7 +1748,8 @@ void ConstantVisitor(const SessionComputation& session_computation,
|
||||
request.request().variadic_op_request();
|
||||
for (const ComputationDataHandle& handle :
|
||||
variadic_op_request.operands()) {
|
||||
ConstantVisitor(session_computation, handle, visited, is_constant);
|
||||
PureFunctionalVisitor(session_computation, handle, num_parameters,
|
||||
visited, is_functional);
|
||||
}
|
||||
break;
|
||||
}
|
||||
@ -1742,67 +1757,74 @@ void ConstantVisitor(const SessionComputation& session_computation,
|
||||
case OpRequest::kUnaryOpRequest: {
|
||||
const UnaryOpRequest& unary_op_request =
|
||||
request.request().unary_op_request();
|
||||
ConstantVisitor(session_computation, unary_op_request.operand(), visited,
|
||||
is_constant);
|
||||
PureFunctionalVisitor(session_computation, unary_op_request.operand(),
|
||||
num_parameters, visited, is_functional);
|
||||
break;
|
||||
}
|
||||
|
||||
case OpRequest::kBatchNormTrainingRequest: {
|
||||
const BatchNormTrainingRequest& batch_norm_training_request =
|
||||
request.request().batch_norm_training_request();
|
||||
ConstantVisitor(session_computation,
|
||||
batch_norm_training_request.operand(), visited,
|
||||
is_constant);
|
||||
ConstantVisitor(session_computation, batch_norm_training_request.scale(),
|
||||
visited, is_constant);
|
||||
ConstantVisitor(session_computation, batch_norm_training_request.offset(),
|
||||
visited, is_constant);
|
||||
PureFunctionalVisitor(session_computation,
|
||||
batch_norm_training_request.operand(),
|
||||
num_parameters, visited, is_functional);
|
||||
PureFunctionalVisitor(session_computation,
|
||||
batch_norm_training_request.scale(), num_parameters,
|
||||
visited, is_functional);
|
||||
PureFunctionalVisitor(session_computation,
|
||||
batch_norm_training_request.offset(),
|
||||
num_parameters, visited, is_functional);
|
||||
break;
|
||||
}
|
||||
|
||||
case OpRequest::kBatchNormInferenceRequest: {
|
||||
const BatchNormInferenceRequest& batch_norm_inference_request =
|
||||
request.request().batch_norm_inference_request();
|
||||
ConstantVisitor(session_computation,
|
||||
batch_norm_inference_request.operand(), visited,
|
||||
is_constant);
|
||||
ConstantVisitor(session_computation, batch_norm_inference_request.scale(),
|
||||
visited, is_constant);
|
||||
ConstantVisitor(session_computation,
|
||||
batch_norm_inference_request.offset(), visited,
|
||||
is_constant);
|
||||
ConstantVisitor(session_computation, batch_norm_inference_request.mean(),
|
||||
visited, is_constant);
|
||||
ConstantVisitor(session_computation,
|
||||
batch_norm_inference_request.variance(), visited,
|
||||
is_constant);
|
||||
PureFunctionalVisitor(session_computation,
|
||||
batch_norm_inference_request.operand(),
|
||||
num_parameters, visited, is_functional);
|
||||
PureFunctionalVisitor(session_computation,
|
||||
batch_norm_inference_request.scale(),
|
||||
num_parameters, visited, is_functional);
|
||||
PureFunctionalVisitor(session_computation,
|
||||
batch_norm_inference_request.offset(),
|
||||
num_parameters, visited, is_functional);
|
||||
PureFunctionalVisitor(session_computation,
|
||||
batch_norm_inference_request.mean(), num_parameters,
|
||||
visited, is_functional);
|
||||
PureFunctionalVisitor(session_computation,
|
||||
batch_norm_inference_request.variance(),
|
||||
num_parameters, visited, is_functional);
|
||||
break;
|
||||
}
|
||||
|
||||
case OpRequest::kBatchNormGradRequest: {
|
||||
const BatchNormGradRequest& batch_norm_grad_request =
|
||||
request.request().batch_norm_grad_request();
|
||||
ConstantVisitor(session_computation, batch_norm_grad_request.operand(),
|
||||
visited, is_constant);
|
||||
ConstantVisitor(session_computation, batch_norm_grad_request.scale(),
|
||||
visited, is_constant);
|
||||
ConstantVisitor(session_computation, batch_norm_grad_request.mean(),
|
||||
visited, is_constant);
|
||||
ConstantVisitor(session_computation, batch_norm_grad_request.variance(),
|
||||
visited, is_constant);
|
||||
ConstantVisitor(session_computation,
|
||||
batch_norm_grad_request.grad_output(), visited,
|
||||
is_constant);
|
||||
PureFunctionalVisitor(session_computation,
|
||||
batch_norm_grad_request.operand(), num_parameters,
|
||||
visited, is_functional);
|
||||
PureFunctionalVisitor(session_computation,
|
||||
batch_norm_grad_request.scale(), num_parameters,
|
||||
visited, is_functional);
|
||||
PureFunctionalVisitor(session_computation, batch_norm_grad_request.mean(),
|
||||
num_parameters, visited, is_functional);
|
||||
PureFunctionalVisitor(session_computation,
|
||||
batch_norm_grad_request.variance(), num_parameters,
|
||||
visited, is_functional);
|
||||
PureFunctionalVisitor(session_computation,
|
||||
batch_norm_grad_request.grad_output(),
|
||||
num_parameters, visited, is_functional);
|
||||
break;
|
||||
}
|
||||
|
||||
case OpRequest::kBinaryOpRequest: {
|
||||
const BinaryOpRequest& binary_op_request =
|
||||
request.request().binary_op_request();
|
||||
ConstantVisitor(session_computation, binary_op_request.lhs(), visited,
|
||||
is_constant);
|
||||
ConstantVisitor(session_computation, binary_op_request.rhs(), visited,
|
||||
is_constant);
|
||||
PureFunctionalVisitor(session_computation, binary_op_request.lhs(),
|
||||
num_parameters, visited, is_functional);
|
||||
PureFunctionalVisitor(session_computation, binary_op_request.rhs(),
|
||||
num_parameters, visited, is_functional);
|
||||
break;
|
||||
}
|
||||
|
||||
@ -1817,8 +1839,8 @@ void ConstantVisitor(const SessionComputation& session_computation,
|
||||
|
||||
} // namespace
|
||||
|
||||
StatusOr<bool> UserComputation::IsConstant(
|
||||
const ComputationDataHandle& handle) {
|
||||
StatusOr<bool> UserComputation::IsConstant(const ComputationDataHandle& handle,
|
||||
int64 num_parameters) {
|
||||
tensorflow::mutex_lock lock(mutex_);
|
||||
|
||||
// Verify that the handle is valid.
|
||||
@ -1829,7 +1851,8 @@ StatusOr<bool> UserComputation::IsConstant(
|
||||
|
||||
bool is_constant = true;
|
||||
std::set<int64> visited;
|
||||
ConstantVisitor(session_computation_, handle, &visited, &is_constant);
|
||||
PureFunctionalVisitor(session_computation_, handle, num_parameters, &visited,
|
||||
&is_constant);
|
||||
|
||||
return is_constant;
|
||||
}
|
||||
|
@ -250,9 +250,11 @@ class UserComputation {
|
||||
StatusOr<std::shared_ptr<const ProgramShape>> ComputeProgramShape(
|
||||
VersionedComputationHandle::Version version) const;
|
||||
|
||||
// Returns true if the given data handle does not depend on any
|
||||
// parameters. That is, the value can be computed at compile time.
|
||||
StatusOr<bool> IsConstant(const ComputationDataHandle& handle);
|
||||
// Returns true if the given data handle does not depend on any parameter with
|
||||
// index higher then num_parameters. That is, the value can be computed at
|
||||
// compile time if we know the first num_parameters arguments.
|
||||
StatusOr<bool> IsConstant(const ComputationDataHandle& handle,
|
||||
int64 num_parameters);
|
||||
|
||||
// Returns the output shape of the operation indicated by the given handle.
|
||||
StatusOr<Shape> GetShape(const ComputationDataHandle& handle);
|
||||
|
@ -71,24 +71,27 @@ class ComputeConstantTest : public ::testing::Test {
|
||||
|
||||
StatusOr<std::unique_ptr<Literal>> ComputeConstantLiteral(
|
||||
Client* client, const ComputationDataHandle& operand,
|
||||
ComputationBuilder* builder, Layout* output_layout = nullptr) {
|
||||
TF_ASSIGN_OR_RETURN(auto computed,
|
||||
builder->ComputeConstant(operand, output_layout));
|
||||
ComputationBuilder* builder, Layout* output_layout = nullptr,
|
||||
tensorflow::gtl::ArraySlice<Literal> parameters = {}) {
|
||||
TF_ASSIGN_OR_RETURN(auto computed, builder->ComputeConstant(
|
||||
operand, output_layout, parameters));
|
||||
return std::move(computed);
|
||||
}
|
||||
|
||||
template <class Scalar>
|
||||
StatusOr<Scalar> ComputeConstantScalar(Client* client,
|
||||
const ComputationDataHandle& operand,
|
||||
ComputationBuilder* builder) {
|
||||
TF_ASSIGN_OR_RETURN(auto literal,
|
||||
ComputeConstantLiteral(client, operand, builder));
|
||||
StatusOr<Scalar> ComputeConstantScalar(
|
||||
Client* client, const ComputationDataHandle& operand,
|
||||
ComputationBuilder* builder,
|
||||
tensorflow::gtl::ArraySlice<Literal> parameters = {}) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto literal,
|
||||
ComputeConstantLiteral(client, operand, builder, nullptr, parameters));
|
||||
return literal->Get<Scalar>({});
|
||||
}
|
||||
|
||||
bool IsConstant(const ComputationDataHandle& operand,
|
||||
ComputationBuilder* builder) {
|
||||
StatusOr<bool> result = builder->IsConstant(operand);
|
||||
ComputationBuilder* builder, int64 num_parameters = 0) {
|
||||
StatusOr<bool> result = builder->IsConstant(operand, num_parameters);
|
||||
EXPECT_TRUE(result.ok()) << result.status();
|
||||
return result.ok() ? result.ValueOrDie() : false;
|
||||
}
|
||||
@ -138,7 +141,25 @@ TEST_F(ComputeConstantTest, ScalarRng) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ComputeConstantTest, DirectParam) {
|
||||
TEST_F(ComputeConstantTest, Param) {
|
||||
for (ClientType client_type : client_types) {
|
||||
Client* client = ClientOrDie(platform_, client_type);
|
||||
ComputationBuilder b(client, TestName());
|
||||
auto param = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "lhs");
|
||||
auto computation = b.Add(param, b.ConstantR0<float>(1.5f));
|
||||
|
||||
std::vector<Literal> arguments;
|
||||
arguments.emplace_back(*Literal::CreateR0(42.5f));
|
||||
EXPECT_TRUE(IsConstant(computation, &b, arguments.size()));
|
||||
|
||||
auto value =
|
||||
ComputeConstantScalar<float>(client, computation, &b, arguments);
|
||||
ASSERT_TRUE(value.ok()) << value.status();
|
||||
EXPECT_EQ(value.ValueOrDie(), 44.0f);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ComputeConstantTest, DirectParamMissing) {
|
||||
for (ClientType client_type : client_types) {
|
||||
Client* client = ClientOrDie(platform_, client_type);
|
||||
ComputationBuilder b(client, TestName());
|
||||
@ -152,7 +173,7 @@ TEST_F(ComputeConstantTest, DirectParam) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ComputeConstantTest, IndirectParam) {
|
||||
TEST_F(ComputeConstantTest, IndirectParamMissing) {
|
||||
for (ClientType client_type : client_types) {
|
||||
Client* client = ClientOrDie(platform_, client_type);
|
||||
ComputationBuilder b(client, TestName());
|
||||
|
@ -357,6 +357,111 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) {
|
||||
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
|
||||
}
|
||||
|
||||
// TODO(b/63003356): 11-06-2017: fails on all back-ends with incorrect result.
|
||||
TEST_F(WhileTest, DISABLED_WhileWithPermutationAndTupleResult) {
|
||||
std::vector<Shape> shape_elements = {
|
||||
ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}),
|
||||
ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})};
|
||||
Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
|
||||
|
||||
// Create a computation for the condition.
|
||||
// Repeat for N iterations.
|
||||
const int N = 2;
|
||||
Computation condition;
|
||||
{
|
||||
ComputationBuilder builder(client_, "condition");
|
||||
auto prev = builder.Parameter(0, result_shape, "prev");
|
||||
auto iteration = builder.GetTupleElement(prev, 0);
|
||||
builder.Gt(builder.ConstantR0<int32>(N), iteration);
|
||||
condition = builder.Build().ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
// Create a computation for the body.
|
||||
// Add 1 to the iteration variable and permute the weights.
|
||||
Computation body;
|
||||
{
|
||||
ComputationBuilder builder(client_, "body");
|
||||
auto prev = builder.Parameter(0, result_shape, "prev");
|
||||
auto iteration = builder.GetTupleElement(prev, 0);
|
||||
auto w1 = builder.GetTupleElement(prev, 1);
|
||||
auto w2 = builder.GetTupleElement(prev, 2);
|
||||
auto w3 = builder.GetTupleElement(prev, 3);
|
||||
auto result = builder.Tuple(
|
||||
{builder.Add(iteration, builder.ConstantR0<int32>(1)), w3, w1, w2});
|
||||
body = builder.Build().ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
// Create a While node with computations for the condition and the body.
|
||||
ComputationBuilder builder(client_, "while");
|
||||
auto init = builder.Tuple(
|
||||
{builder.ConstantR0<int32>(0), builder.ConstantR1<float>(3, 1.f),
|
||||
builder.ConstantR1<float>(3, 2.f), builder.ConstantR1<float>(3, 3.f)});
|
||||
auto result = builder.While(condition, body, init);
|
||||
VLOG(2) << "result = "
|
||||
<< ShapeUtil::HumanString(
|
||||
*builder.GetShape(result).ConsumeValueOrDie());
|
||||
|
||||
auto expected_counter = Literal::CreateR0<int32>(N);
|
||||
auto expected_w1 = Literal::CreateR1<float>({1.0f, 1.0f, 1.0f});
|
||||
auto expected_w2 = Literal::CreateR1<float>({2.0f, 2.0f, 2.0f});
|
||||
auto expected_w3 = Literal::CreateR1<float>({3.0f, 3.0f, 3.0f});
|
||||
auto expected = Literal::MakeTuple({expected_counter.get(), expected_w2.get(),
|
||||
expected_w3.get(), expected_w1.get()});
|
||||
VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
|
||||
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
|
||||
}
|
||||
|
||||
// TODO(b/63003356): 11-06-2017: fails on all back-ends with incorrect result.
|
||||
TEST_F(WhileTest, DISABLED_WhileWithPermutationAndVectorResult) {
|
||||
std::vector<Shape> shape_elements = {
|
||||
ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}),
|
||||
ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})};
|
||||
Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
|
||||
|
||||
// Create a computation for the condition.
|
||||
// Repeat for N iterations.
|
||||
const int N = 2;
|
||||
Computation condition;
|
||||
{
|
||||
ComputationBuilder builder(client_, "condition");
|
||||
auto prev = builder.Parameter(0, result_shape, "prev");
|
||||
auto iteration = builder.GetTupleElement(prev, 0);
|
||||
builder.Gt(builder.ConstantR0<int32>(N), iteration);
|
||||
condition = builder.Build().ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
// Create a computation for the body.
|
||||
// Add 1 to the iteration variable permute the weights.
|
||||
Computation body;
|
||||
{
|
||||
ComputationBuilder builder(client_, "body");
|
||||
auto prev = builder.Parameter(0, result_shape, "prev");
|
||||
auto iteration = builder.GetTupleElement(prev, 0);
|
||||
auto w1 = builder.GetTupleElement(prev, 1);
|
||||
auto w2 = builder.GetTupleElement(prev, 2);
|
||||
auto w3 = builder.GetTupleElement(prev, 3);
|
||||
auto result = builder.Tuple(
|
||||
{builder.Add(iteration, builder.ConstantR0<int32>(1)), w3, w1, w2});
|
||||
body = builder.Build().ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
// Create a While node with computations for the condition and the body.
|
||||
ComputationBuilder builder(client_, "while");
|
||||
auto init = builder.Tuple(
|
||||
{builder.ConstantR0<int32>(0), builder.ConstantR1<float>(3, 1.f),
|
||||
builder.ConstantR1<float>(3, 2.f), builder.ConstantR1<float>(3, 3.f)});
|
||||
auto xla_while = builder.While(condition, body, init);
|
||||
|
||||
auto add12 = builder.Add(builder.GetTupleElement(xla_while, 1),
|
||||
builder.GetTupleElement(xla_while, 2));
|
||||
auto result = builder.Add(add12, builder.GetTupleElement(xla_while, 3));
|
||||
VLOG(2) << "result = "
|
||||
<< ShapeUtil::HumanString(
|
||||
*builder.GetShape(result).ConsumeValueOrDie());
|
||||
std::vector<float> expected = {6.f, 6.f, 6.f};
|
||||
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
|
||||
}
|
||||
|
||||
// Tests a while node when the result type T is a Tuple.
|
||||
//
|
||||
// tuple<int32, vector<float>> result(0, vector<float>(10, 0.0f));
|
||||
|
@ -58,6 +58,7 @@ class HloParser {
|
||||
string* root_name);
|
||||
bool ParseInstruction(HloComputation::Builder* builder, string* root_name);
|
||||
bool ParseSharding(HloInstruction* instruction);
|
||||
bool ParseControlPredecessors(HloInstruction* instruction);
|
||||
bool ParseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
|
||||
bool ParseTupleLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
|
||||
bool ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
|
||||
@ -436,10 +437,35 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
|
||||
return TokenError(StrCat("parsing not yet implemented for op: ",
|
||||
HloOpcodeString(opcode)));
|
||||
}
|
||||
// Parse "sharding=".
|
||||
if (lexer_.GetKind() == TokKind::kComma) {
|
||||
if (!ParseSharding(instruction)) {
|
||||
return false;
|
||||
|
||||
bool has_sharding = false;
|
||||
bool has_control = false;
|
||||
while (EatIfPresent(TokKind::kComma)) {
|
||||
string attribute_name;
|
||||
if (!ParseAttributeName(&attribute_name)) {
|
||||
return TokenError("expects ', sharding=' or ', control-predecessors='");
|
||||
}
|
||||
|
||||
if (attribute_name == "sharding") {
|
||||
// Parse "sharding=".
|
||||
if (has_sharding) {
|
||||
return TokenError("expects at most 1 'sharding='");
|
||||
}
|
||||
has_sharding = true;
|
||||
if (!ParseSharding(instruction)) {
|
||||
return false;
|
||||
}
|
||||
} else if (attribute_name == "control-predecessors") {
|
||||
// Parse "control-predecessors"
|
||||
if (has_control) {
|
||||
return TokenError("expects at most 1 'control-predecessors='");
|
||||
}
|
||||
has_control = true;
|
||||
if (!ParseControlPredecessors(instruction)) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return TokenError(StrCat("unexpected attribute: ", attribute_name));
|
||||
}
|
||||
}
|
||||
|
||||
@ -449,15 +475,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
|
||||
// ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape? ('devices=' ('['
|
||||
// dims ']')* device_list)? '}' dims ::= int_list device_list ::= int_list
|
||||
bool HloParser::ParseSharding(HloInstruction* instruction) {
|
||||
if (!ParseToken(TokKind::kComma,
|
||||
"expects ',' in front of an extra attribute")) {
|
||||
return false;
|
||||
}
|
||||
string attribute_name;
|
||||
if (!ParseAttributeName(&attribute_name) || attribute_name != "sharding") {
|
||||
return TokenError("expects attribute name: sharding");
|
||||
}
|
||||
|
||||
if (!ParseToken(TokKind::kLbrace,
|
||||
"expected '{' to start sharding attribute")) {
|
||||
return false;
|
||||
@ -577,6 +594,34 @@ bool HloParser::ParseSharding(HloInstruction* instruction) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// '{' name+ '}'
|
||||
bool HloParser::ParseControlPredecessors(HloInstruction* instruction) {
|
||||
if (!ParseToken(TokKind::kLbrace,
|
||||
"expects '{' at the beginning of control predecessors")) {
|
||||
return false;
|
||||
}
|
||||
do {
|
||||
string name;
|
||||
if (!ParseName(&name)) {
|
||||
return TokenError("expects a control predecessor");
|
||||
}
|
||||
HloInstruction* pre =
|
||||
tensorflow::gtl::FindPtrOrNull(instruction_pool_, name);
|
||||
if (!pre) {
|
||||
return TokenError(
|
||||
StrCat("control predecessor ", name, " is not defined: "));
|
||||
}
|
||||
Status status = pre->AddControlDependencyTo(instruction);
|
||||
if (!status.ok()) {
|
||||
return TokenError(StrCat("error adding control dependency for: ", name,
|
||||
" status: ", status.ToString()));
|
||||
}
|
||||
} while (EatIfPresent(TokKind::kComma));
|
||||
|
||||
return ParseToken(TokKind::kRbrace,
|
||||
"expects '}' at the end of control predecessors");
|
||||
}
|
||||
|
||||
bool HloParser::SetValueInLiteral(int64 value, int64 linear_index,
|
||||
Literal* literal) {
|
||||
const Shape& shape = literal->shape();
|
||||
|
@ -214,7 +214,7 @@ R"(HloModule TwoSendRecvBothWayRecvFist_module:
|
||||
ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
|
||||
%recv = f32[] recv(), channel_id=15, sharding={maximal device=1}
|
||||
ROOT %constant = f32[] constant(2.1), sharding={maximal device=0}
|
||||
%send = () send(f32[] %constant), channel_id=16, sharding={maximal device=0}
|
||||
%send = () send(f32[] %constant), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv}
|
||||
}
|
||||
|
||||
)"
|
||||
|
@ -361,6 +361,7 @@ message WaitForExecutionResponse {
|
||||
message IsConstantRequest {
|
||||
ComputationHandle computation = 1;
|
||||
ComputationDataHandle operand = 2;
|
||||
int64 num_parameters = 3;
|
||||
}
|
||||
|
||||
message IsConstantResponse {
|
||||
@ -371,6 +372,7 @@ message ComputeConstantRequest {
|
||||
ComputationHandle computation = 1;
|
||||
ComputationDataHandle operand = 2;
|
||||
Layout output_layout = 3;
|
||||
repeated LiteralProto parameters = 4;
|
||||
}
|
||||
|
||||
message ComputeConstantResponse {
|
||||
|
@ -399,7 +399,7 @@ ASBSQueue<TaskType>::~ASBSQueue() {
|
||||
|
||||
template <typename TaskType>
|
||||
Status ASBSQueue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
|
||||
bool added_new_batch = false;
|
||||
ASBSBatch<TaskType>* new_batch = nullptr;
|
||||
size_t size = (*task)->size();
|
||||
if (size > options_.max_batch_size) {
|
||||
return errors::InvalidArgument("Task size ", size,
|
||||
@ -418,15 +418,14 @@ Status ASBSQueue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
|
||||
current_batch_ = nullptr;
|
||||
}
|
||||
if (!current_batch_) {
|
||||
added_new_batch = true;
|
||||
num_enqueued_batches_++;
|
||||
current_batch_ =
|
||||
current_batch_ = new_batch =
|
||||
new ASBSBatch<TaskType>(this, scheduler_->GetEnv()->NowMicros());
|
||||
}
|
||||
current_batch_->AddTask(std::move(*task));
|
||||
num_enqueued_tasks_++;
|
||||
}
|
||||
if (added_new_batch) scheduler_->AddBatch(current_batch_);
|
||||
if (new_batch != nullptr) scheduler_->AddBatch(new_batch);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -208,7 +208,7 @@ def extract_features(features, feature_columns):
|
||||
if tensor.dtype == dtypes.float32:
|
||||
if len(tensor.shape) > 1 and tensor.shape[1] > 1:
|
||||
unstacked = array_ops.unstack(tensor, axis=1)
|
||||
for i in xrange(len(unstacked)):
|
||||
for i in range(len(unstacked)):
|
||||
dense_float_names.append(_FEATURE_NAME_TEMPLATE % (key, i))
|
||||
dense_floats.append(array_ops.reshape(unstacked[i], [-1, 1]))
|
||||
else:
|
||||
|
@ -224,6 +224,7 @@ add_python_module("tensorflow/python/grappler")
|
||||
add_python_module("tensorflow/python/keras")
|
||||
add_python_module("tensorflow/python/keras/activations")
|
||||
add_python_module("tensorflow/python/keras/applications")
|
||||
add_python_module("tensorflow/python/keras/applications/inception_resnet_v2")
|
||||
add_python_module("tensorflow/python/keras/applications/inception_v3")
|
||||
add_python_module("tensorflow/python/keras/applications/mobilenet")
|
||||
add_python_module("tensorflow/python/keras/applications/resnet50")
|
||||
|
@ -30,6 +30,7 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview.
|
||||
@@make_saveable_from_iterator
|
||||
@@read_batch_features
|
||||
@@unbatch
|
||||
@@parallel_interleave
|
||||
@@rejection_resample
|
||||
@@sloppy_interleave
|
||||
|
||||
@ -50,6 +51,7 @@ from tensorflow.contrib.data.python.ops.dataset_ops import get_single_element
|
||||
from tensorflow.contrib.data.python.ops.enumerate_ops import enumerate_dataset
|
||||
from tensorflow.contrib.data.python.ops.error_ops import ignore_errors
|
||||
from tensorflow.contrib.data.python.ops.grouping import group_by_window
|
||||
from tensorflow.contrib.data.python.ops.interleave_ops import parallel_interleave
|
||||
from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave
|
||||
from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator
|
||||
from tensorflow.contrib.data.python.ops.readers import FixedLengthRecordDataset
|
||||
|
@ -191,9 +191,9 @@ def main(_):
|
||||
train_dir = None
|
||||
test_dir = None
|
||||
summary_writer = tf.contrib.summary.create_summary_file_writer(
|
||||
train_dir, flush_secs=10)
|
||||
train_dir, flush_millis=10000)
|
||||
test_summary_writer = tf.contrib.summary.create_summary_file_writer(
|
||||
test_dir, flush_secs=10, name='test')
|
||||
test_dir, flush_millis=10000, name='test')
|
||||
checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt')
|
||||
|
||||
with tf.device(device):
|
||||
|
@ -248,9 +248,9 @@ def main(_):
|
||||
log_dir = os.path.join(FLAGS.dir, "summaries")
|
||||
tf.gfile.MakeDirs(log_dir)
|
||||
train_summary_writer = tf.contrib.summary.create_summary_file_writer(
|
||||
os.path.join(log_dir, "train"), flush_secs=10)
|
||||
os.path.join(log_dir, "train"), flush_millis=10000)
|
||||
test_summary_writer = tf.contrib.summary.create_summary_file_writer(
|
||||
os.path.join(log_dir, "eval"), flush_secs=10, name="eval")
|
||||
os.path.join(log_dir, "eval"), flush_millis=10000, name="eval")
|
||||
|
||||
with tf.device(device):
|
||||
for epoch in range(FLAGS.num_epochs):
|
||||
|
@ -7,6 +7,7 @@ package(
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
@ -30,6 +31,7 @@ py_library(
|
||||
":head",
|
||||
":logit_fns",
|
||||
":multi_head",
|
||||
":replicate_model_fn",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
@ -227,9 +229,69 @@ py_test(
|
||||
"//tensorflow/python:string_ops",
|
||||
"//tensorflow/python/estimator:metric_keys",
|
||||
"//tensorflow/python/estimator:model_fn",
|
||||
"//tensorflow/python/estimator:prediction_keys",
|
||||
"//tensorflow/python/ops/losses",
|
||||
"//tensorflow/python/saved_model:signature_constants",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "replicate_model_fn",
|
||||
srcs = [
|
||||
"python/estimator/replicate_model_fn.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:device",
|
||||
"//tensorflow/python:device_lib",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:state_ops",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/estimator:export_output",
|
||||
"//tensorflow/python/estimator:model_fn",
|
||||
"//tensorflow/python/estimator:util",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "replicate_model_fn_test",
|
||||
size = "small",
|
||||
srcs = ["python/estimator/replicate_model_fn_test.py"],
|
||||
additional_deps = [
|
||||
"//tensorflow/python/estimator",
|
||||
"//tensorflow/python/estimator:dnn",
|
||||
"//tensorflow/python/estimator:export_export",
|
||||
"//tensorflow/python/estimator:export_output",
|
||||
"//tensorflow/python/estimator:model_fn",
|
||||
"//tensorflow/python/estimator:numpy_io",
|
||||
"//tensorflow/python/estimator:optimizers",
|
||||
"//tensorflow/python/estimator:prediction_keys",
|
||||
"//tensorflow/python/feature_column",
|
||||
"//tensorflow/python/ops/losses",
|
||||
"//tensorflow/python/saved_model:signature_constants",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:metrics",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:summary",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:variables",
|
||||
":replicate_model_fn",
|
||||
],
|
||||
tags = ["requires-gpu-sm35"],
|
||||
)
|
||||
|
@ -0,0 +1,470 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Utilities to replicate model_fn's over local GPUs.
|
||||
|
||||
This file contains util that allow to replicate `Estimator.model_fn` over
|
||||
GPUs. Replicated version of a `model_fn` is returned that can subsequently
|
||||
be used with `Estimator`.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.core.framework import node_def_pb2
|
||||
from tensorflow.python.client import device_lib
|
||||
from tensorflow.python.estimator import model_fn as model_fn_lib
|
||||
from tensorflow.python.estimator import util
|
||||
from tensorflow.python.estimator.export import export_output as export_output_lib
|
||||
from tensorflow.python.framework import device as framework_device
|
||||
from tensorflow.python.framework import ops as ops_lib
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import gradients as gradients_lib
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables as variables_lib
|
||||
from tensorflow.python.platform import tf_logging
|
||||
from tensorflow.python.training import training_util
|
||||
|
||||
|
||||
def replicate_model_fn(model_fn, optimizer_fn, devices=None):
|
||||
"""Replicate `Estimator.model_fn` over GPUs within a single host.
|
||||
|
||||
The given `model_fn` specifies a single forward pass of a model. To replicate
|
||||
such a model over GPUs, each GPU gets its own instance of the forward pass
|
||||
(a.k.a. a tower). The input features and labels get sharded into the chunks
|
||||
that correspond to the number of GPUs. Each tower computes its own loss based
|
||||
on its input. For each such loss, gradients are computed. After that, the
|
||||
available losses are summed to form aggregated loss. The available
|
||||
gradients are summed too. Then, they update weights using the specified
|
||||
optimizer.
|
||||
|
||||
If `devices` are `None`, then all available GPUs are going to be used for
|
||||
replication. If no GPUs are available, then the model is going to be
|
||||
placed on the CPU.
|
||||
|
||||
Two modes of local replication over available GPUs are supported:
|
||||
1) If exactly 1 GPU is detected, then variables and operations are placed
|
||||
onto GPU.
|
||||
2) If more than 1 GPU is detected, then variables are going to be placed on
|
||||
the CPU. Replicas of operations are placed on each individual GPU.
|
||||
|
||||
Here is an example of how one might use their `model_fn` to run over GPUs:
|
||||
```python
|
||||
def optimizer_fn():
|
||||
return tf.train.GradientDescentOptimizer(learning_rate=0.001)
|
||||
...
|
||||
def model_fn(...): # See `model_fn` in `Estimator`.
|
||||
loss = ...
|
||||
if mode == tf.estimator.ModeKeys.TRAIN:
|
||||
# See the section below on `EstimatorSpec.train_op`.
|
||||
return EstimatorSpec(mode=mode, loss=loss, train_op=tf.noop())
|
||||
|
||||
# No change for `ModeKeys.EVAL` or `ModeKeys.PREDICT`.
|
||||
return EstimatorSpec(...)
|
||||
...
|
||||
classifier = tf.estimator.Estimator(
|
||||
model_fn=replicate_model_fn.replicate_model_fn(model_fn, optimizer_fn))
|
||||
```
|
||||
|
||||
On `EstimatorSpec.train_op`:
|
||||
`model_fn` returns `EstimatorSpec.train_op` for
|
||||
`tf.estimator.GraphKeys.TRAIN`. It is typically derived using an optimizer.
|
||||
`replicate_model_fn` ignores the returned `EstimatorSpec.train_op`, so there
|
||||
is no need to use an optimizer inside the user's `model_fn`. The
|
||||
`EstimatorSpec.loss` subgraph is going to be executed, while
|
||||
`EstimatorSpec.train_op` isn't going to be executed. One could pass
|
||||
`train_op=tf.noop()` to `EstimatorSpec`.
|
||||
|
||||
On sharding input features and labels:
|
||||
Input features and labels are split for consumption by each tower. They are
|
||||
split across the dimension 0. Features and labels need to be batch major.
|
||||
|
||||
On reduction algorithms:
|
||||
Certain algorithms were chosen for aggregating results of computations on
|
||||
multiple towers:
|
||||
- Losses from all towers are reduced using sum.
|
||||
- Gradients are reduced using sum for each trainable variable.
|
||||
- `eval_metrics_ops` are reduced per metric using `reduce_mean`.
|
||||
- `EstimatorSpec.predictions` and `EstimatorSpec.export_outputs` are
|
||||
reduced using concatenation.
|
||||
- For all other fields of `EstimatorSpec` the values of the first tower
|
||||
are taken.
|
||||
|
||||
On replication of variables:
|
||||
Variables are not duplicated between towers. Instead, they are placed on a
|
||||
single device as defined above and shared across towers.
|
||||
|
||||
Other current limitations:
|
||||
- `predictions` are not supported for `ModeKeys.EVAL`. That is required for
|
||||
`tf.contrib.estimator.add_metrics`.
|
||||
|
||||
Args:
|
||||
model_fn: `model_fn` as defined in `Estimator`. See the section above about
|
||||
the train_op argument of `EstimatorSpec`.
|
||||
optimizer_fn: a function that returns an optimizer instance. The function
|
||||
may accept one `params` argument. This is the `params` argument as
|
||||
defined by `Estimator`. See the `Estimator` documentation for details.
|
||||
devices: Optional list of devices to replicate the model across. This
|
||||
argument can be used to replice only on the subset of available GPUs.
|
||||
If `None`, then all available GPUs are going to be used for replication.
|
||||
If no GPUs are available, then the model is going to be placed on the CPU.
|
||||
|
||||
Returns:
|
||||
A replicated version of the supplied `model_fn`. Returned function that
|
||||
conforms to the requirements of `Estimator`'s `model_fn` and can be used
|
||||
instead of the supplied `model_fn`.
|
||||
"""
|
||||
if not devices:
|
||||
devices = _get_local_devices('GPU') or _get_local_devices('CPU')
|
||||
|
||||
is_a_single_gpu_case = len(devices) == 1 and 'GPU' in devices[0]
|
||||
local_ps_device = '/{}:0'.format('GPU' if is_a_single_gpu_case else 'CPU')
|
||||
|
||||
tf_logging.info('Replicating the `model_fn` across {}. Local parameter '
|
||||
'server device is going to be {}.'.format(
|
||||
devices, local_ps_device))
|
||||
|
||||
def replicated_model_fn(mode, features, labels, params=None, config=None):
|
||||
"""Replicated version of `model_fn` to be used instead."""
|
||||
feature_shards, label_shards = _split_batch(
|
||||
features, labels, len(devices), device=local_ps_device)
|
||||
tower_specs = _get_loss_towers(
|
||||
model_fn=model_fn,
|
||||
mode=mode,
|
||||
features=feature_shards,
|
||||
labels=label_shards,
|
||||
params=params,
|
||||
config=config,
|
||||
devices=devices,
|
||||
local_ps_device=local_ps_device)
|
||||
|
||||
if mode == model_fn_lib.ModeKeys.TRAIN:
|
||||
train_op = _minimize_towers(tower_specs,
|
||||
_call_optimizer_fn(optimizer_fn, params))
|
||||
return _train_spec(
|
||||
tower_specs, train_op, aggregation_device=local_ps_device)
|
||||
elif mode == model_fn_lib.ModeKeys.EVAL:
|
||||
return _eval_spec(tower_specs, aggregation_device=local_ps_device)
|
||||
elif mode == model_fn_lib.ModeKeys.PREDICT:
|
||||
return _predict_spec(tower_specs, aggregation_device=local_ps_device)
|
||||
|
||||
return replicated_model_fn
|
||||
|
||||
|
||||
def _get_local_devices(device_type):
|
||||
local_device_protos = device_lib.list_local_devices()
|
||||
return [
|
||||
device.name
|
||||
for device in local_device_protos
|
||||
if device.device_type == device_type
|
||||
]
|
||||
|
||||
|
||||
def _split_batch(features, labels, number_of_shards, device):
|
||||
"""Split input features and labes into batches."""
|
||||
|
||||
def split_dictionary(dictionary):
|
||||
shards = [{} for _ in range(number_of_shards)]
|
||||
for name, tensor in six.iteritems(dictionary):
|
||||
for i, shard in enumerate(array_ops.split(tensor, number_of_shards)):
|
||||
shards[i][name] = shard
|
||||
return shards
|
||||
|
||||
with ops_lib.name_scope('split_inputs'):
|
||||
with ops_lib.device(device):
|
||||
if isinstance(features, dict):
|
||||
feature_shards = split_dictionary(features)
|
||||
else:
|
||||
feature_shards = array_ops.split(features, number_of_shards)
|
||||
|
||||
if labels is None:
|
||||
label_shards = None
|
||||
elif isinstance(labels, dict):
|
||||
label_shards = split_dictionary(labels)
|
||||
else:
|
||||
label_shards = array_ops.split(labels, number_of_shards)
|
||||
return feature_shards, label_shards
|
||||
|
||||
|
||||
_DEFAULT_NAME_SCOPE_PATTERN = 'tower_{}'
|
||||
|
||||
|
||||
def _get_loss_towers(model_fn,
|
||||
mode,
|
||||
features,
|
||||
labels,
|
||||
params,
|
||||
config,
|
||||
devices,
|
||||
local_ps_device,
|
||||
name_scope_pattern=_DEFAULT_NAME_SCOPE_PATTERN):
|
||||
"""Replicate the loss computation across devices."""
|
||||
tower_specs = []
|
||||
|
||||
model_fn_args = util.fn_args(model_fn)
|
||||
optional_params = {}
|
||||
if 'params' in model_fn_args:
|
||||
optional_params['params'] = copy.deepcopy(params)
|
||||
if 'config' in model_fn_args:
|
||||
optional_params['config'] = copy.deepcopy(config)
|
||||
|
||||
for i, device in enumerate(devices):
|
||||
is_the_first_tower = (i == 0)
|
||||
|
||||
device_setter = _local_device_setter(
|
||||
worker_device=device, ps_device=local_ps_device)
|
||||
|
||||
# We would like to preserve the names of the variables and ops that a user
|
||||
# might be relying on. Names with prefix are going to resolve to variables
|
||||
# and ops of the first tower.
|
||||
name_scope = name_scope_pattern
|
||||
if is_the_first_tower:
|
||||
name_scope = ''
|
||||
|
||||
with variable_scope.variable_scope('', reuse=not is_the_first_tower):
|
||||
with ops_lib.name_scope(name_scope.format(i)):
|
||||
with ops_lib.device(device_setter):
|
||||
labels_shard = None
|
||||
if labels:
|
||||
labels_shard = labels[i]
|
||||
|
||||
tower_specs.append(
|
||||
model_fn(
|
||||
mode=mode,
|
||||
features=features[i],
|
||||
labels=labels_shard,
|
||||
**optional_params))
|
||||
return tower_specs
|
||||
|
||||
|
||||
def _local_device_setter(ps_device, worker_device):
|
||||
"""A device setter that puts distributes Var/Ops to PS/workers."""
|
||||
ps_ops = ['Variable', 'VariableV2', 'VarHandleOp']
|
||||
|
||||
def local_device_chooser(op):
|
||||
current_device = framework_device.DeviceSpec.from_string(op.device or '')
|
||||
|
||||
node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def
|
||||
if node_def.op in ps_ops:
|
||||
ps_device_spec = framework_device.DeviceSpec.from_string(
|
||||
'{}'.format(ps_device))
|
||||
|
||||
ps_device_spec.merge_from(current_device)
|
||||
return ps_device_spec.to_string()
|
||||
else:
|
||||
worker_device_spec = framework_device.DeviceSpec.from_string(
|
||||
worker_device or '')
|
||||
worker_device_spec.merge_from(current_device)
|
||||
return worker_device_spec.to_string()
|
||||
|
||||
return local_device_chooser
|
||||
|
||||
|
||||
def _minimize_towers(tower_specs, optimizer):
|
||||
"""Aggregate and apply gradients for computed losses."""
|
||||
grad_lists = {}
|
||||
for tower_spec in tower_specs:
|
||||
with ops_lib.device(tower_spec.loss.device):
|
||||
variables = variables_lib.trainable_variables()
|
||||
gradients = gradients_lib.gradients(tower_spec.loss, variables)
|
||||
|
||||
for var, grad in zip(variables, gradients):
|
||||
if grad is not None:
|
||||
grad_lists.setdefault(var, []).append(grad)
|
||||
|
||||
aggregated_grads = []
|
||||
with ops_lib.name_scope('gradient_aggregating'):
|
||||
for var, grads in six.iteritems(grad_lists):
|
||||
grad = _compute_sum_on_device(grads, var.device)
|
||||
aggregated_grads.append((grad, var))
|
||||
|
||||
train_op = optimizer.apply_gradients(
|
||||
aggregated_grads, global_step=training_util.get_global_step())
|
||||
|
||||
return train_op
|
||||
|
||||
|
||||
def _call_optimizer_fn(optimizer_fn, params):
|
||||
arguments = {}
|
||||
optimizer_fn_arguments = util.fn_args(optimizer_fn)
|
||||
if 'params' in optimizer_fn_arguments:
|
||||
arguments['params'] = params
|
||||
return optimizer_fn(**arguments)
|
||||
|
||||
|
||||
def _compute_sum_on_device(values, device, name=None):
|
||||
with ops_lib.device(device):
|
||||
return math_ops.add_n(values, name=name)
|
||||
|
||||
|
||||
def _train_spec(tower_specs,
|
||||
train_op,
|
||||
aggregation_device,
|
||||
aggregated_loss_name='loss'):
|
||||
"""Populate replicated EstimatorSpec for `GraphKeys.TRAIN`."""
|
||||
estimator_spec = tower_specs[0]._asdict()
|
||||
estimator_spec['mode'] = model_fn_lib.ModeKeys.TRAIN
|
||||
estimator_spec['train_op'] = train_op
|
||||
estimator_spec['loss'] = _compute_sum_on_device(
|
||||
[spec.loss for spec in tower_specs], aggregation_device,
|
||||
aggregated_loss_name)
|
||||
return model_fn_lib.EstimatorSpec(**estimator_spec)
|
||||
|
||||
|
||||
def _eval_spec(tower_specs, aggregation_device, aggregated_loss_name='loss'):
|
||||
"""Populate replicated EstimatorSpec for `GraphKeys.EVAL`."""
|
||||
estimator_spec = tower_specs[0]._asdict()
|
||||
estimator_spec['mode'] = model_fn_lib.ModeKeys.EVAL
|
||||
estimator_spec['loss'] = _compute_sum_on_device(
|
||||
[spec.loss for spec in tower_specs], aggregation_device,
|
||||
aggregated_loss_name)
|
||||
|
||||
eval_metric_ops_lists = {}
|
||||
for tower_spec in tower_specs:
|
||||
metrics = tower_spec.eval_metric_ops or {}
|
||||
for name, (_, update_op) in six.iteritems(metrics):
|
||||
update_ops = eval_metric_ops_lists.setdefault(name, ([]))
|
||||
update_ops.append(update_op)
|
||||
|
||||
eval_metric_ops = {}
|
||||
for name, (metric_tensor, _) in six.iteritems(tower_specs[0].eval_metric_ops):
|
||||
with ops_lib.control_dependencies(eval_metric_ops_lists[name]):
|
||||
# This operation reduces local variables across all metrics, yet is
|
||||
# called for every metric. This is redundant and it's done because
|
||||
# it is hard to know what local variables correspond to what metric.
|
||||
# Estimator is going to execute all `reduced_update_op`s as part of
|
||||
# a group inside a single `Session.run()` call, which will avoid duplicate
|
||||
# computation.
|
||||
reduced_update_op = _reduce_metric_variables(len(tower_specs))
|
||||
eval_metric_ops[name] = (metric_tensor, reduced_update_op)
|
||||
|
||||
estimator_spec['eval_metric_ops'] = eval_metric_ops
|
||||
return model_fn_lib.EstimatorSpec(**estimator_spec)
|
||||
|
||||
|
||||
def _reduce_metric_variables(number_of_towers):
|
||||
"""Aggregate local variables used in metrics into the first tower."""
|
||||
if number_of_towers == 1:
|
||||
return control_flow_ops.no_op()
|
||||
|
||||
metric_variables = ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)
|
||||
variables_per_tower = len(metric_variables) // number_of_towers
|
||||
|
||||
if len(metric_variables) % number_of_towers != 0:
|
||||
raise ValueError(
|
||||
'Different `EstimatorSpec.eval_metric_ops` across `model_fn()` calls.'
|
||||
' Expected {} local variables, but got {} instead.'.format(
|
||||
variables_per_tower * number_of_towers, len(metric_variables)))
|
||||
|
||||
# `metric_variables` has the size of `variables_per_tower` x
|
||||
# number_of_towers. Each tower is produced by calling the same model_fn.
|
||||
# First `variables_per_tower` correspond to the first tower. Each such
|
||||
# variable has an replica at the `(variables_per_tower * i)` position, where
|
||||
# `i` is `[1.. number_of_towers]`. We are going to add values from replicas
|
||||
# to each variable of the first tower. We then zero out replica values, so
|
||||
# that `_reduce_metric_variables` operation is idempotent. If a metric
|
||||
# is then computed based on local variables from the first tower, then the
|
||||
# resulting metric is an estimate for all `number_of_towers` towers.
|
||||
ops = []
|
||||
for i in range(0, variables_per_tower):
|
||||
next_replica_id = i + variables_per_tower
|
||||
replicas = [
|
||||
metric_variables[replica_id]
|
||||
for replica_id in range(next_replica_id, len(metric_variables),
|
||||
variables_per_tower)
|
||||
] # `replicas` doesn't contain the first-tower variable.
|
||||
|
||||
reduce_op = state_ops.assign_add(metric_variables[i],
|
||||
math_ops.add_n(replicas))
|
||||
|
||||
with ops_lib.control_dependencies([reduce_op]):
|
||||
for replica in replicas:
|
||||
zeros_for_replica = array_ops.zeros(
|
||||
array_ops.shape(replica), dtype=replica.dtype)
|
||||
zero_out_replica_op = state_ops.assign(replica, zeros_for_replica)
|
||||
ops.append(zero_out_replica_op)
|
||||
|
||||
return control_flow_ops.group(*ops)
|
||||
|
||||
|
||||
def _predict_spec(tower_specs, aggregation_device):
|
||||
"""Populate replicated EstimatorSpec for `GraphKeys.PREDICT`."""
|
||||
estimator_spec = tower_specs[0]._asdict()
|
||||
estimator_spec['mode'] = model_fn_lib.ModeKeys.PREDICT
|
||||
|
||||
with ops_lib.device(aggregation_device):
|
||||
estimator_spec['predictions'] = _concat_tensor_dicts(
|
||||
*[tower_spec.predictions for tower_spec in tower_specs])
|
||||
|
||||
export_outputs_dict = _dict_concat(
|
||||
*[tower_spec.export_outputs for tower_spec in tower_specs])
|
||||
|
||||
export_outputs = {}
|
||||
for name, export_output_list in six.iteritems(export_outputs_dict):
|
||||
if isinstance(export_output_list[0], export_output_lib.PredictOutput):
|
||||
export_outputs[name] = export_output_lib.PredictOutput(
|
||||
outputs=_concat_tensor_dicts(*[
|
||||
export_output.outputs for export_output in export_output_list
|
||||
]))
|
||||
elif isinstance(export_output_list[0],
|
||||
export_output_lib.RegressionOutput):
|
||||
export_outputs[name] = export_output_lib.RegressionOutput(
|
||||
value=array_ops.concat(
|
||||
[export_output.value for export_output in export_output_list],
|
||||
axis=0))
|
||||
elif isinstance(export_output_list[0],
|
||||
export_output_lib.ClassificationOutput):
|
||||
scores = None
|
||||
if export_output_list[0].scores is not None:
|
||||
scores = array_ops.concat(
|
||||
[export_output.scores for export_output in export_output_list],
|
||||
axis=0)
|
||||
|
||||
classes = None
|
||||
if export_output_list[0].classes is not None:
|
||||
classes = array_ops.stack(
|
||||
[export_output.classes for export_output in export_output_list],
|
||||
axis=0)
|
||||
|
||||
export_outputs[name] = export_output_lib.ClassificationOutput(
|
||||
scores=scores, classes=classes)
|
||||
|
||||
estimator_spec['export_outputs'] = export_outputs
|
||||
return model_fn_lib.EstimatorSpec(**estimator_spec)
|
||||
|
||||
|
||||
def _concat_tensor_dicts(*tensor_dicts):
|
||||
return {
|
||||
name: array_ops.concat(tensors, axis=0, name=name)
|
||||
for name, tensors in six.iteritems(_dict_concat(*tensor_dicts))
|
||||
}
|
||||
|
||||
|
||||
def _dict_concat(*dicts):
|
||||
list_dict = {}
|
||||
for d in dicts:
|
||||
if d is None:
|
||||
continue
|
||||
|
||||
for k, v in six.iteritems(d):
|
||||
list_dict.setdefault(k, []).append(v)
|
||||
return list_dict
|
@ -0,0 +1,901 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for utilities that replicate `Estimator.model_fn` over GPUs."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
from tensorflow.contrib.estimator.python.estimator import replicate_model_fn
|
||||
from tensorflow.python.estimator import estimator as estimator_lib
|
||||
from tensorflow.python.estimator import model_fn as model_fn_lib
|
||||
from tensorflow.python.estimator.canned import dnn
|
||||
from tensorflow.python.estimator.canned import optimizers
|
||||
from tensorflow.python.estimator.canned import prediction_keys
|
||||
from tensorflow.python.estimator.export import export
|
||||
from tensorflow.python.estimator.export import export_output
|
||||
from tensorflow.python.estimator.inputs import numpy_io
|
||||
from tensorflow.python.feature_column import feature_column
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops as ops_lib
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import metrics as metrics_lib
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops.losses import losses
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import signature_constants
|
||||
from tensorflow.python.summary.writer import writer_cache
|
||||
from tensorflow.python.training import gradient_descent
|
||||
|
||||
|
||||
class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def setUp(self):
|
||||
self._model_dir = tempfile.mkdtemp()
|
||||
|
||||
def test_complete_flow(self):
|
||||
n_classes = 3
|
||||
input_dimension = 2
|
||||
batch_size = 12
|
||||
|
||||
data = np.linspace(
|
||||
0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
|
||||
x_data = data.reshape(batch_size, input_dimension)
|
||||
y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1))
|
||||
train_input_fn = numpy_io.numpy_input_fn(
|
||||
x={'x': x_data},
|
||||
y=y_data,
|
||||
batch_size=batch_size,
|
||||
num_epochs=None,
|
||||
shuffle=True)
|
||||
eval_input_fn = numpy_io.numpy_input_fn(
|
||||
x={'x': x_data}, y=y_data, batch_size=batch_size, shuffle=False)
|
||||
predict_input_fn = numpy_io.numpy_input_fn(
|
||||
x={'x': x_data}, batch_size=batch_size, shuffle=False)
|
||||
|
||||
feature_columns = [
|
||||
feature_column.numeric_column('x', shape=(input_dimension,))
|
||||
]
|
||||
|
||||
estimator = dnn.DNNClassifier(
|
||||
hidden_units=(2, 2),
|
||||
feature_columns=feature_columns,
|
||||
n_classes=n_classes,
|
||||
model_dir=self._model_dir)
|
||||
|
||||
def optimizer_fn():
|
||||
return optimizers.get_optimizer_instance('Adagrad', learning_rate=0.05)
|
||||
|
||||
# TODO(isaprykin): Switch Estimator to use allow_soft_placement=True
|
||||
# during export_savedmodel and then switch this test to replicate over
|
||||
# GPUs instead of CPUs.
|
||||
estimator = estimator_lib.Estimator(
|
||||
model_fn=replicate_model_fn.replicate_model_fn(
|
||||
estimator.model_fn,
|
||||
optimizer_fn,
|
||||
devices=['/cpu:0', '/cpu:0', '/cpu:0']),
|
||||
model_dir=estimator.model_dir,
|
||||
config=estimator.config,
|
||||
params=estimator.params)
|
||||
|
||||
num_steps = 10
|
||||
estimator.train(train_input_fn, steps=num_steps)
|
||||
|
||||
scores = estimator.evaluate(eval_input_fn)
|
||||
self.assertEqual(num_steps, scores[ops_lib.GraphKeys.GLOBAL_STEP])
|
||||
self.assertIn('loss', six.iterkeys(scores))
|
||||
|
||||
predicted_proba = np.array([
|
||||
x[prediction_keys.PredictionKeys.PROBABILITIES]
|
||||
for x in estimator.predict(predict_input_fn)
|
||||
])
|
||||
self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)
|
||||
|
||||
feature_spec = feature_column.make_parse_example_spec(feature_columns)
|
||||
serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
|
||||
feature_spec)
|
||||
export_dir = estimator.export_savedmodel(tempfile.mkdtemp(),
|
||||
serving_input_receiver_fn)
|
||||
self.assertTrue(gfile.Exists(export_dir))
|
||||
|
||||
def _as_label(self, data_in_float):
|
||||
return np.rint(data_in_float).astype(np.int64)
|
||||
|
||||
def tearDown(self):
|
||||
if self._model_dir:
|
||||
writer_cache.FileWriterCache.clear()
|
||||
shutil.rmtree(self._model_dir)
|
||||
|
||||
|
||||
class ReplicateModelTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def model_fn(self, mode, features, labels, params):
|
||||
c = variable_scope.get_variable(
|
||||
'c',
|
||||
initializer=constant_op.constant(10, dtype=dtypes.float64),
|
||||
dtype=dtypes.float64)
|
||||
|
||||
predictions = math_ops.multiply(features, c)
|
||||
|
||||
loss = None
|
||||
if mode is not model_fn_lib.ModeKeys.PREDICT:
|
||||
loss = losses.absolute_difference(
|
||||
labels=labels,
|
||||
predictions=predictions,
|
||||
reduction=losses.Reduction.SUM)
|
||||
loss = math_ops.reduce_sum(loss)
|
||||
|
||||
metrics = {
|
||||
'accuracy': metrics_lib.accuracy(labels, predictions),
|
||||
'auc': metrics_lib.auc(labels, predictions)
|
||||
}
|
||||
|
||||
return model_fn_lib.EstimatorSpec(
|
||||
mode=mode,
|
||||
loss=loss,
|
||||
eval_metric_ops=metrics,
|
||||
predictions={'probabilities': predictions},
|
||||
train_op=control_flow_ops.no_op()) # This train_op isn't actually used.
|
||||
|
||||
def optimizer_fn(self, params):
|
||||
return gradient_descent.GradientDescentOptimizer(params['learning_rate'])
|
||||
|
||||
@property
|
||||
def params(self):
|
||||
params = {}
|
||||
params['learning_rate'] = 1.0
|
||||
return params
|
||||
|
||||
def test_train(self):
|
||||
features = np.array([[1.0], [2.0]])
|
||||
labels = np.array([[1.0], [2.0]])
|
||||
|
||||
with self.test_session() as session:
|
||||
replicated_model_fn = replicate_model_fn.replicate_model_fn(
|
||||
self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1'])
|
||||
estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.TRAIN,
|
||||
features, labels, self.params)
|
||||
session.run(variables.global_variables_initializer())
|
||||
|
||||
# loss = feature * c - label
|
||||
total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0)
|
||||
self.assertEqual(total_loss, session.run(estimator_spec.loss))
|
||||
|
||||
# loss' of c is 3.
|
||||
# new value of c = 10 - learning rate * 3 = 7.0.
|
||||
session.run(estimator_spec.train_op)
|
||||
with variable_scope.variable_scope('', reuse=True):
|
||||
c = variable_scope.get_variable('c', dtype=dtypes.float64)
|
||||
self.assertEqual(7.0, session.run(c))
|
||||
|
||||
def test_train_spec_with_optimizer_without_params(self):
|
||||
|
||||
def optimizer_fn_without_params():
|
||||
return gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
|
||||
|
||||
features = np.array([[1.0], [2.0]])
|
||||
labels = np.array([[1.0], [2.0]])
|
||||
|
||||
with self.test_session() as session: # pylint: disable=unused-variable
|
||||
replicated_model_fn = replicate_model_fn.replicate_model_fn(
|
||||
self.model_fn,
|
||||
optimizer_fn_without_params,
|
||||
devices=['/gpu:0', '/gpu:1'])
|
||||
# This call is going to fail if `replicated_model_fn` is still passing
|
||||
# `params` inside `optimizer_fn`, even though the latter doesn't take any:
|
||||
estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.TRAIN,
|
||||
features, labels, self.params)
|
||||
del estimator_spec
|
||||
|
||||
def test_eval(self):
|
||||
features = np.array([[0.01], [0.002]])
|
||||
labels = np.array([[0.01], [0.02]])
|
||||
|
||||
with self.test_session() as session:
|
||||
replicated_model_fn = replicate_model_fn.replicate_model_fn(
|
||||
self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1'])
|
||||
estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.EVAL, features,
|
||||
labels, self.params)
|
||||
session.run(variables.local_variables_initializer())
|
||||
session.run(variables.global_variables_initializer())
|
||||
|
||||
accuracy, a = estimator_spec.eval_metric_ops['accuracy']
|
||||
auc, b = estimator_spec.eval_metric_ops['auc']
|
||||
|
||||
session.run([a, b])
|
||||
accuracy = session.run(accuracy)
|
||||
auc = session.run(auc)
|
||||
|
||||
# Accuracy is 0.0 (no match) in the first tower.
|
||||
# Accuracy is 1.0 (match) in the second tower, since the feature
|
||||
# times weight "c" happened to be equal to the label.
|
||||
total_loss = ((0.01 * 10 - 0.01) + (0.002 * 10 - 0.02))
|
||||
|
||||
self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01)
|
||||
self.assertEqual(0, auc)
|
||||
self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01)
|
||||
|
||||
def test_predict(self):
|
||||
features = np.array([[0.01], [0.002]])
|
||||
labels = np.array([[0.01], [0.02]])
|
||||
|
||||
with self.test_session() as session:
|
||||
replicated_model_fn = replicate_model_fn.replicate_model_fn(
|
||||
self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1'])
|
||||
estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.PREDICT,
|
||||
features, labels, self.params)
|
||||
session.run(variables.global_variables_initializer())
|
||||
|
||||
self.assertAllClose({
|
||||
'probabilities': np.array([[0.1], [0.02]])
|
||||
}, session.run(estimator_spec.predictions))
|
||||
|
||||
def test_train_single_tower(self):
|
||||
features = np.array([[1.0], [2.0]])
|
||||
labels = np.array([[1.0], [2.0]])
|
||||
|
||||
with self.test_session() as session:
|
||||
replicated_model_fn = replicate_model_fn.replicate_model_fn(
|
||||
self.model_fn, self.optimizer_fn)
|
||||
estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.TRAIN,
|
||||
features, labels, self.params)
|
||||
session.run(variables.global_variables_initializer())
|
||||
|
||||
# loss = feature * c - label
|
||||
total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0)
|
||||
self.assertEqual(total_loss, session.run(estimator_spec.loss))
|
||||
|
||||
# loss' of c is 3.
|
||||
# new value of c = 10 - learning rate * 3 = 7.0.
|
||||
session.run(estimator_spec.train_op)
|
||||
with variable_scope.variable_scope('', reuse=True):
|
||||
c = variable_scope.get_variable('c', dtype=dtypes.float64)
|
||||
self.assertEqual(7.0, session.run(c))
|
||||
|
||||
def test_eval_single_tower(self):
|
||||
features = np.array([[0.01], [0.002]])
|
||||
labels = np.array([[0.01], [0.02]])
|
||||
|
||||
with self.test_session() as session:
|
||||
replicated_model_fn = replicate_model_fn.replicate_model_fn(
|
||||
self.model_fn, self.optimizer_fn, devices=['/gpu:0'])
|
||||
estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.EVAL, features,
|
||||
labels, self.params)
|
||||
session.run(variables.local_variables_initializer())
|
||||
session.run(variables.global_variables_initializer())
|
||||
|
||||
accuracy, a = estimator_spec.eval_metric_ops['accuracy']
|
||||
auc, b = estimator_spec.eval_metric_ops['auc']
|
||||
|
||||
session.run([a, b])
|
||||
accuracy = session.run(accuracy)
|
||||
auc = session.run(auc)
|
||||
|
||||
# Accuracy is 0.0 (no match) in the first tower.
|
||||
# Accuracy is 1.0 (match) in the second tower, since the feature
|
||||
# times weight "c" happened to be equal to the label.
|
||||
total_loss = ((0.01 * 10 - 0.01) + (0.002 * 10 - 0.02))
|
||||
|
||||
self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01)
|
||||
self.assertEqual(0, auc)
|
||||
self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01)
|
||||
|
||||
def test_predict_single_tower(self):
|
||||
features = np.array([[0.01], [0.002]])
|
||||
labels = np.array([[0.01], [0.02]])
|
||||
|
||||
with self.test_session() as session:
|
||||
replicated_model_fn = replicate_model_fn.replicate_model_fn(
|
||||
self.model_fn, self.optimizer_fn, devices=['/gpu:0'])
|
||||
estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.PREDICT,
|
||||
features, labels, self.params)
|
||||
session.run(variables.global_variables_initializer())
|
||||
|
||||
self.assertAllClose({
|
||||
'probabilities': np.array([[0.1], [0.02]])
|
||||
}, session.run(estimator_spec.predictions))
|
||||
|
||||
|
||||
class GetLossTowersTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def model_fn(self, mode, features, labels, params):
|
||||
c = variable_scope.get_variable(
|
||||
'c',
|
||||
initializer=constant_op.constant(0.25, dtype=dtypes.float64),
|
||||
dtype=dtypes.float64)
|
||||
|
||||
predictions = math_ops.add(np.array([0.1, 0.2, 0.3, features[0]]), c)
|
||||
labels = np.array([0.1, 0.2, 0.3, labels[0]])
|
||||
|
||||
loss = losses.absolute_difference(
|
||||
labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
|
||||
|
||||
return model_fn_lib.EstimatorSpec(mode=mode, loss=math_ops.reduce_sum(loss))
|
||||
|
||||
def test_gradients_are_computed(self):
|
||||
with self.test_session() as session:
|
||||
tower_specs = replicate_model_fn._get_loss_towers(
|
||||
self.model_fn,
|
||||
mode=None,
|
||||
features=[[0.6], [1.6]],
|
||||
labels=[[0.6], [0.6]],
|
||||
params=None,
|
||||
config=None,
|
||||
devices=['/gpu:0', '/gpu:1'],
|
||||
local_ps_device='/gpu:0',
|
||||
name_scope_pattern='test_tower_{}')
|
||||
session.run(variables.global_variables_initializer())
|
||||
|
||||
self.assertEqual(len(tower_specs), 2)
|
||||
|
||||
self.assertEqual('/device:GPU:0', tower_specs[0].loss.device)
|
||||
self.assertEqual('Sum:0', tower_specs[0].loss.name)
|
||||
self.assertEqual(1.0, session.run(tower_specs[0].loss))
|
||||
|
||||
self.assertEqual('/device:GPU:1', tower_specs[1].loss.device)
|
||||
self.assertEqual('test_tower_1/Sum:0', tower_specs[1].loss.name)
|
||||
# The input batch for the second tower had a loss that is 1.0
|
||||
# bigger: 0.6 vs 1.6.
|
||||
self.assertEqual(2.0, session.run(tower_specs[1].loss))
|
||||
|
||||
self.assertEqual(1, len(variables.global_variables()))
|
||||
self.assertEqual(1, len(variables.trainable_variables()))
|
||||
|
||||
with variable_scope.variable_scope('', reuse=True):
|
||||
c = variable_scope.get_variable('c', dtype=dtypes.float64)
|
||||
self.assertEqual(0.25, session.run(c))
|
||||
|
||||
|
||||
class SplitBatchTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def evaluate_shards(self, first_list, second_list):
|
||||
evaluate_items = lambda x: x.eval()
|
||||
return list(map(evaluate_items, first_list)), list(
|
||||
map(evaluate_items, second_list))
|
||||
|
||||
def test_simple_half_split(self):
|
||||
with self.test_session() as session: # pylint: disable=unused-variable
|
||||
features = [0.0, 1.0, 2.0, 3.0]
|
||||
labels = [10.0, 11.0, 12.0, 13.0]
|
||||
feature_shards, label_shards = replicate_model_fn._split_batch(
|
||||
features, labels, 2, device='/gpu:0')
|
||||
|
||||
feature_shards, label_shards = self.evaluate_shards(
|
||||
feature_shards, label_shards)
|
||||
|
||||
self.assertAllEqual([[0.0, 1.0], [2.0, 3.0]], feature_shards)
|
||||
self.assertAllEqual([[10.0, 11.0], [12.0, 13.0]], label_shards)
|
||||
|
||||
def test_to_each_their_own(self):
|
||||
with self.test_session() as session: # pylint: disable=unused-variable
|
||||
features = [0.0, 1.0, 2.0, 3.0]
|
||||
labels = [10.0, 11.0, 12.0, 13.0]
|
||||
feature_shards, label_shards = replicate_model_fn._split_batch(
|
||||
features, labels, 4, device='/gpu:0')
|
||||
|
||||
feature_shards, label_shards = self.evaluate_shards(
|
||||
feature_shards, label_shards)
|
||||
|
||||
self.assertAllEqual([[0.0], [1.0], [2.0], [3.0]], feature_shards)
|
||||
self.assertAllEqual([[10.0], [11.0], [12.0], [13.0]], label_shards)
|
||||
|
||||
def test_one_batch(self):
|
||||
with self.test_session() as session: # pylint: disable=unused-variable
|
||||
features = [0.0, 1.0, 2.0, 3.0]
|
||||
labels = [10.0, 11.0, 12.0, 13.0]
|
||||
feature_shards, label_shards = replicate_model_fn._split_batch(
|
||||
features, labels, 1, device='/gpu:0')
|
||||
|
||||
feature_shards, label_shards = self.evaluate_shards(
|
||||
feature_shards, label_shards)
|
||||
|
||||
self.assertAllEqual([[0.0, 1.0, 2.0, 3.0]], feature_shards)
|
||||
self.assertAllEqual([[10.0, 11.0, 12.0, 13.0]], label_shards)
|
||||
|
||||
def test_half_split_in_dictionary(self):
|
||||
with self.test_session() as session: # pylint: disable=unused-variable
|
||||
features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]}
|
||||
labels = [10.0, 11.0, 12.0, 13.0]
|
||||
|
||||
feature_shards, label_shards = replicate_model_fn._split_batch(
|
||||
features, labels, 2, device='/gpu:0')
|
||||
|
||||
self.assertAllEqual([0.0, 1.0], feature_shards[0]['first'].eval())
|
||||
self.assertAllEqual([4.0, 5.0], feature_shards[0]['second'].eval())
|
||||
self.assertAllEqual([2.0, 3.0], feature_shards[1]['first'].eval())
|
||||
self.assertAllEqual([6.0, 7.0], feature_shards[1]['second'].eval())
|
||||
self.assertAllEqual([10.0, 11.0], label_shards[0].eval())
|
||||
self.assertAllEqual([12.0, 13.0], label_shards[1].eval())
|
||||
|
||||
def test_one_batch_in_dictionary(self):
|
||||
with self.test_session() as session: # pylint: disable=unused-variable
|
||||
features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]}
|
||||
labels = [10.0, 11.0, 12.0, 13.0]
|
||||
|
||||
feature_shards, label_shards = replicate_model_fn._split_batch(
|
||||
features, labels, 1, device='/gpu:0')
|
||||
|
||||
self.assertAllEqual([0.0, 1.0, 2.0, 3.0],
|
||||
feature_shards[0]['first'].eval())
|
||||
self.assertAllEqual([4.0, 5.0, 6.0, 7.0],
|
||||
feature_shards[0]['second'].eval())
|
||||
self.assertAllEqual([10.0, 11.0, 12.0, 13.0], label_shards[0].eval())
|
||||
|
||||
def test_feature_and_label_dictionaries(self):
|
||||
with self.test_session() as session: # pylint: disable=unused-variable
|
||||
features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]}
|
||||
labels = {'first': [10.0, 11.0], 'second': [12.0, 13.0]}
|
||||
|
||||
feature_shards, label_shards = replicate_model_fn._split_batch(
|
||||
features, labels, 2, device='/gpu:0')
|
||||
|
||||
self.assertAllEqual([0.0, 1.0], feature_shards[0]['first'].eval())
|
||||
self.assertAllEqual([4.0, 5.0], feature_shards[0]['second'].eval())
|
||||
self.assertAllEqual([2.0, 3.0], feature_shards[1]['first'].eval())
|
||||
self.assertAllEqual([6.0, 7.0], feature_shards[1]['second'].eval())
|
||||
self.assertAllEqual([10.0], label_shards[0]['first'].eval())
|
||||
self.assertAllEqual([12.0], label_shards[0]['second'].eval())
|
||||
self.assertAllEqual([11], label_shards[1]['first'].eval())
|
||||
self.assertAllEqual([13.0], label_shards[1]['second'].eval())
|
||||
|
||||
|
||||
class TrainSpecTest(test_util.TensorFlowTestCase):
|
||||
|
||||
expected_predictions = {}
|
||||
|
||||
def create_estimator_spec(self, loss):
|
||||
return model_fn_lib.EstimatorSpec(
|
||||
mode=model_fn_lib.ModeKeys.TRAIN,
|
||||
loss=loss,
|
||||
train_op=loss, # Not used; currently required.
|
||||
predictions=self.expected_predictions)
|
||||
|
||||
def create_constant_loss(self, loss_value):
|
||||
return constant_op.constant(loss_value, dtype=dtypes.float64)
|
||||
|
||||
def test_example(self):
|
||||
with self.test_session() as session:
|
||||
tower_losses = list(map(self.create_constant_loss, [2, 4, 6]))
|
||||
tower_specs = list(map(self.create_estimator_spec, tower_losses))
|
||||
|
||||
expected_train_op = tower_losses[1]
|
||||
|
||||
estimator_spec = replicate_model_fn._train_spec(
|
||||
tower_specs, expected_train_op, aggregation_device='/gpu:0')
|
||||
|
||||
self.assertEqual(expected_train_op, estimator_spec.train_op)
|
||||
self.assertEqual(2 + 4 + 6, session.run(estimator_spec.loss))
|
||||
self.assertEqual(self.expected_predictions, estimator_spec.predictions)
|
||||
|
||||
|
||||
class EvalSpecTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def create_estimator_spec(self, loss, metrics):
|
||||
return model_fn_lib.EstimatorSpec(
|
||||
mode=model_fn_lib.ModeKeys.EVAL, loss=loss, eval_metric_ops=metrics)
|
||||
|
||||
def create_constant_loss(self, loss_value):
|
||||
return constant_op.constant(loss_value, dtype=dtypes.float64)
|
||||
|
||||
def create_eval_metrics(self, noise):
|
||||
predictions = np.array([0.1, 0.2, 0.3, 0.6 + noise])
|
||||
labels = np.array([0.1, 0.2, 0.3, 0.6])
|
||||
|
||||
metrics = {
|
||||
'accuracy': metrics_lib.accuracy(labels, predictions),
|
||||
'auc': metrics_lib.auc(labels, predictions)
|
||||
}
|
||||
return metrics
|
||||
|
||||
def test_example(self):
|
||||
with self.test_session() as session:
|
||||
tower_losses = map(self.create_constant_loss, [2, 4, 6])
|
||||
tower_metrics = map(self.create_eval_metrics, [0, 0.2, 0.3])
|
||||
tower_specs = [
|
||||
self.create_estimator_spec(l, m)
|
||||
for l, m in zip(tower_losses, tower_metrics)
|
||||
]
|
||||
session.run(variables.local_variables_initializer())
|
||||
|
||||
estimator_spec = replicate_model_fn._eval_spec(
|
||||
tower_specs, aggregation_device='/device:GPU:0')
|
||||
|
||||
accuracy, a = estimator_spec.eval_metric_ops['accuracy']
|
||||
auc, b = estimator_spec.eval_metric_ops['auc']
|
||||
|
||||
self.assertEqual('/device:CPU:0', accuracy.device)
|
||||
self.assertEqual('/device:CPU:0', auc.device)
|
||||
|
||||
session.run([a, b])
|
||||
accuracy = session.run(accuracy)
|
||||
auc = session.run(auc)
|
||||
|
||||
self.assertNear((12 - 2) / 12, accuracy, 0.01)
|
||||
self.assertEqual(0, auc)
|
||||
self.assertEqual(2 + 4 + 6, session.run(estimator_spec.loss))
|
||||
|
||||
def test_handles_single_tower(self):
|
||||
with self.test_session() as session:
|
||||
tower_losses = map(self.create_constant_loss, [5])
|
||||
tower_metrics = map(self.create_eval_metrics, [0.2])
|
||||
tower_specs = [
|
||||
self.create_estimator_spec(l, m)
|
||||
for l, m in zip(tower_losses, tower_metrics)
|
||||
]
|
||||
session.run(variables.local_variables_initializer())
|
||||
|
||||
estimator_spec = replicate_model_fn._eval_spec(
|
||||
tower_specs, aggregation_device='/device:GPU:0')
|
||||
|
||||
accuracy, a = estimator_spec.eval_metric_ops['accuracy']
|
||||
auc, b = estimator_spec.eval_metric_ops['auc']
|
||||
|
||||
self.assertEqual('/device:CPU:0', accuracy.device)
|
||||
self.assertEqual('/device:CPU:0', auc.device)
|
||||
|
||||
session.run([a, b])
|
||||
accuracy = session.run(accuracy)
|
||||
auc = session.run(auc)
|
||||
|
||||
self.assertNear((4 - 1) / 4, accuracy, 0.01)
|
||||
self.assertEqual(0, auc)
|
||||
self.assertEqual(5, session.run(estimator_spec.loss))
|
||||
|
||||
|
||||
class PredictSpecTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def model_fn(self, mode, features, labels, params):
|
||||
c = variable_scope.get_variable(
|
||||
'c',
|
||||
initializer=constant_op.constant(0.25, dtype=dtypes.float64),
|
||||
dtype=dtypes.float64)
|
||||
|
||||
predictions = math_ops.add(np.array([features[0], features[0]]), c)
|
||||
|
||||
return model_fn_lib.EstimatorSpec(
|
||||
mode=model_fn_lib.ModeKeys.PREDICT,
|
||||
predictions={
|
||||
'probabilities': predictions
|
||||
})
|
||||
|
||||
def test_example(self):
|
||||
with self.test_session() as session:
|
||||
tower_specs = replicate_model_fn._get_loss_towers(
|
||||
self.model_fn,
|
||||
mode=None,
|
||||
features=[[0.1], [0.2]],
|
||||
labels=[[], []],
|
||||
params=None,
|
||||
config=None,
|
||||
devices=['/gpu:0', '/gpu:1'],
|
||||
local_ps_device='/gpu:0',
|
||||
)
|
||||
session.run(variables.global_variables_initializer())
|
||||
|
||||
estimator_spec = replicate_model_fn._predict_spec(
|
||||
tower_specs, aggregation_device='/gpu:0')
|
||||
|
||||
self.assertEqual('/device:GPU:0',
|
||||
estimator_spec.predictions['probabilities'].device)
|
||||
self.assertAllClose({
|
||||
'probabilities': np.array([0.35, 0.35, 0.45, 0.45])
|
||||
}, session.run(estimator_spec.predictions))
|
||||
|
||||
|
||||
class ReduceMetricVariablesTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def create_metric_variable(self, initial_value, name):
|
||||
return variable_scope.variable(
|
||||
initial_value,
|
||||
trainable=False,
|
||||
collections=[ops_lib.GraphKeys.METRIC_VARIABLES],
|
||||
validate_shape=True,
|
||||
name=name)
|
||||
|
||||
def create_tower_metrics(self, tower_id):
|
||||
with variable_scope.variable_scope('', reuse=(tower_id != 0)):
|
||||
self.create_metric_variable(1.3 * (tower_id + 1), 'total')
|
||||
self.create_metric_variable(2.3 * (tower_id + 1), 'count')
|
||||
self.create_metric_variable(
|
||||
np.array([3.3, 3.5, 3.7]) * (tower_id + 1), 'total')
|
||||
|
||||
def test_example(self):
|
||||
with self.test_session() as session:
|
||||
for tower_id in range(3):
|
||||
self.create_tower_metrics(tower_id)
|
||||
|
||||
session.run(
|
||||
variables.variables_initializer(
|
||||
ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))
|
||||
|
||||
session.run(
|
||||
replicate_model_fn._reduce_metric_variables(number_of_towers=3))
|
||||
|
||||
# 1st tower = 1.3, 2.3, [3.3, 3.5, 3.7]
|
||||
# 2nd tower = 2.6, 4.6, [6.6, 7.0, 7.4]
|
||||
# 3rd tower = 3.9, 6.9, [9.9, 10.5, 11.1]
|
||||
# Reduced = 7.8, 13.8, [19.8, 21.0, 22.2]
|
||||
# Towers are accumulated in the first tower.
|
||||
local_metrics = session.run(
|
||||
ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))
|
||||
|
||||
self.assertNear(7.8, local_metrics[0], 0.01)
|
||||
self.assertNear(13.8, local_metrics[1], 0.01)
|
||||
self.assertAllClose([19.8, 21., 22.1], local_metrics[2], 0.01)
|
||||
self.assertNear(0.0, local_metrics[3], 0.01)
|
||||
self.assertNear(0.0, local_metrics[4], 0.01)
|
||||
self.assertAllClose([0.0, 0.0, 0.0], local_metrics[5], 0.01)
|
||||
self.assertNear(0.0, local_metrics[6], 0.01)
|
||||
self.assertNear(0.0, local_metrics[7], 0.01)
|
||||
self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01)
|
||||
|
||||
def test_reduce_is_idempotent(self):
|
||||
with self.test_session() as session:
|
||||
for tower_id in range(3):
|
||||
self.create_tower_metrics(tower_id)
|
||||
|
||||
session.run(
|
||||
variables.variables_initializer(
|
||||
ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))
|
||||
|
||||
for _ in range(20):
|
||||
session.run(
|
||||
replicate_model_fn._reduce_metric_variables(number_of_towers=3))
|
||||
|
||||
local_metrics = session.run(
|
||||
ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))
|
||||
|
||||
self.assertNear(7.8, local_metrics[0], 0.01)
|
||||
self.assertNear(13.8, local_metrics[1], 0.01)
|
||||
self.assertAllClose([19.8, 21., 22.1], local_metrics[2], 0.01)
|
||||
self.assertNear(0.0, local_metrics[3], 0.01)
|
||||
self.assertNear(0.0, local_metrics[4], 0.01)
|
||||
self.assertAllClose([0.0, 0.0, 0.0], local_metrics[5], 0.01)
|
||||
self.assertNear(0.0, local_metrics[6], 0.01)
|
||||
self.assertNear(0.0, local_metrics[7], 0.01)
|
||||
self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01)
|
||||
|
||||
def test_handles_single_tower(self):
|
||||
with self.test_session() as session:
|
||||
self.create_tower_metrics(0)
|
||||
session.run(
|
||||
variables.variables_initializer(
|
||||
ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))
|
||||
|
||||
session.run(
|
||||
replicate_model_fn._reduce_metric_variables(number_of_towers=1))
|
||||
|
||||
local_metrics = session.run(
|
||||
ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))
|
||||
|
||||
self.assertNear(1.3, local_metrics[0], 0.01)
|
||||
self.assertNear(2.3, local_metrics[1], 0.01)
|
||||
self.assertAllClose([3.3, 3.5, 3.7], local_metrics[2], 0.01)
|
||||
|
||||
def test_doesnt_accept_uneven_number_of_variables(self):
|
||||
with self.test_session() as session:
|
||||
for tower_id in range(3):
|
||||
self.create_tower_metrics(tower_id)
|
||||
self.create_metric_variable(-1.0, 'oddball')
|
||||
|
||||
session.run(
|
||||
variables.variables_initializer(
|
||||
ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, ''):
|
||||
session.run(
|
||||
replicate_model_fn._reduce_metric_variables(number_of_towers=3))
|
||||
|
||||
|
||||
class MergeExportOutputsTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def optimizer_fn(self):
|
||||
return gradient_descent.GradientDescentOptimizer(1.0)
|
||||
|
||||
def model_fn(self, mode, features, labels, params):
|
||||
c = variable_scope.get_variable(
|
||||
'c',
|
||||
initializer=constant_op.constant(10, dtype=dtypes.float64),
|
||||
dtype=dtypes.float64)
|
||||
|
||||
predictions = {'probabilities': math_ops.multiply(features, c)}
|
||||
loss = losses.absolute_difference(
|
||||
labels=labels,
|
||||
predictions=predictions['probabilities'],
|
||||
reduction=losses.Reduction.SUM)
|
||||
|
||||
metrics = {
|
||||
'accuracy': metrics_lib.accuracy(labels, predictions['probabilities']),
|
||||
'auc': metrics_lib.auc(labels, predictions['probabilities'])
|
||||
}
|
||||
tensor_string_repr = str(features)
|
||||
classes = constant_op.constant(
|
||||
re.search('(split_inputs/split:[0-9])', tensor_string_repr).group(1),
|
||||
dtype=dtypes.string)
|
||||
|
||||
export_outputs = {
|
||||
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
|
||||
export_output.PredictOutput(predictions),
|
||||
'classification_output':
|
||||
export_output.ClassificationOutput(predictions['probabilities'],
|
||||
classes),
|
||||
'classification_scores':
|
||||
export_output.ClassificationOutput(
|
||||
scores=predictions['probabilities']),
|
||||
'classification_classes':
|
||||
export_output.ClassificationOutput(classes=classes),
|
||||
'regression_output':
|
||||
export_output.RegressionOutput(predictions['probabilities']),
|
||||
}
|
||||
|
||||
return model_fn_lib.EstimatorSpec(
|
||||
mode=mode,
|
||||
loss=math_ops.reduce_sum(loss),
|
||||
eval_metric_ops=metrics,
|
||||
predictions=predictions,
|
||||
train_op=loss, # This train_op isn't actually used.
|
||||
export_outputs=export_outputs)
|
||||
|
||||
def replicate_estimator_spec(self, session):
|
||||
features = np.array([0.01, 0.002])
|
||||
labels = np.array([0.01, 0.02])
|
||||
|
||||
replicated_model_fn = replicate_model_fn.replicate_model_fn(
|
||||
self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1'])
|
||||
estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.PREDICT,
|
||||
features, labels, {})
|
||||
session.run(variables.global_variables_initializer())
|
||||
return estimator_spec
|
||||
|
||||
def test_merde_predict_output(self):
|
||||
with self.test_session() as session:
|
||||
estimator_spec = self.replicate_estimator_spec(session)
|
||||
self.assertAllClose(
|
||||
{
|
||||
'probabilities': np.array([0.1, 0.02])
|
||||
},
|
||||
session.run(estimator_spec.export_outputs[
|
||||
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY].outputs))
|
||||
|
||||
def test_merge_classification_output_scores_classes(self):
|
||||
with self.test_session() as session:
|
||||
estimator_spec = self.replicate_estimator_spec(session)
|
||||
self.assertAllClose(
|
||||
[0.1, 0.02],
|
||||
session.run(
|
||||
estimator_spec.export_outputs['classification_output'].scores))
|
||||
self.assertAllEqual(
|
||||
[b'split_inputs/split:0', b'split_inputs/split:1'],
|
||||
session.run(
|
||||
estimator_spec.export_outputs['classification_output'].classes))
|
||||
|
||||
def test_merge_classification_output_scores(self):
|
||||
with self.test_session() as session:
|
||||
estimator_spec = self.replicate_estimator_spec(session)
|
||||
self.assertAllClose(
|
||||
[0.1, 0.02],
|
||||
session.run(
|
||||
estimator_spec.export_outputs['classification_scores'].scores))
|
||||
self.assertEqual(
|
||||
None, estimator_spec.export_outputs['classification_scores'].classes)
|
||||
|
||||
def test_merge_classification_output_classes(self):
|
||||
with self.test_session() as session:
|
||||
estimator_spec = self.replicate_estimator_spec(session)
|
||||
self.assertAllEqual(
|
||||
[b'split_inputs/split:0', b'split_inputs/split:1'],
|
||||
session.run(
|
||||
estimator_spec.export_outputs['classification_classes'].classes))
|
||||
self.assertEqual(
|
||||
None, estimator_spec.export_outputs['classification_classes'].scores)
|
||||
|
||||
def test_merge_regression_output(self):
|
||||
with self.test_session() as session:
|
||||
estimator_spec = self.replicate_estimator_spec(session)
|
||||
self.assertAllClose(
|
||||
[0.1, 0.02],
|
||||
session.run(estimator_spec.export_outputs['regression_output'].value))
|
||||
|
||||
|
||||
class GetLocalDevicesTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def test_there_is_at_least_a_cpu(self):
|
||||
self.assertTrue(replicate_model_fn._get_local_devices('CPU'))
|
||||
|
||||
def test_there_is_no_xpu(self):
|
||||
self.assertFalse(
|
||||
replicate_model_fn._get_local_devices('XPU')) # XPU doesn't exist.
|
||||
|
||||
def test_whether_there_is_a_gpu(self):
|
||||
self.assertEqual(
|
||||
len(replicate_model_fn._get_local_devices('GPU')),
|
||||
test.is_gpu_available())
|
||||
|
||||
|
||||
class LocalDeviceSetterTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def test_vars_are_on_ps_but_ops_are_on_workers(self):
|
||||
local_device_setter = replicate_model_fn._local_device_setter(
|
||||
ps_device='/device:GPU:3', worker_device='/device:GPU:2')
|
||||
|
||||
with ops_lib.device(local_device_setter):
|
||||
c = variables.Variable(0.01)
|
||||
self.assertEqual('/device:GPU:3', c.device)
|
||||
|
||||
cc = variables.Variable(0.02)
|
||||
self.assertEqual('/device:GPU:3', cc.device)
|
||||
|
||||
ccc = variables.Variable(0.03)
|
||||
self.assertEqual('/device:GPU:3', ccc.device)
|
||||
|
||||
c_op = array_ops.concat(c, axis=0)
|
||||
self.assertEqual('/device:GPU:2', c_op.device)
|
||||
|
||||
cc_op = array_ops.concat(cc, axis=0)
|
||||
self.assertEqual('/device:GPU:2', cc_op.device)
|
||||
|
||||
|
||||
class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def test_example(self):
|
||||
with self.test_session() as session:
|
||||
total = replicate_model_fn._compute_sum_on_device(
|
||||
[1.0, 2.0, 3.0, 4.0], device='/device:GPU:0', name='test_sum')
|
||||
|
||||
self.assertEqual('/device:GPU:0', total.device)
|
||||
self.assertEqual('test_sum', total.op.name)
|
||||
self.assertEqual(10.0, session.run(total))
|
||||
|
||||
|
||||
class ConcatTensorDictsTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def test_example(self):
|
||||
tensor_dicts = [
|
||||
{
|
||||
'a': np.array([1.0, 2.0]),
|
||||
'b': np.array([11.0]),
|
||||
'c': np.array([21.0]),
|
||||
},
|
||||
{
|
||||
'a': np.array([3.0]),
|
||||
'b': np.array([12.0, 13.0]),
|
||||
},
|
||||
{
|
||||
'b': np.array([14.0]),
|
||||
},
|
||||
]
|
||||
|
||||
with self.test_session() as session:
|
||||
self.assertAllClose({
|
||||
'a': np.array([1.0, 2.0, 3.0]),
|
||||
'b': np.array([11.0, 12.0, 13.0, 14.0]),
|
||||
'c': np.array([21.0]),
|
||||
}, session.run(replicate_model_fn._concat_tensor_dicts(*tensor_dicts)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -24,6 +24,7 @@ tf_custom_op_py_library(
|
||||
"python/framework/__init__.py",
|
||||
"python/framework/checkpoint_utils.py",
|
||||
"python/framework/experimental.py",
|
||||
"python/framework/graph_util.py",
|
||||
"python/framework/tensor_util.py",
|
||||
"python/ops/__init__.py",
|
||||
"python/ops/accumulate_n_v2.py",
|
||||
@ -32,6 +33,7 @@ tf_custom_op_py_library(
|
||||
"python/ops/checkpoint_ops.py",
|
||||
"python/ops/ops.py",
|
||||
"python/ops/prettyprint_ops.py",
|
||||
"python/ops/sort_ops.py",
|
||||
"python/ops/variables.py",
|
||||
],
|
||||
dso = [
|
||||
@ -231,6 +233,17 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "graph_util_test",
|
||||
srcs = ["python/framework/graph_util_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":framework_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:platform",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "tensor_util_test",
|
||||
srcs = ["python/framework/tensor_util_test.py"],
|
||||
@ -307,6 +320,20 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "sort_ops_test",
|
||||
size = "medium",
|
||||
srcs = ["python/ops/sort_ops_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":framework_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:random_ops",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
|
@ -79,6 +79,8 @@ See the @{$python/contrib.framework} guide.
|
||||
@@load_embedding_initializer
|
||||
@@load_linear_multiclass_bias_initializer
|
||||
@@load_variable_slot_initializer
|
||||
|
||||
@@sort
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.contrib.framework.python.framework.checkpoint_utils import *
|
||||
from tensorflow.contrib.framework.python.framework.experimental import experimental
|
||||
from tensorflow.contrib.framework.python.framework.graph_util import *
|
||||
from tensorflow.contrib.framework.python.framework.tensor_util import *
|
||||
# pylint: enable=wildcard-import
|
||||
from tensorflow.python.util import decorator_utils
|
||||
|
128
tensorflow/contrib/framework/python/framework/graph_util.py
Normal file
128
tensorflow/contrib/framework/python/framework/graph_util.py
Normal file
@ -0,0 +1,128 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Helpers to manipulate a tensor graph in python.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import copy
|
||||
import six
|
||||
|
||||
# pylint: disable=unused-import
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.core.framework import node_def_pb2
|
||||
from tensorflow.python.framework.graph_util_impl import _assert_nodes_are_present
|
||||
from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes
|
||||
from tensorflow.python.framework.graph_util_impl import _extract_graph_summary
|
||||
from tensorflow.python.framework.graph_util_impl import _node_name
|
||||
|
||||
__all__ = ["fuse_op"]
|
||||
|
||||
|
||||
def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes,
|
||||
output_quantized, op_name, op_type):
|
||||
"""Fuse subgraph between input_nodes and output_nodes into a single custom op.
|
||||
|
||||
Args:
|
||||
graph_def: A graph_pb2.GraphDef proto.
|
||||
input_nodes: input nodes to the subgraph to be fused.
|
||||
output_nodes: output nodes to the subgraph to be fused.
|
||||
output_dtypes: A list of output datatypes for the custom op
|
||||
output_quantized: A boolean flag that indicates if output is quantized
|
||||
op_name: fused op name.
|
||||
op_type: fused op type.
|
||||
Returns:
|
||||
The GraphDef of the new graph.
|
||||
|
||||
Raises:
|
||||
TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto.
|
||||
"""
|
||||
|
||||
if not isinstance(graph_def, graph_pb2.GraphDef):
|
||||
raise TypeError("graph_def must be a graph_pb2.GraphDef proto.")
|
||||
|
||||
if isinstance(input_nodes, six.string_types):
|
||||
raise TypeError("input_nodes must be a list.")
|
||||
|
||||
if isinstance(output_nodes, six.string_types):
|
||||
raise TypeError("output_nodes must be a list.")
|
||||
|
||||
name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
|
||||
graph_def)
|
||||
_assert_nodes_are_present(name_to_node, input_nodes + output_nodes)
|
||||
|
||||
# Nodes upto and including input_nodes
|
||||
reachable_by_input = _bfs_for_reachable_nodes(input_nodes, name_to_input_name)
|
||||
# Nodes upto and including output_nodes
|
||||
reachable_by_output = _bfs_for_reachable_nodes(output_nodes,
|
||||
name_to_input_name)
|
||||
|
||||
# Set of nodes in the list input_nodes
|
||||
input_nodes_set = set(input_nodes)
|
||||
|
||||
# Set of nodes in the list output_nodes
|
||||
output_nodes_set = set(output_nodes)
|
||||
|
||||
nodes_post_output = []
|
||||
for node in graph_def.node:
|
||||
n = _node_name(node.name)
|
||||
if n in reachable_by_output:
|
||||
if n not in reachable_by_input and n not in output_nodes_set:
|
||||
# n is between input and output, i.e., part of the fused op
|
||||
next_to_visit = [n]
|
||||
while next_to_visit:
|
||||
cur_node = next_to_visit[0]
|
||||
del next_to_visit[0]
|
||||
if cur_node in reachable_by_input and cur_node not in input_nodes_set:
|
||||
raise TypeError("Node %s uses input %s not in input_nodes." %
|
||||
(n, cur_node))
|
||||
if cur_node not in input_nodes_set:
|
||||
next_to_visit += name_to_input_name[cur_node]
|
||||
else:
|
||||
nodes_post_output.append(n)
|
||||
|
||||
# Add all nodes upto the input nodes
|
||||
out = graph_pb2.GraphDef()
|
||||
reachable_by_input_sorted = sorted(
|
||||
list(reachable_by_input), key=lambda n: name_to_seq_num[n])
|
||||
for node in reachable_by_input_sorted:
|
||||
out.node.extend([copy.deepcopy(name_to_node[node])])
|
||||
|
||||
# Add the custom op
|
||||
new_node = node_def_pb2.NodeDef()
|
||||
for node in input_nodes:
|
||||
new_node.input.append(node)
|
||||
new_node.attr["_output_types"].list.type[:] = output_dtypes
|
||||
new_node.attr["_output_quantized"].b = output_quantized
|
||||
new_node.op = op_type
|
||||
new_node.name = op_name
|
||||
out.node.extend([new_node])
|
||||
|
||||
# Add the nodes in the output of the custom op
|
||||
for index, n in enumerate(output_nodes):
|
||||
assert len(name_to_node[n].input) == 1
|
||||
new_node = copy.deepcopy(name_to_node[n])
|
||||
del new_node.input[:]
|
||||
new_node.input.append(op_name + (":" + str(index) if index != 0 else ""))
|
||||
out.node.extend([new_node])
|
||||
|
||||
# Add the nodes post output_nodes
|
||||
for n in nodes_post_output:
|
||||
out.node.extend([copy.deepcopy(name_to_node[n])])
|
||||
|
||||
out.library.CopyFrom(graph_def.library)
|
||||
out.versions.CopyFrom(graph_def.versions)
|
||||
return out
|
@ -0,0 +1,61 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""@graph_util tests."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.framework.python.framework import graph_util
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.core.framework import node_def_pb2
|
||||
from tensorflow.core.framework import types_pb2
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def GetNewNode(name, op, input_nodes):
|
||||
new_node = node_def_pb2.NodeDef()
|
||||
new_node.op = op
|
||||
new_node.name = name
|
||||
for node in input_nodes:
|
||||
new_node.input.append(node)
|
||||
return new_node
|
||||
|
||||
|
||||
class GraphUtilTest(test.TestCase):
|
||||
|
||||
def testGraphUtil(self):
|
||||
graph_def = graph_pb2.GraphDef()
|
||||
node_a = GetNewNode('A', 'Placeholder', [])
|
||||
node_b = GetNewNode('B', 'Op1', ['A'])
|
||||
node_c = GetNewNode('C', 'Op1', ['B'])
|
||||
node_d = GetNewNode('D', 'Op1', ['C'])
|
||||
node_e = GetNewNode('E', 'Op1', ['D'])
|
||||
graph_def.node.extend([node_a, node_b, node_c, node_d, node_e])
|
||||
fused_graph_def = graph_util.fuse_op(
|
||||
graph_def, ['A'], ['D'], [types_pb2.DT_FLOAT], True, 'FusedOp', 'Op2')
|
||||
self.assertEqual(len(fused_graph_def.node), 4)
|
||||
self.assertEqual(fused_graph_def.node[0].name, 'A')
|
||||
self.assertEqual(fused_graph_def.node[1].name, 'FusedOp')
|
||||
self.assertEqual(fused_graph_def.node[1].input[0], 'A')
|
||||
self.assertEqual(fused_graph_def.node[1].op, 'Op2')
|
||||
self.assertEqual(fused_graph_def.node[1].attr['_output_quantized'].b, True)
|
||||
self.assertEqual(fused_graph_def.node[1].attr['_output_types'].list.type,
|
||||
[types_pb2.DT_FLOAT])
|
||||
self.assertEqual(fused_graph_def.node[2].name, 'D')
|
||||
self.assertEqual(fused_graph_def.node[3].name, 'E')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -24,5 +24,6 @@ from tensorflow.contrib.framework.python.ops.arg_scope import *
|
||||
from tensorflow.contrib.framework.python.ops.checkpoint_ops import *
|
||||
from tensorflow.contrib.framework.python.ops.ops import *
|
||||
from tensorflow.contrib.framework.python.ops.prettyprint_ops import *
|
||||
from tensorflow.contrib.framework.python.ops.sort_ops import *
|
||||
from tensorflow.contrib.framework.python.ops.variables import *
|
||||
# pylint: enable=wildcard-import
|
||||
|
113
tensorflow/contrib/framework/python/ops/sort_ops.py
Normal file
113
tensorflow/contrib/framework/python/ops/sort_ops.py
Normal file
@ -0,0 +1,113 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Support for sorting tensors.
|
||||
|
||||
@@sort
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import ops as framework_ops
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
|
||||
|
||||
def sort(values, axis=-1, direction='ASCENDING', name=None):
|
||||
"""Sorts a tensor.
|
||||
|
||||
Args:
|
||||
values: 1-D or higher numeric `Tensor`.
|
||||
axis: The axis along which to sort. The default is -1, which sorts the last
|
||||
axis.
|
||||
direction: The direction in which to sort the values (`'ASCENDING'` or
|
||||
`'DESCENDING'`).
|
||||
name: Optional name for the operation.
|
||||
|
||||
Returns:
|
||||
A `Tensor` with the same dtype and shape as `values`, with the elements
|
||||
sorted along the given `axis`.
|
||||
|
||||
Raises:
|
||||
ValueError: If axis is not a constant scalar, or the direction is invalid.
|
||||
"""
|
||||
with framework_ops.name_scope(name, 'sort'):
|
||||
if direction not in _SORT_IMPL:
|
||||
raise ValueError('%s should be one of %s' %
|
||||
(direction, ', '.join(sorted(_SORT_IMPL.keys()))))
|
||||
# Axis must be an integer, not a Tensor.
|
||||
axis = framework_ops.convert_to_tensor(axis, name='axis')
|
||||
axis_static = tensor_util.constant_value(axis)
|
||||
if axis.shape.ndims != 0 or axis_static is None:
|
||||
raise ValueError('axis must be a constant scalar')
|
||||
axis_static = int(axis_static) # Avoids NumPy casting error
|
||||
|
||||
values = framework_ops.convert_to_tensor(values, name='values')
|
||||
|
||||
return _SORT_IMPL[direction](values, axis_static)
|
||||
|
||||
|
||||
def _descending_sort(values, axis):
|
||||
"""Sorts values in reverse using `top_k`.
|
||||
|
||||
Args:
|
||||
values: Tensor of numeric values.
|
||||
axis: Index of the axis which values should be sorted along.
|
||||
|
||||
Returns:
|
||||
The sorted values.
|
||||
"""
|
||||
k = array_ops.shape(values)[axis]
|
||||
rank = array_ops.rank(values)
|
||||
# Fast path: sorting the last axis.
|
||||
if axis == -1 or axis + 1 == values.get_shape().ndims:
|
||||
return nn_ops.top_k(values, k)[0]
|
||||
|
||||
# Otherwise, transpose the array. Swap axes `axis` and `rank - 1`.
|
||||
if axis < 0:
|
||||
# Make axis a Tensor with the real axis index if needed.
|
||||
axis += rank
|
||||
transposition = array_ops.concat(
|
||||
[
|
||||
# Axes up to axis are unchanged.
|
||||
math_ops.range(axis),
|
||||
# Swap axis and rank - 1.
|
||||
[rank - 1],
|
||||
# Axes in [axis + 1, rank - 1) are unchanged.
|
||||
math_ops.range(axis + 1, rank - 1),
|
||||
# Swap axis and rank - 1.
|
||||
[axis]
|
||||
],
|
||||
axis=0)
|
||||
top_k_input = array_ops.transpose(values, transposition)
|
||||
values, unused_indices = nn_ops.top_k(top_k_input, k)
|
||||
# transposition contains a single cycle of length 2 (swapping 2 elements),
|
||||
# so it is an involution (it is its own inverse).
|
||||
return array_ops.transpose(values, transposition)
|
||||
|
||||
|
||||
def _ascending_sort(values, axis):
|
||||
# Negate the values to get the ascending order from descending sort.
|
||||
values_or_indices = _descending_sort(-values, axis)
|
||||
return -values_or_indices
|
||||
|
||||
|
||||
_SORT_IMPL = {
|
||||
'ASCENDING': _ascending_sort,
|
||||
'DESCENDING': _descending_sort,
|
||||
}
|
95
tensorflow/contrib/framework/python/ops/sort_ops_test.py
Normal file
95
tensorflow/contrib/framework/python/ops/sort_ops_test.py
Normal file
@ -0,0 +1,95 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the sort wrapper."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.framework.python.ops import sort_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class SortTest(test.TestCase):
|
||||
|
||||
def testRandom_lowDimensionality(self):
|
||||
self._testRandom_lowDimensionality(negative_axis=False)
|
||||
|
||||
def testRandom_lowDimensionality_negative(self):
|
||||
self._testRandom_lowDimensionality(negative_axis=True)
|
||||
|
||||
def _testRandom_lowDimensionality(self, negative_axis):
|
||||
np.random.seed(42)
|
||||
for _ in range(20):
|
||||
rank = np.random.randint(1, 3)
|
||||
shape = [np.random.randint(0, 20) for _ in range(rank)]
|
||||
arr = np.random.random(shape)
|
||||
sort_axis = np.random.choice(rank)
|
||||
if negative_axis:
|
||||
sort_axis = -1 - sort_axis
|
||||
with self.test_session():
|
||||
self.assertAllEqual(
|
||||
np.sort(arr, axis=sort_axis),
|
||||
sort_ops.sort(constant_op.constant(arr), axis=sort_axis).eval())
|
||||
|
||||
def testRandom_highDimensionality(self):
|
||||
np.random.seed(100)
|
||||
for _ in range(20):
|
||||
rank = np.random.randint(5, 15)
|
||||
shape = [np.random.randint(1, 4) for _ in range(rank)]
|
||||
arr = np.random.random(shape)
|
||||
sort_axis = np.random.choice(rank)
|
||||
with self.test_session():
|
||||
self.assertAllEqual(
|
||||
np.sort(arr, axis=sort_axis),
|
||||
sort_ops.sort(constant_op.constant(arr), axis=sort_axis).eval())
|
||||
|
||||
def testScalar(self):
|
||||
# Create an empty scalar where the static shape is unknown.
|
||||
zeros_length_1 = array_ops.zeros(
|
||||
random_ops.random_uniform([1], minval=0, maxval=1, dtype=dtypes.int32),
|
||||
dtype=dtypes.int32)
|
||||
scalar = array_ops.zeros(zeros_length_1)
|
||||
|
||||
sort = sort_ops.sort(scalar)
|
||||
with self.test_session():
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sort.eval()
|
||||
|
||||
def testNegativeOutOfBounds_staticShape(self):
|
||||
arr = constant_op.constant([3, 4, 5])
|
||||
with self.assertRaises(ValueError):
|
||||
sort_ops.sort(arr, axis=-4)
|
||||
|
||||
def testDescending(self):
|
||||
arr = np.random.random((10, 5, 5))
|
||||
with self.test_session():
|
||||
self.assertAllEqual(
|
||||
np.sort(arr, axis=0)[::-1],
|
||||
sort_ops.sort(
|
||||
constant_op.constant(arr),
|
||||
axis=0,
|
||||
direction='DESCENDING').eval())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -246,8 +246,8 @@ def image(name, tensor, bad_color=None, max_images=3, family=None):
|
||||
"""Writes an image summary if possible."""
|
||||
|
||||
def function(tag, scope):
|
||||
if bad_color is None:
|
||||
bad_color_ = constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8)
|
||||
bad_color_ = (constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8)
|
||||
if bad_color is None else bad_color)
|
||||
# Note the identity to move the tensor to the CPU.
|
||||
return gen_summary_ops.write_image_summary(
|
||||
context.context().summary_writer_resource,
|
||||
|
@ -95,3 +95,10 @@ tf_proto_library_cc(
|
||||
cc_api_version = 2,
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
tf_proto_library_cc(
|
||||
name = "tf_op_stats_proto",
|
||||
srcs = ["tf_op_stats.proto"],
|
||||
cc_api_version = 2,
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
127
tensorflow/contrib/tpu/profiler/tf_op_stats.proto
Normal file
127
tensorflow/contrib/tpu/profiler/tf_op_stats.proto
Normal file
@ -0,0 +1,127 @@
|
||||
// This proto describes the format of tensorflow operation level stats for
|
||||
// profiling (in tensorboard) purpose.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package tensorflow.tpu;
|
||||
|
||||
// Result proto for OpMetrics.
|
||||
message OpMetricsResult {
|
||||
// True if this OP is executed on the device; False if it is executed on the
|
||||
// host.
|
||||
optional bool on_device = 1;
|
||||
reserved 2; // was uint32 id.
|
||||
// Name of this OP.
|
||||
optional string name = 3;
|
||||
// Rank of this OP.
|
||||
optional uint64 rank = 4;
|
||||
// The starting time in cycles of the last instance of this OP executed.
|
||||
optional double last_starttime_in_cycles = 5;
|
||||
// The ending time in cycles of the last instance of this OP executed.
|
||||
optional double last_endtime_in_cycles = 6;
|
||||
// If this OP (say A), is an immediate child of another OP (say B), this field
|
||||
// stores the sum of duration in microseconds of A inside B. If A appears more
|
||||
// than once in B, the duration of all A's appearances will be added together.
|
||||
// This sum will be reset after the self-time of B is calculated so that it
|
||||
// can be reused for a new parent OP.
|
||||
optional double sum_of_duration_in_us_as_children = 7;
|
||||
// Number of instances that this OP occurred.
|
||||
optional uint64 occurrences = 8;
|
||||
// Total time in microseconds spent in this OP (accumulated
|
||||
// over all of its occurrences).
|
||||
optional double total_time_in_us = 9;
|
||||
// Total self time in microseconds spent in this OP
|
||||
// (accumulated over all of its occurrences).
|
||||
optional double total_self_time_in_us = 10;
|
||||
// The total self time as a fraction of sum of all OP's
|
||||
// total self time on the host.
|
||||
optional double host_total_self_time_as_fraction_of_all_op_time = 11;
|
||||
// Cumulative total self time in fraction on the host.
|
||||
optional double host_cumulative_total_self_time_as_fraction_of_all_op_time =
|
||||
12;
|
||||
// The total self time as a fraction of sum of all OP's
|
||||
// total self time on the device.
|
||||
optional double device_total_self_time_as_fraction_of_all_op_time = 13;
|
||||
// Cumulative total self time in fraction on the device.
|
||||
optional double device_cumulative_total_self_time_as_fraction_of_all_op_time =
|
||||
14;
|
||||
// Total number of FLOPs incurred by this OP.
|
||||
optional double total_flops = 15;
|
||||
// Total time in microseconds that the MXU is occupied by this OP.
|
||||
optional double total_bytes_accessed = 16;
|
||||
// Total time in microseconds that the MXU is occupied by this OP.
|
||||
optional double mxu_occupancy_in_us = 17;
|
||||
// Total time in microseconds that the XU is occupied by this OP.
|
||||
optional double xu_occupancy_in_us = 18;
|
||||
// Total DMA access stall time in microseconds.
|
||||
optional double total_dma_stall_in_us = 19;
|
||||
}
|
||||
|
||||
// Result proto for OpMetricsDb.
|
||||
message OpMetricsDbResult {
|
||||
// A bunch of OpMetricsResults.
|
||||
repeated OpMetricsResult metrics_db = 1;
|
||||
}
|
||||
|
||||
// Result proto for StepInfo.
|
||||
message StepInfoResult {
|
||||
// The (micro) step number.
|
||||
optional uint32 step_num = 1;
|
||||
// The step duration in picoseconds.
|
||||
optional uint64 duration_ps = 2;
|
||||
// The infeed duration in picoseconds.
|
||||
// Can turn into a map if we want a variable number of ops.
|
||||
optional uint64 infeed_duration_ps = 3;
|
||||
}
|
||||
|
||||
// Result proto for a sequence of steps.
|
||||
message StepSequenceResult {
|
||||
// A sequence of StepInfoResults.
|
||||
repeated StepInfoResult step_sequence = 1;
|
||||
}
|
||||
|
||||
// Result proto for a StepDatabase.
|
||||
message StepDatabaseResult {
|
||||
// A map from core_id to StepSequenceResult.
|
||||
map<uint32, StepSequenceResult> step_sequence_per_core = 1;
|
||||
}
|
||||
|
||||
// Result proto for Dashboard data.
|
||||
message DashboardResult {
|
||||
// The total iteration time in nanoseconds.
|
||||
optional double iteration_time_ns = 1;
|
||||
// The total number of iterations.
|
||||
optional int32 num_iterations = 2;
|
||||
// The total computation time in nanoseconds.
|
||||
optional double computation_time_ns = 3;
|
||||
// The total number of computations.
|
||||
optional int32 num_computations = 4;
|
||||
}
|
||||
|
||||
// Result proto for HloExtraInfo.
|
||||
message HloExtraInfoResult {
|
||||
// Category of the HLO op given by the compiler.
|
||||
optional string category = 1;
|
||||
// The long name of the HLO that includes the dimensions.
|
||||
optional string long_name = 2;
|
||||
}
|
||||
|
||||
// Result proto for HloExtraInfoMap.
|
||||
message HloExtraInfoMapResult {
|
||||
// A map from HLO name to HloExtraInfo.
|
||||
map<string, HloExtraInfoResult> hlo_extrainfo_map = 1;
|
||||
}
|
||||
|
||||
// Result proto for TfStatsHelper.
|
||||
message TfOpStats {
|
||||
// The result for the TF-metric database.
|
||||
optional OpMetricsDbResult tf_metrics_db = 1;
|
||||
// The result for the HLO-metric database.
|
||||
optional OpMetricsDbResult hlo_metrics_db = 2;
|
||||
// The result for the step database.
|
||||
optional StepDatabaseResult step_db = 3;
|
||||
// The result for the TPU dashboard.
|
||||
optional DashboardResult dashboard = 4;
|
||||
// The result for the HloExtraInfoMap.
|
||||
optional HloExtraInfoMapResult hlo_extrainfo_map = 5;
|
||||
}
|
@ -66,7 +66,7 @@ _CROSS_REPLICA_SUM_OP = 'CrossReplicaSum'
|
||||
_RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY]
|
||||
|
||||
# TODO(b/65703635): Flip the value and remove all dead code.
|
||||
_WRAP_INPUT_FN_INTO_WHILE_LOOP = False
|
||||
_WRAP_INPUT_FN_INTO_WHILE_LOOP = True
|
||||
|
||||
|
||||
def _create_global_step(graph):
|
||||
|
@ -1414,16 +1414,19 @@ LIB_INTERNAL_PUBLIC_HEADERS = tf_additional_lib_hdrs() + [
|
||||
"platform/tracing.h",
|
||||
]
|
||||
|
||||
# Replicated for lib_internal and lib_internal_impl.
|
||||
LIB_INTERNAL_DEFINES = (tf_additional_lib_defines() + [
|
||||
"TF_USE_SNAPPY",
|
||||
] + tf_additional_verbs_lib_defines() +
|
||||
tf_additional_mpi_lib_defines() +
|
||||
tf_additional_gdr_lib_defines())
|
||||
|
||||
cc_library(
|
||||
name = "lib_internal",
|
||||
srcs = LIB_INTERNAL_PRIVATE_HEADERS,
|
||||
hdrs = LIB_INTERNAL_PUBLIC_HEADERS,
|
||||
copts = tf_copts(),
|
||||
defines = tf_additional_lib_defines() + [
|
||||
"TF_USE_SNAPPY",
|
||||
] + tf_additional_verbs_lib_defines() +
|
||||
tf_additional_mpi_lib_defines() +
|
||||
tf_additional_gdr_lib_defines(),
|
||||
defines = LIB_INTERNAL_DEFINES,
|
||||
linkopts = select({
|
||||
"//tensorflow:freebsd": [],
|
||||
"//tensorflow:windows": [],
|
||||
@ -1477,6 +1480,7 @@ cc_library(
|
||||
),
|
||||
hdrs = LIB_INTERNAL_PUBLIC_HEADERS,
|
||||
copts = tf_copts(),
|
||||
defines = LIB_INTERNAL_DEFINES,
|
||||
deps = tf_additional_lib_deps() + [
|
||||
":lib_hash_crc32c_accelerate_internal",
|
||||
":lib_proto_parsing",
|
||||
|
@ -46,92 +46,218 @@ constexpr char kDefaultApiDefDir[] =
|
||||
"tensorflow/core/api_def/base_api";
|
||||
constexpr char kOverridesFilePath[] =
|
||||
"tensorflow/cc/ops/op_gen_overrides.pbtxt";
|
||||
constexpr char kApiDefFileFormat[] = "api_def_%c.pbtxt";
|
||||
constexpr char kAlphabet[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
|
||||
constexpr char kApiDefFileFormat[] = "api_def_%s.pbtxt";
|
||||
constexpr char kApiDefFilePattern[] = "api_def_*.pbtxt";
|
||||
|
||||
// Get map from first character to ApiDefs for ops
|
||||
// that start with that character.
|
||||
std::unordered_map<char, ApiDefs> GenerateApiDef(
|
||||
const OpList& ops, const OpGenOverrides& overrides) {
|
||||
void FillBaseApiDef(ApiDef* api_def, const OpDef& op) {
|
||||
api_def->set_graph_op_name(op.name());
|
||||
// Add arg docs
|
||||
for (auto& input_arg : op.input_arg()) {
|
||||
if (!input_arg.description().empty()) {
|
||||
auto* api_def_in_arg = api_def->add_in_arg();
|
||||
api_def_in_arg->set_name(input_arg.name());
|
||||
api_def_in_arg->set_description(input_arg.description());
|
||||
}
|
||||
}
|
||||
for (auto& output_arg : op.output_arg()) {
|
||||
if (!output_arg.description().empty()) {
|
||||
auto* api_def_out_arg = api_def->add_out_arg();
|
||||
api_def_out_arg->set_name(output_arg.name());
|
||||
api_def_out_arg->set_description(output_arg.description());
|
||||
}
|
||||
}
|
||||
// Add attr docs
|
||||
for (auto& attr : op.attr()) {
|
||||
if (!attr.description().empty()) {
|
||||
auto* api_def_attr = api_def->add_attr();
|
||||
api_def_attr->set_name(attr.name());
|
||||
api_def_attr->set_description(attr.description());
|
||||
}
|
||||
}
|
||||
// Add docs
|
||||
api_def->set_summary(op.summary());
|
||||
api_def->set_description(op.description());
|
||||
}
|
||||
|
||||
// Checks if arg1 should be before arg2 according to ordering in args.
|
||||
bool CheckArgBefore(const ApiDef::Arg* arg1, const ApiDef::Arg* arg2,
|
||||
const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) {
|
||||
for (auto& arg : args) {
|
||||
if (arg.name() == arg2->name()) {
|
||||
return false;
|
||||
} else if (arg.name() == arg1->name()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Checks if attr1 should be before attr2 according to ordering in op_def.
|
||||
bool CheckAttrBefore(const ApiDef::Attr* attr1, const ApiDef::Attr* attr2,
|
||||
const OpDef& op_def) {
|
||||
for (auto& attr : op_def.attr()) {
|
||||
if (attr.name() == attr2->name()) {
|
||||
return false;
|
||||
} else if (attr.name() == attr1->name()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Applies renames to args.
|
||||
void ApplyArgOverrides(
|
||||
protobuf::RepeatedPtrField<ApiDef::Arg>* args,
|
||||
const protobuf::RepeatedPtrField<OpGenOverride::Rename>& renames,
|
||||
const protobuf::RepeatedPtrField<OpDef::ArgDef>& op_args,
|
||||
const string& op_name) {
|
||||
for (auto& rename : renames) {
|
||||
// First check if rename is valid.
|
||||
bool valid = false;
|
||||
for (const auto& op_arg : op_args) {
|
||||
if (op_arg.name() == rename.from()) {
|
||||
valid = true;
|
||||
}
|
||||
}
|
||||
QCHECK(valid) << rename.from() << " is not a valid argument for "
|
||||
<< op_name;
|
||||
bool found_arg = false;
|
||||
// If Arg is already in ApiDef, just update it.
|
||||
for (int i = 0; i < args->size(); ++i) {
|
||||
auto* arg = args->Mutable(i);
|
||||
if (arg->name() == rename.from()) {
|
||||
arg->set_rename_to(rename.to());
|
||||
found_arg = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!found_arg) { // not in ApiDef, add a new arg.
|
||||
auto* new_arg = args->Add();
|
||||
new_arg->set_name(rename.from());
|
||||
new_arg->set_rename_to(rename.to());
|
||||
}
|
||||
}
|
||||
// We don't really need a specific order here right now.
|
||||
// However, it is clearer if order follows OpDef.
|
||||
std::sort(args->pointer_begin(), args->pointer_end(),
|
||||
[&](ApiDef::Arg* arg1, ApiDef::Arg* arg2) {
|
||||
return CheckArgBefore(arg1, arg2, op_args);
|
||||
});
|
||||
}
|
||||
|
||||
// Returns existing attribute with the given name if such
|
||||
// attribute exists. Otherwise, adds a new attribute and returns it.
|
||||
ApiDef::Attr* FindOrAddAttr(ApiDef* api_def, const string attr_name) {
|
||||
// If Attr is already in ApiDef, just update it.
|
||||
for (int i = 0; i < api_def->attr_size(); ++i) {
|
||||
auto* attr = api_def->mutable_attr(i);
|
||||
if (attr->name() == attr_name) {
|
||||
return attr;
|
||||
}
|
||||
}
|
||||
// Add a new Attr.
|
||||
auto* new_attr = api_def->add_attr();
|
||||
new_attr->set_name(attr_name);
|
||||
return new_attr;
|
||||
}
|
||||
|
||||
// Applies renames and default values to attributes.
|
||||
void ApplyAttrOverrides(ApiDef* api_def, const OpGenOverride& op_override,
|
||||
const OpDef& op_def) {
|
||||
for (auto& attr_rename : op_override.attr_rename()) {
|
||||
auto* attr = FindOrAddAttr(api_def, attr_rename.from());
|
||||
attr->set_rename_to(attr_rename.to());
|
||||
}
|
||||
|
||||
for (auto& attr_default : op_override.attr_default()) {
|
||||
auto* attr = FindOrAddAttr(api_def, attr_default.name());
|
||||
*(attr->mutable_default_value()) = attr_default.value();
|
||||
}
|
||||
// We don't really need a specific order here right now.
|
||||
// However, it is clearer if order follows OpDef.
|
||||
std::sort(api_def->mutable_attr()->pointer_begin(),
|
||||
api_def->mutable_attr()->pointer_end(),
|
||||
[&](ApiDef::Attr* attr1, ApiDef::Attr* attr2) {
|
||||
return CheckAttrBefore(attr1, attr2, op_def);
|
||||
});
|
||||
}
|
||||
|
||||
void ApplyOverridesToApiDef(ApiDef* api_def, const OpDef& op,
|
||||
const OpGenOverride& op_override) {
|
||||
// Fill ApiDef with data based on op and op_override.
|
||||
// Set visibility
|
||||
if (op_override.skip()) {
|
||||
api_def->set_visibility(ApiDef_Visibility_SKIP);
|
||||
} else if (op_override.hide()) {
|
||||
api_def->set_visibility(ApiDef_Visibility_HIDDEN);
|
||||
}
|
||||
// Add endpoints
|
||||
if (!op_override.rename_to().empty()) {
|
||||
api_def->add_endpoint()->set_name(op_override.rename_to());
|
||||
} else if (!op_override.alias().empty()) {
|
||||
api_def->add_endpoint()->set_name(op.name());
|
||||
}
|
||||
|
||||
for (auto& alias : op_override.alias()) {
|
||||
auto* endpoint = api_def->add_endpoint();
|
||||
endpoint->set_name(alias);
|
||||
}
|
||||
|
||||
ApplyArgOverrides(api_def->mutable_in_arg(), op_override.input_rename(),
|
||||
op.input_arg(), api_def->graph_op_name());
|
||||
ApplyArgOverrides(api_def->mutable_out_arg(), op_override.output_rename(),
|
||||
op.output_arg(), api_def->graph_op_name());
|
||||
ApplyAttrOverrides(api_def, op_override, op);
|
||||
}
|
||||
|
||||
// Get map from ApiDef file path to corresponding ApiDefs proto.
|
||||
std::unordered_map<string, ApiDefs> GenerateApiDef(
|
||||
const string& api_def_dir, const OpList& ops,
|
||||
const OpGenOverrides& overrides) {
|
||||
std::unordered_map<string, OpGenOverride> name_to_override;
|
||||
for (const auto& op_override : overrides.op()) {
|
||||
name_to_override[op_override.name()] = op_override;
|
||||
}
|
||||
|
||||
std::unordered_map<char, ApiDefs> api_defs_map;
|
||||
std::unordered_map<string, ApiDefs> api_defs_map;
|
||||
|
||||
for (const auto& op : ops.op()) {
|
||||
CHECK(!op.name().empty())
|
||||
<< "Encountered empty op name: %s" << op.DebugString();
|
||||
const char file_id = toupper(op.name()[0]);
|
||||
CHECK(isalpha(file_id)) << "Unexpected op name: " << op.name();
|
||||
ApiDef* api_def = api_defs_map[file_id].add_op();
|
||||
api_def->set_graph_op_name(op.name());
|
||||
string file_path = io::JoinPath(api_def_dir, kApiDefFileFormat);
|
||||
file_path = strings::Printf(file_path.c_str(), op.name().c_str());
|
||||
ApiDef* api_def = api_defs_map[file_path].add_op();
|
||||
FillBaseApiDef(api_def, op);
|
||||
|
||||
if (name_to_override.find(op.name()) != name_to_override.end()) {
|
||||
const auto& op_override = name_to_override[op.name()];
|
||||
// Set visibility
|
||||
if (op_override.skip()) {
|
||||
api_def->set_visibility(ApiDef_Visibility_SKIP);
|
||||
} else if (op_override.hide()) {
|
||||
api_def->set_visibility(ApiDef_Visibility_HIDDEN);
|
||||
}
|
||||
// Add endpoints
|
||||
if (!op_override.rename_to().empty()) {
|
||||
auto* endpoint = api_def->add_endpoint();
|
||||
endpoint->set_name(op_override.rename_to());
|
||||
} else {
|
||||
auto* endpoint = api_def->add_endpoint();
|
||||
endpoint->set_name(op.name());
|
||||
}
|
||||
for (auto& alias : op_override.alias()) {
|
||||
auto* endpoint = api_def->add_endpoint();
|
||||
endpoint->set_name(alias);
|
||||
}
|
||||
// Add attributes
|
||||
for (auto& attr : op.attr()) {
|
||||
auto* api_def_attr = api_def->add_attr();
|
||||
api_def_attr->set_name(attr.name());
|
||||
for (auto& attr_override : op_override.attr_default()) {
|
||||
if (attr.name() == attr_override.name()) {
|
||||
*(api_def_attr->mutable_default_value()) = attr_override.value();
|
||||
}
|
||||
}
|
||||
for (auto& attr_rename : op_override.attr_rename()) {
|
||||
if (attr.name() == attr_rename.from()) {
|
||||
api_def_attr->set_rename_to(attr_rename.to());
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto* endpoint = api_def->add_endpoint();
|
||||
endpoint->set_name(op.name());
|
||||
ApplyOverridesToApiDef(api_def, op, name_to_override[op.name()]);
|
||||
}
|
||||
// Add docs
|
||||
api_def->set_summary(op.summary());
|
||||
api_def->set_description(op.description());
|
||||
}
|
||||
return api_defs_map;
|
||||
}
|
||||
|
||||
// Reads golden api defs file with the given suffix.
|
||||
string GetGoldenApiDefsStr(Env* env, const string& api_files_dir, char suffix) {
|
||||
string file_path = strings::Printf(
|
||||
io::JoinPath(api_files_dir, kApiDefFileFormat).c_str(), suffix);
|
||||
if (env->FileExists(file_path).ok()) {
|
||||
// Reads golden ApiDef files and returns a map from file name to ApiDef file
|
||||
// contents.
|
||||
std::unordered_map<string, string> GetGoldenApiDefs(
|
||||
Env* env, const string& api_files_dir) {
|
||||
std::vector<string> matching_paths;
|
||||
TF_CHECK_OK(env->GetMatchingPaths(
|
||||
io::JoinPath(api_files_dir, kApiDefFilePattern), &matching_paths));
|
||||
|
||||
std::unordered_map<string, string> file_path_to_api_def;
|
||||
for (auto& file_path : matching_paths) {
|
||||
string file_contents;
|
||||
TF_EXPECT_OK(ReadFileToString(env, file_path, &file_contents));
|
||||
return file_contents;
|
||||
TF_CHECK_OK(ReadFileToString(env, file_path, &file_contents));
|
||||
file_path_to_api_def[file_path] = file_contents;
|
||||
}
|
||||
return "";
|
||||
return file_path_to_api_def;
|
||||
}
|
||||
|
||||
void RunApiTest(bool update_api_def, const string& api_files_dir) {
|
||||
// Read C++ overrides file
|
||||
string overrides_file_contents;
|
||||
OpGenOverrides overrides;
|
||||
Env* env = Env::Default();
|
||||
TF_EXPECT_OK(
|
||||
ReadFileToString(env, kOverridesFilePath, &overrides_file_contents));
|
||||
TF_EXPECT_OK(ReadTextProto(env, kOverridesFilePath, &overrides));
|
||||
|
||||
// Read all ops
|
||||
OpList ops;
|
||||
@ -139,29 +265,22 @@ void RunApiTest(bool update_api_def, const string& api_files_dir) {
|
||||
const std::vector<string> multi_line_fields = {"description"};
|
||||
|
||||
// Get expected ApiDefs
|
||||
OpGenOverrides overrides;
|
||||
auto new_api_defs_map = GenerateApiDef(ops, overrides);
|
||||
const auto new_api_defs_map = GenerateApiDef(api_files_dir, ops, overrides);
|
||||
|
||||
bool updated_at_least_one_file = false;
|
||||
const auto golden_api_defs_map = GetGoldenApiDefs(env, api_files_dir);
|
||||
|
||||
for (char c : kAlphabet) {
|
||||
string golden_api_defs_str = GetGoldenApiDefsStr(env, api_files_dir, c);
|
||||
string new_api_defs_str = new_api_defs_map[c].DebugString();
|
||||
for (auto new_api_entry : new_api_defs_map) {
|
||||
const auto& file_path = new_api_entry.first;
|
||||
const auto& golden_api_defs_str = golden_api_defs_map.at(file_path);
|
||||
string new_api_defs_str = new_api_entry.second.DebugString();
|
||||
new_api_defs_str = PBTxtToMultiline(new_api_defs_str, multi_line_fields);
|
||||
if (golden_api_defs_str == new_api_defs_str) {
|
||||
continue;
|
||||
}
|
||||
if (update_api_def) {
|
||||
string output_file_path =
|
||||
io::JoinPath(api_files_dir, strings::Printf(kApiDefFileFormat, c));
|
||||
if (new_api_defs_str.empty()) {
|
||||
std::cout << "Deleting " << output_file_path << "..." << std::endl;
|
||||
TF_EXPECT_OK(env->DeleteFile(output_file_path));
|
||||
} else {
|
||||
std::cout << "Updating " << output_file_path << "..." << std::endl;
|
||||
TF_EXPECT_OK(
|
||||
WriteStringToFile(env, output_file_path, new_api_defs_str));
|
||||
}
|
||||
std::cout << "Updating " << file_path << "..." << std::endl;
|
||||
TF_EXPECT_OK(WriteStringToFile(env, file_path, new_api_defs_str));
|
||||
updated_at_least_one_file = true;
|
||||
} else {
|
||||
EXPECT_EQ(golden_api_defs_str, new_api_defs_str)
|
||||
@ -170,6 +289,21 @@ void RunApiTest(bool update_api_def, const string& api_files_dir) {
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& golden_api_entry : golden_api_defs_map) {
|
||||
const auto& file_path = golden_api_entry.first;
|
||||
if (new_api_defs_map.find(file_path) == new_api_defs_map.end()) {
|
||||
if (update_api_def) {
|
||||
std::cout << "Deleting " << file_path << "..." << std::endl;
|
||||
TF_EXPECT_OK(env->DeleteFile(file_path));
|
||||
updated_at_least_one_file = true;
|
||||
} else {
|
||||
EXPECT_EQ("", golden_api_entry.second)
|
||||
<< "To update golden API files, run "
|
||||
<< "tensorflow/core/api_def/update_api_def.sh.";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (update_api_def && !updated_at_least_one_file) {
|
||||
std::cout << "Api def files are already up to date." << std::endl;
|
||||
}
|
||||
|
@ -1,670 +0,0 @@
|
||||
op {
|
||||
graph_op_name: "Abort"
|
||||
endpoint {
|
||||
name: "Abort"
|
||||
}
|
||||
summary: "Raise a exception to abort the process when called."
|
||||
description: <<END
|
||||
If exit_without_error is true, the process will exit normally,
|
||||
otherwise it will exit with a SIGABORT signal.
|
||||
|
||||
Returns nothing but an exception.
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "Abs"
|
||||
endpoint {
|
||||
name: "Abs"
|
||||
}
|
||||
summary: "Computes the absolute value of a tensor."
|
||||
description: <<END
|
||||
Given a tensor `x`, this operation returns a tensor containing the absolute
|
||||
value of each element in `x`. For example, if x is an input element and y is
|
||||
an output element, this operation computes \\(y = |x|\\).
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "AccumulatorApplyGradient"
|
||||
endpoint {
|
||||
name: "AccumulatorApplyGradient"
|
||||
}
|
||||
summary: "Applies a gradient to a given accumulator."
|
||||
description: <<END
|
||||
Does not add if local_step is lesser than the accumulator's global_step.
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "AccumulatorNumAccumulated"
|
||||
endpoint {
|
||||
name: "AccumulatorNumAccumulated"
|
||||
}
|
||||
summary: "Returns the number of gradients aggregated in the given accumulators."
|
||||
}
|
||||
op {
|
||||
graph_op_name: "AccumulatorSetGlobalStep"
|
||||
endpoint {
|
||||
name: "AccumulatorSetGlobalStep"
|
||||
}
|
||||
summary: "Updates the accumulator with a new value for global_step."
|
||||
description: <<END
|
||||
Logs warning if the accumulator's value is already higher than
|
||||
new_global_step.
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "AccumulatorTakeGradient"
|
||||
endpoint {
|
||||
name: "AccumulatorTakeGradient"
|
||||
}
|
||||
summary: "Extracts the average gradient in the given ConditionalAccumulator."
|
||||
description: <<END
|
||||
The op blocks until sufficient (i.e., more than num_required)
|
||||
gradients have been accumulated. If the accumulator has already
|
||||
aggregated more than num_required gradients, it returns the average of
|
||||
the accumulated gradients. Also automatically increments the recorded
|
||||
global_step in the accumulator by 1, and resets the aggregate to 0.
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "Acos"
|
||||
endpoint {
|
||||
name: "Acos"
|
||||
}
|
||||
summary: "Computes acos of x element-wise."
|
||||
}
|
||||
op {
|
||||
graph_op_name: "Acosh"
|
||||
endpoint {
|
||||
name: "Acosh"
|
||||
}
|
||||
summary: "Computes inverse hyperbolic cosine of x element-wise."
|
||||
}
|
||||
op {
|
||||
graph_op_name: "Add"
|
||||
endpoint {
|
||||
name: "Add"
|
||||
}
|
||||
summary: "Returns x + y element-wise."
|
||||
description: <<END
|
||||
*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting
|
||||
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "AddManySparseToTensorsMap"
|
||||
endpoint {
|
||||
name: "AddManySparseToTensorsMap"
|
||||
}
|
||||
summary: "Add an `N`-minibatch `SparseTensor` to a `SparseTensorsMap`, return `N` handles."
|
||||
description: <<END
|
||||
A `SparseTensor` of rank `R` is represented by three tensors: `sparse_indices`,
|
||||
`sparse_values`, and `sparse_shape`, where
|
||||
|
||||
```sparse_indices.shape[1] == sparse_shape.shape[0] == R```
|
||||
|
||||
An `N`-minibatch of `SparseTensor` objects is represented as a `SparseTensor`
|
||||
having a first `sparse_indices` column taking values between `[0, N)`, where
|
||||
the minibatch size `N == sparse_shape[0]`.
|
||||
|
||||
The input `SparseTensor` must have rank `R` greater than 1, and the first
|
||||
dimension is treated as the minibatch dimension. Elements of the `SparseTensor`
|
||||
must be sorted in increasing order of this first dimension. The stored
|
||||
`SparseTensor` objects pointed to by each row of the output `sparse_handles`
|
||||
will have rank `R-1`.
|
||||
|
||||
The `SparseTensor` values can then be read out as part of a minibatch by passing
|
||||
the given keys as vector elements to `TakeManySparseFromTensorsMap`. To ensure
|
||||
the correct `SparseTensorsMap` is accessed, ensure that the same
|
||||
`container` and `shared_name` are passed to that Op. If no `shared_name`
|
||||
is provided here, instead use the *name* of the Operation created by calling
|
||||
`AddManySparseToTensorsMap` as the `shared_name` passed to
|
||||
`TakeManySparseFromTensorsMap`. Ensure the Operations are colocated.
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "AddN"
|
||||
endpoint {
|
||||
name: "AddN"
|
||||
}
|
||||
summary: "Add all input tensors element wise."
|
||||
}
|
||||
op {
|
||||
graph_op_name: "AddSparseToTensorsMap"
|
||||
endpoint {
|
||||
name: "AddSparseToTensorsMap"
|
||||
}
|
||||
summary: "Add a `SparseTensor` to a `SparseTensorsMap` return its handle."
|
||||
description: <<END
|
||||
A `SparseTensor` is represented by three tensors: `sparse_indices`,
|
||||
`sparse_values`, and `sparse_shape`.
|
||||
|
||||
This operator takes the given `SparseTensor` and adds it to a container
|
||||
object (a `SparseTensorsMap`). A unique key within this container is generated
|
||||
in the form of an `int64`, and this is the value that is returned.
|
||||
|
||||
The `SparseTensor` can then be read out as part of a minibatch by passing
|
||||
the key as a vector element to `TakeManySparseFromTensorsMap`. To ensure
|
||||
the correct `SparseTensorsMap` is accessed, ensure that the same
|
||||
`container` and `shared_name` are passed to that Op. If no `shared_name`
|
||||
is provided here, instead use the *name* of the Operation created by calling
|
||||
`AddSparseToTensorsMap` as the `shared_name` passed to
|
||||
`TakeManySparseFromTensorsMap`. Ensure the Operations are colocated.
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "AdjustContrast"
|
||||
endpoint {
|
||||
name: "AdjustContrast"
|
||||
}
|
||||
summary: "Deprecated. Disallowed in GraphDef version >= 2."
|
||||
}
|
||||
op {
|
||||
graph_op_name: "AdjustContrastv2"
|
||||
endpoint {
|
||||
name: "AdjustContrastv2"
|
||||
}
|
||||
summary: "Adjust the contrast of one or more images."
|
||||
description: <<END
|
||||
`images` is a tensor of at least 3 dimensions. The last 3 dimensions are
|
||||
interpreted as `[height, width, channels]`. The other dimensions only
|
||||
represent a collection of images, such as `[batch, height, width, channels].`
|
||||
|
||||
Contrast is adjusted independently for each channel of each image.
|
||||
|
||||
For each channel, the Op first computes the mean of the image pixels in the
|
||||
channel and then adjusts each component of each pixel to
|
||||
`(x - mean) * contrast_factor + mean`.
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "AdjustHue"
|
||||
endpoint {
|
||||
name: "AdjustHue"
|
||||
}
|
||||
summary: "Adjust the hue of one or more images."
|
||||
description: <<END
|
||||
`images` is a tensor of at least 3 dimensions. The last dimension is
|
||||
interpretted as channels, and must be three.
|
||||
|
||||
The input image is considered in the RGB colorspace. Conceptually, the RGB
|
||||
colors are first mapped into HSV. A delta is then applied all the hue values,
|
||||
and then remapped back to RGB colorspace.
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "AdjustSaturation"
|
||||
endpoint {
|
||||
name: "AdjustSaturation"
|
||||
}
|
||||
summary: "Adjust the saturation of one or more images."
|
||||
description: <<END
|
||||
`images` is a tensor of at least 3 dimensions. The last dimension is
|
||||
interpretted as channels, and must be three.
|
||||
|
||||
The input image is considered in the RGB colorspace. Conceptually, the RGB
|
||||
colors are first mapped into HSV. A scale is then applied all the saturation
|
||||
values, and then remapped back to RGB colorspace.
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "All"
|
||||
endpoint {
|
||||
name: "All"
|
||||
}
|
||||
summary: "Computes the \"logical and\" of elements across dimensions of a tensor."
|
||||
description: <<END
|
||||
Reduces `input` along the dimensions given in `reduction_indices`. Unless
|
||||
`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
|
||||
`reduction_indices`. If `keep_dims` is true, the reduced dimensions are
|
||||
retained with length 1.
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "AllCandidateSampler"
|
||||
endpoint {
|
||||
name: "AllCandidateSampler"
|
||||
}
|
||||
summary: "Generates labels for candidate sampling with a learned unigram distribution."
|
||||
description: <<END
|
||||
See explanations of candidate sampling and the data formats at
|
||||
go/candidate-sampling.
|
||||
|
||||
For each batch, this op picks a single set of sampled candidate labels.
|
||||
|
||||
The advantages of sampling candidates per-batch are simplicity and the
|
||||
possibility of efficient dense matrix multiplication. The disadvantage is that
|
||||
the sampled candidates must be chosen independently of the context and of the
|
||||
true labels.
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "Angle"
|
||||
endpoint {
|
||||
name: "Angle"
|
||||
}
|
||||
summary: "Returns the argument of a complex number."
|
||||
description: <<END
|
||||
Given a tensor `input` of complex numbers, this operation returns a tensor of
|
||||
type `float` that is the argument of each element in `input`. All elements in
|
||||
`input` must be complex numbers of the form \\(a + bj\\), where *a*
|
||||
is the real part and *b* is the imaginary part.
|
||||
|
||||
The argument returned by this operation is of the form \\(atan2(b, a)\\).
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
# tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
|
||||
tf.angle(input) ==> [2.0132, 1.056]
|
||||
```
|
||||
|
||||
@compatibility(numpy)
|
||||
Equivalent to np.angle.
|
||||
@end_compatibility
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "Any"
|
||||
endpoint {
|
||||
name: "Any"
|
||||
}
|
||||
summary: "Computes the \"logical or\" of elements across dimensions of a tensor."
|
||||
description: <<END
|
||||
Reduces `input` along the dimensions given in `reduction_indices`. Unless
|
||||
`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
|
||||
`reduction_indices`. If `keep_dims` is true, the reduced dimensions are
|
||||
retained with length 1.
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "ApplyAdadelta"
|
||||
endpoint {
|
||||
name: "ApplyAdadelta"
|
||||
}
|
||||
summary: "Update \'*var\' according to the adadelta scheme."
|
||||
description: <<END
|
||||
accum = rho() * accum + (1 - rho()) * grad.square();
|
||||
update = (update_accum + epsilon).sqrt() * (accum + epsilon()).rsqrt() * grad;
|
||||
update_accum = rho() * update_accum + (1 - rho()) * update.square();
|
||||
var -= update;
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "ApplyAdagrad"
|
||||
endpoint {
|
||||
name: "ApplyAdagrad"
|
||||
}
|
||||
summary: "Update \'*var\' according to the adagrad scheme."
|
||||
description: <<END
|
||||
accum += grad * grad
|
||||
var -= lr * grad * (1 / sqrt(accum))
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "ApplyAdagradDA"
|
||||
endpoint {
|
||||
name: "ApplyAdagradDA"
|
||||
}
|
||||
summary: "Update \'*var\' according to the proximal adagrad scheme."
|
||||
}
|
||||
op {
|
||||
graph_op_name: "ApplyAdam"
|
||||
endpoint {
|
||||
name: "ApplyAdam"
|
||||
}
|
||||
summary: "Update \'*var\' according to the Adam algorithm."
|
||||
description: <<END
|
||||
lr_t <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
|
||||
m_t <- beta1 * m_{t-1} + (1 - beta1) * g_t
|
||||
v_t <- beta2 * v_{t-1} + (1 - beta2) * g_t * g_t
|
||||
variable <- variable - lr_t * m_t / (sqrt(v_t) + epsilon)
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "ApplyCenteredRMSProp"
|
||||
endpoint {
|
||||
name: "ApplyCenteredRMSProp"
|
||||
}
|
||||
summary: "Update \'*var\' according to the centered RMSProp algorithm."
|
||||
description: <<END
|
||||
The centered RMSProp algorithm uses an estimate of the centered second moment
|
||||
(i.e., the variance) for normalization, as opposed to regular RMSProp, which
|
||||
uses the (uncentered) second moment. This often helps with training, but is
|
||||
slightly more expensive in terms of computation and memory.
|
||||
|
||||
Note that in dense implementation of this algorithm, mg, ms, and mom will
|
||||
update even if the grad is zero, but in this sparse implementation, mg, ms,
|
||||
and mom will not update in iterations during which the grad is zero.
|
||||
|
||||
mean_square = decay * mean_square + (1-decay) * gradient ** 2
|
||||
mean_grad = decay * mean_grad + (1-decay) * gradient
|
||||
|
||||
Delta = learning_rate * gradient / sqrt(mean_square + epsilon - mean_grad ** 2)
|
||||
|
||||
mg <- rho * mg_{t-1} + (1-rho) * grad
|
||||
ms <- rho * ms_{t-1} + (1-rho) * grad * grad
|
||||
mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon)
|
||||
var <- var - mom
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "ApplyFtrl"
|
||||
endpoint {
|
||||
name: "ApplyFtrl"
|
||||
}
|
||||
summary: "Update \'*var\' according to the Ftrl-proximal scheme."
|
||||
description: <<END
|
||||
accum_new = accum + grad * grad
|
||||
linear += grad + (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var
|
||||
quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2
|
||||
var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0
|
||||
accum = accum_new
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "ApplyFtrlV2"
|
||||
endpoint {
|
||||
name: "ApplyFtrlV2"
|
||||
}
|
||||
summary: "Update \'*var\' according to the Ftrl-proximal scheme."
|
||||
description: <<END
|
||||
grad_with_shrinkage = grad + 2 * l2_shrinkage * var
|
||||
accum_new = accum + grad_with_shrinkage * grad_with_shrinkage
|
||||
linear += grad_with_shrinkage +
|
||||
(accum_new^(-lr_power) - accum^(-lr_power)) / lr * var
|
||||
quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2
|
||||
var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0
|
||||
accum = accum_new
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "ApplyGradientDescent"
|
||||
endpoint {
|
||||
name: "ApplyGradientDescent"
|
||||
}
|
||||
summary: "Update \'*var\' by subtracting \'alpha\' * \'delta\' from it."
|
||||
}
|
||||
op {
|
||||
graph_op_name: "ApplyMomentum"
|
||||
endpoint {
|
||||
name: "ApplyMomentum"
|
||||
}
|
||||
summary: "Update \'*var\' according to the momentum scheme. Set use_nesterov = True if you"
|
||||
description: <<END
|
||||
want to use Nesterov momentum.
|
||||
|
||||
accum = accum * momentum + grad
|
||||
var -= lr * accum
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "ApplyProximalAdagrad"
|
||||
endpoint {
|
||||
name: "ApplyProximalAdagrad"
|
||||
}
|
||||
summary: "Update \'*var\' and \'*accum\' according to FOBOS with Adagrad learning rate."
|
||||
description: <<END
|
||||
accum += grad * grad
|
||||
prox_v = var - lr * grad * (1 / sqrt(accum))
|
||||
var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0}
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "ApplyProximalGradientDescent"
|
||||
endpoint {
|
||||
name: "ApplyProximalGradientDescent"
|
||||
}
|
||||
summary: "Update \'*var\' as FOBOS algorithm with fixed learning rate."
|
||||
description: <<END
|
||||
prox_v = var - alpha * delta
|
||||
var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0}
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "ApplyRMSProp"
|
||||
endpoint {
|
||||
name: "ApplyRMSProp"
|
||||
}
|
||||
summary: "Update \'*var\' according to the RMSProp algorithm."
|
||||
description: <<END
|
||||
Note that in dense implementation of this algorithm, ms and mom will
|
||||
update even if the grad is zero, but in this sparse implementation, ms
|
||||
and mom will not update in iterations during which the grad is zero.
|
||||
|
||||
mean_square = decay * mean_square + (1-decay) * gradient ** 2
|
||||
Delta = learning_rate * gradient / sqrt(mean_square + epsilon)
|
||||
|
||||
ms <- rho * ms_{t-1} + (1-rho) * grad * grad
|
||||
mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
|
||||
var <- var - mom
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "ApproximateEqual"
|
||||
endpoint {
|
||||
name: "ApproximateEqual"
|
||||
}
|
||||
summary: "Returns the truth value of abs(x-y) < tolerance element-wise."
|
||||
}
|
||||
op {
|
||||
graph_op_name: "ArgMax"
|
||||
endpoint {
|
||||
name: "ArgMax"
|
||||
}
|
||||
summary: "Returns the index with the largest value across dimensions of a tensor."
|
||||
description: <<END
|
||||
Note that in case of ties the identity of the return value is not guaranteed.
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "ArgMin"
|
||||
endpoint {
|
||||
name: "ArgMin"
|
||||
}
|
||||
summary: "Returns the index with the smallest value across dimensions of a tensor."
|
||||
description: <<END
|
||||
Note that in case of ties the identity of the return value is not guaranteed.
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "AsString"
|
||||
endpoint {
|
||||
name: "AsString"
|
||||
}
|
||||
summary: "Converts each entry in the given tensor to strings. Supports many numeric"
|
||||
description: <<END
|
||||
types and boolean.
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "Asin"
|
||||
endpoint {
|
||||
name: "Asin"
|
||||
}
|
||||
summary: "Computes asin of x element-wise."
|
||||
}
|
||||
op {
|
||||
graph_op_name: "Asinh"
|
||||
endpoint {
|
||||
name: "Asinh"
|
||||
}
|
||||
summary: "Computes inverse hyperbolic sine of x element-wise."
|
||||
}
|
||||
op {
|
||||
graph_op_name: "Assert"
|
||||
endpoint {
|
||||
name: "Assert"
|
||||
}
|
||||
summary: "Asserts that the given condition is true."
|
||||
description: <<END
|
||||
If `condition` evaluates to false, print the list of tensors in `data`.
|
||||
`summarize` determines how many entries of the tensors to print.
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "Assign"
|
||||
endpoint {
|
||||
name: "Assign"
|
||||
}
|
||||
summary: "Update \'ref\' by assigning \'value\' to it."
|
||||
description: <<END
|
||||
This operation outputs "ref" after the assignment is done.
|
||||
This makes it easier to chain operations that need to use the reset value.
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "AssignAdd"
|
||||
endpoint {
|
||||
name: "AssignAdd"
|
||||
}
|
||||
summary: "Update \'ref\' by adding \'value\' to it."
|
||||
description: <<END
|
||||
This operation outputs "ref" after the update is done.
|
||||
This makes it easier to chain operations that need to use the reset value.
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "AssignSub"
|
||||
endpoint {
|
||||
name: "AssignSub"
|
||||
}
|
||||
summary: "Update \'ref\' by subtracting \'value\' from it."
|
||||
description: <<END
|
||||
This operation outputs "ref" after the update is done.
|
||||
This makes it easier to chain operations that need to use the reset value.
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "Atan"
|
||||
endpoint {
|
||||
name: "Atan"
|
||||
}
|
||||
summary: "Computes atan of x element-wise."
|
||||
}
|
||||
op {
|
||||
graph_op_name: "Atan2"
|
||||
endpoint {
|
||||
name: "Atan2"
|
||||
}
|
||||
summary: "Computes arctangent of `y/x` element-wise, respecting signs of the arguments."
|
||||
description: <<END
|
||||
This is the angle \( \theta \in [-\pi, \pi] \) such that
|
||||
\[ x = r \cos(\theta) \]
|
||||
and
|
||||
\[ y = r \sin(\theta) \]
|
||||
where \(r = \sqrt(x^2 + y^2) \).
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "Atanh"
|
||||
endpoint {
|
||||
name: "Atanh"
|
||||
}
|
||||
summary: "Computes inverse hyperbolic tangent of x element-wise."
|
||||
}
|
||||
op {
|
||||
graph_op_name: "AudioSpectrogram"
|
||||
endpoint {
|
||||
name: "AudioSpectrogram"
|
||||
}
|
||||
summary: "Produces a visualization of audio data over time."
|
||||
description: <<END
|
||||
Spectrograms are a standard way of representing audio information as a series of
|
||||
slices of frequency information, one slice for each window of time. By joining
|
||||
these together into a sequence, they form a distinctive fingerprint of the sound
|
||||
over time.
|
||||
|
||||
This op expects to receive audio data as an input, stored as floats in the range
|
||||
-1 to 1, together with a window width in samples, and a stride specifying how
|
||||
far to move the window between slices. From this it generates a three
|
||||
dimensional output. The lowest dimension has an amplitude value for each
|
||||
frequency during that time slice. The next dimension is time, with successive
|
||||
frequency slices. The final dimension is for the channels in the input, so a
|
||||
stereo audio input would have two here for example.
|
||||
|
||||
This means the layout when converted and saved as an image is rotated 90 degrees
|
||||
clockwise from a typical spectrogram. Time is descending down the Y axis, and
|
||||
the frequency decreases from left to right.
|
||||
|
||||
Each value in the result represents the square root of the sum of the real and
|
||||
imaginary parts of an FFT on the current window of samples. In this way, the
|
||||
lowest dimension represents the power of each frequency in the current window,
|
||||
and adjacent windows are concatenated in the next dimension.
|
||||
|
||||
To get a more intuitive and visual look at what this operation does, you can run
|
||||
tensorflow/examples/wav_to_spectrogram to read in an audio file and save out the
|
||||
resulting spectrogram as a PNG image.
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "AudioSummary"
|
||||
endpoint {
|
||||
name: "AudioSummary"
|
||||
}
|
||||
summary: "Outputs a `Summary` protocol buffer with audio."
|
||||
description: <<END
|
||||
The summary has up to `max_outputs` summary values containing audio. The
|
||||
audio is built from `tensor` which must be 3-D with shape `[batch_size,
|
||||
frames, channels]` or 2-D with shape `[batch_size, frames]`. The values are
|
||||
assumed to be in the range of `[-1.0, 1.0]` with a sample rate of `sample_rate`.
|
||||
|
||||
The `tag` argument is a scalar `Tensor` of type `string`. It is used to
|
||||
build the `tag` of the summary values:
|
||||
|
||||
* If `max_outputs` is 1, the summary value tag is '*tag*/audio'.
|
||||
* If `max_outputs` is greater than 1, the summary value tags are
|
||||
generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc.
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "AudioSummaryV2"
|
||||
endpoint {
|
||||
name: "AudioSummaryV2"
|
||||
}
|
||||
summary: "Outputs a `Summary` protocol buffer with audio."
|
||||
description: <<END
|
||||
The summary has up to `max_outputs` summary values containing audio. The
|
||||
audio is built from `tensor` which must be 3-D with shape `[batch_size,
|
||||
frames, channels]` or 2-D with shape `[batch_size, frames]`. The values are
|
||||
assumed to be in the range of `[-1.0, 1.0]` with a sample rate of `sample_rate`.
|
||||
|
||||
The `tag` argument is a scalar `Tensor` of type `string`. It is used to
|
||||
build the `tag` of the summary values:
|
||||
|
||||
* If `max_outputs` is 1, the summary value tag is '*tag*/audio'.
|
||||
* If `max_outputs` is greater than 1, the summary value tags are
|
||||
generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc.
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "AvgPool"
|
||||
endpoint {
|
||||
name: "AvgPool"
|
||||
}
|
||||
summary: "Performs average pooling on the input."
|
||||
description: <<END
|
||||
Each entry in `output` is the mean of the corresponding size `ksize`
|
||||
window in `value`.
|
||||
END
|
||||
}
|
||||
op {
|
||||
graph_op_name: "AvgPool3D"
|
||||
endpoint {
|
||||
name: "AvgPool3D"
|
||||
}
|
||||
summary: "Performs 3D average pooling on the input."
|
||||
}
|
||||
op {
|
||||
graph_op_name: "AvgPool3DGrad"
|
||||
endpoint {
|
||||
name: "AvgPool3DGrad"
|
||||
}
|
||||
summary: "Computes gradients of average pooling function."
|
||||
}
|
||||
op {
|
||||
graph_op_name: "AvgPoolGrad"
|
||||
endpoint {
|
||||
name: "AvgPoolGrad"
|
||||
}
|
||||
summary: "Computes gradients of the average pooling function."
|
||||
}
|
16
tensorflow/core/api_def/base_api/api_def_Abort.pbtxt
Normal file
16
tensorflow/core/api_def/base_api/api_def_Abort.pbtxt
Normal file
@ -0,0 +1,16 @@
|
||||
op {
|
||||
graph_op_name: "Abort"
|
||||
attr {
|
||||
name: "error_msg"
|
||||
description: <<END
|
||||
A string which is the message associated with the exception.
|
||||
END
|
||||
}
|
||||
summary: "Raise a exception to abort the process when called."
|
||||
description: <<END
|
||||
If exit_without_error is true, the process will exit normally,
|
||||
otherwise it will exit with a SIGABORT signal.
|
||||
|
||||
Returns nothing but an exception.
|
||||
END
|
||||
}
|
9
tensorflow/core/api_def/base_api/api_def_Abs.pbtxt
Normal file
9
tensorflow/core/api_def/base_api/api_def_Abs.pbtxt
Normal file
@ -0,0 +1,9 @@
|
||||
op {
|
||||
graph_op_name: "Abs"
|
||||
summary: "Computes the absolute value of a tensor."
|
||||
description: <<END
|
||||
Given a tensor `x`, this operation returns a tensor containing the absolute
|
||||
value of each element in `x`. For example, if x is an input element and y is
|
||||
an output element, this operation computes \\(y = |x|\\).
|
||||
END
|
||||
}
|
26
tensorflow/core/api_def/base_api/api_def_AccumulateNV2.pbtxt
Normal file
26
tensorflow/core/api_def/base_api/api_def_AccumulateNV2.pbtxt
Normal file
@ -0,0 +1,26 @@
|
||||
op {
|
||||
graph_op_name: "AccumulateNV2"
|
||||
in_arg {
|
||||
name: "inputs"
|
||||
description: <<END
|
||||
A list of `Tensor` objects, each with same shape and type.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "shape"
|
||||
description: <<END
|
||||
Shape of elements of `inputs`.
|
||||
END
|
||||
}
|
||||
summary: "Returns the element-wise sum of a list of tensors."
|
||||
description: <<END
|
||||
`tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not
|
||||
wait for all of its inputs to be ready before beginning to sum. This can
|
||||
save memory if inputs are ready at different times, since minimum temporary
|
||||
storage is proportional to the output size rather than the inputs size.
|
||||
|
||||
Unlike the original `accumulate_n`, `accumulate_n_v2` is differentiable.
|
||||
|
||||
Returns a `Tensor` of same shape and type as the elements of `inputs`.
|
||||
END
|
||||
}
|
@ -0,0 +1,32 @@
|
||||
op {
|
||||
graph_op_name: "AccumulatorApplyGradient"
|
||||
in_arg {
|
||||
name: "handle"
|
||||
description: <<END
|
||||
The handle to a accumulator.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "local_step"
|
||||
description: <<END
|
||||
The local_step value at which the gradient was computed.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "gradient"
|
||||
description: <<END
|
||||
A tensor of the gradient to be accumulated.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "dtype"
|
||||
description: <<END
|
||||
The data type of accumulated gradients. Needs to correspond to the type
|
||||
of the accumulator.
|
||||
END
|
||||
}
|
||||
summary: "Applies a gradient to a given accumulator."
|
||||
description: <<END
|
||||
Does not add if local_step is lesser than the accumulator's global_step.
|
||||
END
|
||||
}
|
@ -0,0 +1,16 @@
|
||||
op {
|
||||
graph_op_name: "AccumulatorNumAccumulated"
|
||||
in_arg {
|
||||
name: "handle"
|
||||
description: <<END
|
||||
The handle to an accumulator.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "num_accumulated"
|
||||
description: <<END
|
||||
The number of gradients aggregated in the given accumulator.
|
||||
END
|
||||
}
|
||||
summary: "Returns the number of gradients aggregated in the given accumulators."
|
||||
}
|
@ -0,0 +1,20 @@
|
||||
op {
|
||||
graph_op_name: "AccumulatorSetGlobalStep"
|
||||
in_arg {
|
||||
name: "handle"
|
||||
description: <<END
|
||||
The handle to an accumulator.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "new_global_step"
|
||||
description: <<END
|
||||
The new global_step value to set.
|
||||
END
|
||||
}
|
||||
summary: "Updates the accumulator with a new value for global_step."
|
||||
description: <<END
|
||||
Logs warning if the accumulator's value is already higher than
|
||||
new_global_step.
|
||||
END
|
||||
}
|
@ -0,0 +1,36 @@
|
||||
op {
|
||||
graph_op_name: "AccumulatorTakeGradient"
|
||||
in_arg {
|
||||
name: "handle"
|
||||
description: <<END
|
||||
The handle to an accumulator.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "num_required"
|
||||
description: <<END
|
||||
Number of gradients required before we return an aggregate.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "average"
|
||||
description: <<END
|
||||
The average of the accumulated gradients.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "dtype"
|
||||
description: <<END
|
||||
The data type of accumulated gradients. Needs to correspond to the type
|
||||
of the accumulator.
|
||||
END
|
||||
}
|
||||
summary: "Extracts the average gradient in the given ConditionalAccumulator."
|
||||
description: <<END
|
||||
The op blocks until sufficient (i.e., more than num_required)
|
||||
gradients have been accumulated. If the accumulator has already
|
||||
aggregated more than num_required gradients, it returns the average of
|
||||
the accumulated gradients. Also automatically increments the recorded
|
||||
global_step in the accumulator by 1, and resets the aggregate to 0.
|
||||
END
|
||||
}
|
4
tensorflow/core/api_def/base_api/api_def_Acos.pbtxt
Normal file
4
tensorflow/core/api_def/base_api/api_def_Acos.pbtxt
Normal file
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "Acos"
|
||||
summary: "Computes acos of x element-wise."
|
||||
}
|
4
tensorflow/core/api_def/base_api/api_def_Acosh.pbtxt
Normal file
4
tensorflow/core/api_def/base_api/api_def_Acosh.pbtxt
Normal file
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "Acosh"
|
||||
summary: "Computes inverse hyperbolic cosine of x element-wise."
|
||||
}
|
8
tensorflow/core/api_def/base_api/api_def_Add.pbtxt
Normal file
8
tensorflow/core/api_def/base_api/api_def_Add.pbtxt
Normal file
@ -0,0 +1,8 @@
|
||||
op {
|
||||
graph_op_name: "Add"
|
||||
summary: "Returns x + y element-wise."
|
||||
description: <<END
|
||||
*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting
|
||||
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
|
||||
END
|
||||
}
|
@ -0,0 +1,68 @@
|
||||
op {
|
||||
graph_op_name: "AddManySparseToTensorsMap"
|
||||
in_arg {
|
||||
name: "sparse_indices"
|
||||
description: <<END
|
||||
2-D. The `indices` of the minibatch `SparseTensor`.
|
||||
`sparse_indices[:, 0]` must be ordered values in `[0, N)`.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "sparse_values"
|
||||
description: <<END
|
||||
1-D. The `values` of the minibatch `SparseTensor`.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "sparse_shape"
|
||||
description: <<END
|
||||
1-D. The `shape` of the minibatch `SparseTensor`.
|
||||
The minibatch size `N == sparse_shape[0]`.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "sparse_handles"
|
||||
description: <<END
|
||||
1-D. The handles of the `SparseTensor` now stored in the
|
||||
`SparseTensorsMap`. Shape: `[N]`.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "container"
|
||||
description: <<END
|
||||
The container name for the `SparseTensorsMap` created by this op.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "shared_name"
|
||||
description: <<END
|
||||
The shared name for the `SparseTensorsMap` created by this op.
|
||||
If blank, the new Operation's unique name is used.
|
||||
END
|
||||
}
|
||||
summary: "Add an `N`-minibatch `SparseTensor` to a `SparseTensorsMap`, return `N` handles."
|
||||
description: <<END
|
||||
A `SparseTensor` of rank `R` is represented by three tensors: `sparse_indices`,
|
||||
`sparse_values`, and `sparse_shape`, where
|
||||
|
||||
```sparse_indices.shape[1] == sparse_shape.shape[0] == R```
|
||||
|
||||
An `N`-minibatch of `SparseTensor` objects is represented as a `SparseTensor`
|
||||
having a first `sparse_indices` column taking values between `[0, N)`, where
|
||||
the minibatch size `N == sparse_shape[0]`.
|
||||
|
||||
The input `SparseTensor` must have rank `R` greater than 1, and the first
|
||||
dimension is treated as the minibatch dimension. Elements of the `SparseTensor`
|
||||
must be sorted in increasing order of this first dimension. The stored
|
||||
`SparseTensor` objects pointed to by each row of the output `sparse_handles`
|
||||
will have rank `R-1`.
|
||||
|
||||
The `SparseTensor` values can then be read out as part of a minibatch by passing
|
||||
the given keys as vector elements to `TakeManySparseFromTensorsMap`. To ensure
|
||||
the correct `SparseTensorsMap` is accessed, ensure that the same
|
||||
`container` and `shared_name` are passed to that Op. If no `shared_name`
|
||||
is provided here, instead use the *name* of the Operation created by calling
|
||||
`AddManySparseToTensorsMap` as the `shared_name` passed to
|
||||
`TakeManySparseFromTensorsMap`. Ensure the Operations are colocated.
|
||||
END
|
||||
}
|
10
tensorflow/core/api_def/base_api/api_def_AddN.pbtxt
Normal file
10
tensorflow/core/api_def/base_api/api_def_AddN.pbtxt
Normal file
@ -0,0 +1,10 @@
|
||||
op {
|
||||
graph_op_name: "AddN"
|
||||
in_arg {
|
||||
name: "inputs"
|
||||
description: <<END
|
||||
Must all be the same size and shape.
|
||||
END
|
||||
}
|
||||
summary: "Add all input tensors element wise."
|
||||
}
|
@ -0,0 +1,58 @@
|
||||
op {
|
||||
graph_op_name: "AddSparseToTensorsMap"
|
||||
in_arg {
|
||||
name: "sparse_indices"
|
||||
description: <<END
|
||||
2-D. The `indices` of the `SparseTensor`.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "sparse_values"
|
||||
description: <<END
|
||||
1-D. The `values` of the `SparseTensor`.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "sparse_shape"
|
||||
description: <<END
|
||||
1-D. The `shape` of the `SparseTensor`.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "sparse_handle"
|
||||
description: <<END
|
||||
0-D. The handle of the `SparseTensor` now stored in the
|
||||
`SparseTensorsMap`.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "container"
|
||||
description: <<END
|
||||
The container name for the `SparseTensorsMap` created by this op.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "shared_name"
|
||||
description: <<END
|
||||
The shared name for the `SparseTensorsMap` created by this op.
|
||||
If blank, the new Operation's unique name is used.
|
||||
END
|
||||
}
|
||||
summary: "Add a `SparseTensor` to a `SparseTensorsMap` return its handle."
|
||||
description: <<END
|
||||
A `SparseTensor` is represented by three tensors: `sparse_indices`,
|
||||
`sparse_values`, and `sparse_shape`.
|
||||
|
||||
This operator takes the given `SparseTensor` and adds it to a container
|
||||
object (a `SparseTensorsMap`). A unique key within this container is generated
|
||||
in the form of an `int64`, and this is the value that is returned.
|
||||
|
||||
The `SparseTensor` can then be read out as part of a minibatch by passing
|
||||
the key as a vector element to `TakeManySparseFromTensorsMap`. To ensure
|
||||
the correct `SparseTensorsMap` is accessed, ensure that the same
|
||||
`container` and `shared_name` are passed to that Op. If no `shared_name`
|
||||
is provided here, instead use the *name* of the Operation created by calling
|
||||
`AddSparseToTensorsMap` as the `shared_name` passed to
|
||||
`TakeManySparseFromTensorsMap`. Ensure the Operations are colocated.
|
||||
END
|
||||
}
|
8
tensorflow/core/api_def/base_api/api_def_AddV2.pbtxt
Normal file
8
tensorflow/core/api_def/base_api/api_def_AddV2.pbtxt
Normal file
@ -0,0 +1,8 @@
|
||||
op {
|
||||
graph_op_name: "AddV2"
|
||||
summary: "Returns x + y element-wise."
|
||||
description: <<END
|
||||
*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting
|
||||
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
|
||||
END
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "AdjustContrast"
|
||||
summary: "Deprecated. Disallowed in GraphDef version >= 2."
|
||||
}
|
@ -0,0 +1,36 @@
|
||||
op {
|
||||
graph_op_name: "AdjustContrastv2"
|
||||
endpoint {
|
||||
name: "AdjustContrast"
|
||||
}
|
||||
in_arg {
|
||||
name: "images"
|
||||
description: <<END
|
||||
Images to adjust. At least 3-D.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "contrast_factor"
|
||||
description: <<END
|
||||
A float multiplier for adjusting contrast.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "output"
|
||||
description: <<END
|
||||
The contrast-adjusted image or images.
|
||||
END
|
||||
}
|
||||
summary: "Adjust the contrast of one or more images."
|
||||
description: <<END
|
||||
`images` is a tensor of at least 3 dimensions. The last 3 dimensions are
|
||||
interpreted as `[height, width, channels]`. The other dimensions only
|
||||
represent a collection of images, such as `[batch, height, width, channels].`
|
||||
|
||||
Contrast is adjusted independently for each channel of each image.
|
||||
|
||||
For each channel, the Op first computes the mean of the image pixels in the
|
||||
channel and then adjusts each component of each pixel to
|
||||
`(x - mean) * contrast_factor + mean`.
|
||||
END
|
||||
}
|
30
tensorflow/core/api_def/base_api/api_def_AdjustHue.pbtxt
Normal file
30
tensorflow/core/api_def/base_api/api_def_AdjustHue.pbtxt
Normal file
@ -0,0 +1,30 @@
|
||||
op {
|
||||
graph_op_name: "AdjustHue"
|
||||
in_arg {
|
||||
name: "images"
|
||||
description: <<END
|
||||
Images to adjust. At least 3-D.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "delta"
|
||||
description: <<END
|
||||
A float delta to add to the hue.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "output"
|
||||
description: <<END
|
||||
The hue-adjusted image or images.
|
||||
END
|
||||
}
|
||||
summary: "Adjust the hue of one or more images."
|
||||
description: <<END
|
||||
`images` is a tensor of at least 3 dimensions. The last dimension is
|
||||
interpretted as channels, and must be three.
|
||||
|
||||
The input image is considered in the RGB colorspace. Conceptually, the RGB
|
||||
colors are first mapped into HSV. A delta is then applied all the hue values,
|
||||
and then remapped back to RGB colorspace.
|
||||
END
|
||||
}
|
@ -0,0 +1,30 @@
|
||||
op {
|
||||
graph_op_name: "AdjustSaturation"
|
||||
in_arg {
|
||||
name: "images"
|
||||
description: <<END
|
||||
Images to adjust. At least 3-D.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "scale"
|
||||
description: <<END
|
||||
A float scale to add to the saturation.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "output"
|
||||
description: <<END
|
||||
The hue-adjusted image or images.
|
||||
END
|
||||
}
|
||||
summary: "Adjust the saturation of one or more images."
|
||||
description: <<END
|
||||
`images` is a tensor of at least 3 dimensions. The last dimension is
|
||||
interpretted as channels, and must be three.
|
||||
|
||||
The input image is considered in the RGB colorspace. Conceptually, the RGB
|
||||
colors are first mapped into HSV. A scale is then applied all the saturation
|
||||
values, and then remapped back to RGB colorspace.
|
||||
END
|
||||
}
|
42
tensorflow/core/api_def/base_api/api_def_All.pbtxt
Normal file
42
tensorflow/core/api_def/base_api/api_def_All.pbtxt
Normal file
@ -0,0 +1,42 @@
|
||||
op {
|
||||
graph_op_name: "All"
|
||||
endpoint {
|
||||
name: "All"
|
||||
}
|
||||
endpoint {
|
||||
name: "ReduceAll"
|
||||
}
|
||||
in_arg {
|
||||
name: "input"
|
||||
description: <<END
|
||||
The tensor to reduce.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "reduction_indices"
|
||||
rename_to: "axis"
|
||||
description: <<END
|
||||
The dimensions to reduce. Must be in the range
|
||||
`[-rank(input), rank(input))`.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "output"
|
||||
description: <<END
|
||||
The reduced tensor.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "keep_dims"
|
||||
description: <<END
|
||||
If true, retain reduced dimensions with length 1.
|
||||
END
|
||||
}
|
||||
summary: "Computes the \"logical and\" of elements across dimensions of a tensor."
|
||||
description: <<END
|
||||
Reduces `input` along the dimensions given in `reduction_indices`. Unless
|
||||
`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
|
||||
`reduction_indices`. If `keep_dims` is true, the reduced dimensions are
|
||||
retained with length 1.
|
||||
END
|
||||
}
|
@ -0,0 +1,80 @@
|
||||
op {
|
||||
graph_op_name: "AllCandidateSampler"
|
||||
in_arg {
|
||||
name: "true_classes"
|
||||
description: <<END
|
||||
A batch_size * num_true matrix, in which each row contains the
|
||||
IDs of the num_true target_classes in the corresponding original label.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "sampled_candidates"
|
||||
description: <<END
|
||||
A vector of length num_sampled, in which each element is
|
||||
the ID of a sampled candidate.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "true_expected_count"
|
||||
description: <<END
|
||||
A batch_size * num_true matrix, representing
|
||||
the number of times each candidate is expected to occur in a batch
|
||||
of sampled candidates. If unique=true, then this is a probability.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "sampled_expected_count"
|
||||
description: <<END
|
||||
A vector of length num_sampled, for each sampled
|
||||
candidate representing the number of times the candidate is expected
|
||||
to occur in a batch of sampled candidates. If unique=true, then this is a
|
||||
probability.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "num_true"
|
||||
description: <<END
|
||||
Number of true labels per context.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "num_sampled"
|
||||
description: <<END
|
||||
Number of candidates to produce.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "unique"
|
||||
description: <<END
|
||||
If unique is true, we sample with rejection, so that all sampled
|
||||
candidates in a batch are unique. This requires some approximation to
|
||||
estimate the post-rejection sampling probabilities.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "seed"
|
||||
description: <<END
|
||||
If either seed or seed2 are set to be non-zero, the random number
|
||||
generator is seeded by the given seed. Otherwise, it is seeded by a
|
||||
random seed.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "seed2"
|
||||
description: <<END
|
||||
An second seed to avoid seed collision.
|
||||
END
|
||||
}
|
||||
summary: "Generates labels for candidate sampling with a learned unigram distribution."
|
||||
description: <<END
|
||||
See explanations of candidate sampling and the data formats at
|
||||
go/candidate-sampling.
|
||||
|
||||
For each batch, this op picks a single set of sampled candidate labels.
|
||||
|
||||
The advantages of sampling candidates per-batch are simplicity and the
|
||||
possibility of efficient dense matrix multiplication. The disadvantage is that
|
||||
the sampled candidates must be chosen independently of the context and of the
|
||||
true labels.
|
||||
END
|
||||
}
|
23
tensorflow/core/api_def/base_api/api_def_Angle.pbtxt
Normal file
23
tensorflow/core/api_def/base_api/api_def_Angle.pbtxt
Normal file
@ -0,0 +1,23 @@
|
||||
op {
|
||||
graph_op_name: "Angle"
|
||||
summary: "Returns the argument of a complex number."
|
||||
description: <<END
|
||||
Given a tensor `input` of complex numbers, this operation returns a tensor of
|
||||
type `float` that is the argument of each element in `input`. All elements in
|
||||
`input` must be complex numbers of the form \\(a + bj\\), where *a*
|
||||
is the real part and *b* is the imaginary part.
|
||||
|
||||
The argument returned by this operation is of the form \\(atan2(b, a)\\).
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
# tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
|
||||
tf.angle(input) ==> [2.0132, 1.056]
|
||||
```
|
||||
|
||||
@compatibility(numpy)
|
||||
Equivalent to np.angle.
|
||||
@end_compatibility
|
||||
END
|
||||
}
|
42
tensorflow/core/api_def/base_api/api_def_Any.pbtxt
Normal file
42
tensorflow/core/api_def/base_api/api_def_Any.pbtxt
Normal file
@ -0,0 +1,42 @@
|
||||
op {
|
||||
graph_op_name: "Any"
|
||||
endpoint {
|
||||
name: "Any"
|
||||
}
|
||||
endpoint {
|
||||
name: "ReduceAny"
|
||||
}
|
||||
in_arg {
|
||||
name: "input"
|
||||
description: <<END
|
||||
The tensor to reduce.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "reduction_indices"
|
||||
rename_to: "axis"
|
||||
description: <<END
|
||||
The dimensions to reduce. Must be in the range
|
||||
`[-rank(input), rank(input))`.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "output"
|
||||
description: <<END
|
||||
The reduced tensor.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "keep_dims"
|
||||
description: <<END
|
||||
If true, retain reduced dimensions with length 1.
|
||||
END
|
||||
}
|
||||
summary: "Computes the \"logical or\" of elements across dimensions of a tensor."
|
||||
description: <<END
|
||||
Reduces `input` along the dimensions given in `reduction_indices`. Unless
|
||||
`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
|
||||
`reduction_indices`. If `keep_dims` is true, the reduced dimensions are
|
||||
retained with length 1.
|
||||
END
|
||||
}
|
65
tensorflow/core/api_def/base_api/api_def_ApplyAdadelta.pbtxt
Normal file
65
tensorflow/core/api_def/base_api/api_def_ApplyAdadelta.pbtxt
Normal file
@ -0,0 +1,65 @@
|
||||
op {
|
||||
graph_op_name: "ApplyAdadelta"
|
||||
in_arg {
|
||||
name: "var"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "accum"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "accum_update"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "lr"
|
||||
description: <<END
|
||||
Scaling factor. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "rho"
|
||||
description: <<END
|
||||
Decay factor. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "epsilon"
|
||||
description: <<END
|
||||
Constant factor. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "grad"
|
||||
description: <<END
|
||||
The gradient.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "out"
|
||||
description: <<END
|
||||
Same as "var".
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
description: <<END
|
||||
If True, updating of the var, accum and update_accum tensors will be protected by
|
||||
a lock; otherwise the behavior is undefined, but may exhibit less contention.
|
||||
END
|
||||
}
|
||||
summary: "Update \'*var\' according to the adadelta scheme."
|
||||
description: <<END
|
||||
accum = rho() * accum + (1 - rho()) * grad.square();
|
||||
update = (update_accum + epsilon).sqrt() * (accum + epsilon()).rsqrt() * grad;
|
||||
update_accum = rho() * update_accum + (1 - rho()) * update.square();
|
||||
var -= update;
|
||||
END
|
||||
}
|
46
tensorflow/core/api_def/base_api/api_def_ApplyAdagrad.pbtxt
Normal file
46
tensorflow/core/api_def/base_api/api_def_ApplyAdagrad.pbtxt
Normal file
@ -0,0 +1,46 @@
|
||||
op {
|
||||
graph_op_name: "ApplyAdagrad"
|
||||
in_arg {
|
||||
name: "var"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "accum"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "lr"
|
||||
description: <<END
|
||||
Scaling factor. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "grad"
|
||||
description: <<END
|
||||
The gradient.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "out"
|
||||
description: <<END
|
||||
Same as "var".
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
description: <<END
|
||||
If `True`, updating of the var and accum tensors will be protected
|
||||
by a lock; otherwise the behavior is undefined, but may exhibit less
|
||||
contention.
|
||||
END
|
||||
}
|
||||
summary: "Update \'*var\' according to the adagrad scheme."
|
||||
description: <<END
|
||||
accum += grad * grad
|
||||
var -= lr * grad * (1 / sqrt(accum))
|
||||
END
|
||||
}
|
@ -0,0 +1,65 @@
|
||||
op {
|
||||
graph_op_name: "ApplyAdagradDA"
|
||||
in_arg {
|
||||
name: "var"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "gradient_accumulator"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "gradient_squared_accumulator"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "grad"
|
||||
description: <<END
|
||||
The gradient.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "lr"
|
||||
description: <<END
|
||||
Scaling factor. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "l1"
|
||||
description: <<END
|
||||
L1 regularization. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "l2"
|
||||
description: <<END
|
||||
L2 regularization. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "global_step"
|
||||
description: <<END
|
||||
Training step number. Must be a scalar.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "out"
|
||||
description: <<END
|
||||
Same as "var".
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
description: <<END
|
||||
If True, updating of the var and accum tensors will be protected by
|
||||
a lock; otherwise the behavior is undefined, but may exhibit less contention.
|
||||
END
|
||||
}
|
||||
summary: "Update \'*var\' according to the proximal adagrad scheme."
|
||||
}
|
90
tensorflow/core/api_def/base_api/api_def_ApplyAdam.pbtxt
Normal file
90
tensorflow/core/api_def/base_api/api_def_ApplyAdam.pbtxt
Normal file
@ -0,0 +1,90 @@
|
||||
op {
|
||||
graph_op_name: "ApplyAdam"
|
||||
in_arg {
|
||||
name: "var"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "m"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "v"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "beta1_power"
|
||||
description: <<END
|
||||
Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "beta2_power"
|
||||
description: <<END
|
||||
Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "lr"
|
||||
description: <<END
|
||||
Scaling factor. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "beta1"
|
||||
description: <<END
|
||||
Momentum factor. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "beta2"
|
||||
description: <<END
|
||||
Momentum factor. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "epsilon"
|
||||
description: <<END
|
||||
Ridge term. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "grad"
|
||||
description: <<END
|
||||
The gradient.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "out"
|
||||
description: <<END
|
||||
Same as "var".
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
description: <<END
|
||||
If `True`, updating of the var, m, and v tensors will be protected
|
||||
by a lock; otherwise the behavior is undefined, but may exhibit less
|
||||
contention.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "use_nesterov"
|
||||
description: <<END
|
||||
If `True`, uses the nesterov update.
|
||||
END
|
||||
}
|
||||
summary: "Update \'*var\' according to the Adam algorithm."
|
||||
description: <<END
|
||||
lr_t <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
|
||||
m_t <- beta1 * m_{t-1} + (1 - beta1) * g_t
|
||||
v_t <- beta2 * v_{t-1} + (1 - beta2) * g_t * g_t
|
||||
variable <- variable - lr_t * m_t / (sqrt(v_t) + epsilon)
|
||||
END
|
||||
}
|
@ -0,0 +1,86 @@
|
||||
op {
|
||||
graph_op_name: "ApplyCenteredRMSProp"
|
||||
in_arg {
|
||||
name: "var"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "mg"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "ms"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "mom"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "lr"
|
||||
description: <<END
|
||||
Scaling factor. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "rho"
|
||||
description: <<END
|
||||
Decay rate. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "epsilon"
|
||||
description: <<END
|
||||
Ridge term. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "grad"
|
||||
description: <<END
|
||||
The gradient.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "out"
|
||||
description: <<END
|
||||
Same as "var".
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
description: <<END
|
||||
If `True`, updating of the var, mg, ms, and mom tensors is
|
||||
protected by a lock; otherwise the behavior is undefined, but may exhibit less
|
||||
contention.
|
||||
END
|
||||
}
|
||||
summary: "Update \'*var\' according to the centered RMSProp algorithm."
|
||||
description: <<END
|
||||
The centered RMSProp algorithm uses an estimate of the centered second moment
|
||||
(i.e., the variance) for normalization, as opposed to regular RMSProp, which
|
||||
uses the (uncentered) second moment. This often helps with training, but is
|
||||
slightly more expensive in terms of computation and memory.
|
||||
|
||||
Note that in dense implementation of this algorithm, mg, ms, and mom will
|
||||
update even if the grad is zero, but in this sparse implementation, mg, ms,
|
||||
and mom will not update in iterations during which the grad is zero.
|
||||
|
||||
mean_square = decay * mean_square + (1-decay) * gradient ** 2
|
||||
mean_grad = decay * mean_grad + (1-decay) * gradient
|
||||
|
||||
Delta = learning_rate * gradient / sqrt(mean_square + epsilon - mean_grad ** 2)
|
||||
|
||||
mg <- rho * mg_{t-1} + (1-rho) * grad
|
||||
ms <- rho * ms_{t-1} + (1-rho) * grad * grad
|
||||
mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon)
|
||||
var <- var - mom
|
||||
END
|
||||
}
|
73
tensorflow/core/api_def/base_api/api_def_ApplyFtrl.pbtxt
Normal file
73
tensorflow/core/api_def/base_api/api_def_ApplyFtrl.pbtxt
Normal file
@ -0,0 +1,73 @@
|
||||
op {
|
||||
graph_op_name: "ApplyFtrl"
|
||||
in_arg {
|
||||
name: "var"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "accum"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "linear"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "grad"
|
||||
description: <<END
|
||||
The gradient.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "lr"
|
||||
description: <<END
|
||||
Scaling factor. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "l1"
|
||||
description: <<END
|
||||
L1 regulariation. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "l2"
|
||||
description: <<END
|
||||
L2 regulariation. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "lr_power"
|
||||
description: <<END
|
||||
Scaling factor. Must be a scalar.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "out"
|
||||
description: <<END
|
||||
Same as "var".
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
description: <<END
|
||||
If `True`, updating of the var and accum tensors will be protected
|
||||
by a lock; otherwise the behavior is undefined, but may exhibit less
|
||||
contention.
|
||||
END
|
||||
}
|
||||
summary: "Update \'*var\' according to the Ftrl-proximal scheme."
|
||||
description: <<END
|
||||
accum_new = accum + grad * grad
|
||||
linear += grad + (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var
|
||||
quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2
|
||||
var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0
|
||||
accum = accum_new
|
||||
END
|
||||
}
|
75
tensorflow/core/api_def/base_api/api_def_ApplyFtrlV2.pbtxt
Normal file
75
tensorflow/core/api_def/base_api/api_def_ApplyFtrlV2.pbtxt
Normal file
@ -0,0 +1,75 @@
|
||||
op {
|
||||
graph_op_name: "ApplyFtrlV2"
|
||||
in_arg {
|
||||
name: "var"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "accum"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "linear"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "grad"
|
||||
description: <<END
|
||||
The gradient.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "lr"
|
||||
description: <<END
|
||||
Scaling factor. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "l1"
|
||||
description: <<END
|
||||
L1 regulariation. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "l2"
|
||||
description: <<END
|
||||
L2 shrinkage regulariation. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "lr_power"
|
||||
description: <<END
|
||||
Scaling factor. Must be a scalar.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "out"
|
||||
description: <<END
|
||||
Same as "var".
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
description: <<END
|
||||
If `True`, updating of the var and accum tensors will be protected
|
||||
by a lock; otherwise the behavior is undefined, but may exhibit less
|
||||
contention.
|
||||
END
|
||||
}
|
||||
summary: "Update \'*var\' according to the Ftrl-proximal scheme."
|
||||
description: <<END
|
||||
grad_with_shrinkage = grad + 2 * l2_shrinkage * var
|
||||
accum_new = accum + grad_with_shrinkage * grad_with_shrinkage
|
||||
linear += grad_with_shrinkage +
|
||||
(accum_new^(-lr_power) - accum^(-lr_power)) / lr * var
|
||||
quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2
|
||||
var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0
|
||||
accum = accum_new
|
||||
END
|
||||
}
|
@ -0,0 +1,35 @@
|
||||
op {
|
||||
graph_op_name: "ApplyGradientDescent"
|
||||
in_arg {
|
||||
name: "var"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "alpha"
|
||||
description: <<END
|
||||
Scaling factor. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "delta"
|
||||
description: <<END
|
||||
The change.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "out"
|
||||
description: <<END
|
||||
Same as "var".
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
description: <<END
|
||||
If `True`, the subtraction will be protected by a lock;
|
||||
otherwise the behavior is undefined, but may exhibit less contention.
|
||||
END
|
||||
}
|
||||
summary: "Update \'*var\' by subtracting \'alpha\' * \'delta\' from it."
|
||||
}
|
62
tensorflow/core/api_def/base_api/api_def_ApplyMomentum.pbtxt
Normal file
62
tensorflow/core/api_def/base_api/api_def_ApplyMomentum.pbtxt
Normal file
@ -0,0 +1,62 @@
|
||||
op {
|
||||
graph_op_name: "ApplyMomentum"
|
||||
in_arg {
|
||||
name: "var"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "accum"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "lr"
|
||||
description: <<END
|
||||
Scaling factor. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "grad"
|
||||
description: <<END
|
||||
The gradient.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "momentum"
|
||||
description: <<END
|
||||
Momentum. Must be a scalar.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "out"
|
||||
description: <<END
|
||||
Same as "var".
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
description: <<END
|
||||
If `True`, updating of the var and accum tensors will be protected
|
||||
by a lock; otherwise the behavior is undefined, but may exhibit less
|
||||
contention.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "use_nesterov"
|
||||
description: <<END
|
||||
If `True`, the tensor passed to compute grad will be
|
||||
var - lr * momentum * accum, so in the end, the var you get is actually
|
||||
var - lr * momentum * accum.
|
||||
END
|
||||
}
|
||||
summary: "Update \'*var\' according to the momentum scheme. Set use_nesterov = True if you"
|
||||
description: <<END
|
||||
want to use Nesterov momentum.
|
||||
|
||||
accum = accum * momentum + grad
|
||||
var -= lr * accum
|
||||
END
|
||||
}
|
@ -0,0 +1,58 @@
|
||||
op {
|
||||
graph_op_name: "ApplyProximalAdagrad"
|
||||
in_arg {
|
||||
name: "var"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "accum"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "lr"
|
||||
description: <<END
|
||||
Scaling factor. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "l1"
|
||||
description: <<END
|
||||
L1 regularization. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "l2"
|
||||
description: <<END
|
||||
L2 regularization. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "grad"
|
||||
description: <<END
|
||||
The gradient.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "out"
|
||||
description: <<END
|
||||
Same as "var".
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
description: <<END
|
||||
If True, updating of the var and accum tensors will be protected by
|
||||
a lock; otherwise the behavior is undefined, but may exhibit less contention.
|
||||
END
|
||||
}
|
||||
summary: "Update \'*var\' and \'*accum\' according to FOBOS with Adagrad learning rate."
|
||||
description: <<END
|
||||
accum += grad * grad
|
||||
prox_v = var - lr * grad * (1 / sqrt(accum))
|
||||
var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0}
|
||||
END
|
||||
}
|
@ -0,0 +1,51 @@
|
||||
op {
|
||||
graph_op_name: "ApplyProximalGradientDescent"
|
||||
in_arg {
|
||||
name: "var"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "alpha"
|
||||
description: <<END
|
||||
Scaling factor. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "l1"
|
||||
description: <<END
|
||||
L1 regularization. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "l2"
|
||||
description: <<END
|
||||
L2 regularization. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "delta"
|
||||
description: <<END
|
||||
The change.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "out"
|
||||
description: <<END
|
||||
Same as "var".
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
description: <<END
|
||||
If True, the subtraction will be protected by a lock;
|
||||
otherwise the behavior is undefined, but may exhibit less contention.
|
||||
END
|
||||
}
|
||||
summary: "Update \'*var\' as FOBOS algorithm with fixed learning rate."
|
||||
description: <<END
|
||||
prox_v = var - alpha * delta
|
||||
var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0}
|
||||
END
|
||||
}
|
72
tensorflow/core/api_def/base_api/api_def_ApplyRMSProp.pbtxt
Normal file
72
tensorflow/core/api_def/base_api/api_def_ApplyRMSProp.pbtxt
Normal file
@ -0,0 +1,72 @@
|
||||
op {
|
||||
graph_op_name: "ApplyRMSProp"
|
||||
in_arg {
|
||||
name: "var"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "ms"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "mom"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "lr"
|
||||
description: <<END
|
||||
Scaling factor. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "rho"
|
||||
description: <<END
|
||||
Decay rate. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "epsilon"
|
||||
description: <<END
|
||||
Ridge term. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "grad"
|
||||
description: <<END
|
||||
The gradient.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "out"
|
||||
description: <<END
|
||||
Same as "var".
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
description: <<END
|
||||
If `True`, updating of the var, ms, and mom tensors is protected
|
||||
by a lock; otherwise the behavior is undefined, but may exhibit less
|
||||
contention.
|
||||
END
|
||||
}
|
||||
summary: "Update \'*var\' according to the RMSProp algorithm."
|
||||
description: <<END
|
||||
Note that in dense implementation of this algorithm, ms and mom will
|
||||
update even if the grad is zero, but in this sparse implementation, ms
|
||||
and mom will not update in iterations during which the grad is zero.
|
||||
|
||||
mean_square = decay * mean_square + (1-decay) * gradient ** 2
|
||||
Delta = learning_rate * gradient / sqrt(mean_square + epsilon)
|
||||
|
||||
ms <- rho * ms_{t-1} + (1-rho) * grad * grad
|
||||
mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
|
||||
var <- var - mom
|
||||
END
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "ApproximateEqual"
|
||||
summary: "Returns the truth value of abs(x-y) < tolerance element-wise."
|
||||
}
|
15
tensorflow/core/api_def/base_api/api_def_ArgMax.pbtxt
Normal file
15
tensorflow/core/api_def/base_api/api_def_ArgMax.pbtxt
Normal file
@ -0,0 +1,15 @@
|
||||
op {
|
||||
graph_op_name: "ArgMax"
|
||||
in_arg {
|
||||
name: "dimension"
|
||||
description: <<END
|
||||
int32 or int64, must be in the range `[-rank(input), rank(input))`.
|
||||
Describes which dimension of the input Tensor to reduce across. For vectors,
|
||||
use dimension = 0.
|
||||
END
|
||||
}
|
||||
summary: "Returns the index with the largest value across dimensions of a tensor."
|
||||
description: <<END
|
||||
Note that in case of ties the identity of the return value is not guaranteed.
|
||||
END
|
||||
}
|
15
tensorflow/core/api_def/base_api/api_def_ArgMin.pbtxt
Normal file
15
tensorflow/core/api_def/base_api/api_def_ArgMin.pbtxt
Normal file
@ -0,0 +1,15 @@
|
||||
op {
|
||||
graph_op_name: "ArgMin"
|
||||
in_arg {
|
||||
name: "dimension"
|
||||
description: <<END
|
||||
int32 or int64, must be in the range `[-rank(input), rank(input))`.
|
||||
Describes which dimension of the input Tensor to reduce across. For vectors,
|
||||
use dimension = 0.
|
||||
END
|
||||
}
|
||||
summary: "Returns the index with the smallest value across dimensions of a tensor."
|
||||
description: <<END
|
||||
Note that in case of ties the identity of the return value is not guaranteed.
|
||||
END
|
||||
}
|
42
tensorflow/core/api_def/base_api/api_def_AsString.pbtxt
Normal file
42
tensorflow/core/api_def/base_api/api_def_AsString.pbtxt
Normal file
@ -0,0 +1,42 @@
|
||||
op {
|
||||
graph_op_name: "AsString"
|
||||
attr {
|
||||
name: "precision"
|
||||
description: <<END
|
||||
The post-decimal precision to use for floating point numbers.
|
||||
Only used if precision > -1.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "scientific"
|
||||
description: <<END
|
||||
Use scientific notation for floating point numbers.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "shortest"
|
||||
description: <<END
|
||||
Use shortest representation (either scientific or standard) for
|
||||
floating point numbers.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "width"
|
||||
description: <<END
|
||||
Pad pre-decimal numbers to this width.
|
||||
Applies to both floating point and integer numbers.
|
||||
Only used if width > -1.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "fill"
|
||||
description: <<END
|
||||
The value to pad if width > -1. If empty, pads with spaces.
|
||||
Another typical value is '0'. String cannot be longer than 1 character.
|
||||
END
|
||||
}
|
||||
summary: "Converts each entry in the given tensor to strings. Supports many numeric"
|
||||
description: <<END
|
||||
types and boolean.
|
||||
END
|
||||
}
|
4
tensorflow/core/api_def/base_api/api_def_Asin.pbtxt
Normal file
4
tensorflow/core/api_def/base_api/api_def_Asin.pbtxt
Normal file
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "Asin"
|
||||
summary: "Computes asin of x element-wise."
|
||||
}
|
4
tensorflow/core/api_def/base_api/api_def_Asinh.pbtxt
Normal file
4
tensorflow/core/api_def/base_api/api_def_Asinh.pbtxt
Normal file
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "Asinh"
|
||||
summary: "Computes inverse hyperbolic sine of x element-wise."
|
||||
}
|
26
tensorflow/core/api_def/base_api/api_def_Assert.pbtxt
Normal file
26
tensorflow/core/api_def/base_api/api_def_Assert.pbtxt
Normal file
@ -0,0 +1,26 @@
|
||||
op {
|
||||
graph_op_name: "Assert"
|
||||
in_arg {
|
||||
name: "condition"
|
||||
description: <<END
|
||||
The condition to evaluate.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "data"
|
||||
description: <<END
|
||||
The tensors to print out when condition is false.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "summarize"
|
||||
description: <<END
|
||||
Print this many entries of each tensor.
|
||||
END
|
||||
}
|
||||
summary: "Asserts that the given condition is true."
|
||||
description: <<END
|
||||
If `condition` evaluates to false, print the list of tensors in `data`.
|
||||
`summarize` determines how many entries of the tensors to print.
|
||||
END
|
||||
}
|
42
tensorflow/core/api_def/base_api/api_def_Assign.pbtxt
Normal file
42
tensorflow/core/api_def/base_api/api_def_Assign.pbtxt
Normal file
@ -0,0 +1,42 @@
|
||||
op {
|
||||
graph_op_name: "Assign"
|
||||
in_arg {
|
||||
name: "ref"
|
||||
description: <<END
|
||||
Should be from a `Variable` node. May be uninitialized.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "value"
|
||||
description: <<END
|
||||
The value to be assigned to the variable.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "output_ref"
|
||||
description: <<END
|
||||
= Same as "ref". Returned as a convenience for operations that want
|
||||
to use the new value after the variable has been reset.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "validate_shape"
|
||||
description: <<END
|
||||
If true, the operation will validate that the shape
|
||||
of 'value' matches the shape of the Tensor being assigned to. If false,
|
||||
'ref' will take on the shape of 'value'.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
description: <<END
|
||||
If True, the assignment will be protected by a lock;
|
||||
otherwise the behavior is undefined, but may exhibit less contention.
|
||||
END
|
||||
}
|
||||
summary: "Update \'ref\' by assigning \'value\' to it."
|
||||
description: <<END
|
||||
This operation outputs "ref" after the assignment is done.
|
||||
This makes it easier to chain operations that need to use the reset value.
|
||||
END
|
||||
}
|
34
tensorflow/core/api_def/base_api/api_def_AssignAdd.pbtxt
Normal file
34
tensorflow/core/api_def/base_api/api_def_AssignAdd.pbtxt
Normal file
@ -0,0 +1,34 @@
|
||||
op {
|
||||
graph_op_name: "AssignAdd"
|
||||
in_arg {
|
||||
name: "ref"
|
||||
description: <<END
|
||||
Should be from a `Variable` node.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "value"
|
||||
description: <<END
|
||||
The value to be added to the variable.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "output_ref"
|
||||
description: <<END
|
||||
= Same as "ref". Returned as a convenience for operations that want
|
||||
to use the new value after the variable has been updated.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
description: <<END
|
||||
If True, the addition will be protected by a lock;
|
||||
otherwise the behavior is undefined, but may exhibit less contention.
|
||||
END
|
||||
}
|
||||
summary: "Update \'ref\' by adding \'value\' to it."
|
||||
description: <<END
|
||||
This operation outputs "ref" after the update is done.
|
||||
This makes it easier to chain operations that need to use the reset value.
|
||||
END
|
||||
}
|
@ -0,0 +1,29 @@
|
||||
op {
|
||||
graph_op_name: "AssignAddVariableOp"
|
||||
in_arg {
|
||||
name: "resource"
|
||||
description: <<END
|
||||
handle to the resource in which to store the variable.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "value"
|
||||
description: <<END
|
||||
the value by which the variable will be incremented.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "dtype"
|
||||
description: <<END
|
||||
the dtype of the value.
|
||||
END
|
||||
}
|
||||
summary: "Adds a value to the current value of a variable."
|
||||
description: <<END
|
||||
Any ReadVariableOp which depends directly or indirectly on this assign is
|
||||
guaranteed to see the incremented value or a subsequent newer one.
|
||||
|
||||
Outputs the incremented value, which can be used to totally order the
|
||||
increments to this variable.
|
||||
END
|
||||
}
|
34
tensorflow/core/api_def/base_api/api_def_AssignSub.pbtxt
Normal file
34
tensorflow/core/api_def/base_api/api_def_AssignSub.pbtxt
Normal file
@ -0,0 +1,34 @@
|
||||
op {
|
||||
graph_op_name: "AssignSub"
|
||||
in_arg {
|
||||
name: "ref"
|
||||
description: <<END
|
||||
Should be from a `Variable` node.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "value"
|
||||
description: <<END
|
||||
The value to be subtracted to the variable.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "output_ref"
|
||||
description: <<END
|
||||
= Same as "ref". Returned as a convenience for operations that want
|
||||
to use the new value after the variable has been updated.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
description: <<END
|
||||
If True, the subtraction will be protected by a lock;
|
||||
otherwise the behavior is undefined, but may exhibit less contention.
|
||||
END
|
||||
}
|
||||
summary: "Update \'ref\' by subtracting \'value\' from it."
|
||||
description: <<END
|
||||
This operation outputs "ref" after the update is done.
|
||||
This makes it easier to chain operations that need to use the reset value.
|
||||
END
|
||||
}
|
@ -0,0 +1,29 @@
|
||||
op {
|
||||
graph_op_name: "AssignSubVariableOp"
|
||||
in_arg {
|
||||
name: "resource"
|
||||
description: <<END
|
||||
handle to the resource in which to store the variable.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "value"
|
||||
description: <<END
|
||||
the value by which the variable will be incremented.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "dtype"
|
||||
description: <<END
|
||||
the dtype of the value.
|
||||
END
|
||||
}
|
||||
summary: "Subtracts a value from the current value of a variable."
|
||||
description: <<END
|
||||
Any ReadVariableOp which depends directly or indirectly on this assign is
|
||||
guaranteed to see the incremented value or a subsequent newer one.
|
||||
|
||||
Outputs the incremented value, which can be used to totally order the
|
||||
increments to this variable.
|
||||
END
|
||||
}
|
@ -0,0 +1,26 @@
|
||||
op {
|
||||
graph_op_name: "AssignVariableOp"
|
||||
in_arg {
|
||||
name: "resource"
|
||||
description: <<END
|
||||
handle to the resource in which to store the variable.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "value"
|
||||
description: <<END
|
||||
the value to set the new tensor to use.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "dtype"
|
||||
description: <<END
|
||||
the dtype of the value.
|
||||
END
|
||||
}
|
||||
summary: "Assigns a new value to a variable."
|
||||
description: <<END
|
||||
Any ReadVariableOp with a control dependency on this op is guaranteed to return
|
||||
this value or a subsequent newer value of the variable.
|
||||
END
|
||||
}
|
4
tensorflow/core/api_def/base_api/api_def_Atan.pbtxt
Normal file
4
tensorflow/core/api_def/base_api/api_def_Atan.pbtxt
Normal file
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "Atan"
|
||||
summary: "Computes atan of x element-wise."
|
||||
}
|
11
tensorflow/core/api_def/base_api/api_def_Atan2.pbtxt
Normal file
11
tensorflow/core/api_def/base_api/api_def_Atan2.pbtxt
Normal file
@ -0,0 +1,11 @@
|
||||
op {
|
||||
graph_op_name: "Atan2"
|
||||
summary: "Computes arctangent of `y/x` element-wise, respecting signs of the arguments."
|
||||
description: <<END
|
||||
This is the angle \( \theta \in [-\pi, \pi] \) such that
|
||||
\[ x = r \cos(\theta) \]
|
||||
and
|
||||
\[ y = r \sin(\theta) \]
|
||||
where \(r = \sqrt(x^2 + y^2) \).
|
||||
END
|
||||
}
|
4
tensorflow/core/api_def/base_api/api_def_Atanh.pbtxt
Normal file
4
tensorflow/core/api_def/base_api/api_def_Atanh.pbtxt
Normal file
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "Atanh"
|
||||
summary: "Computes inverse hyperbolic tangent of x element-wise."
|
||||
}
|
@ -0,0 +1,63 @@
|
||||
op {
|
||||
graph_op_name: "AudioSpectrogram"
|
||||
in_arg {
|
||||
name: "input"
|
||||
description: <<END
|
||||
Float representation of audio data.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "spectrogram"
|
||||
description: <<END
|
||||
3D representation of the audio frequencies as an image.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "window_size"
|
||||
description: <<END
|
||||
How wide the input window is in samples. For the highest efficiency
|
||||
this should be a power of two, but other values are accepted.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "stride"
|
||||
description: <<END
|
||||
How widely apart the center of adjacent sample windows should be.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "magnitude_squared"
|
||||
description: <<END
|
||||
Whether to return the squared magnitude or just the
|
||||
magnitude. Using squared magnitude can avoid extra calculations.
|
||||
END
|
||||
}
|
||||
summary: "Produces a visualization of audio data over time."
|
||||
description: <<END
|
||||
Spectrograms are a standard way of representing audio information as a series of
|
||||
slices of frequency information, one slice for each window of time. By joining
|
||||
these together into a sequence, they form a distinctive fingerprint of the sound
|
||||
over time.
|
||||
|
||||
This op expects to receive audio data as an input, stored as floats in the range
|
||||
-1 to 1, together with a window width in samples, and a stride specifying how
|
||||
far to move the window between slices. From this it generates a three
|
||||
dimensional output. The lowest dimension has an amplitude value for each
|
||||
frequency during that time slice. The next dimension is time, with successive
|
||||
frequency slices. The final dimension is for the channels in the input, so a
|
||||
stereo audio input would have two here for example.
|
||||
|
||||
This means the layout when converted and saved as an image is rotated 90 degrees
|
||||
clockwise from a typical spectrogram. Time is descending down the Y axis, and
|
||||
the frequency decreases from left to right.
|
||||
|
||||
Each value in the result represents the square root of the sum of the real and
|
||||
imaginary parts of an FFT on the current window of samples. In this way, the
|
||||
lowest dimension represents the power of each frequency in the current window,
|
||||
and adjacent windows are concatenated in the next dimension.
|
||||
|
||||
To get a more intuitive and visual look at what this operation does, you can run
|
||||
tensorflow/examples/wav_to_spectrogram to read in an audio file and save out the
|
||||
resulting spectrogram as a PNG image.
|
||||
END
|
||||
}
|
47
tensorflow/core/api_def/base_api/api_def_AudioSummary.pbtxt
Normal file
47
tensorflow/core/api_def/base_api/api_def_AudioSummary.pbtxt
Normal file
@ -0,0 +1,47 @@
|
||||
op {
|
||||
graph_op_name: "AudioSummary"
|
||||
in_arg {
|
||||
name: "tag"
|
||||
description: <<END
|
||||
Scalar. Used to build the `tag` attribute of the summary values.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "tensor"
|
||||
description: <<END
|
||||
2-D of shape `[batch_size, frames]`.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "summary"
|
||||
description: <<END
|
||||
Scalar. Serialized `Summary` protocol buffer.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "sample_rate"
|
||||
description: <<END
|
||||
The sample rate of the signal in hertz.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "max_outputs"
|
||||
description: <<END
|
||||
Max number of batch elements to generate audio for.
|
||||
END
|
||||
}
|
||||
summary: "Outputs a `Summary` protocol buffer with audio."
|
||||
description: <<END
|
||||
The summary has up to `max_outputs` summary values containing audio. The
|
||||
audio is built from `tensor` which must be 3-D with shape `[batch_size,
|
||||
frames, channels]` or 2-D with shape `[batch_size, frames]`. The values are
|
||||
assumed to be in the range of `[-1.0, 1.0]` with a sample rate of `sample_rate`.
|
||||
|
||||
The `tag` argument is a scalar `Tensor` of type `string`. It is used to
|
||||
build the `tag` of the summary values:
|
||||
|
||||
* If `max_outputs` is 1, the summary value tag is '*tag*/audio'.
|
||||
* If `max_outputs` is greater than 1, the summary value tags are
|
||||
generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc.
|
||||
END
|
||||
}
|
@ -0,0 +1,50 @@
|
||||
op {
|
||||
graph_op_name: "AudioSummaryV2"
|
||||
endpoint {
|
||||
name: "AudioSummary"
|
||||
}
|
||||
in_arg {
|
||||
name: "tag"
|
||||
description: <<END
|
||||
Scalar. Used to build the `tag` attribute of the summary values.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "tensor"
|
||||
description: <<END
|
||||
2-D of shape `[batch_size, frames]`.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "sample_rate"
|
||||
description: <<END
|
||||
The sample rate of the signal in hertz.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "summary"
|
||||
description: <<END
|
||||
Scalar. Serialized `Summary` protocol buffer.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "max_outputs"
|
||||
description: <<END
|
||||
Max number of batch elements to generate audio for.
|
||||
END
|
||||
}
|
||||
summary: "Outputs a `Summary` protocol buffer with audio."
|
||||
description: <<END
|
||||
The summary has up to `max_outputs` summary values containing audio. The
|
||||
audio is built from `tensor` which must be 3-D with shape `[batch_size,
|
||||
frames, channels]` or 2-D with shape `[batch_size, frames]`. The values are
|
||||
assumed to be in the range of `[-1.0, 1.0]` with a sample rate of `sample_rate`.
|
||||
|
||||
The `tag` argument is a scalar `Tensor` of type `string`. It is used to
|
||||
build the `tag` of the summary values:
|
||||
|
||||
* If `max_outputs` is 1, the summary value tag is '*tag*/audio'.
|
||||
* If `max_outputs` is greater than 1, the summary value tags are
|
||||
generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc.
|
||||
END
|
||||
}
|
48
tensorflow/core/api_def/base_api/api_def_AvgPool.pbtxt
Normal file
48
tensorflow/core/api_def/base_api/api_def_AvgPool.pbtxt
Normal file
@ -0,0 +1,48 @@
|
||||
op {
|
||||
graph_op_name: "AvgPool"
|
||||
in_arg {
|
||||
name: "value"
|
||||
description: <<END
|
||||
4-D with shape `[batch, height, width, channels]`.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "output"
|
||||
description: <<END
|
||||
The average pooled output tensor.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "ksize"
|
||||
description: <<END
|
||||
The size of the sliding window for each dimension of `value`.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "strides"
|
||||
description: <<END
|
||||
The stride of the sliding window for each dimension of `value`.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "padding"
|
||||
description: <<END
|
||||
The type of padding algorithm to use.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "data_format"
|
||||
description: <<END
|
||||
Specify the data format of the input and output data. With the
|
||||
default format "NHWC", the data is stored in the order of:
|
||||
[batch, in_height, in_width, in_channels].
|
||||
Alternatively, the format could be "NCHW", the data storage order of:
|
||||
[batch, in_channels, in_height, in_width].
|
||||
END
|
||||
}
|
||||
summary: "Performs average pooling on the input."
|
||||
description: <<END
|
||||
Each entry in `output` is the mean of the corresponding size `ksize`
|
||||
window in `value`.
|
||||
END
|
||||
}
|
46
tensorflow/core/api_def/base_api/api_def_AvgPool3D.pbtxt
Normal file
46
tensorflow/core/api_def/base_api/api_def_AvgPool3D.pbtxt
Normal file
@ -0,0 +1,46 @@
|
||||
op {
|
||||
graph_op_name: "AvgPool3D"
|
||||
in_arg {
|
||||
name: "input"
|
||||
description: <<END
|
||||
Shape `[batch, depth, rows, cols, channels]` tensor to pool over.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "output"
|
||||
description: <<END
|
||||
The average pooled output tensor.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "ksize"
|
||||
description: <<END
|
||||
1-D tensor of length 5. The size of the window for each dimension of
|
||||
the input tensor. Must have `ksize[0] = ksize[4] = 1`.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "strides"
|
||||
description: <<END
|
||||
1-D tensor of length 5. The stride of the sliding window for each
|
||||
dimension of `input`. Must have `strides[0] = strides[4] = 1`.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "padding"
|
||||
description: <<END
|
||||
The type of padding algorithm to use.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "data_format"
|
||||
description: <<END
|
||||
The data format of the input and output data. With the
|
||||
default format "NDHWC", the data is stored in the order of:
|
||||
[batch, in_depth, in_height, in_width, in_channels].
|
||||
Alternatively, the format could be "NCDHW", the data storage order is:
|
||||
[batch, in_channels, in_depth, in_height, in_width].
|
||||
END
|
||||
}
|
||||
summary: "Performs 3D average pooling on the input."
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user