import _ from 'lodash';
import { AST } from '../ast';
import { Vars } from './base';
import { Assert } from '@cotera/utilities';
import { z } from 'zod';
import { Interpreter } from '../interpreter';
import deepEquals from 'fast-deep-equal';
import { buildCacheEntryFn } from './expansion-cache';
import { TC, TyStackTrace } from '../type-checker';

/**
 * Memoization helper used by `expandExpr`.
 *
 * `openScopes` reports the variable scopes an expression refers to, taken from
 * the keys of the type checker's `vars` result. It asserts that the expression
 * type-checks successfully, i.e. that `TC.checkExpr` did not return a
 * `TyStackTrace`.
 * NOTE(review): presumably the cache uses these scope names to decide which
 * parts of `vars` participate in the cache key — confirm in `expansion-cache`.
 */
const cacheFor = buildCacheEntryFn<AST.ExprFR, AST.ExprIR>({
  openScopes: (expr) => {
    const checked = TC.checkExpr(expr);
    Assert.assert(!(checked instanceof TyStackTrace));
    const scopeNames = Object.keys(checked.vars);
    return new Set(scopeNames);
  },
});

/**
 * Expands a frontend expression (`AST.ExprFR`) into an IR expression
 * (`AST.ExprIR`).
 *
 * - `expr-var` nodes are replaced by the bound expression at
 *   `vars[scope].exprs[name]` (falling back to the node's `default`). The
 *   replacement is itself expanded recursively.
 * - `macro-expr-case` nodes are resolved during expansion by interpreting each
 *   `when` condition in order. Only the selected branch (the first `then`
 *   whose condition is `true`, otherwise `else`) is expanded.
 * - `macro-apply-vars-to-expr` nodes bind `expr.vars.exprs` (expanded in the
 *   outer `vars`) under `expr.scope` and expand `expr.sources.from` with them.
 * - Every other node is rebuilt with its child expressions expanded.
 *
 * Results are memoized through `cacheFor`. When the expansion is deep-equal
 * to the input, the input object itself is returned (and cached) instead of
 * the structurally identical copy.
 *
 * @param expr - The expression to expand.
 * @param vars - Expression variables in scope, keyed by scope name.
 * @returns The expanded IR expression.
 * @throws Error when a referenced expression variable is missing and has no
 *   default.
 * @throws When a `macro-expr-case` condition does not interpret to a boolean
 *   or `null` (zod validation error).
 */
export const expandExpr = (expr: AST.ExprFR, vars: Vars): AST.ExprIR => {
  const { t } = expr;

  const cacheEntry = cacheFor(expr, vars);

  if (cacheEntry.t === 'existing') {
    return cacheEntry.val;
  }

  let ir: AST.ExprIR;

  switch (t) {
    case 'scalar':
    case 'attr':
      // Leaves: nothing to expand.
      ir = expr;
      break;
    case 'expr-var': {
      const variable = vars[expr.scope]?.exprs?.[expr.name];
      if (variable === undefined) {
        if (expr.default === null) {
          throw new Error(
            `Missing non defaulted expression variable "${expr.scope}"."${expr.name}"`
          );
        } else {
          ir = expandExpr(expr.default, vars);
        }
      } else {
        // The bound expression may itself reference variables, so expand it too.
        ir = expandExpr(variable, vars);
      }
      break;
    }
    case 'cast':
      ir = { ...expr, expr: expandExpr(expr.expr, vars) };
      break;
    case 'invariants':
      ir = {
        ...expr,
        expr: expandExpr(expr.expr, vars),
        invariants: _.mapValues(expr.invariants, (expr) =>
          expandExpr(expr, vars)
        ),
      };
      break;
    case 'function-call':
      ir = {
        ...expr,
        args: expr.args.map((expr) => expandExpr(expr, vars)),
      };
      break;
    case 'window':
      ir = {
        ...expr,
        args: expr.args.map((expr) => expandExpr(expr, vars)),
        over: {
          partitionBy: expr.over.partitionBy.map((expr) =>
            expandExpr(expr, vars)
          ),
          orderBy: expr.over.orderBy.map(({ direction, expr }) => ({
            direction,
            expr: expandExpr(expr, vars),
          })),
        },
      };
      break;
    case 'case':
      ir = {
        ...expr,
        cases: expr.cases.map(({ when, then }) => ({
          when: expandExpr(when, vars),
          then: expandExpr(then, vars),
        })),
        else: expr.else ? expandExpr(expr.else, vars) : undefined,
      };
      break;
    case 'get-field':
      ir = {
        ...expr,
        expr: expandExpr(expr.expr, vars),
      };
      break;
    case 'make-struct':
      ir = {
        ...expr,
        fields: _.mapValues(expr.fields, (field) => expandExpr(field, vars)),
      };
      break;
    case 'make-array':
      ir = {
        ...expr,
        elements: expr.elements.map((elem) => expandExpr(elem, vars)),
      };
      break;
    case 'macro-expr-case': {
      // Find the first case whose condition interprets to `true`. A `null`
      // condition counts as not matching. Conditions are evaluated in order
      // and evaluation stops at the first match.
      const chosen = expr.cases.find(({ when }) => {
        const cond = z
          .boolean()
          .nullable()
          .parse(Interpreter.evalExprIR(expandExpr(when, vars)));
        return cond === true;
      });

      // Expand only the selected branch. The `else` branch is expanded lazily
      // so that an unexpandable fallback (e.g. one referencing a missing,
      // non-defaulted variable) does not fail a macro whose case matched.
      ir = expandExpr(chosen !== undefined ? chosen.then : expr.else, vars);
      break;
    }
    case 'macro-apply-vars-to-expr': {
      // The bound expressions are expanded in the *outer* scope. They then
      // replace whatever was previously bound under `expr.scope`.
      const combinedVars: Vars = {
        ...vars,
        [expr.scope]: {
          exprs: _.mapValues(expr.vars.exprs, (expr) => expandExpr(expr, vars)),
        },
      };

      ir = expandExpr(expr.sources.from, combinedVars);
      break;
    }
    default:
      return Assert.unreachable(t);
  }

  // Reuse the input object when expansion produced a structurally identical tree.
  const res = deepEquals(ir, expr) ? (expr as AST.ExprIR) : ir;
  cacheEntry.set(res);
  return res;
};
