import { AST } from '../../ast';
import { Ty } from '../../ty';
import { intersperse, sqlExprMacro } from '../sql-ast';
import { SqlDialect, assertConstantString } from './dialect';
import {
  PrimitiveAttributeTypeToPostgresType,
  PostgresDialect,
} from './postgres';
import { Expression } from '../../builder';

/**
 * Primitive types whose DuckDB spelling differs from the Postgres one.
 * These entries win over the inherited Postgres mapping below.
 */
const duckDbPrimitiveTypeOverrides = {
  int: 'int64',
  super: 'json',
};

/**
 * DuckDB native type name for every primitive attribute type.
 *
 * Starts from the Postgres mapping and replaces the entries listed in
 * `duckDbPrimitiveTypeOverrides`.
 */
const ScalarAttributeTypeToDuckDbNativeType: Record<
  Ty.PrimitiveAttributeType,
  string
> = {
  ...PrimitiveAttributeTypeToPostgresType,
  ...duckDbPrimitiveTypeOverrides,
};

/**
 * Wraps a struct field name in double quotes for use as a DuckDB identifier,
 * doubling any embedded double quotes so the identifier cannot be terminated
 * early by the name itself.
 */
const quoteStructFieldName = (name: string): string =>
  `"${name.replace(/"/g, '""')}"`;

/**
 * Translates an attribute type into the DuckDB type expression (raw SQL text).
 *
 * - `struct` becomes `struct("field" type, ...)`, recursing into each field type.
 * - `enum` uses the same DuckDB type as the primitive `string`.
 * - `record` uses the same DuckDB type as the primitive `super`.
 * - `array` becomes the element's DuckDB type followed by `[]`.
 * - anything else is looked up as a primitive in
 *   `ScalarAttributeTypeToDuckDbNativeType`.
 *
 * @param ty the attribute type to translate
 * @returns the DuckDB type expression
 */
const typeMapping = (ty: Ty.AttributeType): string => {
  if (ty.k === 'struct') {
    const fieldDefinitions = Object.entries(ty.fields).map(
      ([name, { ty: fieldTy }]) =>
        `${quoteStructFieldName(name)} ${typeMapping(fieldTy)}`
    );
    return `struct(${fieldDefinitions.join(', ')})`;
  }

  if (ty.k === 'enum') {
    return ScalarAttributeTypeToDuckDbNativeType.string;
  }

  if (ty.k === 'record') {
    return ScalarAttributeTypeToDuckDbNativeType.super;
  }

  if (ty.k === 'array') {
    return `${typeMapping(ty.t.ty)}[]`;
  }

  return ScalarAttributeTypeToDuckDbNativeType[ty.t];
};

/**
 * SQL dialect emitting DuckDB-native SQL.
 *
 * Inherits everything from `PostgresDialect` and overrides array/struct
 * literals, several function translations, series generation, type names,
 * struct field access, and casts involving `super` (JSON) values.
 */
export const DuckDbNativeDialect: SqlDialect = {
  ...PostgresDialect,

  /** Builds an `ARRAY[e1, e2, ...]` literal from the element expressions. */
  makeArray: ({ elements }) => {
    const ast = sqlExprMacro`ARRAY[${intersperse<AST.ExprIR | string>(
      elements,
      ', '
    )}]`;
    return ast;
  },

  functionOverrides: {
    ...PostgresDialect.functionOverrides,
    // NULL inputs are excluded from the aggregated array via a FILTER clause.
    array_agg: ([arg0]) =>
      sqlExprMacro`array_agg(${arg0!}) filter (where (${arg0!}) is not null)`,
    log_2: ([arg0]) => sqlExprMacro`log2(${arg0!})`,
    log_10: ([arg0]) => sqlExprMacro`log10(${arg0!})`,
    /**
     * `date_part(unit, ts)`: the unit must be a constant from
     * `AST.DATE_PART_UNITS`. The timestamp is cast to timestamptz and
     * converted to UTC before extraction; the result is cast to int.
     */
    date_part: ([arg0, arg1]) => {
      const unit = assertConstantString(arg1!, AST.DATE_PART_UNITS);
      return sqlExprMacro`cast(date_part('${unit}', (((${arg0!})::timestamptz) at time zone 'UTC')) as int)`;
    },
    /**
     * `date_diff(start, end, unit)`: the unit (third argument) must be a
     * constant from `AST.DATE_DIFF_UNITS`. Both timestamps are normalized to
     * UTC the same way as in `date_part`.
     */
    date_diff: ([arg0, arg1, arg2]) => {
      return sqlExprMacro`date_diff('${assertConstantString(
        arg2!,
        AST.DATE_DIFF_UNITS
      )}', (((${arg0!})::timestamptz) at time zone 'UTC'), (((${arg1!})::timestamptz) at time zone 'UTC'))`;
    },
    /**
     * Cosine distance, defined as `1 - cosine_similarity(a, b)`.
     *
     * Uses the explicit `array_cosine_similarity` function rather than the
     * `<=>` operator: recent DuckDB releases document `<=>` as cosine
     * *distance* (matching pgvector), in which case `1 - (a <=> b)` would
     * produce a similarity instead of a distance.
     */
    cosine_distance: ([arg0, arg1]) => {
      return sqlExprMacro`(1 - array_cosine_similarity(${arg0!}, ${arg1!}))`;
    },
  },

  /** Produces rows with a single column `n` from `start` to `stop`. */
  generateSeries: ({ start, stop }) => {
    return sqlExprMacro`select generate_series as n from generate_series(${start.toString()}, ${stop.toString()})`;
  },

  typeMapping,

  /**
   * Builds a DuckDB struct literal `{"name": expr, ...}`.
   *
   * Each entry is emitted with a trailing comma. Embedded double quotes in
   * field names are doubled so the quoted key cannot be terminated early.
   */
  makeStruct: (fields) => {
    return sqlExprMacro`{${Object.entries(fields).map(
      ([name, expr]) =>
        sqlExprMacro`"${name.replace(/"/g, '""')}": ${expr},`
    )}}`;
  },

  /**
   * Reads field `name` from a struct expression using dot access. Method
   * syntax is required so `this.attr` resolves to this dialect's identifier
   * rendering.
   */
  getPropertyFromStruct(expr, name, _wantedTy) {
    return sqlExprMacro`(${expr}).${this.attr(name)}`;
  },

  supportsFileSources: true,

  /**
   * Casts `expr` to `targetTy`.
   *
   * - struct → string: serialized with `to_json`.
   * - anything → `super`: serialized with `to_json` (this branch is checked
   *   before the next one, so `super` → `super` also goes through `to_json`).
   * - `super` → anything else: the JSON root is extracted as text with
   *   `->>'$'` and then cast to the target DuckDB type.
   * - otherwise a plain `cast(expr as <type>)`.
   */
  cast: (expr, targetTy) => {
    const { ty } = Expression.fromAst(expr);
    if (
      ty.ty.k === 'struct' &&
      targetTy.k === 'primitive' &&
      targetTy.t === 'string'
    ) {
      return sqlExprMacro`to_json(${expr})`;
    }
    if (targetTy.k === 'primitive' && targetTy.t === 'super') {
      return sqlExprMacro`to_json(${expr})`;
    }

    if (ty.ty.k === 'primitive' && ty.ty.t === 'super') {
      return sqlExprMacro`cast((${expr})->>'$' as ${typeMapping(targetTy)})`;
    }

    return sqlExprMacro`cast(${expr} as ${typeMapping(targetTy)})`;
  },
};
