partition.h 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434
  1. // -*- C++ -*-
  2. // Copyright (C) 2007-2015 Free Software Foundation, Inc.
  3. //
  4. // This file is part of the GNU ISO C++ Library. This library is free
  5. // software; you can redistribute it and/or modify it under the terms
  6. // of the GNU General Public License as published by the Free Software
  7. // Foundation; either version 3, or (at your option) any later
  8. // version.
  9. // This library is distributed in the hope that it will be useful, but
  10. // WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  12. // General Public License for more details.
  13. // Under Section 7 of GPL version 3, you are granted additional
  14. // permissions described in the GCC Runtime Library Exception, version
  15. // 3.1, as published by the Free Software Foundation.
  16. // You should have received a copy of the GNU General Public License and
  17. // a copy of the GCC Runtime Library Exception along with this program;
  18. // see the files COPYING3 and COPYING.RUNTIME respectively. If not, see
  19. // <http://www.gnu.org/licenses/>.
  20. /** @file parallel/partition.h
  21. * @brief Parallel implementation of std::partition(),
  22. * std::nth_element(), and std::partial_sort().
  23. * This file is a GNU parallel extension to the Standard C++ Library.
  24. */
  25. // Written by Johannes Singler and Felix Putze.
  26. #ifndef _GLIBCXX_PARALLEL_PARTITION_H
  27. #define _GLIBCXX_PARALLEL_PARTITION_H 1
  28. #include <parallel/basic_iterator.h>
  29. #include <parallel/sort.h>
  30. #include <parallel/random_number.h>
  31. #include <bits/stl_algo.h>
  32. #include <parallel/parallel.h>
  33. /** @brief Decide whether to declare certain variables volatile. */
  34. #define _GLIBCXX_VOLATILE volatile
  35. namespace __gnu_parallel
  36. {
  37. /** @brief Parallel implementation of std::partition.
  38. * @param __begin Begin iterator of input sequence to split.
  39. * @param __end End iterator of input sequence to split.
  40. * @param __pred Partition predicate, possibly including some kind
  41. * of pivot.
  42. * @param __num_threads Maximum number of threads to use for this task.
  43. * @return Number of elements not fulfilling the predicate. */
  44. template<typename _RAIter, typename _Predicate>
  45. typename std::iterator_traits<_RAIter>::difference_type
  46. __parallel_partition(_RAIter __begin, _RAIter __end,
  47. _Predicate __pred, _ThreadIndex __num_threads)
  48. {
  49. typedef std::iterator_traits<_RAIter> _TraitsType;
  50. typedef typename _TraitsType::value_type _ValueType;
  51. typedef typename _TraitsType::difference_type _DifferenceType;
  52. _DifferenceType __n = __end - __begin;
  53. _GLIBCXX_CALL(__n)
  54. const _Settings& __s = _Settings::get();
  55. // shared
  56. _GLIBCXX_VOLATILE _DifferenceType __left = 0, __right = __n - 1,
  57. __dist = __n,
  58. __leftover_left, __leftover_right,
  59. __leftnew, __rightnew;
  60. // just 0 or 1, but int to allow atomic operations
  61. int* __reserved_left = 0, * __reserved_right = 0;
  62. _DifferenceType __chunk_size = __s.partition_chunk_size;
  63. //at least two chunks per thread
  64. if (__dist >= 2 * __num_threads * __chunk_size)
  65. # pragma omp parallel num_threads(__num_threads)
  66. {
  67. # pragma omp single
  68. {
  69. __num_threads = omp_get_num_threads();
  70. __reserved_left = new int[__num_threads];
  71. __reserved_right = new int[__num_threads];
  72. if (__s.partition_chunk_share > 0.0)
  73. __chunk_size = std::max<_DifferenceType>
  74. (__s.partition_chunk_size, (double)__n
  75. * __s.partition_chunk_share / (double)__num_threads);
  76. else
  77. __chunk_size = __s.partition_chunk_size;
  78. }
  79. while (__dist >= 2 * __num_threads * __chunk_size)
  80. {
  81. # pragma omp single
  82. {
  83. _DifferenceType __num_chunks = __dist / __chunk_size;
  84. for (_ThreadIndex __r = 0; __r < __num_threads; ++__r)
  85. {
  86. __reserved_left [__r] = 0; // false
  87. __reserved_right[__r] = 0; // false
  88. }
  89. __leftover_left = 0;
  90. __leftover_right = 0;
  91. } //implicit barrier
  92. // Private.
  93. _DifferenceType __thread_left, __thread_left_border,
  94. __thread_right, __thread_right_border;
  95. __thread_left = __left + 1;
  96. // Just to satisfy the condition below.
  97. __thread_left_border = __thread_left - 1;
  98. __thread_right = __n - 1;
  99. // Just to satisfy the condition below.
  100. __thread_right_border = __thread_right + 1;
  101. bool __iam_finished = false;
  102. while (!__iam_finished)
  103. {
  104. if (__thread_left > __thread_left_border)
  105. {
  106. _DifferenceType __former_dist =
  107. __fetch_and_add(&__dist, -__chunk_size);
  108. if (__former_dist < __chunk_size)
  109. {
  110. __fetch_and_add(&__dist, __chunk_size);
  111. __iam_finished = true;
  112. break;
  113. }
  114. else
  115. {
  116. __thread_left =
  117. __fetch_and_add(&__left, __chunk_size);
  118. __thread_left_border =
  119. __thread_left + (__chunk_size - 1);
  120. }
  121. }
  122. if (__thread_right < __thread_right_border)
  123. {
  124. _DifferenceType __former_dist =
  125. __fetch_and_add(&__dist, -__chunk_size);
  126. if (__former_dist < __chunk_size)
  127. {
  128. __fetch_and_add(&__dist, __chunk_size);
  129. __iam_finished = true;
  130. break;
  131. }
  132. else
  133. {
  134. __thread_right =
  135. __fetch_and_add(&__right, -__chunk_size);
  136. __thread_right_border =
  137. __thread_right - (__chunk_size - 1);
  138. }
  139. }
  140. // Swap as usual.
  141. while (__thread_left < __thread_right)
  142. {
  143. while (__pred(__begin[__thread_left])
  144. && __thread_left <= __thread_left_border)
  145. ++__thread_left;
  146. while (!__pred(__begin[__thread_right])
  147. && __thread_right >= __thread_right_border)
  148. --__thread_right;
  149. if (__thread_left > __thread_left_border
  150. || __thread_right < __thread_right_border)
  151. // Fetch new chunk(__s).
  152. break;
  153. std::iter_swap(__begin + __thread_left,
  154. __begin + __thread_right);
  155. ++__thread_left;
  156. --__thread_right;
  157. }
  158. }
  159. // Now swap the leftover chunks to the right places.
  160. if (__thread_left <= __thread_left_border)
  161. # pragma omp atomic
  162. ++__leftover_left;
  163. if (__thread_right >= __thread_right_border)
  164. # pragma omp atomic
  165. ++__leftover_right;
  166. # pragma omp barrier
  167. _DifferenceType
  168. __leftold = __left,
  169. __leftnew = __left - __leftover_left * __chunk_size,
  170. __rightold = __right,
  171. __rightnew = __right + __leftover_right * __chunk_size;
  172. // <=> __thread_left_border + (__chunk_size - 1) >= __leftnew
  173. if (__thread_left <= __thread_left_border
  174. && __thread_left_border >= __leftnew)
  175. {
  176. // Chunk already in place, reserve spot.
  177. __reserved_left[(__left - (__thread_left_border + 1))
  178. / __chunk_size] = 1;
  179. }
  180. // <=> __thread_right_border - (__chunk_size - 1) <= __rightnew
  181. if (__thread_right >= __thread_right_border
  182. && __thread_right_border <= __rightnew)
  183. {
  184. // Chunk already in place, reserve spot.
  185. __reserved_right[((__thread_right_border - 1) - __right)
  186. / __chunk_size] = 1;
  187. }
  188. # pragma omp barrier
  189. if (__thread_left <= __thread_left_border
  190. && __thread_left_border < __leftnew)
  191. {
  192. // Find spot and swap.
  193. _DifferenceType __swapstart = -1;
  194. for (int __r = 0; __r < __leftover_left; ++__r)
  195. if (__reserved_left[__r] == 0
  196. && __compare_and_swap(&(__reserved_left[__r]), 0, 1))
  197. {
  198. __swapstart = __leftold - (__r + 1) * __chunk_size;
  199. break;
  200. }
  201. #if _GLIBCXX_ASSERTIONS
  202. _GLIBCXX_PARALLEL_ASSERT(__swapstart != -1);
  203. #endif
  204. std::swap_ranges(__begin + __thread_left_border
  205. - (__chunk_size - 1),
  206. __begin + __thread_left_border + 1,
  207. __begin + __swapstart);
  208. }
  209. if (__thread_right >= __thread_right_border
  210. && __thread_right_border > __rightnew)
  211. {
  212. // Find spot and swap
  213. _DifferenceType __swapstart = -1;
  214. for (int __r = 0; __r < __leftover_right; ++__r)
  215. if (__reserved_right[__r] == 0
  216. && __compare_and_swap(&(__reserved_right[__r]), 0, 1))
  217. {
  218. __swapstart = __rightold + __r * __chunk_size + 1;
  219. break;
  220. }
  221. #if _GLIBCXX_ASSERTIONS
  222. _GLIBCXX_PARALLEL_ASSERT(__swapstart != -1);
  223. #endif
  224. std::swap_ranges(__begin + __thread_right_border,
  225. __begin + __thread_right_border
  226. + __chunk_size, __begin + __swapstart);
  227. }
  228. #if _GLIBCXX_ASSERTIONS
  229. # pragma omp barrier
  230. # pragma omp single
  231. {
  232. for (_DifferenceType __r = 0; __r < __leftover_left; ++__r)
  233. _GLIBCXX_PARALLEL_ASSERT(__reserved_left[__r] == 1);
  234. for (_DifferenceType __r = 0; __r < __leftover_right; ++__r)
  235. _GLIBCXX_PARALLEL_ASSERT(__reserved_right[__r] == 1);
  236. }
  237. #endif
  238. __left = __leftnew;
  239. __right = __rightnew;
  240. __dist = __right - __left + 1;
  241. }
  242. # pragma omp flush(__left, __right)
  243. } // end "recursion" //parallel
  244. _DifferenceType __final_left = __left, __final_right = __right;
  245. while (__final_left < __final_right)
  246. {
  247. // Go right until key is geq than pivot.
  248. while (__pred(__begin[__final_left])
  249. && __final_left < __final_right)
  250. ++__final_left;
  251. // Go left until key is less than pivot.
  252. while (!__pred(__begin[__final_right])
  253. && __final_left < __final_right)
  254. --__final_right;
  255. if (__final_left == __final_right)
  256. break;
  257. std::iter_swap(__begin + __final_left, __begin + __final_right);
  258. ++__final_left;
  259. --__final_right;
  260. }
  261. // All elements on the left side are < piv, all elements on the
  262. // right are >= piv
  263. delete[] __reserved_left;
  264. delete[] __reserved_right;
  265. // Element "between" __final_left and __final_right might not have
  266. // been regarded yet
  267. if (__final_left < __n && !__pred(__begin[__final_left]))
  268. // Really swapped.
  269. return __final_left;
  270. else
  271. return __final_left + 1;
  272. }
  273. /**
  274. * @brief Parallel implementation of std::nth_element().
  275. * @param __begin Begin iterator of input sequence.
  276. * @param __nth _Iterator of element that must be in position afterwards.
  277. * @param __end End iterator of input sequence.
  278. * @param __comp Comparator.
  279. */
  280. template<typename _RAIter, typename _Compare>
  281. void
  282. __parallel_nth_element(_RAIter __begin, _RAIter __nth,
  283. _RAIter __end, _Compare __comp)
  284. {
  285. typedef std::iterator_traits<_RAIter> _TraitsType;
  286. typedef typename _TraitsType::value_type _ValueType;
  287. typedef typename _TraitsType::difference_type _DifferenceType;
  288. _GLIBCXX_CALL(__end - __begin)
  289. _RAIter __split;
  290. _RandomNumber __rng;
  291. const _Settings& __s = _Settings::get();
  292. _DifferenceType __minimum_length = std::max<_DifferenceType>(2,
  293. std::max(__s.nth_element_minimal_n, __s.partition_minimal_n));
  294. // Break if input range to small.
  295. while (static_cast<_SequenceIndex>(__end - __begin) >= __minimum_length)
  296. {
  297. _DifferenceType __n = __end - __begin;
  298. _RAIter __pivot_pos = __begin + __rng(__n);
  299. // Swap __pivot_pos value to end.
  300. if (__pivot_pos != (__end - 1))
  301. std::iter_swap(__pivot_pos, __end - 1);
  302. __pivot_pos = __end - 1;
  303. // _Compare must have first_value_type, second_value_type,
  304. // result_type
  305. // _Compare ==
  306. // __gnu_parallel::_Lexicographic<S, int,
  307. // __gnu_parallel::_Less<S, S> >
  308. // __pivot_pos == std::pair<S, int>*
  309. __gnu_parallel::__binder2nd<_Compare, _ValueType, _ValueType, bool>
  310. __pred(__comp, *__pivot_pos);
  311. // Divide, leave pivot unchanged in last place.
  312. _RAIter __split_pos1, __split_pos2;
  313. __split_pos1 = __begin + __parallel_partition(__begin, __end - 1,
  314. __pred,
  315. __get_max_threads());
  316. // Left side: < __pivot_pos; __right side: >= __pivot_pos
  317. // Swap pivot back to middle.
  318. if (__split_pos1 != __pivot_pos)
  319. std::iter_swap(__split_pos1, __pivot_pos);
  320. __pivot_pos = __split_pos1;
  321. // In case all elements are equal, __split_pos1 == 0
  322. if ((__split_pos1 + 1 - __begin) < (__n >> 7)
  323. || (__end - __split_pos1) < (__n >> 7))
  324. {
  325. // Very unequal split, one part smaller than one 128th
  326. // elements not strictly larger than the pivot.
  327. __gnu_parallel::__unary_negate<__gnu_parallel::
  328. __binder1st<_Compare, _ValueType,
  329. _ValueType, bool>, _ValueType>
  330. __pred(__gnu_parallel::__binder1st<_Compare, _ValueType,
  331. _ValueType, bool>(__comp, *__pivot_pos));
  332. // Find other end of pivot-equal range.
  333. __split_pos2 = __gnu_sequential::partition(__split_pos1 + 1,
  334. __end, __pred);
  335. }
  336. else
  337. // Only skip the pivot.
  338. __split_pos2 = __split_pos1 + 1;
  339. // Compare iterators.
  340. if (__split_pos2 <= __nth)
  341. __begin = __split_pos2;
  342. else if (__nth < __split_pos1)
  343. __end = __split_pos1;
  344. else
  345. break;
  346. }
  347. // Only at most _Settings::partition_minimal_n __elements __left.
  348. __gnu_sequential::nth_element(__begin, __nth, __end, __comp);
  349. }
  /** @brief Parallel implementation of std::partial_sort().
   *  @param __begin Begin iterator of input sequence.
   *  @param __middle Sort until this position.
   *  @param __end End iterator of input sequence.
   *  @param __comp Comparator.
   *
   *  Two steps: a parallel selection around @p __middle, then a sort of
   *  only the prefix [__begin, __middle). */
  template<typename _RAIter, typename _Compare>
    void
    __parallel_partial_sort(_RAIter __begin, _RAIter __middle,
                            _RAIter __end, _Compare __comp)
    {
      // Selection step: bring the element belonging at __middle into place.
      __parallel_nth_element(__begin, __middle, __end, __comp);

      // Ordering step: only the selected prefix has to be sorted.
      std::sort(__begin, __middle, __comp);
    }
  364. } //namespace __gnu_parallel
  365. #undef _GLIBCXX_VOLATILE
  366. #endif /* _GLIBCXX_PARALLEL_PARTITION_H */