cai_lw's competitive programming library
 
Loading...
Searching...
No Matches
wavelet_array.hpp
1#pragma once
2
#include <algorithm>
#include <array>
#include <iterator>
#include <limits>
#include <memory>
#include <vector>

#include "cplib/range/bit_dict.hpp"
9
10namespace cplib {
11
22template <typename T, int M = std::numeric_limits<T>::digits>
24 public:
25 static_assert(M > 0 && M <= std::numeric_limits<T>::digits);
26
27 static constexpr T max_value = M == std::numeric_limits<T>::digits ? std::numeric_limits<T>::max() : (T(1) << M) - 1;
28 using size_type = BitDict::size_type;
29
32
34 WaveletArray(std::vector<T>&& data) {
35 std::vector<T> temp = std::move(data);
36 *this = build_and_sort(temp.data(), temp.size());
37 }
38
45 static WaveletArray build_and_sort(T* data, size_type size) {
47 std::unique_ptr<T[]> temp(new T[size]);
48 for (int lvl = M - 1; lvl >= 0; lvl--) {
49 BitDict& dict = wa.bit_dict[lvl];
50 dict = BitDict(size);
51 // Set the dict with the lvl-th bit of each element, then stably sort them by this bit.
52 T* in = data;
53 T* out0 = temp.get();
54 T* out1 = temp.get() + (size - 1);
55 const auto bit_generator = [&]() {
56 bool bit = (*in >> lvl) & 1;
57 if (bit) {
58 *out1 = *in;
59 out1--;
60 } else {
61 *out0 = *in;
62 out0++;
63 }
64 in++;
65 return bit;
66 };
67 dict.fill_with_bit_generator(bit_generator);
68 dict.build();
69 // temp has all "0" elements in original order, followed by all "1" elements in reversed original order.
70 T* data_mid = std::copy(temp.get(), out0, data);
71 std::reverse_copy(out0, temp.get() + size, data_mid);
72 }
73 return wa;
74 }
75
77 size_type size() const { return bit_dict[0].size(); }
78
84 T get(size_type idx) const {
85 T ret = 0;
86 for (int lvl = M - 1; lvl >= 0; lvl--) {
87 bool bit = bit_dict[lvl].get(idx);
88 ret |= T(bit) << lvl;
89 idx = bit_dict[lvl].rank_to_child(idx, bit);
90 }
91 return ret;
92 }
93
99 T range_nth(size_type left, size_type right, size_type n) const {
100 T ret = 0;
101 for (int lvl = M - 1; lvl >= 0; lvl--) {
102 const BitDict& bd = bit_dict[lvl];
103 size_type zero_count = bd.rank0(right) - bd.rank0(left);
104 bool bit = n >= zero_count;
105 ret |= T(bit) << lvl;
106 if (bit) {
107 n -= zero_count;
108 }
109 left = bd.rank_to_child(left, bit);
110 right = bd.rank_to_child(right, bit);
111 }
112 return ret;
113 }
114
120 size_type range_count(size_type left, size_type right, T val) const {
121 for (int lvl = M - 1; lvl >= 0; lvl--) {
122 bool bit = (val >> lvl) & 1;
123 left = bit_dict[lvl].rank_to_child(left, bit);
124 right = bit_dict[lvl].rank_to_child(right, bit);
125 }
126 return right - left;
127 }
128
135 size_type range_count_between(size_type left, size_type right, T low, T high) const {
136 return _rangefreq(left, right, low, high, M - 1);
137 }
138
139 private:
140 std::array<BitDict, M> bit_dict;
141
142 size_type _rangefreq(size_type left, size_type right, T low, T high, int lvl) const {
143 if (left >= right) {
144 return 0;
145 } else if (high - low == (lvl == M - 1 ? max_value : (T(1) << (lvl + 1)) - 1)) {
146 return right - left;
147 }
148 T bit_mask = T(1) << lvl;
149 const BitDict& bd = bit_dict[lvl];
150 if (bit_mask & (low ^ high)) {
151 T split = ~(bit_mask - 1) & high;
152 return _rangefreq(bd.rank_to_child(left, false), bd.rank_to_child(right, false), low, split - 1, lvl - 1) +
153 _rangefreq(bd.rank_to_child(left, true), bd.rank_to_child(right, true), split, high, lvl - 1);
154 } else {
155 bool bit = bit_mask & low;
156 return _rangefreq(bd.rank_to_child(left, bit), bd.rank_to_child(right, bit), low, high, lvl - 1);
157 }
158 }
159};
160
161} // namespace cplib
Static bit sequence with rank query in .
Definition: bit_dict.hpp:73
size_type rank_to_child(size_type idx, bool bit) const
Move one level downwards in a wavelet tree.
Definition: bit_dict.hpp:151
void fill_with_bit_generator(BitGenerator &gen)
Overwrite all bits with a bit generator.
Definition: bit_dict.hpp:112
void build()
Set up auxiliary data to prepare for rank queries.
Definition: bit_dict.hpp:121
size_type rank0(size_type idx) const
Returns the number of 0 bits in the range [0, idx).
Definition: bit_dict.hpp:138
Efficient representation of a wavelet tree, supporting various static range queries.
Definition: wavelet_array.hpp:23
static WaveletArray build_and_sort(T *data, size_type size)
Builds a WaveletArray from a mutable buffer, permuting it in place as a side effect (stably ordered by bit-reversed value, not ascending).
Definition: wavelet_array.hpp:45
T get(size_type idx) const
Returns the element at the given index.
Definition: wavelet_array.hpp:84
WaveletArray()
Creates an empty WaveletArray.
Definition: wavelet_array.hpp:31
size_type size() const
Returns the number of elements.
Definition: wavelet_array.hpp:77
size_type range_count(size_type left, size_type right, T val) const
Returns the number of the given value in the range [left, right).
Definition: wavelet_array.hpp:120
size_type range_count_between(size_type left, size_type right, T low, T high) const
Returns the number of elements in the range [left, right) whose values are between [low, high] (inclusive).
Definition: wavelet_array.hpp:135
T range_nth(size_type left, size_type right, size_type n) const
Returns the 0-indexed n-th smallest element in the range [left, right).
Definition: wavelet_array.hpp:99
WaveletArray(std::vector< T > &&data)
Creates a WaveletArray by consuming an array of elements.
Definition: wavelet_array.hpp:34