문제 링크
https://www.acmicpc.net/problem/20303
💡 풀이
중간에 시행착오가 좀 있었지만, Union-Find와 배낭 문제를 이용하면 풀 수 있는 문제입니다.
- 서로 친구 관계가 있는 아이들끼리 Union-Find를 통해 그룹을 만들어 줍니다.
- 각 그룹의 인원 수, 그룹의 총 사탕 개수로 이루어진 노드를 만들고 DP를 이용해 얻을 수 있는 최대 사탕 개수를 구합니다.
이렇게 크게 두 가지 과정을 통해 문제를 해결할 수 있습니다.
Union-Find
우선, Union-Find를 사용해 서로 관계있는 아이들끼리 그룹을 형성합니다.
parents = new int[N + 1];
...
for (int i = 1; i <= N; i++) {
candies[i] = Integer.parseInt(st.nextToken());
parents[i] = i;
}
// Union-Find 진행
for (int i = 0; i < M; i++) {
st = new StringTokenizer(br.readLine());
int node1 = Integer.parseInt(st.nextToken());
int node2 = Integer.parseInt(st.nextToken());
union(node1, node2);
}
...
static boolean union(int node1, int node2) {
node1 = find(node1);
node2 = find(node2);
if (node1 == node2)
return false;
if (node1 <= node2)
parents[node2] = node1;
else
parents[node1] = node2;
return true;
}
static int find(int node) {
if(parents[node] == node)
return node;
return parents[node] = find(parents[node]);
}
그리고, 이를 통해 만들어진 `parents` 배열을 통해 그룹을 만들어 List에 저장합니다.
// group[i][0] := i 그룹의 인원 수, group[i][1] := i 그룹의 모든 사탕 수
group = new long[N + 1][2];
for (int index = 1; index <= N; index++) {
int gNum = find(parents[index]);
group[gNum][0]++; // 해당 그룹 인원 증가
group[gNum][1] += candies[index];
}
// List 에 그룹을 담고 인원 수를 기준으로 정렬
ArrayList<Node> list = new ArrayList<>();
for (int index = 1; index <= N; index++) {
if (parents[index] == index) {
list.add(new Node((int)group[index][0], group[index][1]));
}
}
Collections.sort(list, ((o1, o2) -> o1.cost - o2.cost));
...
static class Node {
int cost;
long value;
Node(int cost, long value) {
this.cost = cost;
this.value = value;
}
}
배낭 문제(DP)
만들어진 cost와 value로 구성된 Node의 List를 DP를 이용해 해결합니다. 기존의 배낭 문제와 같은 2차원 다이나믹 프로그래밍 방식으로 해결 할 수 있습니다.
long[][] dp = new long[list.size() + 1][K];
for (int i = 1; i <= list.size(); i++) {
for (int j = 0; j < K; j++) {
if (list.get(i - 1).cost > j) {
dp[i][j] = dp[i - 1][j];
} else {
dp[i][j] = Math.max(dp[i - 1][j - list.get(i - 1).cost] + list.get(i - 1).value, dp[i - 1][j]);
}
}
}
System.out.println(dp[list.size()][K - 1]);
전체 코드
전체적인 코드는 다음과 같습니다.
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.StringTokenizer;
class Q20303 {
static int N, M, K;
static int[] candies;
static int[] parents;
static long[][] group;
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = new StringTokenizer(br.readLine());
N = Integer.parseInt(st.nextToken());
M = Integer.parseInt(st.nextToken());
K = Integer.parseInt(st.nextToken());
parents = new int[N + 1];
candies = new int[N + 1];
st = new StringTokenizer(br.readLine());
for (int i = 1; i <= N; i++) {
candies[i] = Integer.parseInt(st.nextToken());
parents[i] = i;
}
// Union-Find 진행
for (int i = 0; i < M; i++) {
st = new StringTokenizer(br.readLine());
int node1 = Integer.parseInt(st.nextToken());
int node2 = Integer.parseInt(st.nextToken());
union(node1, node2);
}
// group[i][0] := i 그룹의 인원 수, group[i][1] := i 그룹의 모든 사탕 수
group = new long[N + 1][2];
for (int index = 1; index <= N; index++) {
int gNum = find(parents[index]);
group[gNum][0]++; // 해당 그룹 인원 증가
group[gNum][1] += candies[index];
}
// List 에 그룹을 담고 인원 수를 기준으로 정렬
ArrayList<Node> list = new ArrayList<>();
for (int index = 1; index <= N; index++) {
if (parents[index] == index) {
list.add(new Node((int)group[index][0], group[index][1]));
}
}
Collections.sort(list, ((o1, o2) -> o1.cost - o2.cost));
long[][] dp = new long[list.size() + 1][K];
for (int i = 1; i <= list.size(); i++) {
for (int j = 0; j < K; j++) {
if (list.get(i - 1).cost > j) {
dp[i][j] = dp[i - 1][j];
} else {
dp[i][j] = Math.max(dp[i - 1][j - list.get(i - 1).cost] + list.get(i - 1).value, dp[i - 1][j]);
}
}
}
System.out.println(dp[list.size()][K - 1]);
}
static boolean union(int node1, int node2) {
node1 = find(node1);
node2 = find(node2);
if (node1 == node2)
return false;
if (node1 <= node2)
parents[node2] = node1;
else
parents[node1] = node2;
return true;
}
static int find(int node) {
if(parents[node] == node)
return node;
return parents[node] = find(parents[node]);
}
static class Node {
int cost;
long value;
Node(int cost, long value) {
this.cost = cost;
this.value = value;
}
}
}
'알고리즘 > BOJ' 카테고리의 다른 글
[BOJ] 1509 : 팰린드롬 분할(Java) (1) | 2024.02.06 |
---|---|
[BOJ] 4386 : 별자리 만들기(Java) (0) | 2024.02.06 |
[BOJ] 6603 : 로또(Java) (0) | 2024.01.31 |
[BOJ] 1202 : 보석 도둑(Java) (0) | 2024.01.30 |
[BOJ] 17136 : 색종이 붙이기(Java) (0) | 2024.01.23 |