import { intersperse, SqlExprAst, sqlExprMacro } from './sql-ast';
import _ from 'lodash';
import type { AST } from '../ast';
import { Assert } from '@cotera/utilities';
import { Expression } from '../builder/expression';
import { And, Constant, Not } from '../builder/utilities';
import { isFunctionIdentifier } from '../ast/func-identifier';

/**
 * Compiles the arguments of a single function call (given as expression IR)
 * into a SQL expression fragment.
 */
export type FnCompilationRule = (args: ReadonlyArray<AST.ExprIR>) => SqlExprAst;

/** Given a function identifier / operator name, produces its compilation rule. */
type FnCompilationRuleBuilder = (fnName: string) => FnCompilationRule;

/**
 * Builds a rule that places `op` between every pair of arguments and wraps
 * the whole chain in parentheses, e.g. `(a and b and c)`.
 */
const infix: FnCompilationRuleBuilder = (op) => {
  // The separator text depends only on `op`, so build it once per rule.
  const separator = ` ${op} `;
  return (args) => sqlExprMacro`(${intersperse(args as SqlExprAst, [separator])})`;
};

/**
 * Like `infix`, but ignores the function identifier it is invoked with and
 * emits `alias` as the SQL operator instead (e.g. `add` compiles to `+`).
 */
const infixAlias = (alias: string): FnCompilationRuleBuilder => {
  const aliasedBuilder: FnCompilationRuleBuilder = (_ignoredOp) => infix(alias);
  return aliasedBuilder;
};
/**
 * Like `functionCall`, but ignores the function identifier it is invoked with
 * and emits a call to the SQL function named `alias` instead
 * (e.g. `null_if` compiles to `nullif(...)`).
 */
const functionCallAlias = (alias: string): FnCompilationRuleBuilder => {
  // `functionCall` is declared further down this module; it is only
  // dereferenced when the returned builder is invoked, not here.
  const aliasedBuilder: FnCompilationRuleBuilder = (_ignoredOp) =>
    functionCall(alias);
  return aliasedBuilder;
};

/**
 * Builds a rule that emits `op` in front of its operand, wrapped in
 * parentheses, e.g. `(not x)`. The whole argument list is interpolated, so
 * this is meant for unary operators (only `not` uses it in this file).
 */
const prefix: FnCompilationRuleBuilder = (op) =>
  function compilePrefix(args) {
    return sqlExprMacro`(${op} ${args})`;
  };

/** Builds a rule that emits a plain SQL call: `op(arg1, arg2, ...)`. */
const functionCall: FnCompilationRuleBuilder = (op) =>
  function compileCall(args) {
    const argList = intersperse(args as SqlExprAst, [', ']);
    return sqlExprMacro`${op}(${argList})`;
  };

/**
 * Builds a rule that evaluates the function call while compiling, via
 * `Expression#evaluate`, and inlines the resulting value as a SQL constant.
 *
 * The rule asserts that `op` is a recognised function identifier each time it
 * runs.
 */
const interpreter: FnCompilationRuleBuilder = (op) => (args) => {
  Assert.assert(isFunctionIdentifier(op));

  const evaluated = Expression.fromAst({ t: 'function-call', op, args }).evaluate();
  const literal = Constant(evaluated);

  return sqlExprMacro`${literal.ir()}`;
};

/**
 * Builds the rule for `eq` / `neq`.
 *
 * Non-struct operands compile to `(a = b)` / `(a != b)`. When both operands
 * are structs, the comparison is rewritten as the conjunction of per-field
 * equalities (negated for `neq`). That rewritten expression is then compiled
 * in place of the original.
 *
 * Fails an assertion if `op` is not `eq` or `neq`, or if the rule receives
 * anything other than exactly two arguments.
 */
const infixEquality: FnCompilationRuleBuilder = (op) => {
  // Only `eq` and `neq` are supported. Fail loudly instead of emitting
  // `undefined` as the SQL operator, or silently treating an unknown op as a
  // negated equality in the struct branch.
  Assert.assert(
    op === 'eq' || op === 'neq',
    `infixEquality only supports 'eq' and 'neq', got '${op}'`
  );
  const isEq = op === 'eq';
  const sqlOp = isEq ? '=' : '!=';

  return (args) => {
    Assert.assert(args.length === 2, 'Equality is binary');
    const arg0 = Expression.fromAst(args[0]!);
    const arg1 = Expression.fromAst(args[1]!);

    // If we're comparing two structs, then we need to rewrite the
    // comparison into a pairwise comparison of all of the struct fields.
    if (arg0.ty.ty.k === 'struct' && arg1.ty.ty.k === 'struct') {
      const arg0Fields = arg0.ty.ty.fields;
      const arg1Fields = arg1.ty.ty.fields;

      // This is covered by the type checker, but it's worth asserting here.
      const fields = Object.keys(arg0Fields);
      Assert.assert(
        _.isEqual(new Set(fields), new Set(Object.keys(arg1Fields))),
        'Structs must have identical fields to be compared. This is an error in the type checker.'
      );

      // We're ready to compare - we're just going to rewrite the expression
      // into the conjunction of the individual field comparisons, and pass
      // the result back into the whole compilation process again.
      const expr = And(
        ...fields.map((field) => arg0.getField(field).eq(arg1.getField(field)))
      );

      return sqlExprMacro`${isEq ? expr.ir() : Not(expr).ir()}`;
    }

    return sqlExprMacro`(${intersperse(args as SqlExprAst, [` ${sqlOp} `])})`;
  };
};

