import { Assert } from '@cotera/utilities';
import { TyStackTrace } from './ty-stack-trace';
import _ from 'lodash';
import { mergeMacroVars } from './merge-marco-vars';
import { AST } from '../ast';
import { Ty } from '../ty';
import { Parser } from '../parser';
import * as Errs from './type-check-error';
import { checkExpr } from './expr/check-expr';
import type { ExprTypeCheck } from './expr/check-expr';
import {
  implementsRel,
  implementsTy,
  narrowestSuperTypeOf,
} from './implements';
import { checkVars } from './check-vars';
import { attrsThatMustBeGroupedOn } from './expr/must-be-grouped-on';
import { isExprInterpretable } from './type-checker';
import { MacroArgsType, RelMacroChildren } from '../ast/base';
import { RelInterfaceShorthand } from './rel-interface';

/**
 * Result of successfully type checking a relation.
 */
export type RelTypeCheck = {
  /** Type of every attribute the relation exposes, keyed by attribute name. */
  readonly attributes: {
    readonly [attr: string]: Ty.ExtendedAttributeType;
  };
  /** Macro variable information collected while checking, keyed by scope. */
  readonly vars: {
    readonly [scope: string]: AST.MacroArgsType;
  };
};

/**
 * Signature shared by the per-node checkers below: takes one relation node and
 * returns either its type information or a stack trace describing the failure.
 */
type RelCheck<T extends AST.RelFR> = (rel: T) => RelTypeCheck | TyStackTrace;

/**
 * Per-node memo used by `checkRel`. Both successful checks and failures
 * (`TyStackTrace`) are stored. Being a `WeakMap`, an entry does not keep its
 * relation node alive.
 */
const REL_TYPE_CHECK_CACHE: WeakMap<AST.RelFR, RelTypeCheck | TyStackTrace> =
  new WeakMap();

/**
 * Type checks a relation node.
 *
 * The node-level result (success or failure) is looked up in, and on a miss
 * stored into, `REL_TYPE_CHECK_CACHE`. The optional `opts.implements` check is
 * applied on every call, after the cache, and its outcome is not cached.
 *
 * @param rel - relation node to check
 * @param opts - `implements.attributes`, when given, is the interface the
 *   relation's attributes must satisfy
 * @returns the relation's type information, or a `TyStackTrace` on failure
 */
export const checkRel = (
  rel: AST.RelFR,
  opts: {
    implements?: { attributes: RelInterfaceShorthand };
  } = {}
): RelTypeCheck | TyStackTrace => {
  let result = REL_TYPE_CHECK_CACHE.get(rel);

  if (result === undefined) {
    // Cache miss: dispatch on the node tag to the matching checker.
    const { t } = rel;
    switch (t) {
      case 'select':
        result = checkSelect(rel);
        break;
      case 'join':
        result = checkJoin(rel);
        break;
      case 'table':
        result = checkTable(rel);
        break;
      case 'file':
        result = checkFileContract(rel);
        break;
      case 'union':
        result = checkUnion(rel);
        break;
      case 'aggregate':
        result = checkAggregate(rel);
        break;
      case 'values':
        result = checkValues(rel);
        break;
      case 'information-schema':
        result = checkInformationSchema(rel);
        break;
      case 'generate-series':
        result = checkGenerateSeries(rel);
        break;
      case 'macro-apply-vars-to-rel':
        result = checkMacroApplyVarsToRel(rel);
        break;
      case 'macro-rel-case':
        result = checkMacroRelCase(rel);
        break;
      case 'rel-var':
        result = checkRelVar(rel);
        break;
      default:
        return Assert.unreachable(t);
    }

    REL_TYPE_CHECK_CACHE.set(rel, result);
  }

  // Failures are returned as-is; without a requested interface we are done.
  if (result instanceof TyStackTrace || !opts.implements) {
    return result;
  }

  const conformance = implementsRel({
    subject: result.attributes,
    reqs: opts.implements.attributes,
  });

  return conformance.isErr()
    ? TyStackTrace.fromErr({}, conformance.error)
    : result;
};

