백준 1167번: 트리의 지름
문제
트리의 지름
트리의 지름이란 트리에서 임의의 두 점 사이의 거리 중 가장 긴 것을 의미한다.
아이디어
트리의 지름을 구하는 알고리즘은 잘 알려진 웰노운 테크닉이다. 굉장히 간단한 알고리즘으로, 정리하면 다음과 같다.
- 임의의 한 정점 u에서 dfs를 통해 가장 멀리 있는 정점을 찾는다. 이 정점을 v라고 하자.
- v를 시작점으로 dfs를 한다. 이 때 v와 가장 멀리 있는 정점을 w라고 하자.
- v와 w의 거리가 트리의 지름이다.
증명
$u$: 탐색을 시작하는 임의의 정점
$v$: u에서 가장 멀리 있는 정점
$d_1, d_2$: 트리의 지름의 양 끝
$d(u, v)$: $u$와 $v$의 경로의 길이
$p(u, v)$: $u$에서 $v$로 가는 경로
트리의 지름 알고리즘은 귀류법으로 증명할 수 있다. 원래대로라면 $v = d_1$ or $v = d_2$지만, $v \neq d_1. d_2$라고 가정해보자.
Case 1
$u$가 트리의 지름에 포함된 경우이다. 여기서 트리의 지름에 포함된다는 말은 지름의 양 끝을 잇는 경로에 $u$가 포함된다는 의미이다.
이 경우 모순이 발생한다. $v$와 가까운 지름의 끝을 $d_1$, 먼 지름의 끝을 $d_2$라고 가정하자. $d(v, u) + d(u, d_2) \geq d(d_1, d_2)$로 v를 포함한 더 큰 지름을 얻을 수 있기 때문에 $u$가 트리의 지름에 포함되는 경우 $v = d_1$ or $v = d_2$이다.
Case 2
$u$가 트리의 지름과 독립적이고, $p(u, v)$또한 $p(d_1, d_2)$와 독립적인 경우이다.
$d(v, u) + d(u, max(d(u, d_1), d(u, d_2))) \geq d(d_1, d_2)$가 성립해서 모순이 발생하므로 이 경우도 $v = d_1$ or $v = d_2$이다.
Case 3
$p(u, v)$가 $p(d_1, d_2)$와 한 교점에서 만나는 경우이다.
$d(v, max(d(v, d_1), d(v, d_2))) \geq d(d_1, d_2)$가 성립해서 모순이 발생하므로 이 경우도 $v = d_1$ or $v = d_2$이다.
따라서 임의의 한 정점에서 가장 먼 정점을 구했을 때, 그 정점은 항상 트리의 지름의 양 끝 점 중 하나이다.
코드
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#include <bits/stdc++.h>
#define FastIO ios::sync_with_stdio(false); cin.tie(NULL); cout.tie(NULL);
using namespace std;
int V, max_dist = 0, max_node = 1;
vector<pair<int, int>> graph[100001];
bool visited[100001];
void Output() {
cout << max_dist;
}
void Dfs(int cur_node, int cur_dist) {
visited[cur_node] = true;
if(cur_dist > max_dist) {
max_dist = cur_dist;
max_node = cur_node;
}
for(pair nxt : graph[cur_node]) {
int nxt_node = nxt.first;
int nxt_dist = nxt.second;
if(visited[nxt_node]) {
continue;
}
Dfs(nxt_node, cur_dist + nxt_dist);
}
}
void InitVisited() {
for(int i = 0; i < 100001; i++) {
visited[i] = false;
}
}
void Solve() {
Dfs(1, 0);
InitVisited();
Dfs(max_node, 0);
}
void Input() {
cin >> V;
for(int i = 0; i < V; i++) {
int u;
cin >> u;
while(true) {
int v, dist;
cin >> v;
if(v == -1) {
break;
}
cin >> dist;
graph[u].push_back(make_pair(v, dist));
graph[v].push_back(make_pair(u, dist));
}
}
}
int main() {
FastIO;
Input();
Solve();
Output();
}