알고리즘

백준 11049 행렬 곱셈 순서 [JAVA]

경딩 2024. 10. 17. 13:10

 

해당 문제는 dp 문제로 규칙만 찾으면 쉽게 풀 수 있는 문제다.

생각보다 어려웠다.

풀이를 참고하여 풀었지만 너무 어려웠던 만큼 복습을 해보도록 하자

문제

크기가 N×M인 행렬 A와 M×K인 B를 곱할 때 필요한 곱셈 연산의 수는 총 N×M×K번이다. 행렬 N개를 곱하는데 필요한 곱셈 연산의 수는 행렬을 곱하는 순서에 따라 달라지게 된다.

예를 들어, A의 크기가 5×3이고, B의 크기가 3 ×2, C의 크기가 2 ×6인 경우에 행렬의 곱 ABC를 구하는 경우를 생각해 보자.

  • AB를 먼저 곱하고 C를 곱하는 경우 (AB)C에 필요한 곱셈 연산의 수는 5×3×2 + 5×2×6 = 30 + 60 = 90번이다.
  • BC를 먼저 곱하고 A를 곱하는 경우 A(BC)에 필요한 곱셈 연산의 수는 3×2×6 + 5×3×6 = 36 + 90 = 126번이다.

같은 곱셈이지만, 곱셈을 하는 순서에 따라서 곱셈 연산의 수가 달라진다.

행렬 N개의 크기가 주어졌을 때, 모든 행렬을 곱하는데 필요한 곱셈 연산 횟수의 최솟값을 구하는 프로그램을 작성하시오. 입력으로 주어진 행렬의 순서를 바꾸면 안 된다.

 

입력

첫째 줄에 행렬의 개수 N(1 ≤ N ≤ 500)이 주어진다.

둘째 줄부터 N개 줄에는 행렬의 크기 r과 c가 주어진다. (1 ≤ r, c ≤ 500)

항상 순서대로 곱셈을 할 수 있는 크기만 입력으로 주어진다.

 

 

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Arrays;

class Main{
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int n = Integer.parseInt(br.readLine());

        int [] p =  new int[n+2];


        for(int i=1; i<=n; i++){
            String[] s = br.readLine().split(" ");
            int a = Integer.parseInt(s[0]);
            int b = Integer.parseInt(s[1]);
            if(i == 1){
                p[i] = a;
            }
            p[i+1] = b;
        }

        int[][] dp = new int[n+1][n+1];

        for(int k=1; k<=n-1; k++){
            for(int i=1; i+k<=n; i++){
                int last = i+k;
                dp[i][last] = Integer.MAX_VALUE;
                for(int j=i; j< last; j++){
                    dp[i][last] = Math.min(dp[i][j] + dp[j+1][last] + p[i]*p[j+1]*p[last+1] , dp[i][last]);
                }
            }
        }
        System.out.println(dp[1][n]);

    }
}

 

 

연속 행렬 곱셉

행렬의 곱셈은 곱하는 순서에 따라서 전체 연산 횟수에 차이가 발생한다. 여기서 순서란 AB <-> BA처럼 교환법칙을 의미하는 것이 아니다. ABC에서 AB를 먼저 할지, BC를 먼저 할지의 순서이다.

 

여기서 중요한 것은 '두 행렬을 곱하는 것의 반복'이라는 것이다. AxBxC에서 에서 3개의 행렬을 한 번에 곱하는 것이 아니라  (AxB) xC, Ax(BxC)처럼 순서가 다를 뿐 두 행렬의 곱이다,

 

 

(A(B(CD)))
(A((BC)D))
((AB)(CD))
((A(BC))D)
(((AB)C)D)

 

DynamicProgramming 은 동일한 하위 문제에 대한 해답이 계속해서 필요할 때 사용됩니다.

dp에서는 하위 문제애 대해 계산된 해답이 테이블에 저장되므로 다시 계산을 할 필요가 없습니다.

따라고 공통(겹치는) 하위 문제가 없을 때는 DP을 사용하지 않습니다.

다시 필요하지 않은 해답을 저장할  필요가 없기 때문입니다.

 

dp 알고리즘 4단계

1 단계 : 최적 솔루션의  구조를 특성화하기

2 단계 : 최적 솔루션의 값을 반복적으로 정의하기

3 단계: 일반적으로 상향식 방식으로 최종 솔루션 값을 계산함

4 단계: 계산된 정보로부터 최적의 해답을 도출함.

 

<A1, A2 , A3> 문제를 생각해 보자!

