import _ from 'lodash';
import { GroupedRelationRef, RelationRef } from '../relation-ref';
import { Relation } from '../relation';
import { Expression } from '../expression';
import { PercentileCont } from '../aggregate';
import { And, Asc, Case, Constant, GenerateSeries } from '../utilities';
import { MaxOver, MinOver, RowNumberOver } from '../window';

/**
 * Builds a relation describing a bucketed histogram of a target expression,
 * broken down by a grouping expression.
 *
 * Pipeline, in order:
 *  1. Join `rel` against a one-row summary holding the 25th/75th percentiles
 *     of the target and their difference (IQR), then keep only rows whose
 *     target lies within `[PCTILE_25 - 3 * IQR, PCTILE_75 + 3 * IQR]`.
 *  2. Pair a `GenerateSeries(1, numBuckets)` series with each distinct
 *     `(group, MIN_OVERALL, MAX_OVERALL)` row and derive `start`/`end` edges
 *     for every bucket.
 *  3. Tag each remaining row with a bucket number through a `Case` over the
 *     bucket upper edges and count rows per `(BUCKET_ASSIGNMENT, group)`.
 *  4. Left-join the bucket skeleton with the counts so that buckets with no
 *     matching count row report `0`, ordered by `start` ascending.
 *
 * NOTE(review): `MinOver`/`MaxOver` are called here without any partition
 * argument, so the min/max are presumably taken across all groups - confirm
 * against the window helpers.
 *
 * @param rel - Source relation to summarise.
 * @param attr - Maps a relation reference to the `target` expression being
 *   bucketed and the `group` expression used to split the histogram.
 * @param opts - `numBuckets` is the number of buckets generated per group.
 * @returns A relation with `BUCKET_ASSIGNMENT`, `group`, `start`, `end` and
 *   `count` columns.
 */
export const Histogram = (
  rel: Relation,
  attr: (r: RelationRef | GroupedRelationRef) => {
    target: Expression;
    group: Expression;
  },
  opts: { numBuckets: number }
): Relation => {
  type Ref = RelationRef | GroupedRelationRef;

  const { numBuckets } = opts;

  // Thin accessors around the caller-supplied mapping. Each call re-invokes
  // `attr`, so every use site receives its own expression.
  const targetOf = (ref: Ref): Expression => attr(ref).target;
  const groupOf = (ref: Ref): Expression => attr(ref).group;

  // Spread between the windowed max and min of the target.
  const overallRange = (ref: Ref) =>
    MaxOver(targetOf(ref)).sub(MinOver(targetOf(ref)));

  // --- Stage 1: drop outliers using a 3 * IQR fence -----------------------
  const quartiles = rel.summary((row) => ({
    PCTILE_75: PercentileCont(targetOf(row), 0.75),
    PCTILE_25: PercentileCont(targetOf(row), 0.25),
    IQR: PercentileCont(targetOf(row), 0.75).sub(
      PercentileCont(targetOf(row), 0.25)
    ),
  }));

  // `on: true` attaches the summary columns to every source row.
  const withQuartiles = rel.leftJoin(quartiles, (row, stats) => ({
    on: true,
    select: { ...row.star(), ...stats.star() },
  }));

  const relNoOutliers = withQuartiles.where((row) => {
    const aboveLowerFence = targetOf(row).gte(
      row.attr('PCTILE_25').sub(row.attr('IQR').mul(3))
    );
    const belowUpperFence = targetOf(row).lte(
      row.attr('PCTILE_75').add(row.attr('IQR').mul(3))
    );
    return And(aboveLowerFence, belowUpperFence);
  });

  // --- Stage 2: one row per (group, bucket) with its edges ----------------
  const bucketSeries = GenerateSeries(1, numBuckets);

  const groupExtremes = relNoOutliers.select(
    (row) => ({
      group: groupOf(row),
      MIN_OVERALL: MinOver(targetOf(row)),
      MAX_OVERALL: MaxOver(targetOf(row)),
    }),
    { distinct: true }
  );

  const numberedBuckets = bucketSeries.leftJoin(
    groupExtremes,
    (series, extremes) => ({
      on: true,
      select: {
        // Number the series rows within each group, following series order.
        BUCKET_ASSIGNMENT: RowNumberOver({
          partitionBy: extremes.attr('group'),
          orderBy: Asc(series.attr('n')),
        }),
        ...extremes.star(),
      },
    })
  );

  const groupBucketsBase = numberedBuckets.select((row) => {
    const lowest = () => row.attr('MIN_OVERALL');
    const highest = () => row.attr('MAX_OVERALL');
    const ordinal = () => row.attr('BUCKET_ASSIGNMENT');
    const bucketWidth = () => highest().sub(lowest()).div(numBuckets);

    return {
      ...row.pick('BUCKET_ASSIGNMENT', 'group'),
      // Lower edge: min + width * (bucket - 1); falls back to the min when
      // the computed edge is null.
      start: lowest()
        .add(bucketWidth().mul(ordinal().sub(1)))
        .coalesce(lowest())
        .round(1),
      // Upper edge: min + width * bucket; falls back to the max when the
      // computed edge is null.
      end: lowest()
        .add(bucketWidth().mul(ordinal()))
        .coalesce(highest())
        .round(1),
    };
  });

  // --- Stage 3: assign rows to buckets and count them ---------------------
  const bucketedRows = relNoOutliers.select((row) => ({
    ...row.star(),
    BUCKET_ASSIGNMENT: Case(
      _.times(numBuckets, (index) => {
        const bucketNumber = index + 1;
        return {
          // Branch matches when the target is at or below this bucket's
          // upper edge.
          when: targetOf(row).lte(
            MinOver(targetOf(row)).add(
              Constant(bucketNumber).mul(overallRange(row).div(numBuckets))
            )
          ),
          then: bucketNumber,
        };
      })
    ),
  }));

  const base = bucketedRows
    .countBy((row) => ({
      ...row.pick('BUCKET_ASSIGNMENT'),
      group: groupOf(row),
    }))
    // Discard counts for rows that did not receive a bucket number.
    .where((row) => row.attr('BUCKET_ASSIGNMENT').isNotNull());

  // --- Stage 4: attach counts to the bucket skeleton ----------------------
  const withCounts = groupBucketsBase.leftJoin(base, (bucket, tally) => {
    const sameBucket = bucket
      .attr('BUCKET_ASSIGNMENT')
      .eq(tally.attr('BUCKET_ASSIGNMENT'));
    const sameGroup = bucket.attr('group').eq(tally.attr('group'));

    return {
      on: And(sameBucket, sameGroup),
      select: {
        ...bucket.star(),
        // Buckets without a matching count row report zero.
        count: tally.attr('COUNT').coalesce(0),
      },
    };
  });

  return withCounts.orderBy((row) => Asc(row.attr('start')));
};
