import { ColumnDef, Header } from '@tanstack/react-table'
import { RowData } from '@tanstack/table-core/src/types'

/**
 * A column definition extended with the table-cell span values this module works with.
 *
 * It keeps every property of `ColumnDef` except `columns`, which is re-declared
 * recursively so nested entries can carry `rowSpan`/`colSpan` as well.
 * `calculateSpanValues` fills the span values in and `processHeaders` copies them
 * onto the matching header objects.
 */
type HeaderWithTableAttributes<TData extends RowData> = Omit<ColumnDef<TData>, 'columns'> & {
    // Row span for this entry; optional because plain column definitions do not have it yet.
    rowSpan?: number
    // Column span for this entry; optional for the same reason as `rowSpan`.
    colSpan?: number
    // Nested entries of the same shape.
    // NOTE(review): declared as required, yet the functions in this file guard against it
    // being absent (`header.columns && ...`, `header.columns || []`) — confirm whether
    // this should be optional.
    columns: HeaderWithTableAttributes<TData>[]
}

/**
 * One row of headers as produced by `processHeaders`.
 */
export type HeaderRow<TData extends RowData> = {
    // Nesting level of the row: `processHeaders` starts at 0 and adds 1 per level of nested columns.
    level: number
    // Header objects placed on this row, in the order of the column entries of that level.
    headers: Header<TData, unknown>[]
}

/**
 * Returns the headers of `headers` that are not placeholders, in their original order,
 * and prunes placeholders from their nested `subHeaders` as well.
 *
 * Side effect: every kept header with a non-empty `subHeaders` list gets that list
 * replaced by its filtered version (recursively). Placeholder headers are dropped
 * without being modified.
 *
 * @param headers Headers to filter; the array itself is left untouched.
 * @returns A new array holding the same header objects minus the placeholders.
 */
const filterAllPlaceholdersFromHeaders = <TData extends RowData>(headers: Header<TData, unknown>[]) => {
    const realHeaders = headers.filter(candidate => !candidate.isPlaceholder)

    realHeaders.forEach(kept => {
        // Only headers that actually have children need their subtree pruned.
        if (kept.subHeaders.length > 0) {
            kept.subHeaders = filterAllPlaceholdersFromHeaders(kept.subHeaders)
        }
    })

    return realHeaders
}

/**
 * Computes how many levels of nested `columns` exist below the given entries.
 *
 * Entries without children contribute 0; a group contributes 1 plus the depth of its
 * own children. An entry counts as a group only when `columns` is a non-empty list;
 * a missing, `undefined`, `null` or empty `columns` is treated as "no children", the
 * same way `calculateSpanValues` treats it.
 *
 * @param headers Column entries of one nesting level.
 * @returns The deepest nesting level found, 0 when no entry has children.
 */
const calculateMaxDepth = <TData extends RowData>(headers: HeaderWithTableAttributes<TData>[]) => {
    let maxDepth = 0

    for (const header of headers) {
        const children = header.columns

        // Guard on the value, not on key presence: `columns: undefined` must not crash.
        if (children && children.length > 0) {
            maxDepth = Math.max(maxDepth, calculateMaxDepth(children) + 1)
        }
    }

    return maxDepth
}

/**
 * Returns a copy of the column tree with `rowSpan` and `colSpan` set on every entry.
 *
 * - A group (entry with a non-empty `columns` list) gets `rowSpan` 1 and a `colSpan`
 *   equal to the sum of its children's `colSpan` values; its `columns` is replaced,
 *   on the copy, by the annotated children.
 * - Any other entry gets `colSpan` 1 and `rowSpan` `maxDepth - depth + 1`, i.e. it
 *   covers every level from its own down to the deepest one.
 *
 * The entries passed in are not modified: each result is a shallow copy, and the
 * annotated children are attached to that copy only.
 *
 * @param headers  Column entries of the current nesting level.
 * @param depth    Nesting level of `headers`; 0 for the top level.
 * @param maxDepth Deepest nesting level of the whole tree; computed from `headers`
 *                 when omitted and passed down unchanged while recursing.
 * @returns New entries, in input order, carrying numeric `rowSpan` and `colSpan`.
 */
