import { ISO8601_REGEX, Freeze } from '../utils';
import _ from 'lodash';
import { z } from 'zod';
import { AST } from '../ast';
import { Case, Constant, MakeArray, Not } from '../builder/utilities';
import { Relation } from '../builder/relation';
import { Ty } from '../ty';
import { Interpreter } from '../interpreter';
import { SqlDialect } from './dialects';
import { FUNCTION_SQL_RULES } from './func-sql-rules';
import {
  intersperse,
  ParameterizedSql,
  pSqlMacro,
  SqlExprAst,
  sqlExprMacro,
  SqlParam,
  SqlRelAst,
  sqlRelMacro,
} from './sql-ast';
export type { SqlDialect } from './dialects/dialect';
export { originForIR } from './origin';
export { hashIR } from './hash-ir';
export { maxPossibleRows } from './max-possible-rows';
export * as Dialects from './dialects/dialects';
import { TC, TyStackTrace } from '../type-checker';
import { expandRel } from '../macros/expand-rel';
import { Expression } from '../builder/expression';
import { Assert } from '@cotera/utilities';

/** Shape of a compiled query: the SQL text plus its parameter values. */
export type CompiledQuery = {
  readonly sql: string;
  readonly params: readonly Ty.Scalar[];
};

/** Options that control the textual output of SQL generation. */
export type SqlGenOpts = {
  /** Text prepended verbatim to the generated SQL by `toSql`. */
  header?: string;
};

/** Everything the compiler needs: the target SQL dialect and output options. */
export type CompilerOpts = {
  dialect: SqlDialect;
  sqlGen: SqlGenOpts;
};

/**
 * Placeholder inside a CTE body that refers to another CTE by its numeric
 * alias; `toSql` renders it as the relation `cte<name>`.
 */
type CteRef = { t: 'cte'; name: number };

/**
 * Create an independent counter that hands out CTE aliases.
 *
 * Each returned generator starts at 1 and yields the next integer on every
 * call, so aliases are unique within one compilation.
 */
const makeUniqueAliasGenerator = (): (() => number) => {
  let lastIssued = 0;
  return () => ++lastIssued;
};

/**
 * Compile a relation into SQL text plus its parameter values.
 *
 * Pipeline:
 *  1. `expandRel` the relation and compile the root with `toSqlRelAst`.
 *  2. `toCtes` turns every nested relation into its own numbered CTE. The
 *     resulting list has parents before children, and the root query (alias
 *     `null`) comes first.
 *  3. The list is reversed so each CTE is defined before anything that
 *     references it. Each body is then rendered to text:
 *     - a CTE whose whole body is a reference to another CTE is dropped, and
 *       its alias is redirected to the referenced CTE;
 *     - a CTE whose rendered text equals an earlier CTE's text is dropped, and
 *       its alias is redirected to that earlier CTE.
 *  4. The surviving CTEs become a `with ... as (...)` prefix, followed by the
 *     root query, which is always the last entry.
 *
 * Parameter values are de-duplicated: a repeated value reuses the
 * placeholder index of its first occurrence.
 *
 * @param rel  Relation to compile.
 * @param opts Target dialect plus SQL generation options. `sqlGen.header` is
 *             prepended verbatim to the output.
 * @returns The SQL text and the parameter values in placeholder order.
 */
