백준

[백준] 수열과 퀴리 13 Python 3 (Feat 비효율)

안녕 나의 20대 2024. 6. 15.
반응형

 

세그먼트 트리와 지연 전파를 활용한 효율적인 구간 연산 처리

세그먼트 트리(Segment Tree)와 지연 전파(Lazy Propagation)를 활용하여 구간 연산을 효율적으로 처리하는 파이썬 코드를 작성해 보겠습니다. 세그먼트 트리는 구간 합, 구간 최솟값, 구간 최댓값 등 다양한 구간 연산을 빠르게 처리할 수 있는 자료구조입니다. 여기에 지연 전파 기법을 추가하여 업데이트 연산을 최적화할 수 있습니다.

문제 설명

주어진 배열에 대해 다음과 같은 연산을 효율적으로 수행하는 프로그램을 작성합니다:

  1. 특정 구간의 모든 원소에 값을 더하는 연산
  2. 특정 구간의 모든 원소에 값을 곱하는 연산
  3. 특정 구간의 모든 원소를 특정 값으로 바꾸는 연산
  4. 특정 구간의 합을 구하는 연산

이 문제는 세그먼트 트리와 지연 전파를 사용하여 해결할 수 있습니다.

접근 방식

이 문제를 해결하기 위해 다음과 같은 접근 방식을 사용합니다:

  1. 세그먼트 트리 초기화: 주어진 배열을 기반으로 세그먼트 트리를 초기화합니다.
  2. 지연 전파: 구간 업데이트 연산을 지연 전파 기법을 사용하여 최적화합니다.
  3. 쿼리 처리: 구간 합을 빠르게 계산할 수 있도록 합니다.

코드 구현

다음은 위의 접근 방식을 코드로 구현한 것입니다:

from sys import stdin
from math import ceil, log2

# 표준 입력을 사용하여 데이터를 읽어옵니다.
input = stdin.read
MOD = 10**9 + 7

class SegmentTree:
    def __init__(self, n):
        # 세그먼트 트리 초기화
        self.n = n
        self.size = 1 << (ceil(log2(n)) + 1)  # 트리의 크기 계산
        self.tree = [0] * self.size  # 세그먼트 트리 배열
        self.lazy_mult = [1] * self.size  # 곱셈 지연 전파 배열
        self.lazy_add = [0] * self.size  # 덧셈 지연 전파 배열
    
    def build(self, array):
        # 주어진 배열을 이용해 세그먼트 트리를 구축합니다.
        self._build(array, 1, 0, self.n - 1)

    def _build(self, array, node, start, end):
        # 세그먼트 트리를 재귀적으로 구축하는 함수
        if start == end:
            self.tree[node] = array[start]
        else:
            mid = (start + end) // 2
            self._build(array, 2 * node, start, mid)
            self._build(array, 2 * node + 1, mid + 1, end)
            self.tree[node] = (self.tree[2 * node] + self.tree[2 * node + 1]) % MOD

    def _propagate(self, node, start, end):
        # 지연 전파를 처리하는 함수
        if self.lazy_mult[node] != 1 or self.lazy_add[node] != 0:
            self.tree[node] = (self.tree[node] * self.lazy_mult[node] + self.lazy_add[node] * (end - start + 1)) % MOD
            if start != end:
                self.lazy_mult[2 * node] = (self.lazy_mult[2 * node] * self.lazy_mult[node]) % MOD
                self.lazy_mult[2 * node + 1] = (self.lazy_mult[2 * node + 1] * self.lazy_mult[node]) % MOD
                self.lazy_add[2 * node] = (self.lazy_add[2 * node] * self.lazy_mult[node] + self.lazy_add[node]) % MOD
                self.lazy_add[2 * node + 1] = (self.lazy_add[2 * node + 1] * self.lazy_mult[node] + self.lazy_add[node]) % MOD
            self.lazy_mult[node] = 1
            self.lazy_add[node] = 0
    
    def update_plus(self, node, start, end, left, right, diff):
        # 구간 덧셈 업데이트를 수행하는 함수
        self._propagate(node, start, end)
        
        if end < left or start > right:
            return
        
        if left <= start and end <= right:
            self.lazy_add[node] = (self.lazy_add[node] + diff) % MOD
            self._propagate(node, start, end)
            return
        
        mid = (start + end) // 2
        self.update_plus(2 * node, start, mid, left, right, diff)
        self.update_plus(2 * node + 1, mid + 1, end, left, right, diff)
        self.tree[node] = (self.tree[2 * node] + self.tree[2 * node + 1]) % MOD
    
    def update_multi(self, node, start, end, left, right, diff):
        # 구간 곱셈 업데이트를 수행하는 함수
        self._propagate(node, start, end)
        
        if end < left or start > right:
            return
        
        if left <= start and end <= right:
            self.lazy_mult[node] = (self.lazy_mult[node] * diff) % MOD
            self._propagate(node, start, end)
            return
        
        mid = (start + end) // 2
        self.update_multi(2 * node, start, mid, left, right, diff)
        self.update_multi(2 * node + 1, mid + 1, end, left, right, diff)
        self.tree[node] = (self.tree[2 * node] + self.tree[2 * node + 1]) % MOD
    
    def update(self, node, start, end, left, right, val):
        # 구간 값을 특정 값으로 설정하는 업데이트 함수
        self._propagate(node, start, end)
        
        if end < left or start > right:
            return
        
        if left <= start and end <= right:
            self.lazy_mult[node] = 0
            self.lazy_add[node] = val
            self._propagate(node, start, end)
            return
        
        mid = (start + end) // 2
        self.update(2 * node, start, mid, left, right, val)
        self.update(2 * node + 1, mid + 1, end, left, right, val)
        self.tree[node] = (self.tree[2 * node] + self.tree[2 * node + 1]) % MOD
    
    def query(self, node, start, end, left, right):
        # 구간 합을 구하는 쿼리 함수
        self._propagate(node, start, end)
        
        if end < left or start > right:
            return 0
        
        if left <= start and end <= right:
            return self.tree[node] % MOD
        
        mid = (start + end) // 2
        left_query = self.query(2 * node, start, mid, left, right)
        right_query = self.query(2 * node + 1, mid + 1, end, left, right)
        return (left_query + right_query) % MOD

