import cx from 'classnames'
import { ScaleOrdinal, scaleLinear, scaleOrdinal } from 'd3-scale'
import { area, curveMonotoneY, line } from 'd3-shape'
/*
  Based on https://observablehq.com/@d3/line-chart

  TODO(mark): Add graph legend.
*/
import fp, { range, map, min, max, flatten, toPairs, flow, get, merge } from 'lodash/fp'

import D3Chart, { D3ChartDomain } from '~/components/d3/D3Chart'
import { D3_SCALE_ORDINAL } from '~/components/d3/vars'

import { AxisScale } from 'd3-axis'
import cs from './d3_line_graph.scss'

// lodash/fp caps iteratee arguments by default, so callbacks would only see
// the value. Converting with `cap: false` lets callbacks receive (value, key);
// updateData relies on this to get each dataset's name while iterating `data`.
// @ts-expect-error: lodash/fp convert isn't properly typed
const forEach = fp.forEach.convert({ cap: false })

/** A single (x, y) pair in data coordinates; used to build intercept guide lines. */
type Point = {
  x: number
  y: number
}

/**
 * Options for `D3LineGraph`. The constructor accepts a `Partial` of this and
 * merges it over the module's default options.
 */
export interface D3LineGraphOptions {
  // Data access ---------------------------------------------------------------

  /** Datum field plotted on the x axis. */
  xKey: string
  /** Datum field plotted on the y axis. */
  yKey: string
  /** Datum field holding the ± error drawn as a band around projected series. */
  projectedErrorKey: string
  /**
   * Returns true for dataset names that should be drawn as projections
   * (dashed line, error band, hidden dots).
   */
  isProjected: (key: string) => boolean

  // Domains -------------------------------------------------------------------

  /** Explicit x domain. When both `xDomain` and `yDomain` are set they replace the data-derived domains. */
  xDomain: [number, number]
  /** Explicit y domain. When both `xDomain` and `yDomain` are set they replace the data-derived domains. */
  yDomain: [number, number]
  /**
   * The data-derived domains are widened to include these values
   * (lower bound = min(data, *Min), upper bound = max(data, *Max)).
   * NOTE(review): the module defaults assign `null` to the *Max fields and
   * callers may pass `null` to disable a bound; consider `number | null`.
   */
  xDomainMin: number
  xDomainMax: number
  yDomainMin: number
  yDomainMax: number

  // Appearance ----------------------------------------------------------------

  /** Radius of the dot drawn at each datum. */
  dotSize: number
  /** Stroke color for every series; when unset, series get ordinal-scale colors. */
  lineColor: string
  /** Points that get a dashed horizontal guide from x = 0 to the point. */
  intercept: Array<{
    x: number
    y: number
  }>
  /** Stroke color for intercept guides; falls back to `lineColor`. */
  interceptColor: string
}

/**
 * Series to plot, keyed by dataset name. Each series is a list of records
 * whose numeric fields are looked up through the `xKey` / `yKey` options.
 */
export type D3LineGraphData = Record<string, Array<Record<string, number>>>

// Default chart options. The D3LineGraph constructor (and updateChartOptions)
// merge the caller's options over these.
const DEFAULT_OPTIONS = {
  xKey: '',
  yKey: '',
  // Radius of the dot drawn at each datum.
  dotSize: 3,
  // The x domain always extends down to include this value: its lower bound is
  // min(smallest x in the data, xDomainMin). Set to null to use the data alone.
  xDomainMin: 0,
  // The x domain always extends up to include this value: its upper bound is
  // max(largest x in the data, xDomainMax). Set to null to use the data alone.
  xDomainMax: null,
  // Same as xDomainMin, for the y axis.
  yDomainMin: 0,
  // Same as xDomainMax, for the y axis.
  yDomainMax: null,
  // Stroke color shared by every series; null gives each series its own color
  // from an ordinal color scale.
  lineColor: null,
}

/**
 * Line graph built on `D3Chart`.
 *
 * Each entry of the data object is one series. A series is drawn as a line (or,
 * when `isProjected(name)` is true, a dashed line over an error band) plus one
 * dot per datum. Optional `intercept` points get a dashed horizontal guide.
 */
