import {
  Case,
  Corr,
  Count,
  Desc,
  Expression,
  Relation,
  RelationRef,
  RowNumberOver,
  Sum,
  UnionAll,
  GroupedRelationRef,
  Constant,
  Ty,
  Histogram,
  TC,
} from '@cotera/era';

/**
 * Cartesian product of two lists as ordered `[x, y]` pairs.
 *
 * Pairs come out grouped by `xs` element: all of `ys` paired with the first
 * `x`, then all of `ys` paired with the second `x`, and so on.
 */
const cross = <T>(xs: T[], ys: T[]): [T, T][] => {
  const pairs: [T, T][] = [];
  xs.forEach((left) => {
    ys.forEach((right) => {
      pairs.push([left, right]);
    });
  });
  return pairs;
};

/**
 * Summarises `rel` into one `Corr(...)` attribute per ordered pair of its
 * float/int columns.
 *
 * Every column is paired with every column (itself included, and in both
 * orders); the output attribute for a pair is named `"<left> x <right>"`.
 */
export const CorrelationMatrix = (rel: Relation) => {
  const numericCols = columnsOfTypes(rel, ['float', 'int']);
  const pairs = cross(numericCols, numericCols);

  return rel.summary((row) => {
    return Object.fromEntries(
      pairs.map(([left, right]) => [
        `${left} x ${right}`,
        Corr(row.attr(left), row.attr(right)),
      ])
    );
  });
};

/**
 * Names of the attributes of `rel` whose type satisfies at least one of the
 * requested primitive types, as decided by `TC.implementsTy`.
 *
 * @param rel - Relation whose `attributes` are inspected.
 * @param types - Acceptable primitive types; matching any one is enough.
 * @returns Matching attribute names, in `Object.entries(rel.attributes)` order.
 */
export const columnsOfTypes = (
  rel: Relation,
  types: Ty.PrimitiveAttributeType[]
): string[] => {
  const matching: string[] = [];

  for (const [name, attribute] of Object.entries(rel.attributes)) {
    const accepted = types.some((req) =>
      TC.implementsTy({ subject: attribute.ty, req })
    );
    if (accepted) {
      matching.push(name);
    }
  }

  return matching;
};

/**
 * Distinct-value counts for every attribute of `rel`; delegates to
 * `DistinctValuesCountForColumns` with the full attribute list.
 */
export const DistinctValuesCountByColumn = (rel: Relation): Relation => {
  const allColumns = Object.keys(rel.attributes);
  return DistinctValuesCountForColumns(rel, allColumns);
};

/**
 * For each requested column, counts the rows left after a distinct select of
 * that column, and unions the per-column results together.
 *
 * @param rel - Relation to profile.
 * @param cols - Attribute names to count distinct values for.
 * @returns A relation with a `name` (the column) and `count` attribute per column.
 */
export const DistinctValuesCountForColumns = (
  rel: Relation,
  cols: string[]
): Relation => {
  // One summary per column: project the column with DISTINCT, then count rows.
  const countDistinct = (name: string) =>
    rel
      .select((row) => row.pick(name), { distinct: true })
      .summary(() => ({ name, count: Count() }));

  return UnionAll(cols.map((name) => countDistinct(name)));
};

/**
 * Builds a `Histogram` for each requested column and unions them into one
 * relation with `start`, `end`, `count` and `name` attributes.
 *
 * @param rel - Relation to profile.
 * @param cols - Attribute names to bucket.
 * @param opts - Passed through unchanged to `Histogram`.
 */
export const HistogramForColumns = (
  rel: Relation,
  cols: string[],
  opts: { numBuckets: number }
): Relation => {
  const histogramFor = (name: string) => {
    // The column name doubles as the (constant) histogram group.
    const buckets = Histogram(
      rel,
      (row) => ({
        target: row.attr(name),
        group: Constant(name),
      }),
      opts
    );

    // Bucket bounds are cast to float in every per-column branch of the union.
    return buckets.select((bucket) => ({
      start: bucket.attr('start').cast('float'),
      end: bucket.attr('end').cast('float'),
      count: bucket.attr('count'),
      name,
    }));
  };

  return UnionAll(cols.map((name) => histogramFor(name)));
};

/**
 * Histograms for every numeric (per `isNumberType`) attribute of `rel`.
 *
 * When `rel` has no numeric attributes, the result is `rel` filtered by an
 * always-false predicate and projected onto an empty attribute set.
 * NOTE(review): that fallback does not carry the start/end/count/name
 * attributes of the normal result — confirm callers tolerate this.
 *
 * @param rel - Relation to profile.
 * @param opts - Passed through to `HistogramForColumns`.
 */
export const HistogramByColumn = (
  rel: Relation,
  opts: { numBuckets: number }
): Relation => {
  const numericCols: string[] = [];
  for (const [name, attribute] of Object.entries(rel.attributes)) {
    if (isNumberType(attribute.ty)) {
      numericCols.push(name);
    }
  }

  return numericCols.length > 0
    ? HistogramForColumns(rel, numericCols, opts)
    : rel.where(() => false).select(() => ({}));
};

