From aafebf2488854e4607147f4f3d442f7329fb06eb Mon Sep 17 00:00:00 2001
From: Scott Zhu <scottzhu@google.com>
Date: Fri, 18 Sep 2020 10:59:08 -0700
Subject: [PATCH] Internal changes only.

PiperOrigin-RevId: 332484552
Change-Id: Ie9b9b1529706d284fe6d3485ca5eeea04335d04f
---
 .../python/framework/test_combinations.py     | 54 ++++++++++++++++++-
 ...est.combinations.-optional-parameter.pbtxt | 18 +++++++
 ...est.combinations.-parameter-modifier.pbtxt | 17 ++++++
 ...rflow.__internal__.test.combinations.pbtxt |  8 +++
 4 files changed, 95 insertions(+), 2 deletions(-)
 create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.__internal__.test.combinations.-optional-parameter.pbtxt
 create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.__internal__.test.combinations.-parameter-modifier.pbtxt

diff --git a/tensorflow/python/framework/test_combinations.py b/tensorflow/python/framework/test_combinations.py
index 09920f68adf..09b6ba478db 100644
--- a/tensorflow/python/framework/test_combinations.py
+++ b/tensorflow/python/framework/test_combinations.py
@@ -114,8 +114,27 @@ class TestCombination(object):
     return []
 
 
+@tf_export("__internal__.test.combinations.ParameterModifier", v1=[])
 class ParameterModifier(object):
-  """Customizes the behavior of a particular parameter."""
+  """Customizes the behavior of a particular parameter.
+
+  Users should override `modified_arguments()` to modify the parameter they
+  want, eg: change the value of certain parameter or filter it from the params
+  passed to the test case.
+
+  See the sample usage below, it will change any negative parameters to zero
+  before it gets passed to test case.
+  ```
+  class NonNegativeParameterModifier(ParameterModifier):
+
+    def modified_arguments(self, kwargs, requested_parameters):
+      updates = {}
+      for name, value in kwargs.items():
+        if value < 0:
+          updates[name] = 0
+      return updates
+  ```
+  """
 
   DO_NOT_PASS_TO_THE_TEST = object()
 
@@ -171,8 +190,39 @@ class ParameterModifier(object):
       return id(self.__class__)
 
 
+@tf_export("__internal__.test.combinations.OptionalParameter", v1=[])
 class OptionalParameter(ParameterModifier):
-  """A parameter that is optional in `combine()` and in the test signature."""
+  """A parameter that is optional in `combine()` and in the test signature.
+
+  `OptionalParameter` is usually used with `TestCombination` in the
+  `parameter_modifiers()`. It allows `TestCombination` to skip certain
+  parameters when passing them to `combine()`, since the `TestCombination` might
+  consume the param and create some context based on the value it gets.
+
+  See the sample usage below:
+
+  ```
+  class EagerGraphCombination(TestCombination):
+
+    def context_managers(self, kwargs):
+      mode = kwargs.pop("mode", None)
+      if mode is None:
+        return []
+      elif mode == "eager":
+        return [context.eager_mode()]
+      elif mode == "graph":
+        return [ops.Graph().as_default(), context.graph_mode()]
+      else:
+        raise ValueError(
+            "'mode' has to be either 'eager' or 'graph', got {}".format(mode))
+
+    def parameter_modifiers(self):
+      return [test_combinations.OptionalParameter("mode")]
+  ```
+
+  When the test case is generated, the param "mode" will not be passed to the
+  test method, since it is consumed by the `EagerGraphCombination`.
+  """
 
   def modified_arguments(self, kwargs, requested_parameters):
     if self._parameter_name in requested_parameters:
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.test.combinations.-optional-parameter.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.test.combinations.-optional-parameter.pbtxt
new file mode 100644
index 00000000000..4dedda05bf6
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.test.combinations.-optional-parameter.pbtxt
@@ -0,0 +1,18 @@
+path: "tensorflow.__internal__.test.combinations.OptionalParameter"
+tf_class {
+  is_instance: "<class \'tensorflow.python.framework.test_combinations.OptionalParameter\'>"
+  is_instance: "<class \'tensorflow.python.framework.test_combinations.ParameterModifier\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "DO_NOT_PASS_TO_THE_TEST"
+    mtype: "<type \'object\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'parameter_name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "modified_arguments"
+    argspec: "args=[\'self\', \'kwargs\', \'requested_parameters\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.test.combinations.-parameter-modifier.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.test.combinations.-parameter-modifier.pbtxt
new file mode 100644
index 00000000000..9b2438ccc8a
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.test.combinations.-parameter-modifier.pbtxt
@@ -0,0 +1,17 @@
+path: "tensorflow.__internal__.test.combinations.ParameterModifier"
+tf_class {
+  is_instance: "<class \'tensorflow.python.framework.test_combinations.ParameterModifier\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "DO_NOT_PASS_TO_THE_TEST"
+    mtype: "<type \'object\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'parameter_name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "modified_arguments"
+    argspec: "args=[\'self\', \'kwargs\', \'requested_parameters\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.test.combinations.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.test.combinations.pbtxt
index 08695f72bea..b5190c37802 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.test.combinations.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.test.combinations.pbtxt
@@ -4,6 +4,14 @@ tf_module {
     name: "NamedObject"
     mtype: "<type \'type\'>"
   }
+  member {
+    name: "OptionalParameter"
+    mtype: "<type \'type\'>"
+  }
+  member {
+    name: "ParameterModifier"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "TestCombination"
     mtype: "<type \'type\'>"