Yi Kong | 878f994 | 2023-12-13 12:55:04 +0900 | [diff] [blame^] | 1 | //===----------------------------------------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #ifndef _LIBCPP___ALGORITHM_SAMPLE_H |
| 10 | #define _LIBCPP___ALGORITHM_SAMPLE_H |
| 11 | |
| 12 | #include <__algorithm/iterator_operations.h> |
| 13 | #include <__algorithm/min.h> |
| 14 | #include <__assert> |
| 15 | #include <__config> |
| 16 | #include <__iterator/distance.h> |
| 17 | #include <__iterator/iterator_traits.h> |
| 18 | #include <__random/uniform_int_distribution.h> |
| 19 | #include <__type_traits/common_type.h> |
| 20 | #include <__utility/move.h> |
| 21 | |
| 22 | #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) |
| 23 | # pragma GCC system_header |
| 24 | #endif |
| 25 | |
| 26 | _LIBCPP_PUSH_MACROS |
| 27 | #include <__undef_macros> |
| 28 | |
| 29 | _LIBCPP_BEGIN_NAMESPACE_STD |
| 30 | |
| 31 | template <class _AlgPolicy, |
| 32 | class _PopulationIterator, class _PopulationSentinel, class _SampleIterator, class _Distance, |
| 33 | class _UniformRandomNumberGenerator> |
| 34 | _LIBCPP_INLINE_VISIBILITY |
| 35 | _SampleIterator __sample(_PopulationIterator __first, |
| 36 | _PopulationSentinel __last, _SampleIterator __output_iter, |
| 37 | _Distance __n, |
| 38 | _UniformRandomNumberGenerator& __g, |
| 39 | input_iterator_tag) { |
| 40 | |
| 41 | _Distance __k = 0; |
| 42 | for (; __first != __last && __k < __n; ++__first, (void) ++__k) |
| 43 | __output_iter[__k] = *__first; |
| 44 | _Distance __sz = __k; |
| 45 | for (; __first != __last; ++__first, (void) ++__k) { |
| 46 | _Distance __r = uniform_int_distribution<_Distance>(0, __k)(__g); |
| 47 | if (__r < __sz) |
| 48 | __output_iter[__r] = *__first; |
| 49 | } |
| 50 | return __output_iter + _VSTD::min(__n, __k); |
| 51 | } |
| 52 | |
| 53 | template <class _AlgPolicy, |
| 54 | class _PopulationIterator, class _PopulationSentinel, class _SampleIterator, class _Distance, |
| 55 | class _UniformRandomNumberGenerator> |
| 56 | _LIBCPP_INLINE_VISIBILITY |
| 57 | _SampleIterator __sample(_PopulationIterator __first, |
| 58 | _PopulationSentinel __last, _SampleIterator __output_iter, |
| 59 | _Distance __n, |
| 60 | _UniformRandomNumberGenerator& __g, |
| 61 | forward_iterator_tag) { |
| 62 | _Distance __unsampled_sz = _IterOps<_AlgPolicy>::distance(__first, __last); |
| 63 | for (__n = _VSTD::min(__n, __unsampled_sz); __n != 0; ++__first) { |
| 64 | _Distance __r = uniform_int_distribution<_Distance>(0, --__unsampled_sz)(__g); |
| 65 | if (__r < __n) { |
| 66 | *__output_iter++ = *__first; |
| 67 | --__n; |
| 68 | } |
| 69 | } |
| 70 | return __output_iter; |
| 71 | } |
| 72 | |
| 73 | template <class _AlgPolicy, |
| 74 | class _PopulationIterator, class _PopulationSentinel, class _SampleIterator, class _Distance, |
| 75 | class _UniformRandomNumberGenerator> |
| 76 | _LIBCPP_INLINE_VISIBILITY |
| 77 | _SampleIterator __sample(_PopulationIterator __first, |
| 78 | _PopulationSentinel __last, _SampleIterator __output_iter, |
| 79 | _Distance __n, _UniformRandomNumberGenerator& __g) { |
| 80 | _LIBCPP_ASSERT_UNCATEGORIZED(__n >= 0, "N must be a positive number."); |
| 81 | |
| 82 | using _PopIterCategory = typename _IterOps<_AlgPolicy>::template __iterator_category<_PopulationIterator>; |
| 83 | using _Difference = typename _IterOps<_AlgPolicy>::template __difference_type<_PopulationIterator>; |
| 84 | using _CommonType = typename common_type<_Distance, _Difference>::type; |
| 85 | |
| 86 | return std::__sample<_AlgPolicy>( |
| 87 | std::move(__first), std::move(__last), std::move(__output_iter), _CommonType(__n), |
| 88 | __g, _PopIterCategory()); |
| 89 | } |
| 90 | |
| 91 | #if _LIBCPP_STD_VER >= 17 |
| 92 | template <class _PopulationIterator, class _SampleIterator, class _Distance, |
| 93 | class _UniformRandomNumberGenerator> |
| 94 | inline _LIBCPP_INLINE_VISIBILITY |
| 95 | _SampleIterator sample(_PopulationIterator __first, |
| 96 | _PopulationIterator __last, _SampleIterator __output_iter, |
| 97 | _Distance __n, _UniformRandomNumberGenerator&& __g) { |
| 98 | static_assert(__has_forward_iterator_category<_PopulationIterator>::value || |
| 99 | __has_random_access_iterator_category<_SampleIterator>::value, |
| 100 | "SampleIterator must meet the requirements of RandomAccessIterator"); |
| 101 | |
| 102 | return std::__sample<_ClassicAlgPolicy>( |
| 103 | std::move(__first), std::move(__last), std::move(__output_iter), __n, __g); |
| 104 | } |
| 105 | |
| 106 | #endif // _LIBCPP_STD_VER >= 17 |
| 107 | |
| 108 | _LIBCPP_END_NAMESPACE_STD |
| 109 | |
| 110 | _LIBCPP_POP_MACROS |
| 111 | |
| 112 | #endif // _LIBCPP___ALGORITHM_SAMPLE_H |