/**
 * Checks a node that applies macro variables, for one scope, to its `from`
 * relation.
 *
 * The supplied variables are validated by `checkVars` against whatever `from`
 * reports for `rel.scope`. On success the result keeps everything from the
 * `from` check, except that `vars` becomes the merge of the source vars with
 * the vars of every supplied rel/section/expr, minus the `rel.scope` entry.
 */
const checkMacroApplyVarsToRel: RelCheck<
  AST._MacroApplyVarsToRel<RelMacroChildren>
> = (rel) => {
  const source = checkRel(rel.sources.from);

  if (source instanceof TyStackTrace) {
    return source.withFrame({ frame: rel, location: 'from' });
  }

  // A scope the source never mentions is treated as declaring nothing.
  const declared = source.vars[rel.scope] ?? {
    exprs: {},
    rels: {},
    sections: {},
  };

  const applied = checkVars(declared, rel.vars);

  // NOTE(review): a function result looks like checkVars' way of reporting a
  // failure that still needs the frame — confirm against check-vars.
  if (typeof applied === 'function') {
    return applied(rel);
  }

  const { rels, sections, exprs } = applied;

  const merged = mergeMacroVars(
    source.vars,
    ...Object.values(rels).map((checkedRel) => checkedRel.vars),
    ...Object.values(sections).map((checkedSection) => checkedSection.vars),
    ...Object.values(exprs).map((checkedExpr) => checkedExpr.vars)
  );

  if (merged instanceof Errs.TypeCheckError) {
    return TyStackTrace.fromErr({ frame: rel }, merged);
  }

  // The scope that was just applied is no longer reported upwards.
  const { [rel.scope]: _applied, ...remaining } = merged;

  return { ...source, vars: remaining };
};

/**
 * Checks a macro-level `case` over relations.
 *
 * Every `when` must type check, implement `boolean` and be interpretable
 * (undefaulted vars allowed). Every `then` and the `else` must type check.
 * All branches must expose exactly the same attribute names; each attribute's
 * resulting type is the narrowest super type across the branches.
 *
 * @returns merged attributes plus the merged vars of the `else`, every `then`
 *   and every `when`, or a `TyStackTrace` describing the first failure
 */
const checkMacroRelCase: RelCheck<AST._MacroRelCase<RelMacroChildren>> = (
  rel
) => {
  const checkedWhens: ExprTypeCheck[] = [];
  const checkedThens: RelTypeCheck[] = [];

  for (const { when, then } of rel.cases) {
    const whenD = checkExpr(when);

    if (whenD instanceof TyStackTrace) {
      return whenD.withFrame({ frame: rel });
    }

    if (!implementsTy({ subject: whenD.ty, req: 'boolean' })) {
      return TyStackTrace.fromErr(
        { frame: rel, location: 'condition' },
        new Errs.TypeDoesNotMatchExpectation({
          expected: 'boolean',
          found: whenD.ty,
        })
      );
    }

    if (!isExprInterpretable(when, { allow: { undefaultedVars: true } })) {
      return TyStackTrace.fromErr(
        { frame: rel, location: 'condition' },
        new Errs.ConstExprRequired({ name: 'when' })
      );
    }
    checkedWhens.push(whenD);

    const thenD = checkRel(then);

    if (thenD instanceof TyStackTrace) {
      return thenD.withFrame({ frame: rel });
    }

    checkedThens.push(thenD);
  }

  const checkedElse = checkRel(rel.else);

  if (checkedElse instanceof TyStackTrace) {
    // Attach this node's frame, exactly as is done for failing `then`s.
    return checkedElse.withFrame({ frame: rel });
  }

  const branches = [...checkedThens, checkedElse];

  // Collect attribute names from *every* branch. `new Set(...arrays)` would
  // only read the first array, hiding attributes unique to later branches.
  const uniqAttrs = new Set(
    branches.flatMap((branch) => Object.keys(branch.attributes))
  );

  const mergedAttributes: Record<string, Ty.ExtendedAttributeType> = {};

  for (const name of uniqAttrs) {
    const tys = branches.map((branch) => branch.attributes[name]);

    // An attribute missing from any branch is a shape mismatch.
    if (tys.some((ty) => ty === undefined)) {
      return TyStackTrace.fromErr(
        { frame: rel },
        new Errs.MacroIfBranchesMustHaveTheSameType({
          attr: name,
          lhs: _.compact(tys)[0] ?? '*Not Found*',
          rhs: '*Not Found*',
        })
      );
    }

    const superType = narrowestSuperTypeOf(_.compact(tys));

    if (superType.isErr()) {
      return TyStackTrace.fromErr(
        { frame: rel },
        new Errs.MacroIfBranchesMustHaveTheSameType({
          attr: name,
          lhs: superType.error.lhs,
          rhs: superType.error.rhs,
        })
      );
    }

    mergedAttributes[name] = superType.value;
  }

  const mergedVars = mergeMacroVars(
    checkedElse.vars,
    ...checkedThens.map((x) => x.vars),
    ...checkedWhens.map((x) => x.vars)
  );

  if (mergedVars instanceof Errs.TypeCheckError) {
    return TyStackTrace.fromErr({ frame: rel }, mergedVars);
  }

  const res: RelTypeCheck = {
    attributes: mergedAttributes,
    vars: mergedVars,
  };

  return res;
};