export const toSql = (
  rel: AST.RelFR | AST.RelIR,
  opts: CompilerOpts
): { sql: string; params: Ty.Scalar[] } => {
  const expanded = expandRel(rel, {});
  const base = toSqlRelAst(expanded, opts);

  const ctes = toCtes(base.ast, base.ctx, null, makeUniqueAliasGenerator());
  const params: string[] = [];
  // Parameter value -> zero-based index in `params`. Avoids a linear scan
  // of `params` for every parameter.
  const paramIndex = new Map<string, number>();

  const toParam = ({ val }: SqlParam): string => {
    let index = paramIndex.get(val);
    if (index === undefined) {
      index = params.length;
      params.push(val);
      paramIndex.set(val, index);
    }
    return opts.dialect.placeholder({ oneIndexedArgNum: index + 1 });
  };

  // Rendered CTE body -> alias of the first CTE that produced it.
  const sqlToCteMap = new Map<string, number>();
  // Dropped CTE alias -> alias it should be referenced through instead.
  const cteRemapping = new Map<number, number>();

  // Follow the redirect chain to its end; `undefined` if `n` was never remapped.
  const getRemappedNode = (n: number): number | undefined => {
    const existing = cteRemapping.get(n);
    return existing !== undefined
      ? getRemappedNode(existing) ?? existing
      : undefined;
  };

  // Name of the CTE that actually carries alias `n` after all redirects.
  const cteName = (n: number): string => `cte${getRemappedNode(n) ?? n}`;

  const query = _.chain(ctes)
    // `toCtes` lists parents before children. Reverse so every CTE is
    // defined before the CTEs that reference it.
    .reverse()
    .map(
      ({ alias, sql }): { alias: number | null; sqlString: string } | null => {
        const [first, ...rest] = sql;

        // A body that is nothing but a reference to another CTE is an alias.
        if (
          rest.length === 0 &&
          first !== undefined &&
          typeof first !== 'string' &&
          first.t === 'cte'
        ) {
          if (alias === null) {
            return {
              alias: null,
              sqlString: `select * from ${opts.dialect.relation(
                cteName(first.name)
              )}`,
            };
          }
          cteRemapping.set(alias, first.name);
          return null;
        }

        const sqlString = sql
          .map((node) => {
            if (typeof node === 'string') {
              return node;
            }

            const { t } = node;
            switch (t) {
              case 'cte':
                // Reference the surviving CTE this alias was redirected to.
                return opts.dialect.relation(cteName(node.name));
              case 'param':
                return toParam(node);
              default:
                return Assert.unreachable(t);
            }
          })
          .join('');

        // The root query is always emitted and never de-duplicated.
        if (alias === null) {
          return { alias, sqlString };
        }

        // An identical CTE was already emitted, so reuse it instead.
        const existing = sqlToCteMap.get(sqlString);
        if (existing !== undefined) {
          cteRemapping.set(alias, existing);
          return null;
        }

        sqlToCteMap.set(sqlString, alias);
        return { alias, sqlString };
      }
    )
    .compact()
    .map(({ alias, sqlString }, i, allCtes) => {
      if (alias === null) {
        return `\n${sqlString}`;
      }

      // The root query is always the final entry. Every CTE except the one
      // directly before it is therefore followed by a comma.
      const prefix = i === 0 ? 'with ' : '';
      const separator = i + 2 < allCtes.length ? ', ' : '';
      return `${prefix}${opts.dialect.attr(
        `cte${alias}`
      )} as (\n  ${sqlString}\n)${separator}`;
    })
    .join('')
    .value();

  return {
    sql: `${opts.sqlGen?.header ? opts.sqlGen.header : ''}${query.trim()}`,
    params,
  };
};

/**
 * One CTE produced by `toCtes`.
 *
 * `alias` is the CTE's numeric alias, or `null` for the root query. `sql` is
 * the body as a sequence of raw SQL text, bound parameters, and references to
 * other CTEs.
 */
type Cte = {
  alias: null | number;
  sql: (string | SqlParam | CteRef)[];
};

/**
 * Split a compiled relation AST into a flat list of CTE bodies.
 *
 * The AST is flattened. Each embedded relation is compiled with
 * `toSqlRelAst`, given a fresh alias from `genUniqueAlias`, and replaced in
 * the body by a `{ t: 'cte', name }` reference. Text and parameters are kept
 * as-is.
 *
 * @param relAst         Compiled AST of the relation to split.
 * @param ctx            Compiler options forwarded to `toSqlRelAst`.
 * @param myAlias        Alias for this relation's own entry (`null` for the root).
 * @param genUniqueAlias Shared source of fresh CTE aliases.
 * @returns This relation's entry first, followed by the entries of every
 *          nested relation. Each nested relation contributes itself and then
 *          its own descendants.
 */
