문제 난이도
GOLD III
1774번 풀이
간선의 가중치가 주어지지 않아 간선의 가중치를 구하는 로직만 추가해준다면 쉽게 문제를 해결할 수 있다.
이 문제에서 간선의 가중치는 2차원에 있는 신들의 위치 사이의 거리를 구하는 것으로, 아래 공식처럼 유클리드 거리 계산 방법을 이용하면 쉽게 가중치를 구할 수 있다.
위 공식을 코드로 나타낸다면 아래와 같다.
private static double calCost(int x1, int x2, int y1, int y2) {
return Math.sqrt(Math.pow(x1 - y1, 2) + Math.pow(x2 - y2, 2));
}
직전에 크루스칼 알고리즘에 대해 알아보았기 때문에 크루스칼 알고리즘으로 문제를 해결했다.
간선의 가중치를 위 공식으로 알아냈기 때문에, 간선을 기준으로 오름차순 정렬을 진행하여 Union-Find로 문제를 해결할 수 있다.
내가 문제를 풀이한 순서는 다음과 같다.
1. 입력 받기
입력을 Map형태로 받았다.
Key는 노드, Value는 해당 노드의 좌표를 받았다.
HashMap<Integer, int[]> map = new HashMap<>();
for (int i = 1; i <= N; i ++) {
st = new StringTokenizer(br.readLine());
int a = Integer.parseInt(st.nextToken());
int b = Integer.parseInt(st.nextToken());
map.put(i, new int[] {a, b});
}
2. Union-Find를 위한 각 노드의 부모 배열 생성 및 초기화
문제에서, 이미 연결된 각 노드들의 정보를 제공해 주었기에, union을 통해 부모 정보를 담고있는 배열을 업데이트 했다.
int[] parent = new int[N + 1];
//초기화
for (int i = 1; i <= N; i ++) {
parent[i] = i;
}
//문제 조건 업데이트
for (int i = 0; i < M; i ++) {
st = new StringTokenizer(br.readLine());
int a = Integer.parseInt(st.nextToken());
int b = Integer.parseInt(st.nextToken());
union(parent, a, b);
}
3. 간선 가중치 계산
List<Node> lis = makeGraph(parent, N, map);
//신들의 그래프를 만든다.
// 이 때 find를 이용하여 연결되지 않은 신들의 가중치만을 계산한다.
// 이미 연결 되어있는 신들 간의 가중치는 계산할 필요가 없다.
private static List<Node> makeGraph(int[] parent, int N, HashMap<Integer, int[]> map) {
List<Node> lis = new ArrayList<>();
for (int i = 1; i <= N; i ++) {
for (int j = 1; j <= N; j ++) {
if (find(parent, i) != find(parent, j)) {
double cost = calCost(map, i, j);
lis.add(new Node(i, j, cost));
}
}
}
Collections.sort(lis, (o1, o2) -> Double.compare(o1.cost, o2.cost));
return lis;
}
//가중치 계산
private static double calCost(HashMap<Integer, int[]> map, int x, int y) {
int[] xLocation = map.get(x);
int[] yLocation = map.get(y);
return Math.sqrt(Math.pow(xLocation[0] - yLocation[0], 2) + Math.pow(xLocation[1] - yLocation[1], 2));
}
4. Union-Find를 이용해 문제 해결
마지막으로, 연결되지 않은 노드들의 가중치를 바탕으로 최소 신장 트리를 만들면서 문제에서 요구하는 가중치를 구한다.
요구조건에 맞게, 소숫점 두 번째 자리까지 출력한다.
List<Node> lis = makeGraph(parent, N, map);
double mst = 0;
for (Node n: lis) {
int n1 = n.node1;
int n2 = n.node2;
if (find(parent, n1) != find(parent, n2)) {
union(parent, n1, n2);
mst += n.cost;
}
}
System.out.printf("%.2f\n", mst);
전체 코드
public class Main {
public static class Node {
int node1;
int node2;
double cost;
public Node(int n1, int n2, double c) {
this.node1 = n1;
this.node2 = n2;
this.cost = c;
}
}
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = new StringTokenizer(br.readLine());
int N = Integer.parseInt(st.nextToken());
int M = Integer.parseInt(st.nextToken());
HashMap<Integer, int[]> map = new HashMap<>();
for (int i = 1; i <= N; i ++) {
st = new StringTokenizer(br.readLine());
int a = Integer.parseInt(st.nextToken());
int b = Integer.parseInt(st.nextToken());
map.put(i, new int[] {a, b});
}
int[] parent = new int[N + 1];
for (int i = 1; i <= N; i ++) {
parent[i] = i;
}
for (int i = 0; i < M; i ++) {
st = new StringTokenizer(br.readLine());
int a = Integer.parseInt(st.nextToken());
int b = Integer.parseInt(st.nextToken());
union(parent, a, b);
}
List<Node> lis = makeGraph(parent, N, map);
double mst = 0;
for (Node n: lis) {
int n1 = n.node1;
int n2 = n.node2;
if (find(parent, n1) != find(parent, n2)) {
union(parent, n1, n2);
mst += n.cost;
}
}
System.out.printf("%.2f\n", mst);
}
private static List<Node> makeGraph(int[] parent, int N, HashMap<Integer, int[]> map) {
List<Node> lis = new ArrayList<>();
for (int i = 1; i <= N; i ++) {
for (int j = 1; j <= N; j ++) {
if (find(parent, i) != find(parent, j)) {
double cost = calCost(map, i, j);
lis.add(new Node(i, j, cost));
}
}
}
Collections.sort(lis, (o1, o2) -> Double.compare(o1.cost, o2.cost));
return lis;
}
private static double calCost(HashMap<Integer, int[]> map, int x, int y) {
int[] xLocation = map.get(x);
int[] yLocation = map.get(y);
return Math.sqrt(Math.pow(xLocation[0] - yLocation[0], 2) + Math.pow(xLocation[1] - yLocation[1], 2));
}
private static int find(int[] parent, int x) {
if (parent[x] == x) {
return x;
}
return parent[x] = find(parent, parent[x]);
}
private static void union(int[] parent, int a, int b) {
a = find(parent, a);
b = find(parent, b);
if (a == b) return;
if (a < b) {
parent[b] = a;
}
else {
parent[a] = b;
}
}
}
4386번 풀이
위 풀이에서, 미리 Union을 수행하는 부분만 제외하면 똑같은 문제이다.
대신에 이미 이차원 평면에서 각 노드의 좌표를 double로 받아서 계산만 해주면 된다.
자세한 풀이는 1774번과 거의 흡사하니 위 풀이를 참고한다면 금방 풀이가 가능하다.
public class Main {
public static class Node {
int node1;
int node2;
double cost;
public Node(int n1, int n2, double c) {
this.node1 = n1;
this.node2 = n2;
this.cost = c;
}
}
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
int N = Integer.parseInt(br.readLine());
HashMap<Integer, double[]> map = new HashMap<>();
StringTokenizer st;
for (int i = 1; i <= N; i ++) {
st = new StringTokenizer(br.readLine());
double x = Double.parseDouble(st.nextToken());
double y = Double.parseDouble(st.nextToken());
map.put(i, new double[] {x, y});
}
int[] parent = new int[N + 1];
for (int i = 1; i <= N; i ++) {
parent[i] = i;
}
List<Node> lis = makeGraph(parent, N, map);
double mst = 0;
for (Node n: lis) {
int n1 = n.node1;
int n2 = n.node2;
if (find(parent, n1) != find(parent, n2)) {
union(parent, n1, n2);
mst += n.cost;
}
}
System.out.printf("%.2f\n", mst);
}
private static List<Node> makeGraph(int[] parent, int N, HashMap<Integer, double[]> map) {
List<Node> lis = new ArrayList<>();
for (int i = 1; i <= N; i ++) {
for (int j = 1; j <= N; j ++) {
if (find(parent, i) != find(parent, j)) {
double cost = calCost(map, i, j);
lis.add(new Node(i, j, cost));
}
}
}
Collections.sort(lis, (o1, o2) -> Double.compare(o1.cost, o2.cost));
return lis;
}
private static double calCost(HashMap<Integer, double[]> map, int x, int y) {
double[] xLocation = map.get(x);
double[] yLocation = map.get(y);
return Math.sqrt(Math.pow(xLocation[0] - yLocation[0], 2) + Math.pow(xLocation[1] - yLocation[1], 2));
}
private static int find(int[] parent, int x) {
if (parent[x] == x) {
return x;
}
return parent[x] = find(parent, parent[x]);
}
private static void union(int[] parent, int a, int b) {
a = find(parent, a);
b = find(parent, b);
if (a == b) return;
if (a < b) {
parent[b] = a;
}
else {
parent[a] = b;
}
}
}
2023.04 ~ 백엔드 개발자의 기록
포스팅이 좋았다면 "좋아요❤️" 또는 "구독👍🏻" 해주세요!