/**
 * Checks a relation variable.
 *
 * When a default relation is present it must type check and implement the
 * attributes declared on the variable. The result exposes the declared
 * attributes and records the variable under `vars[rel.scope].rels[rel.name]`
 * (merged with the default's own vars, if any).
 */
const checkRelVar: RelCheck<AST._RelVar<RelMacroChildren>> = (rel) => {
  const { attributes, scope, name } = rel;
  let inheritedVars: RelTypeCheck['vars'] = {};

  if (rel.default !== null) {
    const checkedDefault = checkRel(rel.default);

    if (checkedDefault instanceof TyStackTrace) {
      return checkedDefault.withFrame({ frame: rel, location: 'default' });
    }

    // Each declared attribute type becomes a single-entry requirement list.
    const wanted: RelInterfaceShorthand = _.mapValues(attributes, (ty) => [ty]);

    const conformance = implementsRel({
      subject: checkedDefault.attributes,
      reqs: wanted,
    });

    if (conformance.isErr()) {
      return TyStackTrace.fromErr({ frame: rel }, conformance.error);
    }

    inheritedVars = checkedDefault.vars;
  }

  const declaration = {
    [scope]: {
      rels: {
        [name]: { type: attributes, defaulted: rel.default !== null },
      },
      exprs: {},
      sections: {},
    },
  };

  const vars = mergeMacroVars(inheritedVars, declaration);

  if (vars instanceof Errs.TypeCheckError) {
    return vars.toStackTrace({ frame: rel });
  }

  return { attributes, vars };
};

/**
 * Checks a `generate-series` relation.
 *
 * Both `start` and `stop` must type check and implement a non-nullable `int`.
 * Failures are reported at the location (`'start'` or `'stop'`) of the bound
 * that caused them.
 *
 * @returns a relation with the single non-nullable `int` attribute `n` and the
 *   merged vars of both bounds, or a `TyStackTrace` on failure
 */
const checkGenerateSeries: RelCheck<AST._GenerateSeries<RelMacroChildren>> = (
  rel
) => {
  const { start, stop } = rel;

  const startD = checkExpr(start, {});

  if (startD instanceof TyStackTrace) {
    return startD.withFrame({ frame: rel, location: 'start' });
  }

  // TODO: Check constness
  if (!implementsTy({ subject: startD.ty, req: Ty.nn('int') })) {
    return TyStackTrace.fromErr(
      { frame: rel, location: 'start' },
      new Errs.TypeDoesNotMatchExpectation({
        found: startD.ty,
        expected: Ty.nn('int'),
      })
    );
  }

  const stopD = checkExpr(stop, {});

  if (stopD instanceof TyStackTrace) {
    // The failure is in the `stop` expression, so report it there.
    return stopD.withFrame({ frame: rel, location: 'stop' });
  }

  // TODO: Check constness
  if (!implementsTy({ subject: stopD.ty, req: Ty.nn('int') })) {
    return TyStackTrace.fromErr(
      { frame: rel, location: 'stop' },
      new Errs.TypeDoesNotMatchExpectation({
        found: stopD.ty,
        expected: Ty.nn('int'),
      })
    );
  }

  const vars = mergeMacroVars(startD.vars, stopD.vars);

  if (vars instanceof Errs.TypeCheckError) {
    return TyStackTrace.fromErr({ frame: rel }, vars);
  }

  return {
    attributes: {
      n: { nullable: false, ty: { k: 'primitive', t: 'int' }, tags: [] },
    },
    vars,
  };
};