export default class D3LineGraph extends D3Chart<D3LineGraphData, number, number> {
  /** Effective options: the defaults merged with the caller's overrides. */
  chartOptions: D3LineGraphOptions
  /** Assigns a color per series name when `lineColor` is not set. */
  scaleOrdinal: ScaleOrdinal<string, string>

  constructor(
    container,
    layoutOptions = {},
    axisOptions = {},
    chartOptions: Partial<D3LineGraphOptions> = {},
  ) {
    super(container, cs.d3LineGraph, layoutOptions, axisOptions)
    // lodash/fp `merge` returns a new object; values from chartOptions win.
    this.chartOptions = merge(DEFAULT_OPTIONS, chartOptions)
    this.scaleOrdinal = scaleOrdinal(D3_SCALE_ORDINAL)
  }

  // Nothing to clean-up for this chart.
  teardown = () => {}

  /**
   * Computes the x and y domains for the scales.
   *
   * An explicit `xDomain` / `yDomain` option is used as-is for its own axis.
   * Otherwise the axis domain spans the data values for that axis, widened to
   * include `*DomainMin` / `*DomainMax` when those are non-null. A bound that
   * cannot be determined (e.g. no data) falls back to 0.
   */
  getDomains = data => {
    const options = this.chartOptions

    // Both domains fixed by the caller: nothing to derive from the data.
    if (options.xDomain && options.yDomain) {
      return {
        xDomain: options.xDomain,
        yDomain: options.yDomain,
      }
    }

    let allXValues: number[] = []
    let allYValues: number[] = []
    forEach(series => {
      allXValues = allXValues.concat(map(options.xKey, series))
      allYValues = allYValues.concat(map(options.yKey, series))
    }, data)

    // Extent of `values`, widened to include `lower` / `upper` when non-null.
    // lodash `min`/`max` yield undefined for an empty list; `|| 0` maps that to 0.
    const boundedExtent = (
      values: number[],
      lower: number | null | undefined,
      upper: number | null | undefined,
    ): D3ChartDomain<number> => {
      let lo = min(values)
      if (lower != null) {
        lo = min([lo, lower])
      }
      let hi = max(values)
      // Each bound is guarded by its own option (previously the y upper bound
      // was mistakenly guarded by yDomainMin).
      if (upper != null) {
        hi = max([hi, upper])
      }
      return [lo || 0, hi || 0] as D3ChartDomain<number>
    }

    return {
      // An explicit domain for one axis is honored even if the other is unset.
      xDomain:
        options.xDomain ||
        boundedExtent(allXValues, options.xDomainMin, options.xDomainMax),
      yDomain:
        options.yDomain ||
        boundedExtent(allYValues, options.yDomainMin, options.yDomainMax),
    }
  }

  /** Stroke color for the series named `key`: `lineColor` if set, else the ordinal scale. */
  getColor = key => {
    const options = this.chartOptions
    return options.lineColor || this.scaleOrdinal(key)
  }

