Merge branch 'r2.4' into cherrypick_350241208
commit ed76f960cb
RELEASE.md
tensorflow
  api_template.__init__.py
  c/experimental/filesystem
  core
    framework
    kernels
    ops
  python
    keras
      mixed_precision
      saving/saved_model
    kernel_tests
    ops
  tools
    ci_build
    dockerfiles
      dockerfiles/onednn
        ubuntu-18.04-mpi-horovod-jupyter.Dockerfile
        ubuntu-18.04-mpi-horovod.Dockerfile
        ubuntu-18.04-mpich-horovod-jupyter.Dockerfile
        ubuntu-18.04-mpich-horovod.Dockerfile
        ubuntu-20.04-mpi-horovod-jupyter.Dockerfile
        ubuntu-20.04-mpi-horovod.Dockerfile
        ubuntu-20.04-mpich-horovod-jupyter.Dockerfile
        ubuntu-20.04-mpich-horovod.Dockerfile
      partials/onednn/ubuntu
      spec.yml
      tools.Dockerfile
    pip_package
@@ -24,7 +24,7 @@
 ## Breaking Changes

 * TF Core:
-  * Certain float32 ops run in lower precsion on Ampere based GPUs, including matmuls and convolutions, due to the use of [TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/). Specifically, inputs to such ops are rounded from 23 bits of precision to 10
+  * Certain float32 ops run in lower precision on Ampere based GPUs, including matmuls and convolutions, due to the use of [TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/). Specifically, inputs to such ops are rounded from 23 bits of precision to 10
    bits of precision. This is unlikely to cause issues in practice for deep learning models. In some cases, TensorFloat-32 is also used for complex64 ops.
    TensorFloat-32 can be disabled by running `tf.config.experimental.enable_tensor_float_32_execution(False)`.
 * The byte layout for string tensors across the C-API has been updated to match TF Core/C++; i.e., a contiguous array of `tensorflow::tstring`/`TF_TString`s.
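For context, a minimal sketch of how the toggle named in the note above is used; this assumes TensorFlow 2.4, where `tf.config.experimental.tensor_float_32_execution_enabled` is the matching query API:

```python
import tensorflow as tf

# TensorFloat-32 is on by default for float32 matmuls/convolutions on
# Ampere GPUs in TF 2.4; this restores full 23-bit float32 precision.
tf.config.experimental.enable_tensor_float_32_execution(False)

# Query the current setting.
print(tf.config.experimental.tensor_float_32_execution_enabled())  # False
```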
@@ -176,7 +176,7 @@
 * For Keras model, the individual call of `Model.evaluate` uses no cached data for evaluation, while `Model.fit` uses cached data when `validation_data` arg is provided for better performance.
 * Adds a `save_traces` argument to `model.save`/ `tf.keras.models.save_model` which determines whether the SavedModel format stores the Keras model/layer call functions. The traced functions allow Keras to revive custom models and layers without the original class definition, but if this isn't required the tracing can be disabled with the added option.
 * The `tf.keras.mixed_precision` API is now non-experimental. The non-experimental API differs from the experimental API in several ways.
-  * `tf.keras.mixed_precision.Policy` no longer takes in a `tf.mixed_precision.experimental.LossScale` in the constructor, and no longer has a `LossScale` associated with it. Instead, `Model.compile` will automatically wrap the optimizer with a `LossScaleOptimizer` using dynamic loss scaling if `Policy.name` is "mixed_float16".
+  * `tf.keras.mixed_precision.Policy` no longer takes in a `tf.mixed_precision.experimental.LossScale` in the constructor, and no longer has a `LossScale` associated with it. Instead, `Model.compile` will automatically wrap the optimizer with a `LossScaleOptimizer` using dynamic loss scaling if `Policy.name` is `mixed_float16`.
   * `tf.keras.mixed_precision.LossScaleOptimizer`'s constructor takes in different arguments. In particular, it no longer takes in a `LossScale`, and there is no longer a `LossScale` associated with the `LossScaleOptimizer`. Instead, `LossScaleOptimizer` directly implements fixed or dynamic loss scaling. See the documentation of [`tf.keras.mixed_precision.experimental.LossScaleOptimizer`](https://www.tensorflow.org/api_docs/python/tf/keras/mixed_precision/experimental/LossScaleOptimizer?version=nightly) for details on the differences between the experimental `LossScaleOptimizer` and the new non-experimental `LossScaleOptimizer`.
   * `tf.mixed_precision.experimental.LossScale` and its subclasses are deprecated, as all of its functionality now exists within `tf.keras.mixed_precision.LossScaleOptimizer`
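A hedged sketch of the non-experimental mixed-precision flow described in the bullets above (assumes TF 2.4; the model itself is illustrative):

```python
import tensorflow as tf
from tensorflow.keras import mixed_precision

# Set a global policy instead of constructing a Policy with a LossScale.
mixed_precision.set_global_policy('mixed_float16')

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, input_shape=(4,)),
    tf.keras.layers.Dense(1, dtype='float32'),  # keep outputs in float32
])

# Because the policy name is 'mixed_float16', compile() wraps the optimizer
# in a LossScaleOptimizer with dynamic loss scaling automatically.
model.compile(optimizer='sgd', loss='mse')
print(isinstance(model.optimizer, mixed_precision.LossScaleOptimizer))  # True
```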
@@ -116,7 +116,8 @@ from tensorflow.python.lib.io import file_io as _fi

 # Get sitepackages directories for the python installation.
 _site_packages_dirs = []
-_site_packages_dirs += [] if _site.USER_SITE is None else [_site.USER_SITE]
+if _site.ENABLE_USER_SITE and _site.USER_SITE is not None:
+  _site_packages_dirs += [_site.USER_SITE]
 _site_packages_dirs += [_p for _p in _sys.path if 'site-packages' in _p]
 if 'getsitepackages' in dir(_site):
   _site_packages_dirs += _site.getsitepackages()
@@ -133,7 +133,7 @@ bool ModularFileSystem::FilesExist(const std::vector<std::string>& files,
                                    TransactionToken* token,
                                    std::vector<Status>* status) {
   if (ops_->paths_exist == nullptr)
-    return FileSystem::FilesExist(files, status);
+    return FileSystem::FilesExist(files, token, status);

   std::vector<char*> translated_names;
   translated_names.reserve(files.size());
@@ -234,7 +234,7 @@ Status ModularFileSystem::DeleteRecursively(const std::string& dirname,
                                    "`undeleted_dirs` set to NULL");

   if (ops_->delete_recursively == nullptr)
-    return FileSystem::DeleteRecursively(dirname, undeleted_files,
+    return FileSystem::DeleteRecursively(dirname, token, undeleted_files,
                                          undeleted_dirs);

   UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
@@ -264,7 +264,7 @@ Status ModularFileSystem::DeleteDir(const std::string& dirname,

 Status ModularFileSystem::RecursivelyCreateDir(const std::string& dirname,
                                                TransactionToken* token) {
   if (ops_->recursively_create_dir == nullptr)
-    return FileSystem::RecursivelyCreateDir(dirname);
+    return FileSystem::RecursivelyCreateDir(dirname, token);

   UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
   std::string translated_name = TranslateName(dirname);
@@ -312,7 +312,8 @@ Status ModularFileSystem::Stat(const std::string& fname,

 Status ModularFileSystem::IsDirectory(const std::string& name,
                                       TransactionToken* token) {
-  if (ops_->is_directory == nullptr) return FileSystem::IsDirectory(name);
+  if (ops_->is_directory == nullptr)
+    return FileSystem::IsDirectory(name, token);

   UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
   std::string translated_name = TranslateName(name);
@@ -362,7 +363,8 @@ Status ModularFileSystem::RenameFile(const std::string& src,
 Status ModularFileSystem::CopyFile(const std::string& src,
                                    const std::string& target,
                                    TransactionToken* token) {
-  if (ops_->copy_file == nullptr) return FileSystem::CopyFile(src, target);
+  if (ops_->copy_file == nullptr)
+    return FileSystem::CopyFile(src, target, token);

   UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
   std::string translated_src = TranslateName(src);
@@ -83,10 +83,17 @@ Status LookupInterface::CheckFindArguments(const Tensor& key,
                                            const Tensor& default_value) {
   TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(key, default_value));
   TF_RETURN_IF_ERROR(CheckKeyShape(key.shape()));
-  if (default_value.shape() != value_shape()) {
+  TensorShape fullsize_value_shape = key.shape();
+  for (int i = 0; i < key_shape().dims(); ++i) {
+    fullsize_value_shape.RemoveDim(fullsize_value_shape.dims() - 1);
+  }
+  fullsize_value_shape.AppendShape(value_shape());
+  if (default_value.shape() != value_shape() &&
+      default_value.shape() != fullsize_value_shape) {
     return errors::InvalidArgument(
-        "Expected shape ", value_shape().DebugString(),
-        " for default value, got ", default_value.shape().DebugString());
+        "Expected shape ", value_shape().DebugString(), " or ",
+        fullsize_value_shape.DebugString(), " for default value, got ",
+        default_value.shape().DebugString());
   }
   return Status::OK();
 }
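To make the new shape rule concrete: the accepted "full size" default shape is the keys shape with the trailing key dimensions dropped and the table's value shape appended. A small illustrative computation (plain Python, not TF API; the numbers match the kernel tests further down):

```python
# Scalar keys (0 key dims) of shape [2, 2] in a table with value shape [2]:
key_shape = [2, 2]
key_dims = 0
value_shape = [2]

fullsize_value_shape = key_shape[:len(key_shape) - key_dims] + value_shape
print(fullsize_value_shape)  # [2, 2, 2] -> "Expected shape [2] or [2,2,2]"
```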
@@ -128,7 +128,8 @@ class LookupInterface : public ResourceBase {
   // requirements are satisfied, otherwise it returns InvalidArgument:
   // - DataType of the tensor keys equals to the table key_dtype
   // - DataType of the tensor default_value equals to the table value_dtype
-  // - the default_value tensor shape matches the table's value shape.
+  // - the default_value tensor has the required shape given keys and the
+  //   tables's value shape.
   Status CheckFindArguments(const Tensor& keys, const Tensor& default_value);

   string DebugString() const override {
@@ -56,14 +56,25 @@ class MutableHashTableOfScalars final : public LookupInterface {

   Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
               const Tensor& default_value) override {
-    const V default_val = default_value.flat<V>()(0);
     const auto key_values = key.flat<K>();
     auto value_values = value->flat<V>();
+    const auto default_flat = default_value.flat<V>();
+
+    int64 total = value_values.size();
+    int64 default_total = default_flat.size();
+    bool is_full_size_default = (total == default_total);

     tf_shared_lock l(mu_);
     for (int64 i = 0; i < key_values.size(); ++i) {
+      // is_full_size_default is true:
+      //   Each key has an independent default value, key_values(i)
+      //   corresponding uses default_flat(i) as its default value.
+      //
+      // is_full_size_default is false:
+      //   All keys will share the default_flat(0) as default value.
       value_values(i) = gtl::FindWithDefault(
-          table_, SubtleMustCopyIfIntegral(key_values(i)), default_val);
+          table_, SubtleMustCopyIfIntegral(key_values(i)),
+          is_full_size_default ? default_flat(i) : default_flat(0));
     }

     return Status::OK();
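For intuition, the per-key default selection implemented above behaves like this plain-Python sketch (names are illustrative, not TF API):

```python
def find_with_defaults(table, keys, defaults):
  # "Full size": one default per key, so a missing keys[i] yields defaults[i];
  # otherwise every missing key shares defaults[0].
  is_full_size_default = len(defaults) == len(keys)
  return [
      table.get(k, defaults[i] if is_full_size_default else defaults[0])
      for i, k in enumerate(keys)
  ]

print(find_with_defaults({'a': 1}, ['a', 'b'], [-1, -2]))  # [1, -2]
print(find_with_defaults({'a': 1}, ['a', 'b'], [-9]))      # [1, -9]
```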
@@ -173,11 +184,15 @@ class MutableHashTableOfTensors final : public LookupInterface {

   Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
               const Tensor& default_value) override {
-    const auto default_flat = default_value.flat<V>();
+    const auto default_flat = default_value.flat_inner_dims<V, 2>();
     const auto key_values = key.flat<K>();
     auto value_values = value->flat_inner_dims<V, 2>();
     int64 value_dim = value_shape_.dim_size(0);

+    int64 total = value_values.size();
+    int64 default_total = default_flat.size();
+    bool is_full_size_default = (total == default_total);
+
     tf_shared_lock l(mu_);
     for (int64 i = 0; i < key_values.size(); ++i) {
       ValueArray* value_vec =
@@ -187,8 +202,15 @@ class MutableHashTableOfTensors final : public LookupInterface {
           value_values(i, j) = value_vec->at(j);
         }
       } else {
+        // is_full_size_default is true:
+        //   Each key has an independent default value, key_values(i)
+        //   corresponding uses default_flat(i) as its default value.
+        //
+        // is_full_size_default is false:
+        //   All keys will share the default_flat(0) as default value.
         for (int64 j = 0; j < value_dim; j++) {
-          value_values(i, j) = default_flat(j);
+          value_values(i, j) =
+              is_full_size_default ? default_flat(i, j) : default_flat(0, j);
         }
       }
     }
@@ -169,10 +169,6 @@ REGISTER_OP("LookupTableFindV2")
       ShapeHandle handle;
       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));

-      // Default value must be scalar or vector.
-      ShapeHandle keys;
-      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &keys));
-
       ShapeAndType value_shape_and_type;
       TF_RETURN_IF_ERROR(ValidateTableResourceHandle(
           c,
@@ -25,7 +25,6 @@ namespace {
 TEST(LookupOpsTest, LookupTableFindV2_ShapeFn) {
   ShapeInferenceTestOp op("LookupTableFindV2");
   INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[?];?;?");
-  INFER_ERROR("Shape must be at most rank 1 but is rank 2", op, "[];?;[1,1]");
   TF_ASSERT_OK(NodeDefBuilder("test", "LookupTableFindV2")
                    .Input({"table_handle", 0, DT_RESOURCE})
                    .Input({"keys", 0, DT_INT64})
@@ -57,12 +57,11 @@ class AutoCastVariable(variables.Variable, core.Tensor):
     called.
   """

-  def __init__(self, variable, op=None):
+  def __init__(self, variable):
     """Creates an AutoCastVariable instance.

     Args:
       variable: A floating-point resource variable to wrap.
-      op: Optional operation of this variable.

     Raises:
       ValueError: If `variable` is not a floating-point resource variable
@@ -74,7 +73,11 @@ class AutoCastVariable(variables.Variable, core.Tensor):
       raise ValueError('variable must be a floating point variable but has '
                        'type: %s' % variable.dtype.name)
     self._variable = variable
-    self._op = op
+    # 'delegate' means AutoCastVariable.op return self._variable.op, which will
+    # raise an AttributeError in Eager (as intended). If set to any other value,
+    # AutoCastVariable.op returns that value instead, which is used to set the
+    # op attribute in AutoCastVariable.assign().
+    self._op = 'delegate'

   def _should_cast(self):
     """Returns True if this variable should be casted when accessed."""
@@ -199,10 +202,18 @@ class AutoCastVariable(variables.Variable, core.Tensor):
                  use_locking=None,
                  name=None,
                  read_value=True):
+    # TODO(b/146181571): This logic can be simplified once
+    # DistributedVariable.assign returns a DistributedVariable. Currently for
+    # MirroredStrategy, it returns a Mirrored value.
     if ops.executing_eagerly_outside_functions():
       assign_op = update_fn(value, use_locking, name, False)
       if read_value:
-        return create_autocast_variable(self._variable, op=assign_op)
+        # We create a new AutoCastVariable with the same underlying tf.Variable.
+        # The new AutoCastVariable is identical except the 'op' attribute is
+        # defined. This matches the behavior of tf.Variable.assign.
+        var = create_autocast_variable(self._variable)
+        var._op = assign_op  # pylint:disable=protected-access
+        return var
       return assign_op

     # Fallback to wrapping the returned variable in graph mode if possible
@@ -298,9 +309,9 @@ class AutoCastVariable(variables.Variable, core.Tensor):

   @property
   def op(self):
-    if self._op is not None:
-      return self._op
+    if self._op == 'delegate':
       return self._variable.op
+    return self._op

   def _as_graph_element(self):
     graph_element = self._variable._as_graph_element()  # pylint:disable=protected-access
@@ -469,7 +480,7 @@ ops.register_tensor_conversion_function(AutoCastVariable,
                                         AutoCastVariable._dense_var_to_tensor)  # pylint:disable=protected-access


-def create_autocast_variable(variable, op=None):
+def create_autocast_variable(variable):
   """Creates an AutoCastVariable that wraps another variable.

   This typically just returns `AutoCastVariable(variable)`. But, if the variable
@@ -481,14 +492,13 @@ def create_autocast_variable(variable, op=None):

   Args:
     variable: A floating-point resource variable to wrap.
-    op: Optional operation of this variable.

   Returns:
     An AutoCastVariable that wraps the variable.
   """
   if not isinstance(variable, (distribute_values.DistributedVariable,
                                ps_distribute_values.AggregatingVariable)):
-    return AutoCastVariable(variable, op=op)
+    return AutoCastVariable(variable)

   class AutoCastDistributedVariable(AutoCastVariable, variable.__class__):
     """An AutoCastVariable that also subclasses from variable.__class__.
@@ -511,7 +521,7 @@ def create_autocast_variable(variable, op=None):
     ).format(v=self)
     # pylint: enable=missing-format-attribute

-  return AutoCastDistributedVariable(variable, op=op)
+  return AutoCastDistributedVariable(variable)


 class enable_auto_cast_variables(object):  # pylint:disable=invalid-name
@@ -37,7 +37,14 @@ from tensorflow.python.framework import indexed_slices
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_combinations as combinations
 from tensorflow.python.keras.mixed_precision import autocast_variable
+from tensorflow.python.keras.optimizer_v2 import adadelta
+from tensorflow.python.keras.optimizer_v2 import adagrad
+from tensorflow.python.keras.optimizer_v2 import adam
+from tensorflow.python.keras.optimizer_v2 import adamax
+from tensorflow.python.keras.optimizer_v2 import ftrl
 from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
+from tensorflow.python.keras.optimizer_v2 import nadam
+from tensorflow.python.keras.optimizer_v2 import rmsprop
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variables
@@ -352,11 +359,28 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
     self.assertAllClose(5., self.evaluate(run_assign()))

   @ds_combinations.generate(maybe_distribute)
-  def test_assign_op(self, distribution):
+  def test_op_attribute(self, distribution):
     with distribution.scope():
       x = get_var(0., dtypes.float32)
       x = autocast_variable.create_autocast_variable(x)

+      # Variable.op raises an AttributeError in Eager mode and is an op in graph
+      # mode. Variable.assign(...).op is None in Eager mode and an op in Graph
+      # mode or a tf.function. We test this is also true of AutoCastVariable.
+      if context.executing_eagerly():
+        with self.assertRaisesRegex(
+            AttributeError,
+            'Tensor.op is meaningless when eager execution is enabled'):
+          x.op  # pylint: disable=pointless-statement
+        self.assertIsNone(x.assign(1.0).op)
+        self.assertIsNone(x.assign_add(1.0).op)
+        self.assertIsNone(x.assign_sub(1.0).op)
+      else:
+        self.assertIsNotNone(x.op)
+        self.assertIsNotNone(x.assign(1.0).op)
+        self.assertIsNotNone(x.assign_add(1.0).op)
+        self.assertIsNotNone(x.assign_sub(1.0).op)

       @def_function.function
       def func():
         self.assertIsNotNone(x.assign(1.0).op)
@@ -503,24 +527,50 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
                      'dtype_to_cast_to=float32 '
                      'inner_variable=MirroredVariable.*>')

-  @parameterized.named_parameters(
-      ('v1', gradient_descent_v1.GradientDescentOptimizer),
-      ('v2', gradient_descent_v2.SGD))
-  def test_optimizer(self, optimizer_class):
+  @ds_combinations.generate(combinations.combine(
+      optimizer_class=[
+          adadelta.Adadelta,
+          adagrad.Adagrad,
+          adam.Adam,
+          adamax.Adamax,
+          ftrl.Ftrl,
+          gradient_descent_v2.SGD,
+          nadam.Nadam,
+          rmsprop.RMSprop,
+          gradient_descent_v1.GradientDescentOptimizer
+      ],
+      use_tf_function=[False, True]))
+  def test_optimizer(self, optimizer_class, use_tf_function):
+    if use_tf_function and not context.executing_eagerly():
+      self.skipTest('Test does not support graph mode with tf.function')
     x = get_var(1., dtypes.float32)
     x = autocast_variable.create_autocast_variable(x)
-    opt = optimizer_class(1.)
+    y = get_var(1., dtypes.float32)
+    opt = optimizer_class(learning_rate=1.)

-    @def_function.function
     def f():
-      opt.minimize(lambda: x + 1., var_list=[x])
+      # Minimize both the AutoCastVariable and the normal tf.Variable. Both
+      # variables should be updated to the same value.
+      op = opt.minimize(lambda: x + y, var_list=[x, y])
+      return None if ops.executing_eagerly_outside_functions() else op
+
+    if use_tf_function:
+      f = def_function.function(f)

     if context.executing_eagerly():
       f()
     else:
-      op = f()  # pylint: disable=assignment-from-no-return
+      op = f()
       self.evaluate(variables.global_variables_initializer())
       self.evaluate(op)
+    # Assert the AutoCastVariable has changed from its initial value
+    self.assertNotEqual(self.evaluate(x), 1.)
+    # Assert AutoCastVariable is updated correctly by comparing it to the normal
+    # variable
+    self.assertAlmostEqual(self.evaluate(x), self.evaluate(y))
+    if optimizer_class in (gradient_descent_v2.SGD,
+                           gradient_descent_v1.GradientDescentOptimizer):
+      # With SGD, the variables decreases by exactly 1
+      self.assertEqual(self.evaluate(x), 0)
@@ -135,7 +135,7 @@ def load(path, compile=True, options=None):  # pylint: disable=redefined-builtin

   # Recreate layers and metrics using the info stored in the metadata.
   keras_loader = KerasObjectLoader(metadata, object_graph_def)
-  keras_loader.load_layers()
+  keras_loader.load_layers(compile=compile)

   # Generate a dictionary of all loaded nodes.
   nodes_to_load = {'root': None}
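At the user level this corresponds to loading with `compile=False`; a hedged sketch (the path is hypothetical):

```python
import tensorflow as tf

# With compile=False, metrics are not restored, so a custom metric that
# cannot be deserialized only logs a warning instead of raising.
model = tf.keras.models.load_model('/tmp/saved_model_dir', compile=False)
```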
@@ -360,7 +360,7 @@ class KerasObjectLoader(object):
           obj_child, child_proto, child_id)
       self.loaded_nodes[child_id] = obj_child, setter

-  def load_layers(self):
+  def load_layers(self, compile=True):  # pylint: disable=redefined-builtin
     """Load all layer nodes from the metadata."""
     # Load metrics after models and layers, since it's likely that models
     # and layers will create the metric when initialized (this avoids wasting
@@ -376,9 +376,21 @@ class KerasObjectLoader(object):
                                 node_metadata.metadata)

     for node_metadata in metric_list:
+      try:
         self.loaded_nodes[node_metadata.node_id] = self._load_layer(
             node_metadata.node_id, node_metadata.identifier,
             node_metadata.metadata)
+      except ValueError:
+        # Metrics are only needed when the model is compiled later. We ignore
+        # errors when trying to load custom metrics when `compile=False` until
+        # custom metrics are serialized properly (b/135550038).
+        if compile:
+          raise
+        logging.warning('Unable to restore custom metric. Please ensure that '
+                        'the layer implements `get_config` and `from_config` '
+                        'when saving. In addition, please use the '
+                        '`custom_objects` arg when calling `load_model()`.')

   def _load_layer(self, node_id, identifier, metadata):
     """Load a single layer from a SavedUserObject proto."""
@@ -1147,6 +1147,26 @@ class MetricTest(test.TestCase, parameterized.TestCase):
     self._test_metric_save_and_load(
         metric, self._save_model_dir(), 1, test_sample_weight=False)

+  @keras_parameterized.run_with_all_model_types
+  def test_custom_metric_model(self):
+
+    class CustomMetric(keras.metrics.MeanSquaredError):
+      pass
+
+    model = testing_utils.get_small_mlp(1, 4, input_dim=3)
+    model.compile(
+        loss='mse',
+        optimizer='rmsprop',
+        metrics=[CustomMetric()])
+
+    saved_model_dir = self._save_model_dir()
+    tf_save.save(model, saved_model_dir)
+    with self.assertRaisesRegex(ValueError, 'metric'):
+      keras_load.load(saved_model_dir)
+
+    keras_load.load(saved_model_dir, compile=False)
+

 if __name__ == '__main__':
   test.main()
@@ -3375,6 +3375,71 @@ class MutableHashTableOpTest(test.TestCase):
     result = self.evaluate(output)
     self.assertAllEqual([[0, 1], [-1, -1]], result)

+  def testMutableHashTableFindWithInvalidShapeDefaultValue(self):
+    default_val = [-1, -1]
+    table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
+                                        default_val)
+
+    input_string = constant_op.constant([["brain", "salad"],
+                                         ["tank", "tarkus"]])
+
+    invalid_default_val = constant_op.constant(
+        [[-2, -3], [-4, -5], [-6, -7], [-8, -9]], dtypes.int64)
+
+    with self.assertRaisesRegex(
+        (ValueError, errors_impl.InvalidArgumentError),
+        "Expected shape \[2\] or \[2,2,2\] for default value, got \[4,2]"):
+      self.evaluate(table.lookup(input_string, invalid_default_val))
+
+    invalid_default_val = constant_op.constant([[[-2, -3], [-4, -5]]],
+                                               dtypes.int64)
+    with self.assertRaisesRegex(
+        (ValueError, errors_impl.InvalidArgumentError),
+        "Expected shape \[2\] or \[2,2,2\] for default value, got \[1,2,2\]"):
+      self.evaluate(table.lookup(input_string, invalid_default_val))
+
+  def testMutableHashTableFindHighRankScalarWithDynamicDefaultValue(self):
+    default_val = -1
+    keys = constant_op.constant(["brain", "salad", "surgery"])
+    values = constant_op.constant([0, 1, 2], dtypes.int64)
+    table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
+                                        default_val)
+
+    self.evaluate(table.insert(keys, values))
+    self.assertAllEqual(3, self.evaluate(table.size()))
+
+    input_string = constant_op.constant([["brain", "salad"],
+                                         ["tank", "tarkus"]])
+
+    dynamic_default_val = constant_op.constant([[-2, -3], [-4, -5]],
+                                               dtypes.int64)
+    output = table.lookup(input_string, dynamic_default_val)
+    self.assertAllEqual([2, 2], output.get_shape())
+
+    result = self.evaluate(output)
+    self.assertAllEqual([[0, 1], [-4, -5]], result)
+
+  def testMutableHashTableFindHighRankVectorWithDynamicDefaultValue(self):
+    default_val = [-1, -1]
+    keys = constant_op.constant(["brain", "salad", "surgery"])
+    values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
+    table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
+                                        default_val)
+
+    self.evaluate(table.insert(keys, values))
+    self.assertAllEqual(3, self.evaluate(table.size()))
+
+    input_string = constant_op.constant([["brain", "salad"],
+                                         ["tank", "tarkus"]])
+
+    dynamic_default_val = constant_op.constant(
+        [[[-2, -3], [-4, -5]], [[-6, -7], [-8, -9]]], dtypes.int64)
+    output = table.lookup(input_string, dynamic_default_val)
+    self.assertAllEqual([2, 2, 2], output.get_shape())
+
+    result = self.evaluate(output)
+    self.assertAllEqual([[[0, 1], [2, 3]], [[-6, -7], [-8, -9]]], result)
+
   def testMutableHashTableInsertHighRank(self):
     default_val = -1
     keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]])
@@ -1206,6 +1206,26 @@ class SummaryOpsTest(test_util.TensorFlowTestCase):
     # Reset to default state for other tests.
     summary_ops.set_step(None)

+  @test_util.run_v2_only
+  def testTrace_withProfiler(self):
+
+    @def_function.function
+    def f():
+      x = constant_op.constant(2)
+      y = constant_op.constant(3)
+      return x**y
+
+    assert context.executing_eagerly()
+    logdir = self.get_temp_dir()
+    writer = summary_ops.create_file_writer(logdir)
+    summary_ops.trace_on(graph=True, profiler=True)
+    profiler_outdir = self.get_temp_dir()
+    with writer.as_default():
+      f()
+      summary_ops.trace_export(
+          name='foo', step=1, profiler_outdir=profiler_outdir)
+    writer.close()
+

 def events_from_file(filepath):
   """Returns all events in a single event file.
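The public-API equivalent of the traced-profile flow tested above, as a hedged sketch (paths are hypothetical; `tf.summary.trace_on`/`trace_export` are the exported counterparts of the internal ops in the test):

```python
import tensorflow as tf

@tf.function
def f(x):
  return x * x

writer = tf.summary.create_file_writer('/tmp/logdir')
tf.summary.trace_on(graph=True, profiler=True)
with writer.as_default():
  f(tf.constant(3))
  # Export the traced graph and the profile captured since trace_on().
  tf.summary.trace_export(name='trace', step=1,
                          profiler_outdir='/tmp/logdir')
```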
@@ -1849,7 +1849,7 @@ class MutableHashTable(LookupInterface):

     return op

-  def lookup(self, keys, name=None):
+  def lookup(self, keys, dynamic_default_values=None, name=None):
     """Looks up `keys` in a table, outputs the corresponding values.

     The `default_value` is used for keys not present in the table.
@@ -1857,6 +1857,23 @@ class MutableHashTable(LookupInterface):
     Args:
       keys: Keys to look up. Can be a tensor of any shape. Must match the
        table's key_dtype.
+      dynamic_default_values: The values to use if a key is missing in the
+        table. If None (by default), the `table.default_value` will be used.
+        Shape of `dynamic_default_values` must be same with
+        `table.default_value` or the lookup result tensor.
+        In the latter case, each key will have a different default value.
+
+        For example:
+
+        ```python
+        keys = [0, 1, 3]
+        dynamic_default_values = [[1, 3, 4], [2, 3, 9], [8, 3, 0]]
+
+        # The key '0' will use [1, 3, 4] as default value.
+        # The key '1' will use [2, 3, 9] as default value.
+        # The key '3' will use [8, 3, 0] as default value.
+        ```
+
       name: A name for the operation (optional).

     Returns:
|
||||
(self.resource_handle, keys, self._default_value)):
|
||||
keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
|
||||
with ops.colocate_with(self.resource_handle):
|
||||
values = gen_lookup_ops.lookup_table_find_v2(self.resource_handle, keys,
|
||||
self._default_value)
|
||||
values = gen_lookup_ops.lookup_table_find_v2(
|
||||
self.resource_handle, keys, dynamic_default_values
|
||||
if dynamic_default_values is not None else self._default_value)
|
||||
return values
|
||||
|
||||
def insert(self, keys, values, name=None):
|
||||
|
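A runnable sketch of the per-key defaults wired through above, mirroring the new kernel tests (assumes the TF 2.4 `tf.lookup.experimental.MutableHashTable` export of this class):

```python
import tensorflow as tf

table = tf.lookup.experimental.MutableHashTable(
    key_dtype=tf.string, value_dtype=tf.int64, default_value=-1)
table.insert(tf.constant(['brain', 'salad']),
             tf.constant([0, 1], dtype=tf.int64))

keys = tf.constant([['brain', 'salad'], ['tank', 'tarkus']])
# One default per lookup position: misses fall back to the matching entry.
defaults = tf.constant([[-2, -3], [-4, -5]], dtype=tf.int64)
print(table.lookup(keys, dynamic_default_values=defaults))
# -> [[0, 1], [-4, -5]]
```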
@@ -1370,4 +1370,7 @@ def trace_off():
   context.context().disable_run_metadata()

   if profiler:
-    _profiler.stop()
+    try:
+      _profiler.stop()
+    except _profiler.ProfilerNotRunningError:
+      pass
@@ -0,0 +1,80 @@
+# Dockerfile for Ubuntu 16.04 manylinux2010 custom ops with GPU.
+
+FROM nvidia/cuda:11.0-cudnn8-devel-ubuntu16.04 as devtoolset
+
+LABEL maintainer="Amit Patankar <amitpatankar@google.com>"
+
+ENV DEBIAN_FRONTEND=noninteractive
+RUN apt-get update && apt-get install -y \
+    cpio \
+    file \
+    flex \
+    g++ \
+    make \
+    rpm2cpio \
+    unar \
+    wget \
+    && \
+    rm -rf /var/lib/apt/lists/*
+
+ADD devtoolset/fixlinks.sh fixlinks.sh
+ADD devtoolset/build_devtoolset.sh build_devtoolset.sh
+ADD devtoolset/rpm-patch.sh rpm-patch.sh
+
+# Set up a sysroot for glibc 2.12 / libstdc++ 4.4 / devtoolset-7 in /dt7.
+RUN /build_devtoolset.sh devtoolset-7 /dt7
+# Set up a sysroot for glibc 2.12 / libstdc++ 4.4 / devtoolset-8 in /dt8.
+RUN /build_devtoolset.sh devtoolset-8 /dt8
+
+# TODO(klimek): Split up into two different docker images.
+FROM nvidia/cuda:11.0-cudnn8-devel-ubuntu16.04
+
+LABEL maintainer="Amit Patankar <amitpatankar@google.com>"
+
+COPY --from=devtoolset /dt7 /dt7
+COPY --from=devtoolset /dt8 /dt8
+
+# Install TensorRT.
+RUN apt-get update && apt-get install -y \
+    libnvinfer-dev=7.1.3-1+cuda11.0 \
+    libnvinfer7=7.1.3-1+cuda11.0 \
+    libnvinfer-plugin-dev=7.1.3-1+cuda11.0 \
+    libnvinfer-plugin7=7.1.3-1+cuda11.0 \
+    && \
+    rm -rf /var/lib/apt/lists/*
+
+# Copy and run the install scripts.
+COPY install/*.sh /install/
+ARG DEBIAN_FRONTEND=noninteractive
+RUN /install/install_bootstrap_deb_packages.sh
+RUN /install/install_deb_packages.sh
+RUN /install/install_clang.sh
+RUN /install/install_bazel.sh
+RUN /install/install_buildifier.sh
+
+ENV TF_NEED_CUDA=1
+
+# Install python 3.6.
+RUN add-apt-repository ppa:deadsnakes/ppa && \
+    apt-get update && apt-get install -y \
+    python3.6 python3.6-dev python3-pip python3.6-venv && \
+    rm -rf /var/lib/apt/lists/* && \
+    python3.6 -m pip install pip --upgrade && \
+    update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.6 0
+
+# Install python 3.7
+RUN /install/install_python37.sh
+
+# Install pip3.5
+RUN wget https://bootstrap.pypa.io/get-pip.py && python3.5 get-pip.py && rm get-pip.py
+
+RUN /install/install_pip_packages.sh
+RUN /install/install_auditwheel.sh
+
+# Make python3.6 the default python version
+RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.6 0
+
+# Install given tensorflow or tf-nightly version, if not specified, install the
+# latest official release
+ARG TF_PACKAGE=tensorflow
+ARG TF_PACKAGE_VERSION=
+RUN pip3 install ${TF_PACKAGE}${TF_PACKAGE_VERSION:+==${TF_PACKAGE_VERSION}}
@@ -22,16 +22,16 @@ pip --version
 pip install portpicker
 pip install *.whl

-# Make bazel version the same as the env that invokes this script
-rm -rf ~/bazel
-mkdir ~/bazel
-pushd ~/bazel
-wget https://github.com/bazelbuild/bazel/releases/download/"${BAZEL_VERSION}"/bazel-"${BAZEL_VERSION}"-installer-linux-x86_64.sh
-chmod +x bazel-*.sh
-./bazel-"${BAZEL_VERSION}"-installer-linux-x86_64.sh --user
-rm bazel-"${BAZEL_VERSION}"-installer-linux-x86_64.sh
-PATH="/bazel_pip/bin:$PATH"
-popd
+# Install bazelisk
+rm -rf ~/bin/bazel
+mkdir ~/bin/bazel
+wget --no-verbose -O "~/bin/bazel" \
+  "https://github.com/bazelbuild/bazelisk/releases/download/v1.3.0/bazelisk-linux-amd64"
+chmod u+x "~/bin/bazel"
+if [[ ! ":$PATH:" =~ :"~"/bin/?: ]]; then
+  PATH="~/bin:$PATH"
+fi
 which bazel
 bazel version

 # Use default configuration
@@ -85,6 +85,7 @@ ARG HOROVOD_VERSION=

 RUN apt-get update && apt-get install -y --no-install-recommends --fix-missing \
     build-essential \
+    cmake \
     g++-8 \
     gcc-8 \
     python3-dev
@@ -85,6 +85,7 @@ ARG HOROVOD_VERSION=

 RUN apt-get update && apt-get install -y --no-install-recommends --fix-missing \
     build-essential \
+    cmake \
     g++-8 \
     gcc-8 \
     python3-dev
@@ -81,6 +81,7 @@ ARG HOROVOD_VERSION=

 RUN apt-get update && apt-get install -y --no-install-recommends --fix-missing \
     build-essential \
+    cmake \
     g++-8 \
     gcc-8 \
     python3-dev
@@ -81,6 +81,7 @@ ARG HOROVOD_VERSION=

 RUN apt-get update && apt-get install -y --no-install-recommends --fix-missing \
     build-essential \
+    cmake \
     g++-8 \
     gcc-8 \
     python3-dev
@@ -95,6 +95,7 @@ ARG HOROVOD_VERSION=

 RUN apt-get update && apt-get install -y --no-install-recommends --fix-missing \
     build-essential \
+    cmake \
     g++-8 \
     gcc-8 \
     ${PYTHON}-dev
@@ -95,6 +95,7 @@ ARG HOROVOD_VERSION=

 RUN apt-get update && apt-get install -y --no-install-recommends --fix-missing \
     build-essential \
+    cmake \
     g++-8 \
     gcc-8 \
     ${PYTHON}-dev
@@ -91,6 +91,7 @@ ARG HOROVOD_VERSION=

 RUN apt-get update && apt-get install -y --no-install-recommends --fix-missing \
     build-essential \
+    cmake \
     g++-8 \
     gcc-8 \
     ${PYTHON}-dev
@@ -91,6 +91,7 @@ ARG HOROVOD_VERSION=

 RUN apt-get update && apt-get install -y --no-install-recommends --fix-missing \
     build-essential \
+    cmake \
     g++-8 \
     gcc-8 \
     ${PYTHON}-dev
@@ -6,6 +6,7 @@ ARG HOROVOD_VERSION=

 RUN apt-get update && apt-get install -y --no-install-recommends --fix-missing \
     build-essential \
+    cmake \
     g++-8 \
     gcc-8 \
     ${PYTHON}-dev
@@ -6,6 +6,7 @@ ARG HOROVOD_VERSION=

 RUN apt-get update && apt-get install -y --no-install-recommends --fix-missing \
     build-essential \
+    cmake \
     g++-8 \
     gcc-8 \
     python3-dev
@@ -421,6 +421,7 @@ slice_sets:
         - TF_PACKAGE=intel-tensorflow
         - UBUNTU_VERSION=20.04
         - PYTHON=python3.7
+        - DEBIAN_FRONTEND="noninteractive"

     ubuntu-devel-onednn:
         - add_to_name: "-16.04-devel"
@@ -22,7 +22,7 @@ LABEL maintainer="Austin Anderson <angerson@google.com>"

 RUN apt-get update && apt-get install -y python3 python3-pip bash curl
 RUN curl -sSL https://get.docker.com/ | sh
-RUN pip3 install --upgrade pip setuptools pyyaml absl-py cerberus docker
+RUN pip3 install --upgrade pip setuptools pyyaml absl-py cerberus 'docker<=4.3.0'

 WORKDIR /tf
 VOLUME ["/tf"]
@@ -132,10 +132,10 @@ function prepare_src() {
       unzip -o -q ./bazel-bin/tensorflow/tools/pip_package/simple_console_for_windows.zip -d ./bazel-bin/tensorflow/tools/pip_package/simple_console_for_window_unzip
       echo "Unzip finished."
       # runfiles structure after unzip the python binary
-      cp \
+      cp -L \
         bazel-bin/tensorflow/tools/pip_package/simple_console_for_window_unzip/runfiles/org_tensorflow/LICENSE \
         "${TMPDIR}"
-      cp -R \
+      cp -LR \
         bazel-bin/tensorflow/tools/pip_package/simple_console_for_window_unzip/runfiles/org_tensorflow/tensorflow \
         "${TMPDIR}"
       cp_external \
@@ -149,10 +149,10 @@ function prepare_src() {
   RUNFILES=bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow
   if [ -d bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/external ]; then
     # Old-style runfiles structure (--legacy_external_runfiles).
-    cp \
+    cp -L \
      bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/LICENSE \
      "${TMPDIR}"
-    cp -R \
+    cp -LR \
      bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/tensorflow \
      "${TMPDIR}"
    cp_external \
@@ -172,10 +172,10 @@ function prepare_src() {
     fi
   else
     # New-style runfiles structure (--nolegacy_external_runfiles).
-    cp \
+    cp -L \
      bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/LICENSE \
      "${TMPDIR}"
-    cp -R \
+    cp -LR \
      bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/tensorflow \
      "${TMPDIR}"
    cp_external \