/** A file contract exposes exactly its declared attributes and no vars. */
const checkFileContract: RelCheck<AST._FileContract> = ({ attributes }) => ({
  attributes,
  vars: {},
});

/**
 * Checks an information-schema relation.
 *
 * At least one schema must be listed. The exposed attributes are fixed per
 * `info.type` (`'columns'` or `'tables'`) and expanded from their shorthand
 * with `Ty.ty`. No vars are reported.
 */
const checkInformationSchema: RelCheck<AST._InformationSchema> = (info) => {
  if (info.schemas.length === 0) {
    const err = new Errs.BadInformationSchemaSchemaCount();
    return err.toStackTrace({ frame: info });
  }

  // Expands a shorthand attribute listing into the final check result.
  const describe = (columns: {
    [name: string]: Ty.Shorthand;
  }): RelTypeCheck => ({
    attributes: _.mapValues(columns, (shorthand) => Ty.ty(shorthand)),
    vars: {},
  });

  switch (info.type) {
    case 'columns':
      return describe({
        column_name: 'string',
        table_name: 'string',
        table_schema: 'string',
        data_type: 'string',
        is_nullable: 'boolean',
      });
    case 'tables':
      return describe({
        table_name: 'string',
        table_schema: 'string',
      });
    default:
      return Assert.unreachable(info.type);
  }
};

/**
 * Checks an aggregate relation.
 *
 * Selection names must not collide case-insensitively, `from` must type
 * check, and every selected expression is checked in aggregating mode against
 * `from` extended with the node's `groupedAttributes`.
 *
 * @returns the selected attribute types plus the merged vars of the selection
 *   and of `from`, or a `TyStackTrace` on failure
 */
const checkAggregate: RelCheck<AST._Aggregate<RelMacroChildren>> = (agg) => {
  const { selection: requested, groupedAttributes } = agg;

  const clash = conflictingAttributesCheck(Object.keys(requested));
  if (clash instanceof Errs.TypeCheckError) {
    return TyStackTrace.fromErr({ frame: agg }, clash);
  }

  const source = checkRel(agg.sources.from);
  if (source instanceof TyStackTrace) {
    return source.withFrame({ frame: agg, location: 'from' });
  }

  const checked = checkSelectionExprs(
    requested,
    { from: { ...source, groupedAttributes } },
    { aggregating: true }
  );

  // An array result is the `[location, error]` failure pair.
  if (Array.isArray(checked)) {
    const [location, failure] = checked;
    const frame = { frame: agg, location };
    if (failure instanceof TyStackTrace) {
      return failure.withFrame(frame);
    }
    return TyStackTrace.fromErr(frame, failure);
  }

  const mergedVars = mergeMacroVars(
    ...Object.values(checked).map((expr) => expr.vars),
    source.vars
  );

  if (mergedVars instanceof Errs.TypeCheckError) {
    return TyStackTrace.fromErr({ frame: agg }, mergedVars);
  }

  return {
    attributes: _.mapValues(checked, (expr) => expr.ty),
    vars: mergedVars,
  };
};

/**
 * Checks a values literal: every row must provide, for each declared
 * attribute, a value accepted by the validator for that attribute's type.
 * An attribute absent from a row is validated as `null`.
 *
 * The first offending cell is reported at `['position', rowIndex]`.
 */