const calculateSpanValues = <TData extends RowData>(
    headers: HeaderWithTableAttributes<TData>[],
    depth: number = 0,
    maxDepth: number = calculateMaxDepth(headers)
) => {
    return headers.map(header => {
        if (header.columns && header.columns.length) {
            const columns = calculateSpanValues(header.columns, depth + 1, maxDepth)
            const colSpan = columns.reduce((total, child) => {
                return total + child.colSpan
            }, 0)

            // Put the annotated children on the copy instead of writing them back
            // onto the caller's column definition.
            return {
                ...header,
                columns,
                rowSpan: 1,
                colSpan
            }
        }

        return {
            ...header,
            rowSpan: maxDepth - depth + 1,
            colSpan: 1
        }
    })
}

/**
 * Builds one header row per nesting level of `columns`, top level first.
 *
 * For every column entry of a level, the first header in `headers` whose
 * `column.id` equals the entry's `id` is looked up, gets the entry's `colSpan` and
 * `rowSpan` written onto it (1 when the entry has none) and is added to the row.
 * The children of all entries of a level form the next level.
 *
 * Entries for which no header is found are left out of their row instead of
 * crashing; their children are still processed. Spans are copied from the entries
 * as given, so leaving an entry out does not adjust the `colSpan` of its ancestors.
 *
 * NOTE(review): matching relies on the entry's own `id`; entries that do not set
 * `id` explicitly will not match any header — confirm callers always provide it.
 *
 * @param headers Candidate header objects; searched as a flat list at every level
 *                and mutated in place (`colSpan`, `rowSpan`).
 * @param columns Column entries of the current level, with span values.
 * @param level   Level number given to the first produced row.
 * @returns Rows ordered by level; always at least one row, even for empty `columns`.
 */
const processHeaders = <TData extends RowData>(
    headers: Header<TData, unknown>[],
    columns: HeaderWithTableAttributes<TData>[],
    level: number = 0
): HeaderRow<TData>[] => {
    const rowHeaders: Header<TData, unknown>[] = []

    for (const column of columns) {
        const header = headers.find(item => {
            return item.column.id === column.id
        })

        // No header for this entry: skip it rather than dereferencing undefined.
        if (!header) {
            continue
        }

        header.colSpan = column.colSpan ?? 1
        header.rowSpan = column.rowSpan ?? 1
        rowHeaders.push(header)
    }

    const row: HeaderRow<TData> = { level, headers: rowHeaders }

    // The children of every entry on this level make up the next level.
    const childrenColumns: HeaderWithTableAttributes<TData>[] = columns.flatMap(column => {
        return column.columns || []
    })

    if (childrenColumns.length === 0) {
        return [row]
    }

    return [row, ...processHeaders(headers, childrenColumns, level + 1)]
}

/**
 * Turns a list of headers plus the column definitions they belong to into header rows
 * grouped by nesting level.
 *
 * Steps, in order:
 * 1. drop placeholder headers via `filterAllPlaceholdersFromHeaders`;
 * 2. derive `rowSpan`/`colSpan` for the column tree via `calculateSpanValues`;
 * 3. build one row per nesting level via `processHeaders`.
 *
 * The rows contain the header objects that were passed in, not copies; their
 * `colSpan`/`rowSpan` (and, for kept headers, `subHeaders`) are updated along the way.
 *
 * @param headers Header objects to arrange into rows.
 * @param columns Column definitions describing the grouping of those headers.
 * @returns The header rows, top level first.
 */
export const convertHeadersToColumnAndRowsGrouping = <TData extends RowData>(
    headers: Header<TData, unknown>[],
    columns: ColumnDef<TData>[]
) => {
    const realHeaders = filterAllPlaceholdersFromHeaders(headers)

    // Column definitions are viewed as span-carrying entries; the spans are filled in next.
    const columnTree = columns as HeaderWithTableAttributes<TData>[]
    const columnsWithSpans: HeaderWithTableAttributes<TData>[] = calculateSpanValues(columnTree)

    return processHeaders(realHeaders, columnsWithSpans)
}
