Introduce new RngBitGenerator HLO
The new instruction has the same signature as xla::BitGeneratorTy what takes a key and a state and returns uniformly distributed random bits and a new value for the state. Its aim is to enable backend specific lowering for the various random bit generator algorithms what should unlock optimization opportunities. PiperOrigin-RevId: 293569472 Change-Id: I4f69d4f9858378fb1241435032ef75657933c056
This commit is contained in:
		
							parent
							
								
									fc1961d9c1
								
							
						
					
					
						commit
						a1c59c73da
					
				| @ -28,6 +28,7 @@ limitations under the License. | ||||
| #include "tensorflow/compiler/xla/client/lib/math.h" | ||||
| #include "tensorflow/compiler/xla/client/lib/prng.h" | ||||
| #include "tensorflow/compiler/xla/client/xla_builder.h" | ||||
| #include "tensorflow/compiler/xla/xla_data.pb.h" | ||||
| #include "tensorflow/core/framework/op_kernel.h" | ||||
| #include "tensorflow/core/framework/tensor.h" | ||||
| #include "tensorflow/core/framework/tensor_shape.h" | ||||
| @ -38,11 +39,27 @@ namespace { | ||||
| 
 | ||||
| xla::BitGeneratorTy BitGen(Algorithm alg) { | ||||
|   if (alg == RNG_ALG_PHILOX) { | ||||
|     return [](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) { | ||||
|       return xla::PhiloxBitGenerator(key, state, shape, /*scramble=*/false); | ||||
|     return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) { | ||||
|       state = | ||||
|           xla::ConcatInDim(key.builder(), {xla::Reshape(key, {1}), state}, 0); | ||||
|       xla::XlaOp result = | ||||
|           xla::RngBitGenerator(xla::RandomAlgorithm::RNG_PHILOX, state, shape); | ||||
|       xla::XlaOp data = xla::GetTupleElement(result, 1); | ||||
|       xla::XlaOp new_state = | ||||
|           xla::Slice(xla::GetTupleElement(result, 0), {1}, {3}, {1}); | ||||
|       return xla::RngOutput{data, new_state}; | ||||
|     }; | ||||
|   } else { | ||||
|     return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) { | ||||
|       state = xla::ConcatScalars(key.builder(), {key, state}); | ||||
|       xla::XlaOp result = xla::RngBitGenerator( | ||||
|           xla::RandomAlgorithm::RNG_THREE_FRY, state, shape); | ||||
|       xla::XlaOp data = xla::GetTupleElement(result, 1); | ||||
|       xla::XlaOp new_state = xla::Reshape( | ||||
|           xla::Slice(xla::GetTupleElement(result, 0), {1}, {2}, {1}), {}); | ||||
|       return xla::RngOutput{data, new_state}; | ||||
|     }; | ||||
|   } | ||||
|   return xla::ThreeFryBitGenerator; | ||||
| } | ||||
| 
 | ||||
| xla::RngOutput StatefulRngUniform(Algorithm alg, xla::XlaOp key, | ||||
|  | ||||
| @ -26,6 +26,7 @@ limitations under the License. | ||||
| #include "tensorflow/compiler/xla/client/lib/math.h" | ||||
| #include "tensorflow/compiler/xla/client/lib/prng.h" | ||||
| #include "tensorflow/compiler/xla/client/xla_builder.h" | ||||
| #include "tensorflow/compiler/xla/xla_data.pb.h" | ||||
| #include "tensorflow/core/framework/op_kernel.h" | ||||
| #include "tensorflow/core/framework/tensor.h" | ||||
| #include "tensorflow/core/framework/tensor_shape.h" | ||||
| @ -41,11 +42,23 @@ xla::BitGeneratorTy GetBitGeneratorForDevice( | ||||
|   // Turn on the Philox algorithm for the CPU and GPU backends only.
 | ||||
|   if (device_type_string == DEVICE_GPU_XLA_JIT || | ||||
|       device_type_string == DEVICE_CPU_XLA_JIT) { | ||||
|     return [](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) { | ||||
|       return xla::PhiloxBitGenerator(key, state, shape, /*scramble=*/true); | ||||
|     return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) { | ||||
|       std::tie(state, key) = xla::ScramblePhiloxKey(key); | ||||
|       xla::XlaOp philox_state = | ||||
|           xla::ConcatInDim(key.builder(), {xla::Reshape(key, {1}), state}, 0); | ||||
|       xla::XlaOp result = xla::RngBitGenerator(xla::RandomAlgorithm::RNG_PHILOX, | ||||
|                                                philox_state, shape); | ||||
|       return xla::RngOutput{/*value=*/xla::GetTupleElement(result, 1), | ||||
|                             /*state=*/xla::GetTupleElement(result, 0)}; | ||||
|     }; | ||||
|   } | ||||
|   return xla::ThreeFryBitGenerator; | ||||
|   return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) { | ||||
|     state = xla::ConcatScalars(key.builder(), {key, state}); | ||||
|     xla::XlaOp result = | ||||
|         xla::RngBitGenerator(xla::RandomAlgorithm::RNG_DEFAULT, state, shape); | ||||
|     return xla::RngOutput{/*value=*/xla::GetTupleElement(result, 1), | ||||
|                           /*state=*/xla::GetTupleElement(result, 0)}; | ||||
|   }; | ||||
| } | ||||
| 
 | ||||
| }  // namespace
 | ||||