const toCtes = (
  relAst: SqlRelAst,
  ctx: CompilerOpts,
  myAlias: null | number,
  genUniqueAlias: () => number
): Cte[] => {
  // Typescript gets confused when flattening a recursive type;
  const leaves = (relAst as any).flat(Infinity) as (
    | string
    | SqlParam
    | AST.RelIR
  )[];

  const body: Cte['sql'] = [];
  const descendants: Cte[] = [];

  for (const leaf of leaves) {
    // Plain SQL text and bound parameters go straight into this body.
    if (typeof leaf === 'string' || !('t' in leaf) || leaf.t === 'param') {
      body.push(leaf);
      continue;
    }

    // Anything else is a nested relation: it becomes its own CTE.
    const compiled = toSqlRelAst(leaf, ctx);
    const alias = genUniqueAlias();
    body.push({ t: 'cte', name: alias });
    descendants.push(
      ...toCtes(compiled.ast, compiled.ctx, alias, genUniqueAlias)
    );
  }

  return [{ alias: myAlias, sql: body }, ...descendants];
};

/**
 * Compiles one relation node of kind `T` into a SQL relation AST, returning
 * the compiler options alongside it.
 *
 * The AST may still embed child relations (for example, non-table sources);
 * `toCtes` later extracts those into their own CTEs.
 */
type RelationCompiler<T extends AST.RelIR> = (
  rel: T,
  ctx: CompilerOpts
) => { ast: SqlRelAst; ctx: CompilerOpts };

/**
 * Compile a single relation node by dispatching on its `t` tag to the
 * matching `compile*` function.
 *
 * The switch is exhaustive over `AST.RelIR`. A new relation kind fails to
 * type-check at the `Assert.unreachable` call until a case is added for it.
 */
const toSqlRelAst: RelationCompiler<AST.RelIR> = (rel, ctx) => {
  const { t } = rel;
  switch (t) {
    case 'aggregate':
      return compileAggregate(rel, ctx);
    case 'file':
      return compileFileSchema(rel, ctx);
    case 'generate-series':
      return compileGenerateSeries(rel, ctx);
    case 'information-schema':
      return compileInformationSchema(rel, ctx);
    case 'join':
      return compileJoin(rel, ctx);
    case 'select':
      return compileSelect(rel, ctx);
    case 'table':
      return compileTable(rel, ctx);
    case 'union':
      return compileUnion(rel, ctx);
    case 'values':
      return compileValues(rel, ctx);
    default:
      return Assert.unreachable(t);
  }
};

/**
 * Compile a file-backed relation as `select * from "<uri>"`.
 *
 * The URI is emitted as a double-quoted identifier. Embedded double quotes
 * are doubled (standard SQL identifier escaping), so a URI containing `"`
 * cannot close the identifier early and inject SQL.
 *
 * @throws Error if the dialect does not support file sources.
 */
