알고리즘

[백준 10830번] 행렬 제곱

RuBPCase 2022. 2. 6. 09:11
728x90

원본 링크

https://www.acmicpc.net/problem/10830

Intro

오랜만에 보는 분할 정복 기법문제이다.
Divide & Conquer라고도 불리는 분할 정복 기법은 말 그대로 작게 문제를 분할해서 풀어나가고 합치는 방식이다.
여기서도 이 방법을 쓰는데, 나중에 과정을 보면 알겠지만 거듭 제곱을 활용해서 푼다.
이와 비슷한 문제도 찾아보면 나오는데, 심심하면 풀어보기 바란다.
유사 문제 - 1629 곱셈: https://www.acmicpc.net/problem/1629

문제

크기가 N*N인 행렬 A가 주어진다. 이때, A의 B제곱을 구하는 프로그램을 작성하시오. 수가 매우 커질 수 있으니, A^B의 각 원소를 1,000으로 나눈 나머지를 출력한다.

입력

첫째 줄에 행렬의 크기 N과 B가 주어진다. (2 ≤ N ≤ 5, 1 ≤ B ≤ 100,000,000,000)

둘째 줄부터 N개의 줄에 행렬의 각 원소가 주어진다. 행렬의 각 원소는 1,000보다 작거나 같은 자연수 또는 0이다.

출력

첫째 줄부터 N개의 줄에 걸쳐 행렬 A를 B제곱한 결과를 출력한다.

예제

예제 입력 1

2 5
1 2
3 4

예제 출력 1

69 558
337 406

예제 입력 2

3 3
1 2 3
4 5 6
7 8 9

예제 출력 2

468 576 684
62 305 548
656 34 412

예제 입력 3

5 10
1 0 0 0 1
1 0 0 0 1
1 0 0 0 1
1 0 0 0 1
1 0 0 0 1

예제 출력 3

512 0 0 0 512
512 0 0 0 512
512 0 0 0 512
512 0 0 0 512
512 0 0 0 512

주 알고리즘 분류

  • 수학
  • 분할 정복
  • 분할 정복을 이용한 거듭제곱
  • 선형대수학

나만의 풀이

우리의 목표는 주어진 행렬 A를 B제곱한 결과를 출력하는 것이다.
이때, 1000으로 나눈 나머지를 출력해야 함을 잊지 말자.
따라서 행렬 곱을 계산할 때, 1000에 관한 나머지 연산을 수행해야 할 것이다.

단순히 for문을 B-1번(or B번) 돌려서 A를 곱하면 너무 시간이 오래 걸린다.
따라서 행렬 곱과 관련된 약간의 성질을 이용하면 효율적으로 풀 수 있다.
알다시피, 수의 제곱처럼 행렬의 제곱도 A^m * A^n = A^(m+n)형태로 계산 가능하다.
이 점을 이용하면, A^B 또한 A^(B-x) * A^x와 같은 형태로 계산 가능할 것이다.
즉, 분할이 가능함을 알 수 있다.

여기에다가 추가로 B라는 값을 이진수로 변환해보자.
여기서 1이 나오는 자릿수에 해당하는 거듭제곱값을 이용하면 빠르게 계산할 수 있다.
조금 이해가 안 될 수도 있으니 아래의 간단한 예시를 생각해보자.
가령 10010이라는 값이 나왔다고 가정하자.
그러면 A^18A^16 * A^2의 형태로 계산할 수 있게 된다.
거듭제곱의 형태로 계산하면 큰 수라도 금방 도달할 수 있으니 단순 곱에 비해 효율적이다.

이제, 코드를 짜면서 추가적인 설명을 하겠다.
그 전에, 지금 행렬이라는 친구가 나왔으므로, numpy라는 모듈을 불러와서 행렬 연산을 수행할 생각이다.
numpy는 행렬이나 벡터 연산을 효율적으로 수행하기 위해 만든 라이브러리이다.
요즘은 ML이나 DL, 데이터 분석 시 pandas라는 친구랑 합쳐져서 많이 쓴다.
아무튼, 지금은 알고리즘 관련 문제를 푸는 거지만 한번 써보도록 하겠다.

라이브러리를 불러오고 입력을 받아들이는 과정을 먼저 수행한다.
나는 그냥 단순하게 n하고 a, b라 변수 명을 붙였다.
그리고 앞서 말한 개념적인 내용을 토대로 코드를 짜려 한다.

binb = bin(b)[2:]
arr = np.array(a)
tmp = np.identity(n=n)  # 단위 행렬.

본격적인 행렬곱 연산을 수행하기 전, 미리 b를 이진값으로 바꿔주자.
이때, string의 형태로 값이 바뀌며, 0b????....꼴로 나오므로, 앞의 2개 문자를 제거해주자.
이후 arr라는 배열을 생성하고, tmp에는 np.identity()를 써서 단위행렬을 넣어 주자.

for bb in binb[::-1]:  # 이진 거꾸로 보면서
    if bb == '1':
        tmp = tmp.dot(arr)%1000
    arr = arr.dot(arr)%1000  # 내적

