희소 테이블(Sparse Table)
설명
희소 테이블이란 정적 데이터에서 구간 쿼리를 빠르게 계산할수 있는 자료 구조다. 예를 들어 Array[0] - Array[N]
과 함수 $F$가 존재할 때, $0 \leq i \leq j \leq N$을 만족하는 $F(i, j)$를 빠르게 구할 수 있다.
이를 사용하면 희소 테이블을 구성하는 데에는 $O(N \log N)$, 쿼리 당 $O(\log N)$의 처리가 가능해 문제를 $O(N \log N)$의 시간에 해결할 수 있다.
희소 테이블을 사용할 수 있는 조건은 다음과 같다.
- 결합 법칙이 성립: $F(a, b, c) = F(F(a, b), c) = F(a, F(b, c))$
- 데이터가 변하지 않음
이 두 조건을 만족하는 데이터와 함수, 그리고 쿼리 수가 많은 문제가 주어진다면 희소 테이블을 사용하는 것이 좋다.
최댓값, 최솟값 같은 특정 쿼리의 경우에 $O(1)$에 쿼리 처리가 가능하다.
예제
구간 합
전처리
$F(a, b) =\sum_{i = a}^{b} Array[i]$
i | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
---|---|---|---|---|---|---|---|---|
Array[i] | 1 | 3 | 5 | 2 | 8 | 5 | 3 | 10 |
table[k][i]
= $F(i,i+2^{k}-1)$
table[0][i]
table[0][i]
= $F(i, i)$이므로 전부 Array[i]
를 대입한다.
나머지
$F(i, i + 2^{k} - 1) = F(i, i + 2^{k - 1} - 1) + F(i + 2^{k - 1}, i + 2^{k - 1} + 2^{k - 1} - 1)$
table[k][i] = table[k - 1][i] + table[k - 1][i + (1 << (k - 1))]
한 행이 모두 있다면 다음 행은 밑 행을 기준으로 bottom-up하게 채울 수 있다.
i | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
---|---|---|---|---|---|---|---|---|
table[0][i] | 1 | 3 | 5 | 2 | 8 | 5 | 3 | 10 |
table[1][i] | 4 | 8 | 7 | 10 | 13 | 18 | 13 | |
table[2][i] | 11 | 18 | 20 | 28 | 26 | |||
table[3][i] | 37 |
쿼리
$F(a,b$를 구할 때, 범위 $[a, b]$의 길이는 $b - a + 1$이고, 이는 2의 제곱으로 나타낼 수 있다. 2의 제곱수만큼의 길이로 분리된 $F(a, b)$를 계산하면 $O(\log N)$의 시간에 쿼리를 구할 수 있다.
$diff = a - b + 1 = 2^{x} + 2^{y} + 2^{z}$
$F(a, b) = \ F(a, a + 2^{x} - 1) + F(a + 2^{x}, a + 2^{x} + 2^{y} - 1) + F(a + 2^{x} + 2^{y}, a + 2^{x} + 2^{y} + 2^{z} - 1) + F(a + 2^{x} + 2^{y} + 2^{z}, b)$
int partialSum(vector<vector<int>> &table, int a, int b) {
int diff = b - a + 1;
int sum = 0;
for (int i = table.size() - 1, j = a; i >= 0; i--) {
if (diff & (1 << i)) {
sum += table[i][j];
j += (1 << i);
}
}
return sum;
}
문제
위 문제를 희소 테이블로 해결 가능하다.
#include <iostream>
#include <vector>
#include <cmath>
using namespace std;
int partialMin(vector<vector<int>> &table, int a, int b) {
int diff = b - a + 1;
int sum = 0;
for (int i = table.size() - 1, j = a; i >= 0; i--) {
if (diff & (1 << i)) {
sum += table[i][j];
j += (1 << i);
}
}
return sum;
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
cout.tie(NULL);
int N, M;
cin >> N >> M;
int maxLog = static_cast<int>(log2(N));
vector<int> arr(N);
vector<vector<int>> table(maxLog + 1, vector<int>(N));
for (int i = 0; i < N; i++) {
cin >> arr[i];
table[0][i] = arr[i];
}
for (int k = 1; k <= maxLog; k++) {
for (int i = 0; i + (1 << (k - 1)) < N; i++) {
table[k][i] = table[k - 1][i] + table[k - 1][i + (1 << (k - 1))];
}
}
while (M-- > 0) {
int i, j;
cin >> i >> j;
cout << partialMin(table, i - 1, j - 1) << '\n';
}
return 0;
}
그런데 사실 위 문제의 경우 누적합으로 구하는 것이 더 빠르고 편리하다.
구간 최솟값
희소 테이블은 쿼리가 $O(\log N)$일 때도 좋지만 $O(1)$에 처리가 될 때 굉장히 유용하다. 구간 최솟값은 그러한 문제 중 하나다.
전처리
$F(a, b) = Min(a, b)$
i | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
---|---|---|---|---|---|---|---|---|
Array[i] | 1 | 3 | 5 | 2 | 8 | 5 | 3 | 10 |
$F(i, i + 2^{k} - 1) = Min(F(i, i + 2^{k - 1} - 1), F(i + 2^{k - 1}, i + 2^{k - 1} + 2^{k - 1} - 1))$
table[k][i] = min(table[k - 1][i], table[k - 1][i + (1 << (k - 1))])
구간 합과 비슷한 방법으로 테이블을 채울 수 있다.
i | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
---|---|---|---|---|---|---|---|---|
table[0][i] | 1 | 3 | 5 | 2 | 8 | 5 | 3 | 10 |
table[1][i] | 1 | 3 | 2 | 2 | 5 | 3 | 3 | |
table[2][i] | 1 | 2 | 2 | 2 | 3 | |||
table[3][i] | 1 |
쿼리
구간 합을 구할 때는 전체 구간을 2의 제곱수의 크기로 나누어 하나 하나 더해주었지만, 최솟값의 경우 그렇게 할 필요가 없다. 일단 구간의 경계 a, b가 주어지면,$a + 2^{x} \leq b$를 만족하는 $x$의 최솟값을 찾는다.$x = \log (b - a + 1)$로 구할 수 있다.
구간을 여러 개가 아닌 크기 $2^{x}$인 구간 두 개만 있어도 a와 b 사이의 원소는 전부 커버할 수 있다. 따라서 $O(1)$ 안에 구간 최솟값을 구할 수 있다.
문제
위 문제는 구간 최솟값과 최댓값을 구하는 문제인데, 최댓값도 최솟값과 마찬가지로 희소 테이블을 사용해 $O(1)$ 시간에 구할 수 있다.
#include <iostream>
#include <vector>
#include <cmath>
#include <algorithm>
using namespace std;
int partialMin(vector<vector<int>> &minTable, int a, int b) {
int x = log2(b - a + 1);
return min(minTable[x][a], minTable[x][b - (1 << x) + 1]);
}
int partialMax(vector<vector<int>> &maxTable, int a, int b) {
int x = log2(b - a + 1);
return max(maxTable[x][a], maxTable[x][b - (1 << x) + 1]);
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
cout.tie(NULL);
int N, M;
cin >> N >> M;
int maxLog = static_cast<int>(log2(N));
vector<int> arr(N);
vector<vector<int>> minTable(maxLog + 1, vector<int>(N));
vector<vector<int>> maxTable(maxLog + 1, vector<int>(N));
for (int i = 0; i < N; i++) {
cin >> arr[i];
minTable[0][i] = arr[i];
maxTable[0][i] = arr[i];
}
for (int k = 1; k <= maxLog; k++) {
for (int i = 0; i + (1 << (k - 1)) < N; i++) {
minTable[k][i] = min(minTable[k - 1][i], minTable[k - 1][i + (1 << (k - 1))]);
maxTable[k][i] = max(maxTable[k - 1][i], maxTable[k - 1][i + (1 << (k - 1))]);
}
}
while (M-- > 0) {
int i, j;
cin >> i >> j;
cout << partialMin(minTable, i - 1, j - 1) << ' ' << partialMax(maxTable, i - 1, j - 1) << '\n';
}
return 0;
}
Comments