import { blue, grass, orange, purple, ruby, yellow, tomato, teal, lime, pink } from '@radix-ui/colors'
import { AreaChart, BarChart, DonutChart, Legend } from '@tremor/react'
import { DateTime } from 'luxon'
import { match } from 'ts-pattern'
import { cn } from '../../../utils'
import { Card } from '@bpinternal/ui-kit'
import { Flex, Grid, Text } from '@radix-ui/themes'
import { ReactNode } from 'react'
import { chartTypes } from '../types'
import { type Resolution } from './ChartDataProvider'

/**
 * Props accepted by `Chart`.
 *
 * `T` maps series names to their per-row values; every row additionally carries a `timestamp`.
 */
type Props<T extends Record<string, number | Date>> = {
  /** Which chart variant to render; one of the entries of `chartTypes`. */
  type: (typeof chartTypes)[number]
  /** Rows to plot, each carrying a `timestamp` alongside the series values. */
  data: ({ timestamp: Date } & T)[]
  /** Keys of `T` to plot as series (or to total across rows, for metric/donut/pie). */
  categories: (keyof T)[]
  /** Whether the chart legend is shown; `Chart` treats an omitted value as `true`. */
  showLegend?: boolean
  /** Granularity used to format timestamp labels; `Chart` falls back to `'month'`. */
  resolution?: Resolution
  /** Extra class names merged onto the outer card. */
  className?: string
  /** Content rendered inside the card, above the chart body. */
  children?: ReactNode
}

/**
 * Names of the chart palette, index-for-index matching the Radix swatches in `colors`.
 * Declared `as const` so `(typeof chartColorNames)[number]` is a literal union of the names.
 */
export const chartColorNames = [
  'blue', 'yellow', 'ruby', 'grass', 'orange',
  'purple', 'tomato', 'teal', 'lime', 'pink',
] as const

/** Radix step-8 swatch for every palette name in `chartColorNames`. */
const swatchByName: Record<(typeof chartColorNames)[number], string> = {
  blue: blue.blue8,
  yellow: yellow.yellow8,
  ruby: ruby.ruby8,
  grass: grass.grass8,
  orange: orange.orange8,
  purple: purple.purple8,
  tomato: tomato.tomato8,
  teal: teal.teal8,
  lime: lime.lime8,
  pink: pink.pink8,
}

/**
 * Series palette passed to the charts. It is derived from `chartColorNames`, so the two lists
 * cannot drift out of index alignment.
 */
const colors = chartColorNames.map((name) => swatchByName[name])

/**
 * Card that renders `data` as the chart variant selected by `type`.
 *
 * - Shows a placeholder when `data` is empty or when no `categories` are selected.
 * - `area` / `bar` and their `spark_*` variants (which hide grid lines and axes) plot each
 *   category over time. The x-axis label is the row's `timestamp`, formatted for
 *   `resolution` (`'month'` when omitted).
 * - `metric`, `donut` and `pie` first total every category across all rows.
 * - Any other `type` renders a "not supported" placeholder.
 *
 * `children` is rendered inside the card, above the chart body.
 */
export const Chart = <T extends Record<string, number | Date>>({
  children,
  className,
  showLegend = true,
  resolution,
  ...props
}: Props<T>) => {
  return (
    <Card className={cn('h-full w-full bg-white py-4 dark:bg-gray-1', className)}>
      <Flex direction={'column'} className="h-full w-full" gap={'4'}>
        {children}
        {match(props)
          .when(
            ({ data }) => data.length === 0,
            () => <EmptyChartState>No data available</EmptyChartState>
          )
          .when(
            ({ categories }) => categories.length === 0,
            () => <EmptyChartState>No data source selected</EmptyChartState>
          )
          .with({ type: 'area' }, ({ categories, data }) => (
            <AreaChart
              showLegend={showLegend}
              className="-ml-3 h-full w-full"
              colors={colors}
              index={'timestamp'}
              categories={categories as string[]}
              data={withFormattedTimestamps(data, resolution)}
              curveType="monotone"
              connectNulls
            />
          ))
          .with({ type: 'spark_area' }, ({ categories, data }) => (
            <AreaChart
              showLegend={showLegend}
              className="h-full w-full"
              colors={colors}
              index={'timestamp'}
              categories={categories as string[]}
              data={withFormattedTimestamps(data, resolution)}
              showGridLines={false}
              showXAxis={false}
              showYAxis={false}
              curveType="monotone"
              connectNulls
            />
          ))
          .with({ type: 'bar' }, ({ categories, data }) => (
            <BarChart
              showLegend={showLegend}
              className="-ml-3 h-full w-full"
              colors={colors}
              index={'timestamp'}
              categories={categories as string[]}
              data={withFormattedTimestamps(data, resolution)}
            />
          ))
          .with({ type: 'spark_bar' }, ({ categories, data }) => (
            <BarChart
              showLegend={showLegend}
              className="h-full w-full"
              colors={colors}
              index={'timestamp'}
              categories={categories as string[]}
              showGridLines={false}
              showXAxis={false}
              showYAxis={false}
              data={withFormattedTimestamps(data, resolution)}
            />
          ))
          .with({ type: 'metric' }, ({ data, categories }) => {
            const totals = totalsByCategory(data, categories)
            return (
              <Grid
                gap={'4'}
                mx={'auto'}
                className={cn('h-full items-center', {
                  '@[18rem]:grid-cols-2': totals.length >= 2,
                  '@xl:grid-cols-3': totals.length >= 3,
                })}
              >
                {totals.map(({ name, value }, index) => (
                  <Flex key={name} direction={'column'} className="text-2xl" px={'4'}>
                    <Flex align={'center'} gap={'1'}>
                      {/* Inline style: Tailwind cannot generate `bg-[…]` classes assembled at runtime. */}
                      <div className="size-3 flex-none rounded-sm" style={{ backgroundColor: colorAt(index) }} />
                      <Text color="gray" truncate className="text-[0.7em]">
                        {name}
                      </Text>
                    </Flex>
                    <Text weight={'bold'}>{value}</Text>
                  </Flex>
                ))}
              </Grid>
            )
          })
          .with({ type: 'donut' }, ({ data, categories }) =>
            renderProportionChart(totalsByCategory(data, categories), showLegend, 'donut')
          )
          .with({ type: 'pie' }, ({ data, categories }) =>
            renderProportionChart(totalsByCategory(data, categories), showLegend, 'pie')
          )
          .otherwise(() => (
            <EmptyChartState>Chart type not supported</EmptyChartState>
          ))}
      </Flex>
    </Card>
  )
}