const checkValues: RelCheck<AST._ValuesContract> = (valuesLiteral) => {
  const { attributes, values } = valuesLiteral;
  const declared = Object.entries(attributes);

  for (const [index, row] of values.entries()) {
    for (const [attrName, ty] of declared) {
      const found = row[attrName] ?? null;
      const { success } = Parser.validatorForType(ty).safeParse(found);

      if (success) {
        continue;
      }

      return TyStackTrace.fromErr(
        { frame: valuesLiteral, location: ['position', index] },
        new Errs.InvalidValuesLiteral({
          attrName,
          found,
          expected: ty,
        })
      );
    }
  }

  return { attributes, vars: {} };
};

/** A table contract exposes exactly its declared attributes and no vars. */
const checkTable: RelCheck<AST._TableContract> = (table) => ({
  attributes: table.attributes,
  vars: {},
});

/**
 * Checks a union of two relations.
 *
 * Both sides must type check and expose the same attribute names. Each
 * attribute's resulting type is the narrowest super type of the two sides.
 * Attribute mismatches are reported at location `'right'`.
 *
 * @returns the merged attributes and merged vars of both sides, or a
 *   `TyStackTrace` on failure
 */
const checkUnion: RelCheck<AST._Union<RelMacroChildren>> = (union) => {
  const { left: leftSource, right: rightSource } = union.sources;
  const lhs = checkRel(leftSource);
  const rhs = checkRel(rightSource);

  if (lhs instanceof TyStackTrace) {
    return lhs.withFrame({ frame: union, location: 'left' });
  }

  if (rhs instanceof TyStackTrace) {
    return rhs.withFrame({ frame: union, location: 'right' });
  }

  // Shared failure for "attribute missing on one side" and "no common type".
  const columnsMustMatch = (
    name: string,
    left: Ty.ExtendedAttributeType | undefined,
    right: Ty.ExtendedAttributeType | undefined
  ): TyStackTrace =>
    TyStackTrace.fromErr(
      { frame: union, location: 'right' },
      new Errs.UnionColumnsMustMatch({ left, right, name })
    );

  const names = new Set(
    Object.keys(lhs.attributes).concat(Object.keys(rhs.attributes))
  );

  const merged: Record<string, Ty.ExtendedAttributeType> = {};

  for (const name of names) {
    const left = lhs.attributes[name];
    const right = rhs.attributes[name];

    if (!left || !right) {
      return columnsMustMatch(name, left, right);
    }

    const common = narrowestSuperTypeOf([left, right]);

    if (!common.isOk()) {
      return columnsMustMatch(name, left, right);
    }

    merged[name] = common.value;
  }

  const vars = mergeMacroVars(lhs.vars, rhs.vars);

  if (vars instanceof Errs.TypeCheckError) {
    return TyStackTrace.fromErr({ frame: union }, vars);
  }

  return { attributes: merged, vars };
};

/**
 * Type checks every expression of a selection against the given sources.
 *
 * With `opts.aggregating` set, each attribute reported by
 * `attrsThatMustBeGroupedOn` must appear in its source's `groupedAttributes`.
 * Otherwise no expression may be aggregated.
 *
 * @returns the checked expressions keyed like `exprs`, or a
 *   `[location, error]` pair for the first failing expression, where
 *   `location` is `['attribute', name]`
 */