|  | ||||
| @ -305,17 +305,9 @@ std::pair<Philox4x32State, XlaOp> GetPhiloxInputsAndUpdatedState( | ||||
| // numbers are generated in the unit of 128bits.
 | ||||
| std::pair<Philox4x32State, XlaOp> GeneratePhiloxBits(int64 num_elems, | ||||
|                                                      XlaOp initial_state, | ||||
|                                                      Philox4x32Key key, | ||||
|                                                      bool scramble) { | ||||
|                                                      Philox4x32Key key) { | ||||
|   Philox4x32State state; | ||||
|   if (scramble) { | ||||
|     // When `scramble` is true, `initial_state` is not used. This is because
 | ||||
|     // scramble is true only when this function is called by stateless random
 | ||||
|     // ops, for which `initial_state` is always zero.
 | ||||
|     std::tie(state, key) = ScramblePhiloxKey(key); | ||||
|   } else { | ||||
|   state = Uint128ToUint32s(Uint128FromOp(initial_state)); | ||||
|   } | ||||
|   const int64 num_vector4 = CeilOfRatio<int64>(num_elems, 4); | ||||
|   Philox4x32State inputs; | ||||
|   XlaOp new_state; | ||||
| @ -328,16 +320,15 @@ std::pair<Philox4x32State, XlaOp> GeneratePhiloxBits(int64 num_elems, | ||||
| // Generates an array of primitive type U32 with the given shape containing
 | ||||
| // random bits generated by the Philox algorithm. Returns the array and the new
 | ||||
| // state of the random number generator.
 | ||||
| RngOutput PhiloxRngBit32(XlaOp op_key, XlaOp initial_state, const Shape& shape, | ||||
|                          bool scramble) { | ||||
| RngOutput PhiloxRngBit32(XlaOp op_key, XlaOp initial_state, | ||||
|                          const Shape& shape) { | ||||
|   XlaBuilder* builder = op_key.builder(); | ||||
|   const int64 num_elems = ShapeUtil::ElementsIn(shape); | ||||
| 
 | ||||
|   Philox4x32Key key = Uint64ToUint32s(op_key); | ||||
|   Philox4x32State bits; | ||||
|   XlaOp new_state; | ||||
|   std::tie(bits, new_state) = | ||||
|       GeneratePhiloxBits(num_elems, initial_state, key, scramble); | ||||
|   std::tie(bits, new_state) = GeneratePhiloxBits(num_elems, initial_state, key); | ||||
|   // Combining bits[i] in a round-robin fashion, to align with non-XLA
 | ||||
|   // implementations
 | ||||
|   int64 bits_len = (num_elems + 3) / 4; | ||||
| @ -356,8 +347,8 @@ RngOutput PhiloxRngBit32(XlaOp op_key, XlaOp initial_state, const Shape& shape, | ||||
| // Generates an array of primitive type U64 with the given shape containing
 | ||||
| // random bits generated by the Philox algorithm. Returns the array and the new
 | ||||
| // state of the random number generator.
 | ||||
| RngOutput PhiloxRngBit64(XlaOp op_key, XlaOp initial_state, const Shape& shape, | ||||
|                          bool scramble) { | ||||
| RngOutput PhiloxRngBit64(XlaOp op_key, XlaOp initial_state, | ||||
|                          const Shape& shape) { | ||||
|   XlaBuilder* builder = op_key.builder(); | ||||
|   const int64 num_elems = ShapeUtil::ElementsIn(shape); | ||||
| 
 | ||||
| @ -365,7 +356,7 @@ RngOutput PhiloxRngBit64(XlaOp op_key, XlaOp initial_state, const Shape& shape, | ||||
|   Philox4x32State bits32; | ||||
|   XlaOp new_state; | ||||
|   std::tie(bits32, new_state) = | ||||
|       GeneratePhiloxBits(num_elems * 2, initial_state, key, scramble); | ||||
|       GeneratePhiloxBits(num_elems * 2, initial_state, key); | ||||
| 
 | ||||
|   std::array<XlaOp, 2> bits64; | ||||
|   bits64[0] = Uint32sToUint64({bits32[0], bits32[1]}); | ||||
| @ -463,18 +454,18 @@ RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state, | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state, const Shape& shape, | ||||
|                              bool scramble) { | ||||
| RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state, | ||||
|                              const Shape& shape) { | ||||
|   PrimitiveType type = shape.element_type(); | ||||
|   switch (type) { | ||||
|     case F32: | ||||
|     case U32: | ||||
|     case S32: | ||||
|       return PhiloxRngBit32(key, initial_state, shape, scramble); | ||||
|       return PhiloxRngBit32(key, initial_state, shape); | ||||
|     case F64: | ||||
|     case U64: | ||||
|     case S64: | ||||
|       return PhiloxRngBit64(key, initial_state, shape, scramble); | ||||
|       return PhiloxRngBit64(key, initial_state, shape); | ||||
|     default: | ||||
|       return {key.builder()->ReportError(Unimplemented( | ||||
|                   "Types other than F32, F64, U32, S32, U64 and S64 " | ||||
| @ -484,6 +475,13 @@ RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state, const Shape& shape, | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| std::pair<XlaOp, XlaOp> ScramblePhiloxKey(XlaOp key) { | ||||
|   Philox4x32Key pkey = Uint64ToUint32s(key); | ||||
|   auto state_key = ScramblePhiloxKey(pkey); | ||||
|   return std::make_pair(Uint128ToOp(Uint32sToUint128(state_key.first)), | ||||
|                         Uint32sToUint64(state_key.second)); | ||||
| } | ||||
| 
 | ||||
| RngOutput UniformFloatingPointDistribution(XlaOp key, XlaOp initial_state, | ||||
|                                            BitGeneratorTy bit_generator, | ||||
|                                            XlaOp minval, XlaOp maxval, | ||||
|  | ||||
| @ -58,10 +58,10 @@ RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state, | ||||
| // 4x32_10 version of the algorithm for the following reasons:
 | ||||
| //   . 4x32 uses 32-bit multiplication which is fast on GPUs.
 | ||||
| //   . The authors recommend the 10-round variant, and TensorFlow also uses it.
 | ||||
| // 'scramble` controls whether to scramble 'key' and 'initial_state' to form
 | ||||
| // the actual key and state fed to the Philox algorithm.
 | ||||
| RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state, const Shape& shape, | ||||
|                              bool scramble); | ||||
| RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state, | ||||
|                              const Shape& shape); | ||||
| // Returns a scrambled pair of (state, key) from a single key.
 | ||||
| std::pair<XlaOp, XlaOp> ScramblePhiloxKey(XlaOp key); | ||||
| 
 | ||||
| // Uses the given bit generator to generate random bits and then converts the
 | ||||
| // random bits to random numbers of uniform distribution in the given range.
 | ||||
|  | ||||
| @ -37,6 +37,7 @@ limitations under the License. | ||||
| #include "tensorflow/compiler/xla/service/shape_inference.h" | ||||
| #include "tensorflow/compiler/xla/status_macros.h" | ||||
| #include "tensorflow/compiler/xla/util.h" | ||||
| #include "tensorflow/compiler/xla/xla_data.pb.h" | ||||
| 
 | ||||
| namespace xla { | ||||
| 
 | ||||
| @ -1752,6 +1753,36 @@ XlaOp XlaBuilder::RngUniform(XlaOp a, XlaOp b, const Shape& shape) { | ||||
|   return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape); | ||||
| } | ||||
| 
 | ||||
| XlaOp XlaBuilder::RngBitGenerator(RandomAlgorithm algorithm, | ||||
|                                   XlaOp initial_state, const Shape& shape) { | ||||
|   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { | ||||
|     HloInstructionProto instr; | ||||
|     TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); | ||||
|     TF_ASSIGN_OR_RETURN(Shape state_shape, GetShape(initial_state)); | ||||
|     Shape output_shape = shape; | ||||
|     switch (output_shape.element_type()) { | ||||
|       case PrimitiveType::F32: | ||||
|       case PrimitiveType::S32: | ||||
|       case PrimitiveType::U32: | ||||
|         output_shape.set_element_type(PrimitiveType::U32); | ||||
|         break; | ||||
|       case PrimitiveType::F64: | ||||
|       case PrimitiveType::S64: | ||||
|       case PrimitiveType::U64: | ||||
|         output_shape.set_element_type(PrimitiveType::U64); | ||||
|         break; | ||||
|       default: | ||||
|         return InvalidArgument("Unsupported shape for RngBitGenerator: %s", | ||||
|                                PrimitiveType_Name(output_shape.element_type())); | ||||
|     } | ||||
|     *instr.mutable_shape() = | ||||
|         ShapeUtil::MakeTupleShape({state_shape, output_shape}).ToProto(); | ||||
|     instr.set_rng_algorithm(algorithm); | ||||
|     return AddInstruction(std::move(instr), HloOpcode::kRngBitGenerator, | ||||
|                           {initial_state}); | ||||
|   }); | ||||
| } | ||||
| 
 | ||||
| XlaOp XlaBuilder::While(const XlaComputation& condition, | ||||
|                         const XlaComputation& body, XlaOp init) { | ||||
|   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { | ||||
| @ -3460,6 +3491,12 @@ XlaOp RngUniform(const XlaOp a, const XlaOp b, const Shape& shape) { | ||||
|   return a.builder()->RngUniform(a, b, shape); | ||||
| } | ||||
| 
 | ||||
| XlaOp RngBitGenerator(RandomAlgorithm algorithm, const XlaOp initial_state, | ||||
|                       const Shape& shape) { | ||||
|   return initial_state.builder()->RngBitGenerator(algorithm, initial_state, | ||||
|                                                   shape); | ||||
| } | ||||
| 
 | ||||
| XlaOp While(const XlaComputation& condition, const XlaComputation& body, | ||||
|             const XlaOp init) { | ||||
|   return init.builder()->While(condition, body, init); | ||||
|  | ||||
| @ -565,6 +565,9 @@ class XlaBuilder { | ||||
| 
 | ||||
|   XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape); | ||||
| 
 | ||||
|   XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state, | ||||
|                         const Shape& shape); | ||||
| 
 | ||||
|   XlaOp While(const XlaComputation& condition, const XlaComputation& body, | ||||
|               XlaOp init); | ||||
| 
 | ||||
| @ -985,6 +988,8 @@ class XlaBuilder { | ||||
|                    absl::Span<const XlaOp> static_operands); | ||||
|   friend XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape); | ||||
|   friend XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape); | ||||
|   friend XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state, | ||||
|                                const Shape& shape); | ||||
|   friend XlaOp While(const XlaComputation& condition, | ||||
|                      const XlaComputation& body, XlaOp init); | ||||
|   friend XlaOp Conditional(XlaOp predicate, XlaOp true_operand, | ||||
| @ -1856,6 +1861,11 @@ XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape); | ||||
| // computation. Returns values in the semi-open interval [a, b).
 | ||||
| XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape); | ||||
| 
 | ||||
| // Enqueues a B(initial_state) random bit generation instruction onto the
 | ||||
| // computation. Resturns the new key and random bits with the specified shape.
 | ||||
| XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state, | ||||
|                       const Shape& shape); | ||||
| 
 | ||||
| // Enqueues a while node onto the computation.
 | ||||
| XlaOp While(const XlaComputation& condition, const XlaComputation& body, | ||||
|             XlaOp init); | ||||
|  | ||||
| @ -2271,6 +2271,34 @@ implementation-defined. | ||||
| :           :                         : limit of interval                 : | ||||
| | `shape`   | `Shape`                 | Output shape of type T            | | ||||
| 
 | ||||
| ## RngBitGenerator | ||||
| 
 | ||||
| Generates an output with a given shape filled with uniform random bits using the | ||||
| specified algorithm (or backend default) and returns an updated state (with the | ||||
| same shape as initial state) and the generated random data. | ||||
| 
 | ||||
| Initial state is the initial state of the current random number generation. It | ||||
| and the required shape and valid values are dependent on the algorithm used. | ||||
| 
 | ||||
| The output is guaranteed to be a deterministic function of the initial state but | ||||
| it is *not* guaranteed to be deterministic between backends and different | ||||
| compiler versions. | ||||
| 
 | ||||
| <b>`RngBitGenerator(algorithm, key, shape)`</b> | Arguments | Type | Semantics | | ||||
| |---------------- | ----------------- | ------------------------------------- | | ||||
| | `algorithm` | `RandomAlgorithm` | PRNG algorithm to be used. | | | ||||
| `initial_state` | `XlaOp` | Initial state for the PRNG algorithm. | | `shape` | | ||||
| `Shape` | Output shape for generated data. | | ||||
| 
 | ||||
| Available values for `algorithm`: * `rng_default`: Backend specific algorithm | ||||
| with backend specific shape requirements. * `rng_three_fry`: ThreeFry | ||||
| counter-based PRNG algorithm. The `initial_state` shape is `u64[2]` with | ||||
| arbitrary values. | ||||
| [Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf) | ||||
| * `rng_philox`: Philox algorithm to generate random numbers in parallel. The | ||||
| `initial_state` shape is `u64[3]` with arbitrary values. | ||||
| [Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf) | ||||
| 
 | ||||
| ## Scatter | ||||
| 
 | ||||
| The XLA scatter operation generates a result which is the value of the input | ||||
|  | ||||
| @ -4492,6 +4492,25 @@ cc_library( | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| cc_library( | ||||
|     name = "rng_bit_generator_expander", | ||||
|     srcs = ["rng_bit_generator_expander.cc"], | ||||
|     hdrs = ["rng_bit_generator_expander.h"], | ||||
|     deps = [ | ||||
|         ":hlo", | ||||
|         ":hlo_casting_utils", | ||||
|         ":op_expander_pass", | ||||
|         "//tensorflow/compiler/xla:shape_util", | ||||
|         "//tensorflow/compiler/xla:statusor", | ||||
|         "//tensorflow/compiler/xla:util", | ||||
|         "//tensorflow/compiler/xla:xla_data_proto_cc", | ||||
|         "//tensorflow/compiler/xla/client:xla_builder", | ||||
|         "//tensorflow/compiler/xla/client/lib:prng", | ||||
|         "//tensorflow/stream_executor/lib", | ||||
|         "@com_google_absl//absl/container:flat_hash_map", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| cc_library( | ||||
|     name = "slow_operation_alarm", | ||||
|     srcs = ["slow_operation_alarm.cc"], | ||||
|  | ||||
| @ -94,6 +94,7 @@ cc_library( | ||||
|         "//tensorflow/compiler/xla/service:hlo_casting_utils", | ||||
|         "//tensorflow/compiler/xla/service:dump", | ||||
|         "//tensorflow/compiler/xla/service:map_inliner", | ||||
|         "//tensorflow/compiler/xla/service:rng_bit_generator_expander", | ||||
|         "//tensorflow/compiler/xla/service:tree_reduction_rewriter", | ||||
|         "//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter", | ||||
|         "//tensorflow/compiler/xla/service:conditional_to_select", | ||||
|  | ||||
| @ -93,6 +93,7 @@ limitations under the License. | ||||
| #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" | ||||
| #include "tensorflow/compiler/xla/service/map_inliner.h" | ||||
| #include "tensorflow/compiler/xla/service/reshape_mover.h" | ||||
| #include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h" | ||||
| #include "tensorflow/compiler/xla/service/rng_expander.h" | ||||
| #include "tensorflow/compiler/xla/service/scatter_expander.h" | ||||
| #include "tensorflow/compiler/xla/service/slice_sinker.h" | ||||
| @ -241,6 +242,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( | ||||
| 
 | ||||
|   // Expand random number generation.
 | ||||
|   pipeline.AddPass<RngExpander>(); | ||||
|   pipeline.AddPass<RngBitGeneratorExpander>(RandomAlgorithm::RNG_PHILOX); | ||||
| 
 | ||||
|   // Remove zero-sized HLO from the input so that other passes don't have to
 | ||||
|   // handle it.
 | ||||
|  | ||||
| @ -225,6 +225,7 @@ class DfsHloVisitorBase { | ||||
|   virtual Status HandleInfeed(HloInstructionPtr hlo) = 0; | ||||
|   virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0; | ||||
|   virtual Status HandleRng(HloInstructionPtr hlo) = 0; | ||||
|   virtual Status HandleRngBitGenerator(HloInstructionPtr hlo) = 0; | ||||
|   virtual Status HandleRngGetAndUpdateState(HloInstructionPtr hlo) = 0; | ||||
|   virtual Status HandleReverse(HloInstructionPtr hlo) = 0; | ||||
|   virtual Status HandleSort(HloInstructionPtr hlo) = 0; | ||||
|  | ||||
| @ -116,6 +116,9 @@ class DfsHloVisitorWithDefaultBase | ||||
|   Status HandleRng(HloInstructionPtr random) override { | ||||
|     return DefaultAction(random); | ||||
|   } | ||||
|   Status HandleRngBitGenerator(HloInstructionPtr random) override { | ||||
|     return DefaultAction(random); | ||||
|   } | ||||
|   Status HandleRngGetAndUpdateState(HloInstructionPtr random) override { | ||||
|     return DefaultAction(random); | ||||
|   } | ||||
|  | ||||
| @ -1155,6 +1155,7 @@ cc_library( | ||||
|         "//tensorflow/compiler/xla/service:hlo_verifier", | ||||
|         "//tensorflow/compiler/xla/service:llvm_compiler", | ||||
|         "//tensorflow/compiler/xla/service:reshape_mover", | ||||
|         "//tensorflow/compiler/xla/service:rng_bit_generator_expander", | ||||
|         "//tensorflow/compiler/xla/service:rng_expander", | ||||
|         "//tensorflow/compiler/xla/service:slice_sinker", | ||||
|         "//tensorflow/compiler/xla/service:slow_operation_alarm", | ||||
|  | ||||
| @ -80,6 +80,7 @@ limitations under the License. | ||||
| #include "tensorflow/compiler/xla/service/hlo_verifier.h" | ||||
| #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" | ||||
| #include "tensorflow/compiler/xla/service/reshape_mover.h" | ||||
| #include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h" | ||||
| #include "tensorflow/compiler/xla/service/rng_expander.h" | ||||
| #include "tensorflow/compiler/xla/service/slice_sinker.h" | ||||
| #include "tensorflow/compiler/xla/service/slow_operation_alarm.h" | ||||
| @ -127,6 +128,7 @@ Status GpuCompiler::OptimizeHloModule( | ||||
| 
 | ||||
|     // Expand random number generation.
 | ||||
|     pipeline.AddPass<RngExpander>(); | ||||
|     pipeline.AddPass<RngBitGeneratorExpander>(RandomAlgorithm::RNG_PHILOX); | ||||
| 
 | ||||
|     // Remove zero-sized HLO from the input so that other passes don't have to
 | ||||
|     // handle it.
 | ||||
|  | ||||
| @ -35,7 +35,7 @@ import "tensorflow/compiler/xla/xla_data.proto"; | ||||
| option cc_enable_arenas = true; | ||||
| 
 | ||||
| // Serialization of HloInstruction. | ||||
| // Next ID: 70 | ||||
| // Next ID: 71 | ||||
| message HloInstructionProto { | ||||
|   reserved 10; | ||||
|   reserved "parameter_name"; | ||||
| @ -241,6 +241,9 @@ message HloInstructionProto { | ||||
|   // Specifies if all elements updated are guaranteed to be unique by | ||||
|   // the caller. | ||||
|   bool unique_indices = 69; | ||||
| 
 | ||||
|   // RNG algorithm used by kRngBitGenerator. | ||||
|   xla.RandomAlgorithm rng_algorithm = 70; | ||||
| } | ||||
| 
 | ||||
| // Serialization of HloComputation. | ||||
|  | ||||
| @ -746,6 +746,15 @@ Status HloCostAnalysis::HandleRng(const HloInstruction* random) { | ||||
|   return Status::OK(); | ||||
| } | ||||
| 
 | ||||
| Status HloCostAnalysis::HandleRngBitGenerator(const HloInstruction* random) { | ||||
|   // TODO(b/26346211): Implement better estimates for the RNG cost, since the
 | ||||
|   // cost changes with the implementation and the distribution. For now, assume
 | ||||
|   // the cost of each RNG is same as a transcendental operation.
 | ||||
|   current_properties_[kTranscendentalsKey] = | ||||
|       ShapeUtil::ElementsIn(random->shape()); | ||||
|   return Status::OK(); | ||||
| } | ||||
| 
 | ||||
| Status HloCostAnalysis::HandleRngGetAndUpdateState( | ||||
|     const HloInstruction* random) { | ||||
|   return Status::OK(); | ||||
|  | ||||
| @ -84,6 +84,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { | ||||
|   Status HandleInfeed(const HloInstruction* infeed) override; | ||||
|   Status HandleOutfeed(const HloInstruction* outfeed) override; | ||||
|   Status HandleRng(const HloInstruction* random) override; | ||||
|   Status HandleRngBitGenerator(const HloInstruction* random) override; | ||||
|   Status HandleRngGetAndUpdateState(const HloInstruction* random) override; | ||||
|   Status HandleReverse(const HloInstruction* reverse) override; | ||||
|   Status HandleSort(const HloInstruction* sort) override; | ||||
|  | ||||
| @ -39,6 +39,7 @@ limitations under the License. | ||||
| #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_instructions.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_module.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_opcode.h" | ||||
| #include "tensorflow/compiler/xla/service/pattern_matcher.h" | ||||
| #include "tensorflow/compiler/xla/shape_util.h" | ||||
| #include "tensorflow/compiler/xla/types.h" | ||||
| @ -967,6 +968,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { | ||||
|     case HloOpcode::kRemainder: | ||||
|     case HloOpcode::kRng: | ||||
|     case HloOpcode::kRngGetAndUpdateState: | ||||
|     case HloOpcode::kRngBitGenerator: | ||||
|     case HloOpcode::kRoundNearestAfz: | ||||
|     case HloOpcode::kRsqrt: | ||||
|     case HloOpcode::kSelect: | ||||
|  | ||||
| @ -31,6 +31,7 @@ limitations under the License. | ||||
| #include "absl/strings/numbers.h" | ||||
| #include "absl/strings/str_cat.h" | ||||
| #include "absl/strings/str_join.h" | ||||
| #include "absl/strings/string_view.h" | ||||
| #include "absl/types/span.h" | ||||
| #include "tensorflow/compiler/xla/layout_util.h" | ||||
| #include "tensorflow/compiler/xla/literal.h" | ||||
| @ -46,6 +47,7 @@ limitations under the License. | ||||
| #include "tensorflow/compiler/xla/status_macros.h" | ||||
| #include "tensorflow/compiler/xla/types.h" | ||||
| #include "tensorflow/compiler/xla/util.h" | ||||
| #include "tensorflow/compiler/xla/xla_data.pb.h" | ||||
| #include "tensorflow/core/lib/core/errors.h" | ||||
| #include "tensorflow/core/lib/gtl/map_util.h" | ||||
| #include "tensorflow/core/platform/human_readable_json.h" | ||||
| @ -346,6 +348,10 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( | ||||
|     case HloOpcode::kRng: | ||||
|       instruction = CreateRng(shape, proto.distribution(), all_operands()); | ||||
|       break; | ||||
|     case HloOpcode::kRngBitGenerator: | ||||
|       instruction = | ||||
|           CreateRngBitGenerator(shape, operands(0), proto.rng_algorithm()); | ||||
|       break; | ||||
|     case HloOpcode::kRngGetAndUpdateState: | ||||
|       instruction = CreateRngGetAndUpdateState(shape, proto.delta()); | ||||
|       break; | ||||
| @ -743,6 +749,13 @@ HloInstruction::CreateRngGetAndUpdateState(const Shape& shape, int64 delta) { | ||||
|   return absl::make_unique<HloRngGetAndUpdateStateInstruction>(shape, delta); | ||||
| } | ||||
| 
 | ||||
| /* static */ std::unique_ptr<HloInstruction> | ||||
| HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state, | ||||
|                                       RandomAlgorithm algorithm) { | ||||
|   return absl::make_unique<HloRngBitGeneratorInstruction>(shape, state, | ||||
|                                                           algorithm); | ||||
| } | ||||
| 
 | ||||
| /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateNary( | ||||
|     const Shape& shape, HloOpcode opcode, | ||||
|     absl::Span<HloInstruction* const> operands) { | ||||
| @ -1480,6 +1493,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( | ||||
|     case HloOpcode::kTrace: | ||||
|     case HloOpcode::kFusion: | ||||
|     case HloOpcode::kRng: | ||||
|     case HloOpcode::kRngBitGenerator: | ||||
|     case HloOpcode::kRngGetAndUpdateState: | ||||
|     case HloOpcode::kParameter: | ||||
|     case HloOpcode::kGetTupleElement: | ||||
| @ -1950,6 +1964,7 @@ bool HloInstruction::IdenticalSlowPath( | ||||
|     case HloOpcode::kTrace: | ||||
|     case HloOpcode::kFusion: | ||||
|     case HloOpcode::kRng: | ||||
|     case HloOpcode::kRngBitGenerator: | ||||
|     case HloOpcode::kRngGetAndUpdateState: | ||||
|     case HloOpcode::kParameter: | ||||
|     case HloOpcode::kGetTupleElement: | ||||
| @ -2882,6 +2897,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) { | ||||
|       return visitor->HandleOutfeed(this); | ||||
|     case HloOpcode::kRng: | ||||
|       return visitor->HandleRng(this); | ||||
|     case HloOpcode::kRngBitGenerator: | ||||
|       return visitor->HandleRngBitGenerator(this); | ||||
|     case HloOpcode::kRngGetAndUpdateState: | ||||
|       return visitor->HandleRngGetAndUpdateState(this); | ||||
|     case HloOpcode::kWhile: | ||||
| @ -3357,6 +3374,9 @@ string OpMetadataToString(const OpMetadata& metadata) { | ||||
| string RandomDistributionToString(const RandomDistribution& distribution) { | ||||
|   return absl::AsciiStrToLower(RandomDistribution_Name(distribution)); | ||||
| } | ||||
| string RandomAlgorithmToString(const RandomAlgorithm& algorithm) { | ||||
|   return absl::AsciiStrToLower(RandomAlgorithm_Name(algorithm)); | ||||
| } | ||||
| 
 | ||||
| string PrecisionToString(const PrecisionConfig::Precision& precision) { | ||||
|   return absl::AsciiStrToLower(PrecisionConfig::Precision_Name(precision)); | ||||
| @ -3401,6 +3421,24 @@ string ReplicaGroupsToString(const std::vector<ReplicaGroup>& replica_groups) { | ||||
|   return StrCat("{", StrJoin(replica_group_str, ","), "}"); | ||||
| } | ||||
| 
 | ||||
| StatusOr<RandomAlgorithm> StringToRandomAlgorithm(const string& name) { | ||||
|   static std::unordered_map<string, RandomAlgorithm>* map = [] { | ||||
|     static auto* map = new std::unordered_map<string, RandomAlgorithm>; | ||||
|     for (int i = 0; i < RandomAlgorithm_ARRAYSIZE; i++) { | ||||
|       if (RandomAlgorithm_IsValid(i)) { | ||||
|         auto value = static_cast<RandomAlgorithm>(i); | ||||
|         (*map)[RandomAlgorithmToString(value)] = value; | ||||
|       } | ||||
|     } | ||||
|     return map; | ||||
|   }(); | ||||
|   auto found = map->find(absl::AsciiStrToLower(name)); | ||||
|   if (found == map->end()) { | ||||
|     return InvalidArgument("Unknown algorithm"); | ||||
|   } | ||||
|   return found->second; | ||||
| } | ||||
| 
 | ||||
| StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) { | ||||
|   static std::unordered_map<string, RandomDistribution>* map = [] { | ||||
|     static auto* map = new std::unordered_map<string, RandomDistribution>; | ||||
|  | ||||
| @ -512,6 +512,11 @@ class HloInstruction { | ||||
|       const Shape& shape, RandomDistribution distribution, | ||||
|       absl::Span<HloInstruction* const> parameters); | ||||
| 
 | ||||
|   // Creates a stateless random bit generator instruction that fills a shape
 | ||||
|   // with random bits.
 | ||||
|   static std::unique_ptr<HloInstruction> CreateRngBitGenerator( | ||||
|       const Shape& shape, HloInstruction* state, RandomAlgorithm algorithm); | ||||
| 
 | ||||
|   // Creates an instruction to update the random number generator state to
 | ||||
|   // reflect the new state after `delta` units of 32 random bits are generated
 | ||||
|   // and returns the old state.
 | ||||
| @ -2025,12 +2030,14 @@ string PaddingConfigToString(const PaddingConfig& padding); | ||||
| string FrontendAttributesToString( | ||||
|     const FrontendAttributes& frontend_attributes); | ||||
| string OpMetadataToString(const OpMetadata& metadata); | ||||
| string RandomAlgorithmToString(const RandomAlgorithm& algorithm); | ||||
| string RandomDistributionToString(const RandomDistribution& distribution); | ||||
| string PrecisionToString(const PrecisionConfig::Precision& precision); | ||||
| string ConvolutionDimensionNumbersToString( | ||||
|     const ConvolutionDimensionNumbers& dnums); | ||||
| string ReplicaGroupsToString(const std::vector<ReplicaGroup>& replica_groups); | ||||
| 
 | ||||
| StatusOr<RandomAlgorithm> StringToRandomAlgorithm(const string& name); | ||||
| StatusOr<RandomDistribution> StringToRandomDistribution(const string& name); | ||||
| StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name); | ||||
| 
 | ||||
|  | ||||
| @ -31,6 +31,7 @@ limitations under the License. | ||||
| #include "tensorflow/compiler/xla/service/hlo_module.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" | ||||
| #include "tensorflow/compiler/xla/window_util.h" | ||||
| #include "tensorflow/compiler/xla/xla_data.pb.h" | ||||
| #include "tensorflow/core/platform/protobuf.h" | ||||
| 
 | ||||
| namespace xla { | ||||
| @ -2902,4 +2903,40 @@ HloRngGetAndUpdateStateInstruction::CloneWithNewOperandsImpl( | ||||
|   return absl::make_unique<HloRngGetAndUpdateStateInstruction>(shape, delta()); | ||||
| } | ||||
| 
 | ||||
| HloRngBitGeneratorInstruction::HloRngBitGeneratorInstruction( | ||||
|     const Shape& shape, HloInstruction* state, RandomAlgorithm algorithm) | ||||
|     : HloInstruction(HloOpcode::kRngBitGenerator, shape), | ||||
|       algorithm_(algorithm) { | ||||
|   AppendOperand(state); | ||||
| } | ||||
| 
 | ||||
| HloInstructionProto HloRngBitGeneratorInstruction::ToProto() const { | ||||
|   HloInstructionProto proto = HloInstruction::ToProto(); | ||||
|   proto.set_rng_algorithm(algorithm_); | ||||
|   return proto; | ||||
| } | ||||
| 
 | ||||
| std::vector<string> HloRngBitGeneratorInstruction::ExtraAttributesToStringImpl( | ||||
|     const HloPrintOptions& options) const { | ||||
|   return {StrCat("algorithm=", RandomAlgorithmToString(algorithm_))}; | ||||
| } | ||||
| 
 | ||||
| bool HloRngBitGeneratorInstruction::IdenticalSlowPath( | ||||
|     const HloInstruction& other, | ||||
|     const std::function<bool(const HloComputation*, const HloComputation*)>& | ||||
|         eq_computations) const { | ||||
|   const auto& casted_other = | ||||
|       static_cast<const HloRngBitGeneratorInstruction&>(other); | ||||
|   return algorithm() == casted_other.algorithm(); | ||||
| } | ||||
| 
 | ||||
| std::unique_ptr<HloInstruction> | ||||
| HloRngBitGeneratorInstruction::CloneWithNewOperandsImpl( | ||||
|     const Shape& shape, absl::Span<HloInstruction* const> new_operands, | ||||
|     HloCloneContext* /*context*/) const { | ||||
|   CHECK_EQ(new_operands.size(), 1); | ||||
|   return absl::make_unique<HloRngBitGeneratorInstruction>( | ||||
|       shape, new_operands[0], algorithm()); | ||||
| } | ||||
| 
 | ||||
| }  // namespace xla
 | ||||
|  | ||||
| @ -20,6 +20,7 @@ limitations under the License. | ||||
| 
 | ||||
| #include "absl/memory/memory.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_instruction.h" | ||||
| #include "tensorflow/compiler/xla/xla_data.pb.h" | ||||
| 
 | ||||
| namespace xla { | ||||
| 
 | ||||
| @ -1721,6 +1722,28 @@ class HloRngGetAndUpdateStateInstruction : public HloInstruction { | ||||
|   int64 delta_; | ||||
| }; | ||||
| 
 | ||||
| class HloRngBitGeneratorInstruction : public HloInstruction { | ||||
|  public: | ||||
|   HloRngBitGeneratorInstruction(const Shape& shape, HloInstruction* state, | ||||
|                                 RandomAlgorithm algorithm); | ||||
| 
 | ||||
|   RandomAlgorithm algorithm() const { return algorithm_; } | ||||
|   HloInstructionProto ToProto() const override; | ||||
| 
 | ||||
|  private: | ||||
|   std::vector<string> ExtraAttributesToStringImpl( | ||||
|       const HloPrintOptions& options) const override; | ||||
|   bool IdenticalSlowPath( | ||||
|       const HloInstruction& other, | ||||
|       const std::function<bool(const HloComputation*, const HloComputation*)>& | ||||
|           eq_computations) const override; | ||||
|   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( | ||||
|       const Shape& shape, absl::Span<HloInstruction* const> new_operands, | ||||
|       HloCloneContext* context) const override; | ||||
| 
 | ||||
|   RandomAlgorithm algorithm_; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace xla
 | ||||
| 
 | ||||
| #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
 | ||||
|  | ||||
| @ -122,6 +122,7 @@ namespace xla { | ||||
|   V(kReverse, "reverse", 1)                                            \ | ||||
|   V(kRng, "rng", kHloOpcodeIsVariadic)                                 \ | ||||
|   V(kRngGetAndUpdateState, "rng-get-and-update-state", 0)              \ | ||||
|   V(kRngBitGenerator, "rng-bit-generator", 1)                          \ | ||||
|   V(kRoundNearestAfz, "round-nearest-afz", 1)                          \ | ||||
|   V(kRsqrt, "rsqrt", 1)                                                \ | ||||
|   V(kScatter, "scatter", 3)                                            \ | ||||
|  | ||||
| @ -36,6 +36,7 @@ limitations under the License. | ||||
| #include "tensorflow/compiler/xla/primitive_util.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_instruction.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_instructions.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_lexer.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_opcode.h" | ||||
| @ -203,7 +204,8 @@ class HloParserImpl : public HloParser { | ||||
|     kDomain, | ||||
|     kPrecisionList, | ||||
|     kShapeList, | ||||
|     kEnum | ||||
|     kEnum, | ||||
|     kRandomAlgorithm, | ||||
|   }; | ||||
| 
 | ||||
|   struct AttrConfig { | ||||
| @ -322,6 +324,7 @@ class HloParserImpl : public HloParser { | ||||
|   bool ParseComparisonDirection(ComparisonDirection* result); | ||||
|   bool ParseFusionKind(HloInstruction::FusionKind* result); | ||||
|   bool ParseRandomDistribution(RandomDistribution* result); | ||||
|   bool ParseRandomAlgorithm(RandomAlgorithm* result); | ||||
|   bool ParsePrecision(PrecisionConfig::Precision* result); | ||||
|   bool ParseInt64(int64* result); | ||||
|   bool ParseDouble(double* result); | ||||
| @ -1501,6 +1504,18 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, | ||||
|           HloInstruction::CreateRngGetAndUpdateState(shape, *delta)); | ||||
|       break; | ||||
|     } | ||||
|     case HloOpcode::kRngBitGenerator: { | ||||
|       optional<RandomAlgorithm> algorithm; | ||||
|       attrs["algorithm"] = {/*required=*/true, AttrTy::kRandomAlgorithm, | ||||
|                             &algorithm}; | ||||
|       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { | ||||
|         return false; | ||||
|       } | ||||
|       instruction = | ||||
|           builder->AddInstruction(HloInstruction::CreateRngBitGenerator( | ||||
|               shape, operands[0], *algorithm)); | ||||
|       break; | ||||
|     } | ||||
|     case HloOpcode::kReducePrecision: { | ||||
|       optional<int64> exponent_bits; | ||||
|       optional<int64> mantissa_bits; | ||||
| @ -2972,6 +2987,14 @@ bool HloParserImpl::ParseAttributeHelper( | ||||
|             ->emplace(result); | ||||
|         return true; | ||||
|       } | ||||
|       case AttrTy::kRandomAlgorithm: { | ||||
|         RandomAlgorithm result; | ||||
|         if (!ParseRandomAlgorithm(&result)) { | ||||
|           return false; | ||||
|         } | ||||
|         static_cast<optional<RandomAlgorithm>*>(attr_out_ptr)->emplace(result); | ||||
|         return true; | ||||
|       } | ||||
|     } | ||||
|   }(); | ||||
|   if (!success) { | ||||
| @ -3963,6 +3986,23 @@ bool HloParserImpl::ParseRandomDistribution(RandomDistribution* result) { | ||||
|   return true; | ||||
| } | ||||
| 
 | ||||
| bool HloParserImpl::ParseRandomAlgorithm(RandomAlgorithm* result) { | ||||
|   VLOG(3) << "ParseRandomAlgorithm"; | ||||
|   if (lexer_.GetKind() != TokKind::kIdent) { | ||||
|     return TokenError("expects random algorithm"); | ||||
|   } | ||||
|   std::string val = lexer_.GetStrVal(); | ||||
|   auto status_or_result = StringToRandomAlgorithm(val); | ||||
|   if (!status_or_result.ok()) { | ||||
|     return TokenError( | ||||
|         StrFormat("expects random algorithm but sees: %s, error: %s", val, | ||||
|                   status_or_result.status().error_message())); | ||||
|   } | ||||
|   *result = status_or_result.ValueOrDie(); | ||||
|   lexer_.Lex(); | ||||
|   return true; | ||||
| } | ||||
| 
 | ||||
| bool HloParserImpl::ParsePrecision(PrecisionConfig::Precision* result) { | ||||
|   VLOG(3) << "ParsePrecision"; | ||||
|   if (lexer_.GetKind() != TokKind::kIdent) { | ||||
|  | ||||
| @ -1055,6 +1055,17 @@ ENTRY %RngGetAndUpdateState () -> u64[2] { | ||||
| 
 | ||||
| )" | ||||
| }, | ||||
| { | ||||
| "RngBitGenerator", | ||||
| R"(HloModule gng_bit_generator | ||||
| 
 | ||||
| ENTRY %RngBitGenerator (p0: u64[2]) -> (u64[2], u32[11,17]) { | ||||
|   %p0 = u64[2]{0} parameter(0) | ||||
|   ROOT %rand = (u64[2]{0}, u32[11,17]{1,0}) rng-bit-generator(u64[2]{0} %p0), algorithm=rng_three_fry | ||||
| } | ||||
| 
 | ||||
| )" | ||||
| } | ||||
|   }); | ||||
|   // clang-format on
 | ||||
| } | ||||
|  | ||||
| @ -415,6 +415,23 @@ Status ShapeVerifier::HandleRng(HloInstruction* instruction) { | ||||
|   return Status::OK(); | ||||
| } | ||||
| 
 | ||||
| Status ShapeVerifier::HandleRngBitGenerator(HloInstruction* hlo) { | ||||
|   if (!hlo->shape().IsTuple() || hlo->shape().tuple_shapes_size() != 2) { | ||||
|     return InternalError( | ||||
|         "Expected tuple shape with 2 elements for RngBitGenerator. Got: %s", | ||||
|         hlo->shape().ToString()); | ||||
|   } | ||||
|   if (!ShapeUtil::Compatible(hlo->operand(0)->shape(), | ||||
|                              hlo->shape().tuple_shapes(0))) { | ||||
|     return InternalError( | ||||
|         "Expected state shape to match between input and output for " | ||||
|         "RngBitGenerator. Got %s vs. %s", | ||||
|         hlo->operand(0)->shape().ToString(), | ||||
|         hlo->shape().tuple_shapes(0).ToString()); | ||||
|   } | ||||
|   return Status::OK(); | ||||
| } | ||||
| 
 | ||||
| Status ShapeVerifier::HandleRngGetAndUpdateState(HloInstruction* instruction) { | ||||
|   TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 0)); | ||||
|   const Shape& result_shape = instruction->shape(); | ||||
|  | ||||
| @ -65,6 +65,7 @@ class ShapeVerifier : public DfsHloVisitor { | ||||
|   Status HandleInfeed(HloInstruction*) override; | ||||
|   Status HandleOutfeed(HloInstruction*) override; | ||||
|   Status HandleRng(HloInstruction*) override; | ||||
|   Status HandleRngBitGenerator(HloInstruction*) override; | ||||
|   Status HandleRngGetAndUpdateState(HloInstruction*) override; | ||||
|   Status HandleReverse(HloInstruction* reverse) override; | ||||
|   Status HandleSort(HloInstruction* sort) override; | ||||
|  | ||||
| @ -167,6 +167,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { | ||||
|     case HloOpcode::kReduceWindow: | ||||
|     case HloOpcode::kRng: | ||||
|     case HloOpcode::kRngGetAndUpdateState: | ||||
|     case HloOpcode::kRngBitGenerator: | ||||
|     case HloOpcode::kRsqrt: | ||||
|     case HloOpcode::kScatter: | ||||
|     case HloOpcode::kSelectAndScatter: | ||||
|  | ||||
| @ -2260,6 +2260,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( | ||||
|     case HloOpcode::kReplicaId: | ||||
|     case HloOpcode::kReshape: | ||||
|     case HloOpcode::kRng: | ||||
|     case HloOpcode::kRngBitGenerator: | ||||
|     case HloOpcode::kRngGetAndUpdateState: | ||||
|     case HloOpcode::kSend: | ||||
|     case HloOpcode::kSendDone: | ||||
|  | ||||
							
								
								
									
										105
									
								
								tensorflow/compiler/xla/service/rng_bit_generator_expander.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										105
									
								
								tensorflow/compiler/xla/service/rng_bit_generator_expander.cc
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,105 @@ | ||||
| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h" | ||||
| 
 | ||||
| #include "tensorflow/compiler/xla/client/lib/prng.h" | ||||
| #include "tensorflow/compiler/xla/client/xla_builder.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_computation.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_instruction.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_instructions.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_module.h" | ||||
| #include "tensorflow/compiler/xla/shape.h" | ||||
| #include "tensorflow/compiler/xla/shape_util.h" | ||||
| #include "tensorflow/compiler/xla/statusor.h" | ||||
| #include "tensorflow/compiler/xla/util.h" | ||||
| #include "tensorflow/compiler/xla/xla_data.pb.h" | ||||
| #include "tensorflow/stream_executor/lib/statusor.h" | ||||
| 
 | ||||
| namespace xla { | ||||
| 
 | ||||
| bool RngBitGeneratorExpander::InstructionMatchesPattern( | ||||
|     HloInstruction* instruction) { | ||||
|   return instruction->opcode() == HloOpcode::kRngBitGenerator; | ||||
| } | ||||
| 
 | ||||
| StatusOr<HloComputation*> RngBitGeneratorExpander::GetGeneratorComputation( | ||||
|     const Shape& data_shape, const Shape& state_shape, | ||||
|     RandomAlgorithm algorithm, HloModule* module) { | ||||
|   RngGeneratorKey cache_key{data_shape, state_shape, algorithm, module}; | ||||
|   auto it = computation_cache_.find(cache_key); | ||||
|   if (it != computation_cache_.end()) { | ||||
|     return it->second; | ||||
|   } | ||||
| 
 | ||||
|   XlaBuilder builder("rng"); | ||||
|   XlaOp state_param = Parameter(&builder, 0, state_shape, "state"); | ||||
|   XlaOp key_op = Reshape(Slice(state_param, {0}, {1}, {1}), {}); | ||||
|   XlaOp state_op; | ||||
| 
 | ||||
|   BitGeneratorTy generator = nullptr; | ||||
|   switch (algorithm) { | ||||
|     case RandomAlgorithm::RNG_THREE_FRY: | ||||
|       generator = ThreeFryBitGenerator; | ||||
|       state_op = Slice(state_param, {1}, {2}, {1}); | ||||
|       break; | ||||
|     case RandomAlgorithm::RNG_PHILOX: | ||||
|       generator = PhiloxBitGenerator; | ||||
|       state_op = Slice(state_param, {1}, {3}, {1}); | ||||
|       break; | ||||
|     default: | ||||
|       return Unimplemented("Unsupported random algorthm: %s", | ||||
|                            RandomAlgorithm_Name(algorithm)); | ||||
|   } | ||||
| 
 | ||||
|   RngOutput output = generator(key_op, state_op, data_shape); | ||||
|   XlaOp final_state = | ||||
|       ConcatInDim(&builder, {Reshape(key_op, {1}), output.state}, 0); | ||||
|   Tuple(&builder, {final_state, output.value}); | ||||
|   TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build()); | ||||
| 
 | ||||
|   TF_ASSIGN_OR_RETURN(ProgramShape program_shape, | ||||
|                       xla_computation.GetProgramShape()); | ||||
|   HloModuleConfig config(program_shape); | ||||
|   TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto( | ||||
|                                            xla_computation.proto(), config)); | ||||
|   HloCloneContext context(module); | ||||
|   HloComputation* new_computation = | ||||
|       module->DeepCloneComputation(new_module->entry_computation(), &context); | ||||
|   computation_cache_.emplace(cache_key, new_computation); | ||||
|   return new_computation; | ||||
| } | ||||
| 
 | ||||
| StatusOr<HloInstruction*> RngBitGeneratorExpander::ExpandInstruction( | ||||
|     HloInstruction* hlo) { | ||||
|   HloRngBitGeneratorInstruction* rng = Cast<HloRngBitGeneratorInstruction>(hlo); | ||||
|   RandomAlgorithm algorithm = rng->algorithm(); | ||||
|   if (algorithm == RandomAlgorithm::RNG_DEFAULT) { | ||||
|     algorithm = default_algorithm_; | ||||
|   } | ||||
| 
 | ||||
|   HloModule* module = hlo->parent()->parent(); | ||||
|   const Shape& data_shape = rng->shape().tuple_shapes(1); | ||||
|   const Shape& state_shape = rng->operand(0)->shape(); | ||||
|   TF_ASSIGN_OR_RETURN( | ||||
|       HloComputation * generator_computation, | ||||
|       GetGeneratorComputation(data_shape, state_shape, algorithm, module)); | ||||
|   return hlo->parent()->AddInstruction(HloInstruction::CreateCall( | ||||
|       ShapeUtil::MakeTupleShape({state_shape, data_shape}), | ||||
|       {hlo->mutable_operand(0)}, generator_computation)); | ||||
| } | ||||
| 
 | ||||
| }  // namespace xla
 | ||||
							
								
								
									
										72
									
								
								tensorflow/compiler/xla/service/rng_bit_generator_expander.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										72
									
								
								tensorflow/compiler/xla/service/rng_bit_generator_expander.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,72 @@ | ||||
| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_RNG_BIT_GENERATOR_EXPANDER_H_ | ||||
| #define TENSORFLOW_COMPILER_XLA_SERVICE_RNG_BIT_GENERATOR_EXPANDER_H_ | ||||
| 
 | ||||
| #include "absl/container/flat_hash_map.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_computation.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_module.h" | ||||
| #include "tensorflow/compiler/xla/service/op_expander_pass.h" | ||||
| #include "tensorflow/compiler/xla/shape_util.h" | ||||
| #include "tensorflow/compiler/xla/statusor.h" | ||||
| #include "tensorflow/compiler/xla/xla_data.pb.h" | ||||
| 
 | ||||
| namespace xla { | ||||
| 
 | ||||
| class RngBitGeneratorExpander : public OpExpanderPass { | ||||
|  public: | ||||
|   explicit RngBitGeneratorExpander(RandomAlgorithm default_algorithm) | ||||
|       : default_algorithm_(default_algorithm) { | ||||
|     CHECK_NE(default_algorithm_, RandomAlgorithm::RNG_DEFAULT); | ||||
|   } | ||||
| 
 | ||||
|   absl::string_view name() const override { | ||||
|     return "rng-bit-generator-expander"; | ||||
|   } | ||||
| 
 | ||||
|  protected: | ||||
|   struct RngGeneratorKey { | ||||
|     Shape data_shape; | ||||
|     Shape state_shape; | ||||
|     RandomAlgorithm algorithm; | ||||
|     HloModule* module; | ||||
| 
 | ||||
|     template <typename H> | ||||
|     friend H AbslHashValue(H h, const RngGeneratorKey& c) { | ||||
|       return H::combine(std::move(h), c.state_shape, c.data_shape, c.algorithm, | ||||
|                         c.module); | ||||
|     } | ||||
| 
 | ||||
|     bool operator==(const RngGeneratorKey& o) const { | ||||
|       return data_shape == o.data_shape && state_shape == o.state_shape && | ||||
|              algorithm == o.algorithm && module == o.module; | ||||
|     } | ||||
|   }; | ||||
| 
 | ||||
|   bool InstructionMatchesPattern(HloInstruction* instruction) override; | ||||
|   StatusOr<HloInstruction*> ExpandInstruction(HloInstruction* hlo) override; | ||||
|   StatusOr<HloComputation*> GetGeneratorComputation(const Shape& data_shape, | ||||
|                                                     const Shape& state_shape, | ||||
|                                                     RandomAlgorithm algorithm, | ||||
|                                                     HloModule* module); | ||||
| 
 | ||||
|   const RandomAlgorithm default_algorithm_; | ||||
|   absl::flat_hash_map<RngGeneratorKey, HloComputation*> computation_cache_; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace xla
 | ||||
| 
 | ||||
| #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_RNG_BIT_GENERATOR_EXPANDER_H_
 | ||||
| @ -84,7 +84,7 @@ StatusOr<HloComputation*> GetComputationForRng(HloInstruction* rng) { | ||||
| 
 | ||||
|   auto generator = [](xla::XlaOp key, xla::XlaOp state, | ||||
|                       const xla::Shape& shape) { | ||||
|     return PhiloxBitGenerator(key, state, shape, /*scramble=*/false); | ||||
|     return PhiloxBitGenerator(key, state, shape); | ||||
|   }; | ||||
| 
 | ||||
|   XlaOp result; | ||||
|  | ||||
| @ -559,6 +559,13 @@ enum RandomDistribution { | ||||
|   // Next: 4 | ||||
| } | ||||
| 
 | ||||
| enum RandomAlgorithm { | ||||
|   RNG_DEFAULT = 0;  // Backend dependent default algorithm. | ||||
|   RNG_THREE_FRY = 1; | ||||
|   RNG_PHILOX = 2; | ||||
|   // Next: 2 | ||||
| } | ||||
| 
 | ||||
| message TriangularSolveOptions { | ||||
|   // If true, solves ax = b. If false, solves xa = b. | ||||
|   bool left_side = 1; | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user