# Import required modules and classes
from airflow import DAG
from airflow.utils.decorators import apply_defaults
from airflow.exceptions import AirflowException
from airflow.hooks.postgres_hook import PostgresHook
from airflow.operators.python_operator import PythonOperator
import logging
from glob import glob

# Function used by the DAG: loads every table config file in the given directory
# (each config/*.py file is expected to contain a single Python dict literal)
def load_all_jsons_into_list(path_to_json):
    # Initialize the list of configs
    configs = []
    
    # Iterate over every .py file in the given directory
    for f_name in glob(path_to_json + '/*.py'):
        # Open the file
        with open(f_name) as f:
            # Read the file contents as text
            dict_text = f.read()
            
            try:
                # Evaluate the text into a Python dict using eval
                table_conf = eval(dict_text)
            except Exception as e:
                # If the conversion fails, log the error and re-raise
                logging.error(str(e))
                raise
            else:
                # Append the successfully parsed dict to configs
                configs.append(table_conf)
    
    # Return the list of parsed config dicts
    return configs
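
# A minimal sketch of what a file under config/ might contain, given the
# dict-literal format this loader expects. The table name, SQL, and check
# queries below are hypothetical examples, not taken from the original post.
#
# config/nps_summary.py:
# {
#     "table": "nps_summary",
#     "schema": "analytics",
#     "main_sql": "SELECT LEFT(created_at, 10) AS date, COUNT(1) AS cnt FROM raw_data.nps GROUP BY 1;",
#     "input_check": [
#         {"sql": "SELECT COUNT(1) FROM raw_data.nps", "count": 1}
#     ],
#     "output_check": [
#         {"sql": "SELECT COUNT(1) FROM {schema}.temp_{table}", "count": 1}
#     ]
# }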


# Find a table's config by name in the list of table configs
def find(table_name, table_confs):
    # Walk through the table_confs list, checking each table config dict
    for table in table_confs:
        # Check whether the "table" key of the current config matches the given table_name
        if table.get("table") == table_name:
            # If it matches, return that config dict
            return table
    
    # If nothing matches after checking every config, return None
    return None

# Build the summary tables in Redshift
def build_summary_table(dag_root_path, dag, tables_load, redshift_conn_id, start_task=None):
    # Log the DAG root path
    logging.info(dag_root_path)
    
    # Load the table config files (dict literals) from the config/ directory
    table_confs = load_all_jsons_into_list(dag_root_path + "/config/")

    # If a start task was given, chain the new tasks after it
    if start_task is not None:
        prev_task = start_task
    else:
        prev_task = None

    # Create a task for each table listed in tables_load and add it to the DAG
    for table_name in tables_load:
        # Look up the config that matches the given table name
        table = find(table_name, table_confs)
        
        # Create a RedshiftSummaryOperator with the config values
        summarizer = RedshiftSummaryOperator(
            table=table["table"],
            schema=table["schema"],
            redshift_conn_id=redshift_conn_id,
            input_check=table["input_check"],
            main_sql=table["main_sql"],
            output_check=table["output_check"],
            overwrite=table.get("overwrite", True),
            after_sql=table.get("after_sql"),
            pre_sql=table.get("pre_sql"),
            attributes=table.get("attributes", ""),
            dag=dag,
            task_id="analytics" + "__" + table["table"]
        )
        
        # Set the dependency between the previous task and the current one
        if prev_task is not None:
            prev_task >> summarizer
        prev_task = summarizer
    
    # Return the last task added (needed when composing the rest of the DAG)
    return prev_task
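
# A minimal sketch of how a DAG file might call build_summary_table; the
# dag_id, schedule, paths, table name, and connection id below are
# hypothetical, not taken from the original post.
#
# from airflow import DAG
# from datetime import datetime
#
# dag = DAG(
#     dag_id="build_summary",
#     start_date=datetime(2021, 1, 1),
#     schedule_interval="@once",
#     catchup=False,
# )
#
# build_summary_table(
#     dag_root_path="/var/lib/airflow/dags/dag_root",  # directory that holds config/
#     dag=dag,
#     tables_load=["nps_summary"],
#     redshift_conn_id="redshift_dev_db",
# )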

# Function that runs SQL against Redshift
def redshift_sql_function(**context):
    sql = context["params"]["sql"]
    print(sql)
    hook = PostgresHook(postgres_conn_id=context["params"]["redshift_conn_id"])
    hook.run(sql, True)

# Custom operator, extending PythonOperator, that builds a summary table in Redshift
class RedshiftSummaryOperator(PythonOperator):
    @apply_defaults
    def __init__(self, schema, table, redshift_conn_id, input_check, main_sql, output_check, overwrite, params={}, pre_sql="", after_sql="", attributes="", *args, **kwargs):
        # Initialization method of the custom RedshiftSummaryOperator class.
        
        # Redshift schema of the summary table to build
        self.schema = schema
        
        # Name of the summary table to build
        self.table = table
        
        # Airflow connection ID used for Redshift
        self.redshift_conn_id = redshift_conn_id
        
        # List of dicts, each holding a SQL query and a minimum record count, used to validate the input
        self.input_check = input_check
        
        # Main SQL query that produces the summary table
        self.main_sql = main_sql
        
        # List of dicts, each holding a SQL query and a minimum record count, used to validate the output
        self.output_check = output_check
        
        # If True, the existing table is overwritten; if False, records are appended instead
        self.overwrite = overwrite
        
        # Extra parameters passed through to PythonOperator
        self.params = params
        
        # Additional attributes (e.g. DISTKEY/SORTKEY clauses) for the table being created
        self.attributes = attributes

        # SQL to run after main_sql, with {schema} and {table} filled in; empty string if not given
        if after_sql:
            self.after_sql = after_sql.format(schema=self.schema, table=self.table)
        else:
            self.after_sql = ""

        # Build the SQL that creates the temp table and loads the data.
        # Start with pre_sql (run before main_sql), making sure it ends with a semicolon.
        if pre_sql:
            sql = pre_sql
            if not sql.endswith(";"):
                sql += ";"
        else:
            sql = ""

        # Drop the temp table if it already exists
        sql += "DROP TABLE IF EXISTS {schema}.temp_{table};".format(
            schema=self.schema,
            table=self.table
        )

        # Append a CREATE TABLE ... AS statement followed by the original main_sql
        sql += "CREATE TABLE {schema}.temp_{table} {attributes} AS ".format(
            schema=self.schema,
            table=self.table,
            attributes=self.attributes
        ) + self.main_sql
        self.main_sql = sql

        # Initialize the parent PythonOperator with the assembled SQL
        super(RedshiftSummaryOperator, self).__init__(
            python_callable=redshift_sql_function,  # function used as the Python callable
            params={
                "sql": self.main_sql,  # the assembled SQL statement
                "overwrite": self.overwrite,  # whether to overwrite the existing table
                "redshift_conn_id": self.redshift_conn_id  # Redshift connection ID
            },
            provide_context=True,  # pass the Airflow context to the callable
            *args,
            **kwargs
        )
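
    # For illustration only (an assumption based on the string building above,
    # not output captured from the original post): with a hypothetical config of
    # schema="analytics", table="nps_summary", attributes="", pre_sql="" and
    # main_sql="SELECT LEFT(created_at, 10) AS date, COUNT(1) AS cnt FROM raw_data.nps GROUP BY 1;",
    # the assembled SQL passed to redshift_sql_function would be roughly:
    #
    #   DROP TABLE IF EXISTS analytics.temp_nps_summary;
    #   CREATE TABLE analytics.temp_nps_summary AS SELECT LEFT(created_at, 10) AS date, COUNT(1) AS cnt FROM raw_data.nps GROUP BY 1;
    #
    # execute() below then validates the result and swap() renames the temp table
    # to analytics.nps_summary inside a transaction.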


    # Swap the temp table with the main table
    def swap(self):
        # Drop the original table,
        # rename the temp table to the original table name,
        # and grant SELECT on it to the analytics_users group,
        # with {schema} and {table} replaced by the configured values
        sql = """BEGIN;
        DROP TABLE IF EXISTS {schema}.{table} CASCADE;
        ALTER TABLE {schema}.temp_{table} RENAME TO {table};
        GRANT SELECT ON TABLE {schema}.{table} TO GROUP analytics_users;
        END
        """.format(schema=self.schema, table=self.table)
        self.hook.run(sql, True)

    def execute(self, context):
        """
        execute method of RedshiftSummaryOperator.

        1. Run the input checks first
           - input_check must be a list of dicts, each containing "sql" and "count"
        """
        self.hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
        for item in self.input_check:
            (cnt,) = self.hook.get_first(item["sql"])
            if cnt < item["count"]:
                raise AirflowException(
                    "Input Validation Failed for " + str(item["sql"]))

        """
        2. Create the temp table and load the data
        """
        return_value = super(RedshiftSummaryOperator, self).execute(context)

        """
        3. Run the output checks using self.output_check
        """
        for item in self.output_check:
            (cnt,) = self.hook.get_first(item["sql"].format(schema=self.schema, table=self.table))
            if item.get("op") == 'eq':
                if int(cnt) != int(item["count"]):
                    raise AirflowException(
                        "Output Validation of 'eq' Failed for " + str(item["sql"]) + ": " + str(cnt) + " vs. " + str(item["count"])
                    )
            else:
                if cnt < item["count"]:
                    raise AirflowException(
                        "Output Validation Failed for " + str(item["sql"]) + ": " + str(cnt) + " vs. " + str(item["count"])
                    )
        
        """
        4. Now swap in the temp table
        """
        self.swap()

        """
        5. Run after_sql if it is defined
        """
        if self.after_sql:
            self.hook.run(self.after_sql, True)

        return return_value