/** Palette colour for the `index`-th series, cycling once there are more series than colours. */
function colorAt(index: number) {
  return colors[index % colors.length]
}

/**
 * Copies each row with `timestamp` replaced by its display label, so the label can serve as
 * the x-axis index. Labels use `resolution`, or `'month'` when it is not provided.
 */
function withFormattedTimestamps<Row extends { timestamp: Date }>(rows: Row[], resolution: Resolution | undefined) {
  return rows.map((row) => ({ ...row, timestamp: formatDate(row.timestamp, resolution ?? 'month') }))
}

/**
 * Totals every category across `rows`, with missing values counted as 0. Returns one
 * `{ name, value }` entry per distinct category, in `categories` order.
 *
 * The result is built from `categories` rather than from `Object.entries` of an accumulator.
 * `Object.entries` lists integer-like keys first, which would reorder entries and misalign
 * their colours with the legend.
 */
function totalsByCategory<Row>(rows: Row[], categories: (keyof Row)[]) {
  return Array.from(new Set(categories), (category) => ({
    name: String(category),
    value: rows.reduce((sum, row) => sum + Number(row[category] ?? 0), 0),
  }))
}

/**
 * Donut or pie of per-category totals, optionally with a legend above it. The legend is built
 * from the same `totals` list as the slices, so both assign `colors` in the same order.
 */
function renderProportionChart(
  totals: { name: string; value: number }[],
  showLegend: boolean,
  variant: 'donut' | 'pie'
) {
  return (
    <Flex className="h-full w-full" direction={'column'} gap={'4'}>
      {showLegend && (
        <Legend className="flex justify-end" colors={colors} categories={totals.map(({ name }) => name)} />
      )}
      <DonutChart
        className="h-12 grow"
        data={totals}
        colors={colors}
        index="name"
        category="value"
        showLabel={false}
        variant={variant}
      />
    </Flex>
  )
}

/** Placeholder with a dashed border that fills the chart area and shows a short grey message. */
const EmptyChartState = (props: { children: ReactNode }) => {
  return (
    <Flex align={'center'} justify={'center'} className="size-full rounded-sm border border-dashed border-gray-4">
      <Text size={'2'} color="gray" className="text-center">
        {props.children}
      </Text>
    </Flex>
  )
}

/**
 * Formats `timestamp` as an axis label whose granularity matches `resolution`:
 * - `'hour'` gives e.g. "Mar 5, 14:00";
 * - `'month'` gives e.g. "Mar 2024";
 * - `'day'`, `'week'` and any other value give e.g. "Mar 5, 2024".
 *
 * NOTE(review): `DateTime.fromJSDate` formats in the runtime's local zone by default. Confirm
 * this is intended if the bucket timestamps are produced in UTC.
 */
function formatDate(timestamp: Date, resolution: Resolution) {
  let pattern: string
  switch (resolution) {
    case 'hour':
      pattern = 'LLL d, HH:mm'
      break
    case 'month':
      pattern = 'LLL yyyy'
      break
    case 'day':
    case 'week':
    default:
      pattern = 'LLL d, yyyy'
  }
  return DateTime.fromJSDate(timestamp).toFormat(pattern)
}