  /**
   * Clears the data layer, re-renders the axes for `data`, then draws every
   * series, the optional intercept guides, and one dot per datum.
   */
  updateData = (data: D3LineGraphData) => {
    const options = this.chartOptions
    this.dataLayer.selectAll('*').remove()
    this.rerenderAxis(data)

    // Nothing can be plotted without both scales (presumably set up by rerenderAxis).
    if (!this.xScale || !this.yScale) return
    const xScale = this.xScale
    const yScale = this.yScale

    forEach((series, name) => {
      const color = this.getColor(name)

      const xValues = map(options.xKey, series)
      const yValues = map(options.yKey, series)
      // The line/area generators are fed datum indices rather than datums, so
      // the accessors can look up x, y and error values by position.
      const indices = range(0, series.length)

      // curveMonotoneY is intended to preserve "monotonicity",
      // meaning the curve won't look like it's going up
      // when in reality the data is staying level.
      // NOTE(review): per the d3-shape docs, curveMonotoneY assumes the data is
      // monotonic in y and preserves monotonicity in x; curveMonotoneX is the
      // variant for x-ordered data that keeps level y values level. Confirm
      // which one is intended here.
      if (options.isProjected && options.isProjected(name)) {
        const errorValues = map(options.projectedErrorKey, series)

        const projectedLine = line<number>()
          .curve(curveMonotoneY)
          .x(i => xScale(xValues[i]) || 0)
          .y(i => yScale(yValues[i]) || 0)

        // Band covering y - error .. y + error at each x.
        const errorBand = area<number>()
          .curve(curveMonotoneY)
          .x0(i => xScale(xValues[i]) || 0)
          .x1(i => xScale(xValues[i]) || 0)
          .y0(i => yScale(yValues[i] - errorValues[i]) || 0)
          .y1(i => yScale(yValues[i] + errorValues[i]) || 0)

        this.dataLayer
          .append('path')
          .datum(indices)
          .attr('class', cs.area)
          .attr('d', errorBand)

        this.dataLayer
          .append('path')
          .attr('class', cs.line)
          .attr('stroke', color)
          .attr('stroke-dasharray', 6)
          .attr('d', projectedLine(indices))
      } else {
        const solidLine = line<number>()
          .curve(curveMonotoneY)
          .x(i => xScale(xValues[i]) || 0)
          .y(i => yScale(yValues[i]) || 0)

        this.dataLayer
          .append('path')
          .attr('class', cs.line)
          .attr('stroke', color)
          .attr('d', solidLine(indices))
      }
    }, data)

    // Dashed, half-opacity horizontal guide from x = 0 to each intercept point.
    if (options.intercept) {
      forEach(point => {
        const color = options.interceptColor || options.lineColor
        const yIntercept = point.y
        const xIntercept = point.x

        const values = [
          { x: 0, y: yIntercept },
          { x: xIntercept, y: yIntercept },
          // { x: xIntercept, y: 0 },
        ]
        const guideLine = line<Point>()
          .x(p => xScale(p.x) || 0)
          .y(p => yScale(p.y) || 0)
        this.dataLayer
          .append('path')
          .attr('class', cs.line)
          .attr('stroke', color)
          .attr('d', guideLine(values))
          .style('stroke-opacity', 0.5)
          .attr('stroke-dasharray', 6)
      }, options.intercept)
    }

    // Flatten { name: [datum, ...] } into [{ key: name, x, y }, ...] for the dots.
    const points = flow(
      toPairs, // [key, [values]]
      map(pairs =>
        map(
          value => ({
            key: pairs[0],
            x: get(options.xKey, value),
            y: get(options.yKey, value),
          }),
          pairs[1],
        ),
      ),
      flatten,
    )(data)

    // Dots of projected series are drawn but carry the `hidden` class.
    this.dataLayer
      .selectAll('circle')
      .data(points)
      .enter()
      .append('circle')
      .attr('stroke', d => this.getColor(d.key))
      .attr('r', options.dotSize)
      .attr('cx', d => xScale(d.x) || null)
      .attr('cy', d => yScale(d.y) || null)
      .attr('class', d =>
        cx(cs.point, options.isProjected && options.isProjected(d.key) && cs.hidden),
      )
  }

  /** Linear x scale mapping `xDomain` onto [0, chart width]. */
  getXScale = (xDomain: [number, number]): AxisScale<number> => {
    const dim = this.getChartDimensions()

    return scaleLinear().domain(xDomain).range([0, dim.width])
  }

  /**
   * Linear y scale mapping `yDomain` onto [chart height, 0], so larger values
   * get smaller pixel y (drawn higher).
   */
  getYScale = (yDomain: [number, number]): AxisScale<number> => {
    const dim = this.getChartDimensions()

    return scaleLinear().domain(yDomain).range([dim.height, 0])
  }

  /**
   * Replaces the chart options and redraws when data is present.
   * The new options are merged over the defaults, not over the current options.
   */
  updateChartOptions = options => {
    this.chartOptions = merge(DEFAULT_OPTIONS, options)
    if (this._data) {
      this.updateData(this._data)
    }
  }
}
