문제 링크
https://www.acmicpc.net/problem/4386
💡 풀이
주어진 2차원 점들을 최소 비용으로 이었을 때의 가중치 값을 구하는 문제입니다. 그래프의 정점들을 모두 연결했을 때의 가중치의 최솟값을 구하는 문제이므로, 최소 스패닝 트리(MST)를 사용하면 쉽게 풀이할 수 있는 문제입니다.
최소 스패닝 트리(MST) - Prim 알고리즘
저는 최소 스패닝 트리를 해결하기 위해, Prim 알고리즘을 사용했습니다.
프림 알고리즘은
우선 순위 큐(최소힙)를 이용하며 정점과 연결된 간선을 하나씩 추가하며 MST를 만드는 알고리즘
입니다.
static float prim(int start) {
boolean[] visited = new boolean[N];
// 최소힙 := 가중치를 기준으로 오름차순 정렬
PriorityQueue<Edge> pq = new PriorityQueue<>((o1, o2) -> (int)(o1.cost - o2.cost));
pq.offer(new Edge(start, 0));
float result = 0.0f;
while (!pq.isEmpty()) {
Edge edge = pq.poll();
int v = edge.w;
float cost = edge.cost;
// 이미 방문한 점이라면 Pass
if(visited[v]) continue;
visited[v] = true;
result += cost;
// 현재 정점에 연결된 모든 정점을 최소힙에 추가
for (Edge e : graph[v]) {
if (!visited[e.w]) {
pq.add(e);
}
}
}
return result;
}
위와 같이 프림 알고리즘을 사용하기 위해, 먼저 주어진 2차원 점들을 이용해 다른 모든 정점들과의 관계를 정의해줍니다.
그래프 만들기
points = new float[N][2];
for (int i = 0; i < N; i++) {
st = new StringTokenizer(br.readLine());
points[i][0] = Float.parseFloat(st.nextToken());
points[i][1] = Float.parseFloat(st.nextToken());
}
// Tree(Graph) 만들기
graph = new ArrayList[N];
for (int i = 0; i < N; i++) {
graph[i] = new ArrayList<>();
for (int j = 0; j < N; j++) {
if (i != j) {
graph[i].add(new Edge(j, getDistance(points[i], points[j])));
}
}
}
...
static class Edge {
int w;
float cost;
Edge(int w, float cost) {
this.w = w;
this.cost = cost;
}
}
`Edge` 클래스는 현재 연결된 정점과 가중치 값을 필드로 가지는 클래스입니다. 따라서, `graph[i]`의 의미는 `i` 정점에 연결된 Edge들을 의미합니다. 예시로 `graph[i].get(0)`를 살펴보면 하나의 Edge가 될 것이고 `Edge.w = 1`, `Edge.cost = 1.2`라면 i 와 1이 연결되어 있고 이 간선의 가중치의 값이 1.2라는 의미입니다.
전체 코드
전체적인 코드는 다음과 같습니다.
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.PriorityQueue;
import java.util.StringTokenizer;
import static java.lang.Math.*;
class Q4386 {
static int N;
static float answer;
static float[][] points;
static ArrayList<Edge>[] graph;
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st;
N = Integer.parseInt(br.readLine());
points = new float[N][2];
for (int i = 0; i < N; i++) {
st = new StringTokenizer(br.readLine());
points[i][0] = Float.parseFloat(st.nextToken());
points[i][1] = Float.parseFloat(st.nextToken());
}
// Tree(Graph) 만들기
graph = new ArrayList[N];
for (int i = 0; i < N; i++) {
graph[i] = new ArrayList<>();
for (int j = 0; j < N; j++) {
if (i != j) {
graph[i].add(new Edge(j, getDistance(points[i], points[j])));
}
}
}
System.out.printf("%.2f\n", prim(0));
}
static float prim(int start) {
boolean[] visited = new boolean[N];
PriorityQueue<Edge> pq = new PriorityQueue<>((o1, o2) -> (int)(o1.cost - o2.cost));
pq.offer(new Edge(start, 0));
float result = 0.0f;
while (!pq.isEmpty()) {
Edge edge = pq.poll();
int v = edge.w;
float cost = edge.cost;
// 이미 방문한 점이라면 Pass
if(visited[v]) continue;
visited[v] = true;
result += cost;
for (Edge e : graph[v]) {
if (!visited[e.w]) {
pq.add(e);
}
}
}
return result;
}
static float getDistance(float[] point1, float[] point2) {
return (float)(sqrt(pow(point1[0] - point2[0], 2) + pow(point1[1] - point2[1], 2)));
}
static class Edge {
int w;
float cost;
Edge(int w, float cost) {
this.w = w;
this.cost = cost;
}
}
}
'알고리즘 > BOJ' 카테고리의 다른 글
[BOJ] 14003 : 가장 긴 증가하는 부분 수열 5(Java) (1) | 2024.02.08 |
---|---|
[BOJ] 1509 : 팰린드롬 분할(Java) (1) | 2024.02.06 |
[BOJ] 20303 : 할로윈의 양아치(Java) (1) | 2024.02.05 |
[BOJ] 6603 : 로또(Java) (0) | 2024.01.31 |
[BOJ] 1202 : 보석 도둑(Java) (0) | 2024.01.30 |