/**
 * True when `attr` is a primitive attribute type whose `t` is `'float'` or
 * `'int'`; false for every other kind or primitive.
 */
export const isNumberType = (attr: Ty.AttributeType): boolean => {
  if (attr.k !== 'primitive') {
    return false;
  }
  return attr.t === 'float' || attr.t === 'int';
};

/**
 * Tag values that `CategoricalAttributes` assigns to each counted value:
 * `category` for values reported as their own row, and `discrete` (spelled
 * `'other'`) for values ranked past `discreteThreshold`, which are summed
 * into a single aggregate row.
 */
export const CATEGORICAL_TAGS = {
  discrete: 'other',
  category: 'category',
};

/**
 * Frequency breakdown of the values produced by `attr` over `rel`.
 *
 * Values are counted, values that occur only once are dropped, and the rest
 * are ranked by descending count. Values ranked within `discreteThreshold`
 * are reported individually (value cast to string); values ranked past it are
 * summed into a single row whose value is the string `'other'`.
 *
 * @param rel - Relation to profile.
 * @param attr - Builds the expression whose values are counted.
 * @param opts - `discreteThreshold` is the highest rank still reported as its
 *   own row.
 * @returns A relation with `value` and `count` attributes.
 */
export const CategoricalAttributes = (
  rel: Relation,
  attr: (r: RelationRef) => Expression,
  opts: {
    discreteThreshold: number;
  }
) => {
  // Count occurrences per value, keeping only values seen more than once.
  const repeatedValues = rel
    .countBy((row) => ({ value: attr(row) }))
    .where((row) => row.attr('COUNT').gt(1))
    .orderBy((row) => [Desc(row.attr('COUNT')), Desc(row.attr('value'))]);

  // Rank by descending count and tag each value by which side of the
  // threshold it falls on.
  // NOTE(review): the rank orders by count only, so how equal counts are
  // ranked relative to each other is left to the engine — confirm acceptable.
  const tagged = repeatedValues.select((row) => {
    const rank = RowNumberOver({ orderBy: Desc(row.attr('COUNT')) });
    const value = row.attr('value');
    const count = row.attr('COUNT');
    const tag = Case(
      [
        {
          when: rank.gt(opts.discreteThreshold),
          then: CATEGORICAL_TAGS.discrete,
        },
      ],
      {
        else: CATEGORICAL_TAGS.category,
      }
    );

    return { value, count, tag };
  });

  // Values within the threshold: one row each.
  const topCategories = tagged
    .where((row) => row.attr('tag').eq(CATEGORICAL_TAGS.category))
    .select((row) => ({
      count: row.attr('count'),
      value: row.attr('value').cast('string'),
    }))
    .orderBy((row) => Desc(row.attr('count')));

  // Values past the threshold: collapsed into one summed row.
  const everythingElse = tagged
    .where((row) => row.attr('tag').eq(CATEGORICAL_TAGS.discrete))
    .summary((row) => ({
      value: Constant('other', { ty: 'string' }),
      count: Sum(row.attr('count')).cast('int'),
    }));

  // Presumably the final filter removes the aggregate row when nothing fell
  // past the threshold — verify against the engine's Sum-over-empty result.
  return UnionAll([topCategories, everythingElse]).where((row) =>
    row.attr('count').gt(0)
  );
};

/**
 * Runs `CategoricalAttributes` on each requested column and unions the
 * results, adding a `name` attribute that holds the source column's name.
 *
 * @param rel - Relation to profile.
 * @param cols - Attribute names to break down.
 * @param opts - Passed through unchanged to `CategoricalAttributes`.
 */
export const CategoricalAttributesForColumns = (
  rel: Relation,
  cols: string[],
  opts: {
    discreteThreshold: number;
  }
) => {
  const perColumn = cols.map((name) => {
    const breakdown = CategoricalAttributes(rel, (row) => row.attr(name), opts);

    return breakdown.select((row) => ({
      ...row.star(),
      name,
    }));
  });

  return UnionAll(perColumn);
};

/**
 * Categorical breakdown of every attribute of `rel` that is not numeric
 * according to `isNumberType`; delegates to `CategoricalAttributesForColumns`.
 *
 * @param rel - Relation to profile.
 * @param opts - Passed through unchanged.
 */
export const CategoricalAttributesByColumn = (
  rel: Relation,
  opts: {
    discreteThreshold: number;
  }
) => {
  const nonNumericCols = Object.entries(rel.attributes)
    .filter(([, attribute]) => !isNumberType(attribute.ty))
    .map(([name]) => name);

  return CategoricalAttributesForColumns(rel, nonNumericCols, opts);
};

/**
 * Counts how many values produced by `attr` occur exactly once in `rel`.
 *
 * @param rel - Relation to profile.
 * @param attr - Builds the expression whose values are counted.
 * @returns A summary relation with a single `count` attribute.
 */
export const UniqueValuesCount = (
  rel: Relation,
  attr: (t: RelationRef | GroupedRelationRef) => Expression
) => {
  const occurrences = rel.countBy((row) => ({ value: attr(row) }));
  const singletons = occurrences.where((row) => row.attr('COUNT').eq(1));

  return singletons.summary(() => ({ count: Count() }));
};