if __name__ == '__main__':
    # 입력 데이터를 읽어와서 배열과 쿼리 정보를 파싱합니다.
    data = input().split()
    N = int(data[0])
    nums = list(map(int, data[1:N + 1]))

    seg_tree = SegmentTree(N)
    seg_tree.build(nums)
    
    M = int(data[N + 1])
    idx = N + 2
    results = []
    
    for _ in range(M):
        query_type = int(data[idx])
        x = int(data[idx + 1]) - 1
        y = int(data[idx + 2]) - 1
        if query_type == 1:
            v = int(data[idx + 3])
            seg_tree.update_plus(1, 0, N - 1, x, y, v)
            idx += 4
        elif query_type == 2:
            v = int(data[idx + 3])
            seg_tree.update_multi(1, 0, N - 1, x, y, v)
            idx += 4
        elif query_type == 3:
            v = int(data[idx + 3])
            seg_tree.update(1, 0, N - 1, x, y, v)
            idx += 4
        else:
            results.append(seg_tree.query(1, 0, N - 1, x, y))
            idx += 3
    
    for result in results:
        print(result)

코드 설명

  1. SegmentTree 클래스: 세그먼트 트리를 초기화하고, 구간 합, 구간 곱, 구간 치환 등의 연산을 수행합니다. 지연 전파를 사용하여 업데이트 연산을 최적화합니다.
    • build: 주어진 배열을 기반으로 세그먼트 트리를 초기화합니다.
    • _propagate: 지연 전파를 처리합니다.
    • update_plus: 특정 구간의 모든 원소에 값을 더합니다.
    • update_multi: 특정 구간의 모든 원소에 값을 곱합니다.
    • update: 특정 구간의 모든 원소를 특정 값으로 바꿉니다.
    • query: 특정 구간의 합을 구합니다.
  2. 메인 함수: 입력을 처리하고 쿼리를 수행합니다.
    • 입력 데이터를 읽어와서 배열과 쿼리 정보를 파싱합니다.
    • 세그먼트 트리를 초기화하고, 각 쿼리를 처리하여 결과를 출력합니다

번외
Cpp로 아주 빠른 속도로 해결 가능한 문제이지만 파이썬의 경우 메모리와 실행속도 최적화를 위해 상당히 어려운 문제이다.

파이썬으로만 푸는건 한계가 있을지도...

반응형

'백준' 카테고리의 다른 글

[백준] 1046 그림자 Python 3  (2) 2024.06.21
[백준] 2586 소방차 Python 3  (0) 2024.06.19
[백준] 9252 LCS 2 Python 3  (2) 2024.06.12
[백준] 1144 싼 비용 Python 3  (2) 2024.06.10
[백준] 5257 timeismoney Python 3  (0) 2024.06.05

댓글

💲 추천 글