문제
모든 도시 쌍의 최단 거리가 주어질 때, 이를 만족시키는 도로의 최소 개수를 찾고 그 도로들의 총 거리 합을 구하라. 불가능하면 -1을 출력한다.
입력
첫째 줄에 도시 수 N, 이후 N×N 최단 거리 행렬이 주어진다.
출력
필요한 도로의 총 거리 합을 출력한다. 불가능하면 -1을 출력한다.
예제
| 입력 | 출력 |
|---|---|
3 0 2 3 2 0 1 3 1 0 | 3 |
풀이
플로이드-워셜의 역과정으로, 경유점을 통해 대체 가능한 간선을 제거하여 최소 간선 집합을 구한다.
- 모든 중간 노드 k에 대해
maps[i][j]와maps[i][k] + maps[k][j]를 비교한다 maps[i][j] > maps[i][k] + maps[k][j]이면 삼각 부등식 위반이므로 -1을 출력한다maps[i][j] == maps[i][k] + maps[k][j]이면 직접 간선 i-j는 불필요하므로 제거 표시한다- 남은 직접 간선들의 가중치 합을 2로 나누어 출력한다 (양방향이므로)
핵심 아이디어: 경유점을 통해 동일 비용으로 도달 가능한 간선은 불필요하므로, 삼각 부등식 등호가 성립하는 간선을 모두 제거하면 최소 간선 집합이 남는다.
코드
import sys
input = sys.stdin.readline
N = int(input())
maps = [list(map(int, input().split())) for _ in range(N)]
direct = [[True] * N for _ in range(N)]
for k in range(N):
for i in range(N):
for j in range(N):
if i == j or i == k or k == j:
continue
if maps[i][j] > maps[i][k] + maps[k][j]:
print(-1)
exit()
elif maps[i][j] == maps[i][k] + maps[k][j]:
direct[i][j] = direct[j][i] = False
maps = [[maps[i][j] if direct[i][j] else 0 for j in range(N)] for i in range(N)]
print(sum(map(sum, maps)) // 2)복잡도
- 시간: O(N^3)
- 공간: O(N^2)