문제
N x M 격자에 한 자리 수가 채워져 있을 때, 3개의 직사각형으로 분할하여 각 합의 곱이 최대가 되도록 하라.
입력
첫째 줄에 N, M, 이후 N줄에 M자리 숫자가 주어진다.
출력
세 직사각형 합의 곱의 최댓값을 출력한다.
예제
| 입력 | 출력 |
|---|---|
2 3 123 456 | 180 |
풀이
2D 누적합을 구한 뒤, 가로 2줄 분할, 세로 2줄 분할, 교차점 4가지 분할 총 6가지 패턴을 모두 시도한다.
- 2D 누적합 배열을 구축한다
- 가로줄 2개로 3등분하는 모든 조합을 시도한다
- 세로줄 2개로 3등분하는 모든 조합을 시도한다
- ㅗ/ㅜ/ㅏ/ㅓ 형태의 교차 분할을 모든 교차점에서 시도한다
- 모든 경우 중 최대 곱을 출력한다
핵심 아이디어: 직사각형 3개 분할은 가로선만, 세로선만, 교차선 4가지로 총 6패턴이며, 누적합으로 각 영역의 합을 O(1)에 구한다.
코드
package day749;
import java.io.*;
import java.util.*;
public class Day741BOJ1451직사각형나누기 {
static int N, M;
static int[] divide;
static int[][] num;
static long ans = 0;
static long[] tmp;
static long[][] sum;
public static long findSum(int x1, int y1, int x2, int y2) {
if (x1 == 0 && y1 == 0) {
return sum[y2][x2];
} else if (x1 == 0) {
return sum[y2][x2] - sum[y1 - 1][x2];
} else if (y1 == 0) {
return sum[y2][x2] - sum[y2][x1 - 1];
}
return sum[y2][x2] - sum[y2][x1 - 1] - sum[y1 - 1][x2] + sum[y1 - 1][x1 - 1];
}
public static void divideRow(int count, int index) { // 가로줄을 쭉 그었을때
if (count >= 2) {
tmp[0] = findSum(0, 0, M - 1, divide[0] - 1);
tmp[1] = findSum(0, divide[0], M - 1, divide[1] - 1);
tmp[2] = findSum(0, divide[1], M - 1, N - 1);
ans = Math.max(ans, tmp[0] * tmp[1] * tmp[2]);
return;
}
for (int i = index; i + 1 < N; i++) {
divide[count] = i + 1;
divideRow(count + 1, i + 1);
}
}
public static void divideCol(int count, int index) { // 세로줄을 쭉 그었을때
if (count >= 2) {
tmp[0] = findSum(0, 0, divide[0] - 1, N - 1);
tmp[1] = findSum(divide[0], 0, divide[1] - 1, N - 1);
tmp[2] = findSum(divide[1], 0, M - 1, N - 1);
ans = Math.max(ans, tmp[0] * tmp[1] * tmp[2]);
return;
}
for (int i = index; i + 1 < M; i++) {
divide[count] = i + 1;
divideCol(count + 1, i + 1);
}
}
public static void dividePoint() {
for (int i = 1; i <= N - 1; i++) {
for (int j = 1; j <= M - 1; j++) {
// ㅗ 모양
tmp[0] = findSum(0, 0, j - 1, i - 1);
tmp[1] = findSum(j, 0, M - 1, i - 1);
tmp[2] = findSum(0, i, M - 1, N - 1);
ans = Math.max(ans, tmp[0] * tmp[1] * tmp[2]);
// ㅜ 모양
tmp[0] = findSum(0, 0, M - 1, i - 1);
tmp[1] = findSum(0, i, j - 1, N - 1);
tmp[2] = findSum(j, i, M - 1, N - 1);
ans = Math.max(ans, tmp[0] * tmp[1] * tmp[2]);
// ㅏ 모양
tmp[0] = findSum(0, 0, j - 1, N - 1);
tmp[1] = findSum(j, 0, M - 1, i - 1);
tmp[2] = findSum(j, i, M - 1, N - 1);
ans = Math.max(ans, tmp[0] * tmp[1] * tmp[2]);
// ㅓ 모양
tmp[0] = findSum(0, 0, j - 1, i - 1);
tmp[1] = findSum(0, i, j - 1, N - 1);
tmp[2] = findSum(j, 0, M - 1, N - 1);
ans = Math.max(ans, tmp[0] * tmp[1] * tmp[2]);
}
}
}
public static void main(String[] args) throws Exception {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = new StringTokenizer(br.readLine());
N = Integer.parseInt(st.nextToken());
M = Integer.parseInt(st.nextToken());
num = new int[N][M];
sum = new long[N][M];
tmp = new long[3];
for (int i = 0; i < N; i++) {
String s = br.readLine();
for (int j = 0; j < M; j++) {
num[i][j] = s.charAt(j) - '0';
if (i == 0 && j == 0) {
sum[0][0] = num[0][0];
} else if (i == 0) {
sum[0][j] = sum[0][j - 1] + num[i][j];
} else if (j == 0) {
sum[i][0] = sum[i - 1][0] + num[i][j];
}
}
}
for (int i = 1; i < N; i++) {
for (int j = 1; j < M; j++) {
sum[i][j] = sum[i][j - 1] + sum[i - 1][j] - sum[i - 1][j - 1] + num[i][j];
}
}
divide = new int[2];
divideRow(0, 0);
divideCol(0, 0);
dividePoint();
System.out.println(ans);
}
}복잡도
- 시간: O(NM(N+M))
- 공간: O(N*M)