앞서 구한 이진값에서 거꾸로 접근을 수행할 생각이다.
A, A^2, A^4, A^8, ... 꼴로 구해나갈 생각이라 그렇다.
일단 매 반복 끝마다 arr의 값을 제곱해준다.
이때, 계산의 편의를 위해서 나머지 연산도 수행해줬다.
그 앞에서 만일 현재 자릿수가 1인 경우엔 tmp라는 변수에 현재의 arr값을 곱하도록 코드를 짰다.
아까 말한 A^18 = A^16 * A^2의 과정을 위와 같이 구현했다.
그리고 그 이후엔 문제 조건에 맞게 출력만 하면 된다.

전체 코드는 아래와 같다.

import sys
input = sys.stdin.readline
import numpy as np

n, b = map(int, input().rstrip().split())

a = []
for _ in range(n):
    a.append(list(map(int, input().rstrip().split())))

binb = bin(b)[2:]
arr = np.array(a)
tmp = np.identity(n=n)

for bb in binb[::-1]:
    if bb == '1':
        tmp = tmp.dot(arr)%1000
    arr = arr.dot(arr)%1000

for i in range(n):
    for j in range(n):
        print(int(tmp[i][j]), end='')
        if j+1 != n: print(' ', end='')
    print()

엫 이거 왜 이래

테스트 케이스에서도 문제가 없어서 그냥 제출했는데 문제가 생겼다.
백준에서 numpy모듈을 지원하지 않나 보다.
할 수 없이 별도의 메소드를 다시 작성해야겠다. (ㅠㅠ)

앞선 코드를 수정해서 이번엔 numpy 없이코드를 작성하겠다.
따라서 np.dot()연산 대신 mydot()이라는 함수를 새로 만들었다.
mydot()은 입력으로 두 개의 2차원 리스트(행렬)를 받아 행렬곱을 수행한 행렬을 전달한다.

def mydot(a, b):  # np가 안 돼서 할 수 없이 mydot을 생성.
    n = len(a)
    tmp = [[0 for _ in range(n)] for _ in range(n)]
    for i in range(n):
        for j in range(n):
            for k in range(n):
                tmp[i][j] += (a[i][k] * b[k][j])
            tmp[i][j] %= 1000

    return tmp

그냥 단순하게 일반적으로 수행하는 행렬곱 과정을 구현했다.
행렬곱을 하는 원리만 알면 어떤 의미인지 해석할 수 있을 것이다.
이때, i-j loop의 끝부분에서 1000과의 나머지 연산을 별도로 수행했다.

조금 advance한 내용으로, 행렬 곱을 조금 더 빨리 구하기 위한 알고리즘도 있다.
Strassen's trick이라 불리는 내용이 있다.
주어진 n by n 행렬을 n/2 by n/2로 쪼개서 11번의 덧셈 & 7번의 뺄셈 & 7번의 행렬곱을 써서 구해낼 수 있다.
수학적으로 증명했을 때, 기존의 행렬곱 방식 O(n^3) 대비, O(n^2.81)정도의 효율을 지닌다고 한다.
수학적 증명은 Master Theorem을 활용해서 구할 수 있는데, 여기선 생략한다.

위 방식으로 행렬곱을 계산하면 조금 더 빨리 계산할 수 있다.
그러나 보면 알겠지만 실제 n의 범위는 5 이하다.
따라서 그냥 O(n^3)형태를 써도 상관없겠다는 생각이 든다.
앞선 Strassen's trick을 쓰려면 재귀적으로 처리하는 과정도 필요하니 구현 과정은 유제로 남겨놓으려 한다.

그 외의 코드는 앞서 짠 코드와 거의 비슷하다.
다만, 별도로 단위 행렬을 만드는 코드를 넣고 함수 구조를 살짝 바꾼 정도이다.
아래가 진짜 코드 전문이 되겠다.

import sys
input = sys.stdin.readline

def mydot(a, b):  # np가 안 돼서 할 수 없이 mydot을 생성.
    n = len(a)
    tmp = [[0 for _ in range(n)] for _ in range(n)]
    for i in range(n):
        for j in range(n):
            for k in range(n):
                tmp[i][j] += (a[i][k] * b[k][j])
            tmp[i][j] %= 1000

    return tmp


n, b = map(int, input().rstrip().split())

a = []
for _ in range(n):
    a.append(list(map(int, input().rstrip().split())))

binb = bin(b)[2:]
tmp = [[1 if i==j else 0 for i in range(n)] for j in range(n)]  # 단위 행렬.

for bb in binb[::-1]:  # 이진 거꾸로 보면서
    if bb == '1':
        tmp = mydot(tmp, a)
    a = mydot(a, a)  # 내적

for i in range(n):
    for j in range(n):
        print(int(tmp[i][j]), end='')
        if j+1 != n: print(' ', end='')
    print()

마치면서

이미 존재하는 행렬 연산 전용 라이브러리를 이용해서 풀려고 했으나 'ModuleNotFoundError'가 뜬 것이 아쉽다.
(Colab이나 Jupyter에서 데이터 끌올하면 numpy 잘 쓰는데...)
따라서 다시 함수를 만드느라 조금 번거로웠던 점을 빼면 괜찮았다.

앞서 말했듯 행렬 곱 계산 시 조금 더 효율적으로 곱을 구하는 방식이 있다.
다만 이를 구현하는 게 귀찮은 것도 있고, 주어지는 input의 행렬 크기가 5 이하인 점도 있어서 단순 곱으로만 구현했다.
혹시 필요한 개발자들은 해당 내용도 참고해서 구현해보면 좋을 것 같다.

728x90