const checkSelectionExprs = <Key extends string>(
  exprs: Record<Key, AST.ExprFR>,
  sources: Partial<Record<'left' | 'right' | 'from', AST.Source>>,
  opts: { aggregating: boolean }
):
  | Record<Key, ExprTypeCheck>
  | [Errs.TypeCheckErrorLocation, TyStackTrace | Errs.TypeCheckError] => {
  const entries: [Key, ExprTypeCheck][] = [];

  for (const [name, expr] of Object.entries<AST.ExprFR>(exprs)) {
    // The expression itself has to type check first.
    const checked = checkExpr(expr, { sources });

    if (checked instanceof TyStackTrace) {
      return [['attribute', name], checked];
    }

    if (!opts.aggregating) {
      if (checked.aggregated) {
        return [
          ['attribute', name],
          new Errs.AggregateFunctionInNonAggregatedContext(),
        ];
      }
    } else {
      const required = attrsThatMustBeGroupedOn(expr);

      for (const [sourceName, attrNames] of Object.entries(required)) {
        const source = sources[sourceName as 'left' | 'right' | 'from'];
        Assert.assert(source !== undefined, 'We already checked this exists');
        const grouped = source.groupedAttributes ?? [];

        for (const attrName of attrNames) {
          if (!grouped.includes(attrName)) {
            return [
              ['attribute', name],
              new Errs.TryingToSelectUnaggregatedAttributes({ attrName }),
            ];
          }
        }
      }
    }

    entries.push([name as Key, checked]);
  }

  return Object.fromEntries(entries) as Record<Key, ExprTypeCheck>;
};

/**
 * Checks a select relation.
 *
 * Steps, in order: `from` must type check; selection names must not collide
 * case-insensitively; every selected expression is checked (non-aggregating)
 * against `from`; the optional condition is checked against `from` and must
 * not be windowed; every order-by expression is checked against the
 * *selected* attributes and must not be an array, super or struct type.
 *
 * Order-by failures are reported at `['position', i]`, where `i` is the index
 * of the offending order-by.
 *
 * @returns the selected attribute types plus the merged vars of the
 *   selection, order-bys, condition and `from`, or a `TyStackTrace`
 */
const checkSelect: RelCheck<AST._Select<RelMacroChildren>> = (select) => {
  const fromD = checkRel(select.sources.from);

  if (fromD instanceof TyStackTrace) {
    return fromD.withFrame({ frame: select, location: 'from' });
  }

  const maybeErr = conflictingAttributesCheck(Object.keys(select.selection));

  if (maybeErr instanceof Errs.TypeCheckError) {
    return TyStackTrace.fromErr({ frame: select }, maybeErr);
  }

  const selection = checkSelectionExprs(
    select.selection,
    { from: fromD },
    { aggregating: false }
  );

  // An array result is the `[location, error]` failure pair.
  if (Array.isArray(selection)) {
    const [location, err] = selection;
    const frame = { location, frame: select };
    return err instanceof TyStackTrace
      ? err.withFrame(frame)
      : TyStackTrace.fromErr(frame, err);
  }

  const attributes = Object.fromEntries(
    Object.entries(selection).map(([name, expr]) => [name, expr.ty])
  );

  let condDVars: Record<string, MacroArgsType>;
  if (select.condition !== null) {
    const condD = checkExpr(select.condition, {
      sources: { from: fromD },
    });

    if (condD instanceof TyStackTrace) {
      return condD.withFrame({ frame: select, location: 'condition' });
    }

    if (condD.windowed) {
      return TyStackTrace.fromErr(
        { frame: select, location: 'condition' },
        new Errs.CantUseWindowInAWhereClause()
      );
    }
    condDVars = condD.vars;
  } else {
    condDVars = {};
  }

  // Order-bys see the selected attributes, not the attributes of `from`.
  // The source is the same for every order-by, so build it once.
  const orderBySource = {
    attributes: _.mapValues(selection, ({ ty }) => ty),
  };
  const checkedOrderBys: ExprTypeCheck[] = [];

  for (const [i, { expr }] of select.orderBys.entries()) {
    const d = checkExpr(expr, { sources: { from: orderBySource } });

    if (d instanceof TyStackTrace) {
      return d.withFrame({ frame: select, location: ['position', i] });
    }

    if (
      [Ty.isArrayType, Ty.isSuperType, Ty.isStructType].some((f) => f(d.ty))
    ) {
      // Point at the offending order-by, like the failure case above does.
      return TyStackTrace.fromErr(
        { frame: select, location: ['position', i] },
        new Errs.CantOrderByStructArrayOrSuper({ attempted: d.ty })
      );
    }

    checkedOrderBys.push(d);
  }

  const macroVars = mergeMacroVars(
    ...Object.values(selection).map((expr) => expr.vars),
    ...checkedOrderBys.map(({ vars }) => vars),
    condDVars,
    fromD.vars
  );

  if (macroVars instanceof Errs.TypeCheckError) {
    return TyStackTrace.fromErr({ frame: select }, macroVars);
  }

  const res: RelTypeCheck = { attributes, vars: macroVars };

  return res;
};

