[BOJ] 4386 : 별자리 만들기(Java)

문제 링크

https://www.acmicpc.net/problem/4386

 

4386번: 별자리 만들기

도현이는 우주의 신이다. 이제 도현이는 아무렇게나 널브러져 있는 n개의 별들을 이어서 별자리를 하나 만들 것이다. 별자리의 조건은 다음과 같다. 별자리를 이루는 선은 서로 다른 두 별을 일

www.acmicpc.net


💡 풀이

주어진 2차원 점들을 최소 비용으로 이었을 때의 가중치 값을 구하는 문제입니다. 그래프의 정점들을 모두 연결했을 때의 가중치의 최솟값을 구하는 문제이므로, 최소 스패닝 트리(MST)를 사용하면 쉽게 풀이할 수 있는 문제입니다.

최소 스패닝 트리(MST) - Prim 알고리즘

저는 최소 스패닝 트리를 해결하기 위해, Prim 알고리즘을 사용했습니다.

프림 알고리즘은

우선 순위 큐(최소힙)를 이용하며 정점과 연결된 간선을 하나씩 추가하며 MST를 만드는 알고리즘

입니다.

static float prim(int start) {
    boolean[] visited = new boolean[N];
	
    // 최소힙 := 가중치를 기준으로 오름차순 정렬
    PriorityQueue<Edge> pq = new PriorityQueue<>((o1, o2) -> (int)(o1.cost - o2.cost));
    pq.offer(new Edge(start, 0));

    float result = 0.0f;
    while (!pq.isEmpty()) {
        Edge edge = pq.poll();
        int v = edge.w;
        float cost = edge.cost;

        // 이미 방문한 점이라면 Pass
        if(visited[v]) continue;

        visited[v] = true;
        result += cost;
		
        // 현재 정점에 연결된 모든 정점을 최소힙에 추가
        for (Edge e : graph[v]) {
            if (!visited[e.w]) {
                pq.add(e);
            }
        }
    }

    return result;
}

 

위와 같이 프림 알고리즘을 사용하기 위해, 먼저 주어진 2차원 점들을 이용해 다른 모든 정점들과의 관계를 정의해줍니다.

그래프 만들기

points = new float[N][2];
for (int i = 0; i < N; i++) {
    st = new StringTokenizer(br.readLine());
    points[i][0] = Float.parseFloat(st.nextToken());
    points[i][1] = Float.parseFloat(st.nextToken());
}

// Tree(Graph) 만들기
graph = new ArrayList[N];
for (int i = 0; i < N; i++) {
    graph[i] = new ArrayList<>();
    for (int j = 0; j < N; j++) {
        if (i != j) {
            graph[i].add(new Edge(j, getDistance(points[i], points[j])));
        }
    }
}

...
static class Edge {
    int w;
    float cost;

    Edge(int w, float cost) {
        this.w = w;
        this.cost = cost;
    }
}

`Edge` 클래스는 현재 연결된 정점가중치 값을 필드로 가지는 클래스입니다. 따라서, `graph[i]`의 의미는 `i` 정점에 연결된 Edge들을 의미합니다. 예시로 `graph[i].get(0)`를 살펴보면 하나의 Edge가 될 것이고 `Edge.w = 1`, `Edge.cost = 1.2`라면 i 와 1이 연결되어 있고 이 간선의 가중치의 값이 1.2라는 의미입니다.

 

전체 코드

전체적인 코드는 다음과 같습니다.

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.PriorityQueue;
import java.util.StringTokenizer;

import static java.lang.Math.*;

class Q4386 {
    static int N;
    static float answer;
    static float[][] points;
    static ArrayList<Edge>[] graph;

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st;

        N = Integer.parseInt(br.readLine());
        points = new float[N][2];
        for (int i = 0; i < N; i++) {
            st = new StringTokenizer(br.readLine());
            points[i][0] = Float.parseFloat(st.nextToken());
            points[i][1] = Float.parseFloat(st.nextToken());
        }

        // Tree(Graph) 만들기
        graph = new ArrayList[N];
        for (int i = 0; i < N; i++) {
            graph[i] = new ArrayList<>();
            for (int j = 0; j < N; j++) {
                if (i != j) {
                    graph[i].add(new Edge(j, getDistance(points[i], points[j])));
                }
            }
        }

        System.out.printf("%.2f\n", prim(0));
    }

    static float prim(int start) {
        boolean[] visited = new boolean[N];

        PriorityQueue<Edge> pq = new PriorityQueue<>((o1, o2) -> (int)(o1.cost - o2.cost));
        pq.offer(new Edge(start, 0));

        float result = 0.0f;
        while (!pq.isEmpty()) {
            Edge edge = pq.poll();
            int v = edge.w;
            float cost = edge.cost;

            // 이미 방문한 점이라면 Pass
            if(visited[v]) continue;

            visited[v] = true;
            result += cost;

            for (Edge e : graph[v]) {
                if (!visited[e.w]) {
                    pq.add(e);
                }
            }
        }

        return result;
    }

    static float getDistance(float[] point1, float[] point2) {
        return (float)(sqrt(pow(point1[0] - point2[0], 2) + pow(point1[1] - point2[1], 2)));
    }

    static class Edge {
        int w;
        float cost;

        Edge(int w, float cost) {
            this.w = w;
            this.cost = cost;
        }
    }
}