문제 링크
https://www.acmicpc.net/problem/2143
💡 풀이
배열 A와 배열 B의 누적합을 저장하는 두 배열 계산해 하나씩 비교하면 시간 초과가 발생합니다. 따라서, 배열 A의 누적합을 저장한 배열을 순회하며, T - a[i]의 값을 B의 누적합 배열에서 찾을 때, 이분 탐색을 이용해 찾습니다.
누적합 배열
각 A, B 배열에 대해 누적합을 저장하는 배열을 선언해줍니다.
int a_sum_length = a_length * (a_length + 1) / 2;
int[] a_sum = new int[a_sum_length];
int index = 0;
for (int i = 0; i < a_length; i++) {
int sum = 0;
for (int j = i; j < a_length; j++) {
sum += a[j];
a_sum[index++] = sum;
}
}
int b_sum_length = b_length * (b_length + 1) / 2;
int[] b_sum = new int[b_sum_length];
index = 0;
for (int i = 0; i < b_length; i++) {
int sum = 0;
for (int j = i; j < b_length; j++) {
sum += b[j];
b_sum[index++] = sum;
}
}
이분 탐색
B 누적합 배열에서 해당하는 값을 빠르게 찾기 위해 정렬하고 이분 탐색을 이용합니다. 이때, 중복된 숫자가 존재할 수 있기 때문에 이를 찾기 위해 왼쪽과 오른쪽을 확인하는 코드를 추가해줍니다.
static int binarySearch(int[] arr, int target) {
int left = 0;
int right = arr.length - 1;
while (left <= right) {
int mid = left + (right - left) / 2;
if (arr[mid] == target) {
return mid; // 찾는 값이 존재하는 경우 해당 인덱스 반환
} else if (arr[mid] < target) {
left = mid + 1; // 중간값이 찾는 값보다 작으면 오른쪽 영역으로 이동
} else {
right = mid - 1; // 중간값이 찾는 값보다 크면 왼쪽 영역으로 이동
}
}
return -1; // 찾는 값이 존재하지 않는 경우 -1 반환
}
static int leftCount(int[] arr, int index, int target) {
int result = 0;
while (true) {
index--;
if (index < 0) {
break;
}
if (arr[index] == target) {
result++;
}
}
return result;
}
static int rightCount(int[] arr, int index, int target) {
int result = 0;
while (true) {
index++;
if (index >= arr.length) {
break;
}
if (arr[index] == target) {
result++;
}
}
return result;
}
해당 값 찾기
이제, A 누적합 배열을 순회하며, T - a_sum[i] 의 값을 B 누적합 배열에서 찾아줍니다. 이때, 시간을 줄이기 위해 기존에 찾은 값을 저장하는 메모이제이션을 사용했습니다.
Arrays.sort(b_sum);
long answer = 0;
for (int i = 0; i < a_sum_length; i++) {
int number = T - a_sum[i];
if (b_map.containsKey(number)) {
answer += b_map.get(number);
}
else {
int findIndex = binarySearch(b_sum, number);
if (findIndex != -1) {
int count = 1 + leftCount(b_sum, findIndex, number) + rightCount(b_sum, findIndex, number);
b_map.put(number, count);
answer += count;
}
}
}
전체 코드
전체적인 코드는 다음과 같습니다.
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.StringTokenizer;
public class Q2143 {
static int T;
static int[] a, b;
static Map<Integer, Integer> b_map = new HashMap<>();
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st;
T = Integer.parseInt(br.readLine());
int a_length = Integer.parseInt(br.readLine());
a = new int[a_length];
st = new StringTokenizer(br.readLine());
for (int i = 0; i < a_length; i++) {
a[i] = Integer.parseInt(st.nextToken());
}
int b_length = Integer.parseInt(br.readLine());
b = new int[b_length];
st = new StringTokenizer(br.readLine());
for (int i = 0; i < b_length; i++) {
b[i] = Integer.parseInt(st.nextToken());
}
int a_sum_length = a_length * (a_length + 1) / 2;
int[] a_sum = new int[a_sum_length];
int index = 0;
for (int i = 0; i < a_length; i++) {
int sum = 0;
for (int j = i; j < a_length; j++) {
sum += a[j];
a_sum[index++] = sum;
}
}
int b_sum_length = b_length * (b_length + 1) / 2;
int[] b_sum = new int[b_sum_length];
index = 0;
for (int i = 0; i < b_length; i++) {
int sum = 0;
for (int j = i; j < b_length; j++) {
sum += b[j];
b_sum[index++] = sum;
}
}
Arrays.sort(b_sum);
long answer = 0;
for (int i = 0; i < a_sum_length; i++) {
int number = T - a_sum[i];
if (b_map.containsKey(number)) {
answer += b_map.get(number);
}
else {
int findIndex = binarySearch(b_sum, number);
if (findIndex != -1) {
int count = 1 + leftCount(b_sum, findIndex, number) + rightCount(b_sum, findIndex, number);
b_map.put(number, count);
answer += count;
}
}
}
System.out.println(answer);
}
static int binarySearch(int[] arr, int target) {
int left = 0;
int right = arr.length - 1;
while (left <= right) {
int mid = left + (right - left) / 2;
if (arr[mid] == target) {
return mid; // 찾는 값이 존재하는 경우 해당 인덱스 반환
} else if (arr[mid] < target) {
left = mid + 1; // 중간값이 찾는 값보다 작으면 오른쪽 영역으로 이동
} else {
right = mid - 1; // 중간값이 찾는 값보다 크면 왼쪽 영역으로 이동
}
}
return -1; // 찾는 값이 존재하지 않는 경우 -1 반환
}
static int leftCount(int[] arr, int index, int target) {
int result = 0;
while (true) {
index--;
if (index < 0) {
break;
}
if (arr[index] == target) {
result++;
}
}
return result;
}
static int rightCount(int[] arr, int index, int target) {
int result = 0;
while (true) {
index++;
if (index >= arr.length) {
break;
}
if (arr[index] == target) {
result++;
}
}
return result;
}
}
'알고리즘 > BOJ' 카테고리의 다른 글
[BOJ] 17281 : ⚾(Java) (2) | 2024.01.22 |
---|---|
[BOJ] 17070 : 파이프 옮기기 1(Java) (1) | 2024.01.21 |
[BOJ] 1234 : 크리스마스 트리(Java) (1) | 2023.11.27 |
[BOJ] 2473: 세 용액(Java) (1) | 2023.11.15 |
[BOJ] 16398 : 행성 연결(Java) (1) | 2023.10.17 |