/**
 * Checks a join relation.
 *
 * Selection names must not collide case-insensitively, both sources must type
 * check, the `on` condition must type check against both sources and be a
 * primitive boolean, and every selected expression is checked
 * (non-aggregating) against both sources.
 *
 * @returns the selected attribute types plus the merged vars of both sources,
 *   the join condition and the selection, or a `TyStackTrace` on failure
 */
const checkJoin: RelCheck<AST._Join<RelMacroChildren>> = (join) => {
  const maybeErr = conflictingAttributesCheck(Object.keys(join.selection));

  if (maybeErr instanceof Errs.TypeCheckError) {
    return TyStackTrace.fromErr({ frame: join }, maybeErr);
  }

  // Check the join sources
  const leftD = checkRel(join.sources.left);
  const rightD = checkRel(join.sources.right);

  if (leftD instanceof TyStackTrace) {
    return leftD.withFrame({ frame: join, location: 'left' });
  }

  if (rightD instanceof TyStackTrace) {
    return rightD.withFrame({ frame: join, location: 'right' });
  }

  // TODO we need to adjust the nullability of the types of the join to
  // propagate the fact that opposing side of an outer join is now nullable
  // I.e.
  //
  // inner join => no adjustment
  // left join => right side is now nullable
  // right join => left side is now nullable
  const sources = { left: leftD, right: rightD };

  // Check the join condition
  const condD = checkExpr(join.condition, { sources });

  if (condD instanceof TyStackTrace) {
    return condD.withFrame({ frame: join, location: 'on' });
  }

  if (condD.ty.ty.k !== 'primitive' || condD.ty.ty.t !== 'boolean') {
    return TyStackTrace.fromErr(
      { frame: join, location: 'on' },
      new Errs.TypeDoesNotMatchExpectation({
        found: condD.ty,
        expected: Ty.shorthandToTy('boolean'),
      })
    );
  }

  // Check the select
  const selection = checkSelectionExprs(join.selection, sources, {
    aggregating: false,
  });

  // An array result is the `[location, error]` failure pair.
  if (Array.isArray(selection)) {
    const [location, err] = selection;
    const frame = { frame: join, location };
    return err instanceof TyStackTrace
      ? err.withFrame(frame)
      : TyStackTrace.fromErr(frame, err);
  }

  const attributes = _.mapValues(selection, ({ ty }) => ty);

  // The condition's vars must be merged too, otherwise a macro var that is
  // only used in the `on` clause would go unreported.
  const macroVars = mergeMacroVars(
    leftD.vars,
    rightD.vars,
    condD.vars,
    ...Object.values(selection).map((expr) => expr.vars)
  );

  if (macroVars instanceof Errs.TypeCheckError) {
    return TyStackTrace.fromErr({ frame: join }, macroVars);
  }

  return { attributes, vars: macroVars };
};

/**
 * Detects attribute names that are equal once lower-cased.
 *
 * @param attrs - attribute names, in declaration order
 * @returns an error naming the first colliding pair (`lhs` is the name seen
 *   earlier, `rhs` the later one), or `null` when there is no collision
 */
const conflictingAttributesCheck = (
  attrs: string[]
): Errs.TypeCheckError | null => {
  // lower-cased name -> the first original spelling that produced it
  const seenAttrs = new Map<string, string>();

  for (const attr of attrs) {
    const lowered = attr.toLowerCase();
    const existing = seenAttrs.get(lowered);

    // Compare against `undefined` rather than relying on truthiness: an
    // empty-string name is falsy but still counts as already seen.
    if (existing !== undefined) {
      return new Errs.AttributesCantBeTheSameLettersInDifferentCase({
        lhs: existing,
        rhs: attr,
      });
    }

    seenAttrs.set(lowered, attr);
  }

  return null;
};