const compileFileSchema: RelationCompiler<AST._FileContract> = (file, ctx) => {
  if (!ctx.dialect.supportsFileSources) {
    throw new Error('This dialect does not support file sources');
  }
  const quotedUri = file.uri.replace(/"/g, '""');
  const ast = sqlRelMacro`select * from "${quotedUri}"`;
  return { ast, ctx };
};

/**
 * Compile an information-schema relation.
 *
 * The dialect builds a relation for the requested `type` over `schemas`. Its
 * IR is returned as the sole AST element, so `toCtes` compiles it as a
 * nested relation.
 */
const compileInformationSchema: RelationCompiler<AST._InformationSchema> = (
  rel,
  ctx
) => {
  const schemaRelation = ctx.dialect.informationSchema[rel.type](rel.schemas);
  const schemaIr = Relation.wrap(schemaRelation).ir();
  return { ctx, ast: [schemaIr] };
};

/**
 * Compile a generate-series relation.
 *
 * `start` and `stop` are evaluated with the interpreter and must produce
 * numbers (zod throws otherwise). The dialect then renders the series.
 */
const compileGenerateSeries: RelationCompiler<
  AST._GenerateSeries<AST.IRChildren>
> = (rel, ctx) => {
  const start = z.number().parse(Interpreter.evalExprIR(rel.start));
  const stop = z.number().parse(Interpreter.evalExprIR(rel.stop));
  const seriesAst = ctx.dialect.generateSeries({ start, stop });
  return { ctx, ast: compileSqlExprAst(seriesAst, ctx) };
};

/**
 * Compile a union as
 * `select * from <left> union all|distinct select * from <right>`.
 *
 * Both sides stay embedded as child relations for `toCtes` to extract.
 */
const compileUnion: RelationCompiler<AST._Union<AST.IRChildren>> = (
  union,
  ctx
) => {
  const { left, right } = union.sources;
  const mode = union.all ? 'all' : 'distinct';
  const ast: SqlRelAst = sqlRelMacro`select * from ${left} union ${mode} select * from ${right}`;
  return { ast, ctx };
};

/**
 * Compile a join as
 * `select <exprs> from <left> as "left" <how> join <right> as "right" on <cond>`.
 *
 * Selected expressions are emitted in name order. Table sources are
 * referenced by name; any other source stays embedded as a child relation.
 */
const compileJoin: RelationCompiler<AST._Join<AST.IRChildren>> = (
  join,
  ctx
) => {
  const { left: leftSource, right: rightSource } = join.sources;

  const selectList: ParameterizedSql[] = _.sortBy(
    Object.entries(join.selection),
    ([name]) => name
  ).map(
    ([name, expr]) =>
      pSqlMacro`${exprToPSql(expr, ctx)} as ${ctx.dialect.attr(name)}`
  );

  const left =
    leftSource.t === 'table'
      ? renderTableIdentifier(leftSource, ctx)
      : leftSource;
  const right =
    rightSource.t === 'table'
      ? renderTableIdentifier(rightSource, ctx)
      : rightSource;

  const leftAlias = ctx.dialect.attr('left');
  const rightAlias = ctx.dialect.attr('right');

  const ast: SqlRelAst = sqlRelMacro`select ${intersperse(selectList, [
    ', ',
  ])} from ${left} as ${leftAlias} ${
    join.how
  } join ${right} as ${rightAlias} on ${exprToPSql(join.condition, ctx)}`;

  return { ast, ctx };
};

/** Compile an inline `values` relation via the dialect's `values` renderer. */
const compileValues: RelationCompiler<AST._ValuesContract> = (values, ctx) => ({
  ast: compileSqlExprAst(ctx.dialect.values(values), ctx),
  ctx,
});

/** Render a table reference using the dialect's relation quoting. */
const renderTableIdentifier = (
  rel: AST._TableContract,
  ctx: CompilerOpts
) => {
  const { dialect } = ctx;
  return dialect.relation(rel);
};

/**
 * Compile a table scan as `select <cols> from <table>`.
 *
 * Columns are listed in sorted name order. A table with no attributes
 * selects `null`.
 */
const compileTable: RelationCompiler<AST._TableContract> = (rel, ctx) => {
  const columnNames = Object.keys(rel.attributes).sort();
  const columnsSql =
    columnNames.length === 0
      ? 'null'
      : columnNames.map((name) => ctx.dialect.attr(name)).join(', ');

  return {
    ast: [`select ${columnsSql} from ${renderTableIdentifier(rel, ctx)}`],
    ctx,
  };
};

/**
 * Compile a select node to
 * `select [distinct] <exprs> from <source> [where ...] [order by ...] [limit n] [offset n]`.
 *
 * Selected expressions are emitted in name order; an empty selection selects
 * `null`. A table source is referenced by name; any other source stays
 * embedded as a child relation for `toCtes` to extract.
 *
 * A condition that can be evaluated without row data is constant-folded:
 * - `true` drops the `where` clause entirely.
 * - `false` becomes `where false`.
 * - `null` also becomes `where false`, because SQL keeps no rows whose
 *   predicate is NULL.
 */
const compileSelect: RelationCompiler<AST._Select<AST.IRChildren>> = (
  select,
  ctx
) => {
  const exprsSql: ParameterizedSql =
    Object.keys(select.selection).length > 0
      ? intersperse(
          _.sortBy(Object.entries(select.selection), ([name]) => name).map(
            ([name, expr]) =>
              pSqlMacro`${exprToPSql(expr, ctx)} as ${ctx.dialect.attr(name)}`
          ),
          [', ']
        )
      : pSqlMacro`null`;

  const from =
    select.sources.from.t === 'table'
      ? renderTableIdentifier(select.sources.from, ctx)
      : select.sources.from;

  let whereClause: ParameterizedSql = pSqlMacro``;

  if (select.condition !== null) {
    if (TC.isExprInterpretable(select.condition)) {
      // Constant predicate: anything other than `true` (i.e. `false` or a
      // SQL NULL) keeps no rows.
      const val = z
        .boolean()
        .nullable()
        .parse(Interpreter.evalExprIR(select.condition));

      if (val !== true) {
        whereClause = pSqlMacro` where false`;
      }
    } else {
      whereClause = pSqlMacro` where ${exprToPSql(select.condition, ctx)}`;
    }
  }

  const sorts: ParameterizedSql = select.orderBys.map(
    ({ expr, direction }) => pSqlMacro`${exprToPSql(expr, ctx)} ${direction}`
  );

  const orderBy =
    sorts.length > 0
      ? pSqlMacro` order by ${intersperse(sorts, [', '])}`
      : pSqlMacro``;

  const distinct = select.distinct ? 'distinct ' : '';
  const limit =
    typeof select.limit === 'number' ? ` limit ${select.limit}` : '';
  const offset =
    typeof select.offset === 'number' ? ` offset ${select.offset}` : '';

  const ast: SqlRelAst = sqlRelMacro`select ${distinct}${exprsSql} from ${from}${whereClause}${orderBy}${limit}${offset}`;

  return { ast, ctx };
};

/**
 * Compile an aggregate as `select <exprs> from <source> [group by <attrs>]`.
 *
 * Selected expressions are emitted in name order; an empty selection selects
 * `null`. Grouped attributes keep their given order. A table source is
 * referenced by name; any other source stays embedded as a child relation.
 */
const compileAggregate: RelationCompiler<AST._Aggregate<AST.IRChildren>> = (
  agg,
  ctx
) => {
  const { selection, sources, groupedAttributes } = agg;
  const namedExprs = _.sortBy(Object.entries(selection), ([name]) => name);

  const exprsSql: ParameterizedSql =
    namedExprs.length === 0
      ? pSqlMacro`null`
      : intersperse(
          namedExprs.map(
            ([name, expr]) =>
              pSqlMacro`${exprToPSql(expr, ctx)} as ${ctx.dialect.attr(name)}`
          ),
          [', ']
        );

  const from =
    sources.from.t === 'table'
      ? renderTableIdentifier(sources.from, ctx)
      : sources.from;

  const groupBySql: string =
    groupedAttributes.length > 0
      ? ` group by ${groupedAttributes
          .map((name) => ctx.dialect.attr(name))
          .join(', ')}`
      : '';

  const ast: SqlRelAst = sqlRelMacro`select ${exprsSql} from ${from}${groupBySql}`;
  return { ast, ctx };
};

/** Compiles an IR expression of kind `T` into parameterized SQL. */
type ExpressionCompiler<T extends AST.ExprIR> = (
  expr: T,
  opts: CompilerOpts
) => ParameterizedSql;

/**
 * Lower a dialect-produced SQL expression AST to parameterized SQL.
 *
 * Each part is handled by kind:
 * - raw strings pass through unchanged;
 * - nested arrays are lowered recursively;
 * - anything else is an IR expression compiled with `exprToPSql`.
 */
const compileSqlExprAst = (
  ast: SqlExprAst,
  ctx: CompilerOpts
): ParameterizedSql =>
  ast.map((part) => {
    if (typeof part === 'string') return part;
    return Freeze.isReadonlyArray(part)
      ? compileSqlExprAst(part, ctx)
      : exprToPSql(part, ctx);
  });

/**
 * Compile an IR expression into parameterized SQL for `ctx.dialect`.
 *
 * If the type checker reports the expression as interpretable, it is first
 * evaluated with the interpreter and replaced by a constant of the same type
 * (constant folding). The result is then rendered by node kind:
 *  - `scalar`: literals. Struct, array and record values are rebuilt as
 *    typed constants and rendered through the dialect's constructors. Some
 *    strings are inlined and others become bound parameters.
 *  - `attr`, `cast`, `function-call`, `window`, `case`, `get-field`,
 *    `make-struct`, `make-array`: rendered directly or via dialect hooks.
 *  - `invariants`: rewritten into a `case` expression whose failure branch
 *    casts a message string to `int` (see the comment in that branch).
 *
 * @throws Error for a temporal string value that does not match
 *   `ISO8601_REGEX`.
 * @throws Error for a function whose rule is `'override'` when the dialect
 *   provides no override.
 * @throws The dialect's override itself, when that override is an `Error`.
 */
export const exprToPSql: ExpressionCompiler<AST.ExprIR> = (expr, ctx) => {
  let foldedExpr: AST.ExprIR;

  // Constant folding: evaluate expressions that need no row data now and
  // compile the resulting constant, keeping the original expression's type.
  if (TC.isExprInterpretable(expr)) {
    const val = Interpreter.evalExprIR(expr);
    foldedExpr = Constant(val, { ty: Expression.fromAst(expr).ty }).ir();
  } else {
    foldedExpr = expr;
  }

  switch (foldedExpr.t) {
    case 'scalar': {
      const { val, ty } = foldedExpr;
      // NULL literals are rendered by the dialect from the value's type.
      if (val === null) {
        const ast = ctx.dialect.nullLiteral(ty.ty);
        return compileSqlExprAst(ast, ctx);
      }

      if (ty.ty.k === 'primitive') {
        const scalarTy = ty.ty.t;

        switch (scalarTy) {
          case 'super': {
            // Emitted as a constant explicitly cast to `super`.
            return exprToPSql(Constant(val).cast('super').ir(), ctx);
          }
          case 'int':
          case 'float':
          case 'boolean': {
            // Inlined via `toString()`.
            // NOTE(review): non-finite floats would render as `NaN` /
            // `Infinity`; confirm these cannot reach SQL generation.
            const t = typeof val;
            Assert.assert(t === 'number' || t === 'boolean');
            return pSqlMacro`${val.toString()}`;
          }
          case 'string': {
            Assert.assert(typeof val === 'string');
            // A dialect may take over string-literal rendering entirely.
            if (ctx.dialect.scalarLiteralOverrides?.string) {
              return ctx.dialect.scalarLiteralOverrides.string(val);
            }
            if (val === '') {
              return pSqlMacro`''`;
            }

            // Strings made only of the allow-listed characters are inlined
            // as a quoted literal, with single quotes escaped as \'. Every
            // other string is sent as a bound parameter instead.
            // NOTE(review): \' is not a standard SQL escape. A dialect that
            // treats backslashes literally (e.g. PostgreSQL with
            // standard_conforming_strings on) would get broken SQL for values
            // containing ' unless it sets `scalarLiteralOverrides.string`.
            // Confirm per dialect.
            return /^[!*%><&/a-zA-Z0-9-_ .:,{}@"’'`()?=]+$/.test(val)
              ? pSqlMacro`'${val.replaceAll("'", "\\'")}'`
              : pSqlMacro`${{ val, t: 'param' }}`;
          }
          case 'day':
          case 'month':
          case 'year':
          case 'timestamp': {
            // All of these temporal types are emitted as an ISO-8601 string
            // constant cast to `timestamp`.
            if (val instanceof Date) {
              const iso8601 = val.toISOString();
              return exprToPSql(Constant(iso8601).cast('timestamp').ir(), ctx);
            }

            Assert.assert(
              typeof val === 'string',
              'Type checker requires this'
            );

            if (ISO8601_REGEX.test(val)) {
              return exprToPSql(Constant(val).cast('timestamp').ir(), ctx);
            }

            throw new Error(
              `Invalid Timestamp in SQL gen got: ${JSON.stringify(val)}`
            );
          }
          default:
            return Assert.unreachable(scalarTy);
        }
      } else if (ty.ty.k === 'enum') {
        // Enum values are rendered as plain strings.
        return exprToPSql(Constant(val, { ty: 'string' }).ir(), ctx);
      } else if (ty.ty.k === 'struct') {
        Assert.assert(
          val !== null && typeof val === 'object' && !(val instanceof Date),
          'TypeChecker requires this'
        );
        Assert.assert(!Freeze.isReadonlyArray(val));

        // Rebuild the value as a make-struct with one typed constant per
        // declared field (a missing field becomes null), then let the
        // dialect render it.
        const struct: AST._MakeStruct<AST.IRChildren> = {
          t: 'make-struct',
          fields: _.mapValues(ty.ty.fields, (field, name) =>
            Constant(val[name] ?? null, { ty: field }).ir()
          ),
        };

        const typedIR = Expression.fromAst(struct, {}).ir();
        Assert.assert(typedIR.t === 'make-struct');
        const ast = ctx.dialect.makeStruct(typedIR.fields);
        return compileSqlExprAst(ast, ctx);
      } else if (ty.ty.k === 'array') {
        // Rebuild as a make-array of constants typed with `Ty.nn(t)`
        // (presumably the non-null form of the element type).
        const t = ty.ty.t;
        Assert.assert(Freeze.isReadonlyArray(val));
        const arr = MakeArray(
          val.map((x) => Constant(x, { ty: Ty.nn(t) }))
        ).ir();
        Assert.assert(arr.t === 'make-array');
        const ast = ctx.dialect.makeArray(arr);
        return compileSqlExprAst(ast, ctx);
      } else if (ty.ty.k === 'record') {
        Assert.assert(
          val !== null && typeof val === 'object' && !(val instanceof Date),
          'TypeChecker requires this'
        );
        Assert.assert(!Freeze.isReadonlyArray(val));

        // Every record value becomes a constant of the record's inner type.
        const inner = ty.ty.t;
        const typedIR = _.mapValues(val, (item) =>
          Constant(item, { ty: Ty.ty(inner) }).ir()
        );
        const ast = ctx.dialect.makeRecord(typedIR);
        return compileSqlExprAst(ast, ctx);
      } else if (ty.ty.k === 'id') {
        // Ids are rendered as their underlying type `ty.ty.t`.
        return exprToPSql(Constant(val, { ty: ty.ty.t }).ir(), ctx);
      } else if (ty.ty.k === 'range') {
        // Range values are rendered as ints.
        return exprToPSql(Constant(val, { ty: 'int' }).ir(), ctx);
      } else {
        return Assert.unreachable(ty.ty);
      }
    }
    case 'attr':
      // Attributes from `from` are bare names. Any other source is rendered
      // via `dialect.relation` with the source as the schema part (presumably
      // the `left` / `right` aliases introduced by joins).
      return foldedExpr.source === 'from'
        ? pSqlMacro`${ctx.dialect.attr(foldedExpr.name)}`
        : pSqlMacro`${ctx.dialect.relation({
            schema: foldedExpr.source,
            name: foldedExpr.name,
          })}`;
    case 'cast': {
      const ast = ctx.dialect.cast(foldedExpr.expr, foldedExpr.targetTy);
      return compileSqlExprAst(ast, ctx);
    }
    case 'function-call': {
      const rule = FUNCTION_SQL_RULES[foldedExpr.op];
      const override = ctx.dialect.functionOverrides[foldedExpr.op];

      // A dialect may mark a function as unsupported by supplying an Error.
      if (override instanceof Error) {
        throw override;
      }

      // A dialect override takes precedence over the generic rule and
      // receives the call's result type.
      if (override) {
        const wantedTy = TC.checkExpr(foldedExpr);
        Assert.assert(!(wantedTy instanceof TyStackTrace));
        const ast = override(foldedExpr.args, wantedTy.ty.ty);
        return compileSqlExprAst(ast, ctx);
      }

      if (rule === 'override') {
        throw new Error(
          `Function ${foldedExpr.op} required an override but it was not provided by the dialect`
        );
      }

      const ast = rule(foldedExpr.op)(foldedExpr.args);
      return compileSqlExprAst(ast, ctx);
    }
    case 'window': {
      // Renders `op(args) over ([partition by ...] [order by ... [frame]])`.
      // The `rows between` frame is only emitted when there is an order by.
      const {
        op,
        args,
        over: { orderBy, partitionBy },
        frame,
      } = foldedExpr;
      const argsExprs = args.map((arg) => exprToPSql(arg, ctx));
      const partitionExprs = partitionBy.map((expr) => exprToPSql(expr, ctx));
      const orderExprs = orderBy.map(
        ({ expr, direction }) =>
          pSqlMacro`${exprToPSql(expr, ctx)} ${direction}`
      );

      return pSqlMacro`${op}(${intersperse(argsExprs, [', '])}) over (${
        partitionBy.length > 0
          ? pSqlMacro`partition by ${intersperse(partitionExprs, [', '])} `
          : pSqlMacro``
      }${
        orderBy.length > 0
          ? pSqlMacro`order by ${intersperse(orderExprs, [', '])}${
              frame
                ? ` rows between ${frame.preceding} preceding and ${frame.following} following`
                : ''
            }`
          : pSqlMacro``
      })`;
    }
    case 'case': {
      // `case when .. then .. [when ..] [else ..] end`; the branch expressions
      // are compiled when the SQL expression AST is lowered.
      const ast = sqlExprMacro`case ${intersperse(
        foldedExpr.cases.map(
          ({ when, then }) => sqlExprMacro`when ${when} then ${then}`
        ),
        [' ']
      )}${foldedExpr.else ? sqlExprMacro` else ${foldedExpr.else}` : ''} end`;
      return compileSqlExprAst(ast, ctx);
    }
    case 'get-field': {
      // The dialect also receives the type of the whole get-field expression.
      const { ty: wantedTy } = Expression.fromAst(foldedExpr);
      const ast = ctx.dialect.getPropertyFromStruct(
        foldedExpr.expr,
        foldedExpr.name,
        wantedTy.ty
      );

      return compileSqlExprAst(ast, ctx);
    }
    case 'make-struct': {
      const ast = ctx.dialect.makeStruct(foldedExpr.fields);
      return compileSqlExprAst(ast, ctx);
    }
    case 'make-array': {
      const ast = ctx.dialect.makeArray(foldedExpr);
      return compileSqlExprAst(ast, ctx);
    }
    case 'invariants': {
      // Rewritten into `case when not <inv> then <failing cast> ... else <expr> end`.
      // Invariants are checked in name order.
      const { invariants, expr } = foldedExpr;
      const elseClause = Expression.fromAst(expr);
      const invariantsExpr = Case(
        _.chain(invariants)
          .entries()
          .sortBy(([name, __invariant]) => name)
          .map(([name, invariantExpr]) => {
            const invariant = Expression.fromAst(invariantExpr);
            // We need to have the runtime component on concating before we
            // cast in order to defer execution and beat constant folding. In
            // postgres and redshift we _must_ use the invariant in the
            // resulting expression. Relying on execution order is finicky so
            // be very careful this works in all warehouses
            return {
              when: Not(invariant),
              then: Constant(`Invariant *${name}* failed! Got `, {
                impure: true,
              })
                .concat(invariant.cast('string'))
                .cast('int')
                .cast('super')
                .cast(elseClause.ty),
            };
          })
          .value(),
        { else: elseClause }
      );

      return exprToPSql(invariantsExpr.ir(), ctx);
    }
    default:
      return Assert.unreachable(foldedExpr);
  }
};