/**
 * Rule for `format(a, b, ...)`.
 *
 * - Each argument is cast to `string`.
 * - Any cast whose type is nullable is coalesced to `''`, presumably so that a
 *   single NULL does not make the whole concatenation NULL.
 * - The pieces are then joined through the `concat` function.
 *
 * With no arguments, this compiles to the SQL empty-string literal.
 */
const format: FnCompilationRuleBuilder = (_op) => (args) => {
  if (!args.length) return sqlExprMacro`''`;

  const stringPieces = args.map((arg) => {
    const asString = Expression.fromAst(arg).cast('string');
    return (asString.ty.nullable ? asString.coalesce('') : asString).ast;
  });

  return [
    Expression.fromAst({
      t: 'function-call',
      op: 'concat',
      args: stringPieces,
    }).ir(),
  ];
};

/** `concat`: parenthesize every operand and join them with SQL's `||`. */
const concatRule: FnCompilationRuleBuilder = (_op) => (args) =>
  sqlExprMacro`(${intersperse(
    args.map((operand) => sqlExprMacro`(${operand})`),
    [' || ']
  )})`;

/** `is_null`: postfix `is null` test on the first argument. */
const isNullRule: FnCompilationRuleBuilder = (_op) => (args) => {
  const operand = args[0]!;
  return sqlExprMacro`(${operand} is null)`;
};

/** `count_distinct`: `count(distinct x)` over the first argument. */
const countDistinctRule: FnCompilationRuleBuilder = (_op) => (args) =>
  sqlExprMacro`count(distinct ${args[0]!})`;

/**
 * `tag` / `impure`: compile to the first argument unchanged. Any further
 * arguments (such as the tag itself) do not appear in the SQL.
 */
const firstArgRule: FnCompilationRuleBuilder = (_op) => (args) =>
  sqlExprMacro`${args[0]!}`;

/**
 * Compilation rule builder for every function identifier.
 *
 * The value `'override'` means no generic rule is provided here. Presumably a
 * dialect-specific compiler supplies one — TODO confirm against callers.
 */
export const FUNCTION_SQL_RULES: Record<
  AST.FunctionIdentifier,
  FnCompilationRuleBuilder | 'override'
> = {
  lower: functionCall,
  upper: functionCall,
  like: infix,
  format,
  concat: concatRule,
  replace: functionCall,
  length: functionCall,
  add: infixAlias('+'),
  sub: infixAlias('-'),
  mul: infixAlias('*'),
  div: infixAlias('/'),
  to_the_power_of: infixAlias('^'),
  split_part: functionCall,
  round: functionCall,
  abs: functionCall,
  ln: functionCall,
  log_2: 'override',
  log_10: 'override',
  floor: functionCall,
  ceil: functionCall,
  cosine_distance: 'override',
  and: infix,
  or: infix,
  eq: infixEquality,
  neq: infixEquality,
  gt: infixAlias('>'),
  gte: infixAlias('>='),
  lt: infixAlias('<'),
  lte: infixAlias('<='),
  now: functionCall,
  not: prefix,
  is_null: isNullRule,
  is_numeric_string: 'override',
  gen_random_uuid: functionCall,
  random: functionCall,
  sum: functionCall,
  count: functionCall,
  count_distinct: countDistinctRule,
  tag: firstArgRule,
  impure: firstArgRule,
  avg: functionCall,
  min: functionCall,
  max: functionCall,
  string_agg: functionCall,
  date_diff: 'override',
  date_add: 'override',
  date_trunc: 'override',
  date_part: 'override',
  nan: 'override',
  is_nan: 'override',
  null_if: functionCallAlias('nullif'),
  coalesce: functionCall,
  corr: functionCall,
  array_agg: 'override',
  percentile_cont: 'override',
  percentile_disc: 'override',
  stddev_samp: functionCall,
  stddev_pop: functionCall,
  get_from_record: 'override',
  null_of: interpreter,
  type_of: interpreter,
  implements: interpreter,
};