행렬의 크기가 각각 10 ×100, 100 ×5, 5 ×50이라 가정하자

 

  1. ((A1 A2) A3)에 따라 곱하면
    1. 10 × 5 행렬 곱 A1, A2를 계산하기 위해 10 · 100 · 5 = 5000 곱을 실행한다.
    2. 행렬 곱 (A1 A2)와 A3를 곱하기 위해 또 다른 10 · 5 · 50 = 2500 곱셈을 수행합니다.
    3. 따라서 곱 ((A1 A2) A3)을 계산하려면 7500번의 스칼라 곱셈을 해야 합니다.
  2. ((A1 (A2 A3))에 따라 곱하면
    1. 100 × 50 행렬 곱 A1, A2를 계산하기 위해 100 · 5 · 50 = 25000 곱을 실행한다.
    2. 행렬 곱 (A2 A3)와 A1를 곱하기 위해 또 다른 10· 100 · 50  = 5000 곱셈을 수행합니다.
    3. 따라서 곱 ((A1 A2) A3)을 계산하려면 75,000번의 스칼라 곱셈을 해야 합니다.

따라서 첫 번째 괄호에 따라 곱을 계산하는 것이 10배 더 빠릅니다.

스칼라 곱셈 횟수를 최소화하는 곱셈 순서를 결정하자!

 

1단계 : 최적의 괄호 구조

  • k의 일부 값에 대해 먼저 행렬 Ai.. k 및 Ak+1.. j 계산한 다음 이를 곱하여 최종 Ai.. j 값을 도출한다.

 

2단계: 최적의 해를 재귀적으로 정의

  • 최적의 비용이나 다음과 같이 재귀공식으로 설명할 수 있습니다.

 

3단계 :  최적의 비용 계산하기

  • 테이블화된 상향식 접근 방식을 사용하여 최적의 비용을 계산합니다,

 

4단계 : 최적의 해답 구축

  • 정답 출력하기

 

 

 

[4, 10 , 3, 12 ,20, 7]  주어졌다고 가정해 보자.

행렬은  4 x10 , 10 x3 , 3 x12  , 12 x20  , 20 x7이다.

우리는 M [i, j] , 0 <= i, j <= 5를 구해야 한다.

M [i, i]는 모두 0이다.

 

 

대각선을 계속해서 구해보자. 

M [1,3] =  min (M [1][1] + M [2][3] + p0 * p1 * p3 = 0 + 360 + 4*10* 12 = 840

                      (M [1][2] + M [3][3] + p0 * p2 * p3 = 120 + 0 + 4*3* 3 = 264 )  = 264

 

M [2,4] =  min (M [2][2] + M [3][4] + p1 * p2 * p4 = 0 + 720+ 10*3* 20 = 1320

                      (M [2][3] + M [4][4] + p1 * p3 * p4 = 360+ 0 + 10*12* 20  = 2760)  =  1320

 

M [3,5] =  min (M [3][3] + M [4][5] + p2 * p1 * p5 = 0 + 1680+ 3*12* 7= 1932

                      (M [3][4] + M [5][5] + p2 * p2 * p5 = 72-+ 0 + 3*20* 7= 1140)  = 1140

 

 

 

 

M [1,4] =  min (M [1][3] + M [4][4] + p1 * p3 * p4 = 264 + 0 + 4*12* 20 = 1224

                       M [1][2] + M [3][4] + p1 * p2 * p4 = 120 + 720 + 4*3* 20 = 1080

                        M [1][1] + M [2][4] + p1 * p1 * p4 = 0 +1320 + 4*10* 20 = 2120)  = 1080

 

 

M [2,5] =  min (M [2][4] + M [5][5] + p1 * p4 * p5 = 1080 + 0 + 10*20* 7 = 2720

                       M [2][3] + M [4][5] + p1 * p3 * p5 = 264 + 1680+ 10*12* 7  = 2880

                        M [2][2] + M [3][5] + p1 * p2 * p5 = 0+ 1140 + 10*3* 7  = 1350)  = 1350

 

 

p [4, 10 , 3, 12 ,20, 7] 

  [0,  1 ,  2,   3,   4, 5] 

 

M [1,5] =  min (M [1][4] + M [5][5] + p0 * p4 * p5 = 1080 + 0 + 4*20* 7 = 1544

                       M [1][3] + M [4][5] + p0 * p3 * p5 = 264 + 1680+ 4*12* 7  = 2016

                        M [1][2] + M [3][5] + p0 * p2 * p5 = 0+ 1140 + 4*3* 7  = 1344

                        M [1][1] + M [2][5] + p0 * p1 * p5 = 0+ 1140 + 4*10* 7  = 1350 )  = 1630

630

 

출처 : https://www.tutorialspoint.com/data_structures_algorithms/matrix_chain_multiplication.htm