import _ from 'lodash';
import { err, ok, Result } from 'neverthrow';
import * as Errs from '../type-check-error';
import { AST } from '../../ast';
import { Ty } from '../../ty';
import { implementsTy, narrowestSuperTypeOf } from './../implements';
import { getFromRecordTypingRule } from './record-funcs';
import { dateTruncTypingRule } from './time-funcs';
import {
  funcCallShorthandToFuncCall,
  satisfiesOneOf,
  satisfiesSignature,
} from './satisfies-signature';
import {
  implementsTypingRule,
  nullOfTypingRule,
  tagTypingRule,
  typeOfTypingRule,
} from './type-level-func-rules';
import { Assert } from '../../utils';

/**
 * For each primitive source type, the primitive types a value of that type may
 * be cast to.
 *
 * Every entry lists its own type, so an identity cast (e.g. `day` -> `day`) is
 * always permitted, consistently across all primitive types.
 */
export const ALLOWED_PRIMITIVE_CASTS: Record<
  Ty.PrimitiveAttributeType,
  readonly Ty.PrimitiveAttributeType[]
> = {
  boolean: ['boolean', 'string', 'int', 'super'],
  timestamp: ['timestamp', 'string', 'super'],
  day: ['day', 'timestamp', 'string', 'super'],
  month: ['month', 'timestamp', 'day', 'string', 'super'],
  year: ['year', 'timestamp', 'month', 'day', 'string', 'super'],
  float: ['float', 'int', 'string', 'super'],
  int: ['int', 'float', 'string', 'boolean', 'super'],
  string: [
    'string',
    'timestamp',
    'int',
    'float',
    'boolean',
    'super',
    'day',
    'month',
    'year',
  ],
  super: [
    'boolean',
    'timestamp',
    'float',
    'int',
    'string',
    'super',
    'day',
    'month',
    'year',
  ],
};

/**
 * A fully-resolved function signature: the argument types, in order, paired
 * with the type the call produces.
 */
export type FuncCallSignature = Readonly<
  [ReadonlyArray<Ty.ExtendedAttributeType>, Ty.ExtendedAttributeType]
>;

/**
 * Shorthand form of `FuncCallSignature`, written with `Ty.Shorthand` values
 * (e.g. `[['int', 'int'], 'int']`) rather than full attribute types.
 */
export type FuncCallSignatureShorthand = Readonly<
  [ReadonlyArray<Ty.Shorthand>, Ty.Shorthand]
>;

/**
 * Signatures for the binary arithmetic functions (used through
 * `binaryMathRules` by `add`, `sub`, `mul`, `div` and `to_the_power_of`):
 * two `int`s yield an `int`; any `float` operand yields a `float`.
 */
// NOTE(review): these entries are handed to `satisfiesOneOf` in this order;
// keep the order stable unless that helper is known to be order-insensitive.
const MATH_FUNC_SIGNATURES: FuncCallSignatureShorthand[] = [
  [['int', 'int'], 'int'],
  [['float', 'int'], 'float'],
  [['int', 'float'], 'float'],
  [['float', 'float'], 'float'],
];

/**
 * Type-checks a call to a built-in function: given the function's identifier
 * and the types of its arguments, produces either the type of the call's
 * result or a `TypeCheckError` explaining why the call is ill-typed.
 */
export type FunctionTypingRule = (
  name: AST.FunctionIdentifier,
  args: ReadonlyArray<Ty.ExtendedAttributeType>
) => Result<Ty.ExtendedAttributeType, Errs.TypeCheckError>;

/**
 * Post-processes the result type of a successful `FunctionTypingRule`, with
 * access to the argument types the rule was called with (applied by
 * `withModifiers`).
 */
type FunctionTypingModifier = (
  args: ReadonlyArray<Ty.ExtendedAttributeType>,
  res: Ty.ExtendedAttributeType
) => Ty.ExtendedAttributeType;

/** Modifier that runs the result type through `Ty.makeNullable`, ignoring the arguments. */
const alwaysNullable: FunctionTypingModifier = (_args, res) => {
  const nullableResult = Ty.makeNullable(res);
  return nullableResult;
};

/**
 * Wraps a typing rule so that, when it succeeds, its result type is threaded
 * through each modifier in order (left to right), each one seeing the
 * original argument types.
 *
 * @param r - the underlying typing rule.
 * @param modifiers - applied in array order; never mutated, so readonly arrays
 *   are accepted.
 * @returns a rule that returns `r`'s error unchanged on failure (modifiers are
 *   not invoked), or the modified result type on success.
 */
const withModifiers =
  (
    r: FunctionTypingRule,
    modifiers: readonly FunctionTypingModifier[]
  ): FunctionTypingRule =>
  (name, args) => {
    const res = r(name, args);
    if (res.isErr()) {
      // Failures pass straight through; modifiers only adjust successful types.
      return res;
    }
    return ok(
      modifiers.reduce<Ty.ExtendedAttributeType>(
        (curr, modifier) => modifier(args, curr),
        res.value
      )
    );
  };

// Shared typing rule for the binary arithmetic functions (`add`, `sub`, `mul`,
// `div`, `to_the_power_of`): checks the arguments against MATH_FUNC_SIGNATURES
// via `satisfiesOneOf`. Built once at module load and reused by every entry.
const binaryMathRules: FunctionTypingRule =
  satisfiesOneOf(MATH_FUNC_SIGNATURES);

/**
 * Typing rule shared by `eq`/`neq` and the ordering comparisons
 * (`gt`/`gte`/`lt`/`lte`). Struct equality needs special handling, so:
 *
 * - a record on either side is never comparable;
 * - two structs are comparable when the left one implements the right one,
 *   and the result keeps only the tags both sides share;
 * - anything else is comparable when both sides have a narrowest common super
 *   type, and the result takes that super type's tags.
 *
 * On success the result is a boolean that is nullable if either side is.
 * Otherwise a `TypesNotComparable` error is returned.
 */
// NOTE(review): the struct check only tests `left implements right`, so
// `a = b` may type-check while `b = a` does not. Confirm the asymmetry is
// intended (e.g. relied on by codegen) before making it symmetric.
const equalityComparisonRule: FunctionTypingRule = (name, args) => {
  const [left, right] = args;
  Assert.assert(left !== undefined && right !== undefined);

  // One error instance, shared by every failure path below.
  const notComparable = new Errs.TypesNotComparable({
    op: name,
    left,
    right,
  });
  const nullable = left.nullable || right.nullable;

  if ([left, right].some((side) => side.ty.k === 'record')) {
    return err(notComparable);
  }

  if (left.ty.k === 'struct' && right.ty.k === 'struct') {
    if (!implementsTy({ subject: left, req: right })) {
      return err(notComparable);
    }
    return ok({
      ...Ty.shorthandToTy('boolean'),
      nullable,
      tags: _.intersection(left.tags, right.tags),
    });
  }

  const superType = narrowestSuperTypeOf([left, right]);
  if (superType.isErr()) {
    return err(notComparable);
  }
  return ok({
    ty: { k: 'primitive', t: 'boolean' },
    nullable,
    tags: superType.value.tags,
  });
};

/**
 * Builds a variadic typing rule that takes any number of arguments, all of
 * which must implement the primitive type `t`, and produces a `t`.
 *
 * On success the result is nullable when any argument is nullable and carries
 * no tags. On failure an `InvalidFunctionCall` is returned whose `allowed`
 * signature has one `t` per received argument.
 *
 * @param t - the primitive type every argument (and the result) must have.
 */
const variadicPrimitiveRule =
  (t: 'boolean' | 'string'): FunctionTypingRule =>
  (op, args) => {
    if (!args.every((arg) => implementsTy({ subject: arg, req: t }))) {
      return err(
        new Errs.InvalidFunctionCall({
          op,
          recieved: args,
          // Callback takes no parameter, so lodash's `_` is not shadowed.
          allowed: [[args.map(() => Ty.ty(t)), Ty.ty(t)]],
        })
      );
    }
    return ok({
      ty: { k: 'primitive', t },
      nullable: args.some((arg) => arg.nullable),
      tags: [],
    });
  };

/** Typing rule for `and` / `or`: every operand must be boolean; yields boolean. */
const connective: FunctionTypingRule = variadicPrimitiveRule('boolean');

/** Typing rule for `concat`: every operand must be a string; yields string. */
const concatRule: FunctionTypingRule = variadicPrimitiveRule('string');

/**
 * Typing rules for every built-in function, keyed by function identifier.
 *
 * Each entry's `rule` computes the call's result type (or a type error) from
 * the argument types. Entries with `aggregate: true` are aggregate functions.
 */
export const FUNCTION_TYPING_RULES: Record<
  AST.FunctionIdentifier,
  {
    // Set to true on aggregate functions; omitted otherwise.
    aggregate?: boolean;
    rule: FunctionTypingRule;
  }
> = {
  // Two-argument form only; the result is always typed as float.
  round: {
    rule: satisfiesOneOf([
      [['int', 'int'], 'float'],
      [['float', 'int'], 'float'],
    ]),
  },
  split_part: {
    // Its behaviour with nullable inputs differs across dialects. Rather than
    // model that, every argument must be proven non-null (`Ty.nn`).
    rule: satisfiesSignature(
      [Ty.nn('string'), Ty.nn('string'), Ty.nn('int')],
      'string'
    ),
  },
  lower: { rule: satisfiesSignature(['string'], 'string') },
  upper: { rule: satisfiesSignature(['string'], 'string') },
  replace: {
    rule: satisfiesSignature(['string', 'string', 'string'], 'string'),
  },
  like: { rule: satisfiesSignature(['string', 'string'], 'boolean') },
  // Arguments are not inspected at all; the result is always a non-null string.
  format: {
    rule: () =>
      ok({
        ty: { k: 'primitive', t: 'string' },
        nullable: false,
        tags: [],
      }),
  },
  // Variadic: every argument must be a string (see concatRule).
  concat: { rule: concatRule },
  length: { rule: satisfiesSignature(['string'], 'int') },

  // MATH FUNCTIONS
  // Binary arithmetic: int op int -> int, otherwise float (MATH_FUNC_SIGNATURES).
  add: { rule: binaryMathRules },
  sub: { rule: binaryMathRules },
  mul: { rule: binaryMathRules },
  div: { rule: binaryMathRules },
  to_the_power_of: { rule: binaryMathRules },

  // Unary; the result has the same numeric type as the input.
  abs: {
    rule: satisfiesOneOf((['int', 'float'] as const).map((ty) => [[ty], ty])),
  },
  // Unary; int or float input, float result.
  ln: {
    rule: satisfiesOneOf(
      (['int', 'float'] as const).map((ty) => [[ty], 'float'])
    ),
  },
  // NOTE(review): log_2 / log_10 list only a float signature, whereas ln also
  // lists int; confirm whether int inputs should be accepted here too.
  log_2: {
    rule: satisfiesOneOf([[['float'], 'float']]),
  },
  log_10: {
    rule: satisfiesOneOf([[['float'], 'float']]),
  },
  // Unary; the result has the same numeric type as the input.
  floor: {
    rule: satisfiesOneOf((['int', 'float'] as const).map((ty) => [[ty], ty])),
  },
  ceil: {
    rule: satisfiesOneOf((['int', 'float'] as const).map((ty) => [[ty], ty])),
  },
  // Two float arrays -> float.
  cosine_distance: {
    rule: satisfiesOneOf([
      [
        [
          { k: 'array', t: 'float' },
          { k: 'array', t: 'float' },
        ],
        'float',
      ],
    ]),
  },
  // CONNECTIVES (variadic; every operand must be boolean, see connective)
  and: { rule: connective },
  or: { rule: connective },
  // COMPARISONS (the ordering comparisons reuse equalityComparisonRule too)
  eq: { rule: equalityComparisonRule },
  neq: { rule: equalityComparisonRule },
  gt: { rule: equalityComparisonRule },
  gte: { rule: equalityComparisonRule },
  lt: { rule: equalityComparisonRule },
  lte: { rule: equalityComparisonRule },
  // RECORD
  get_from_record: {
    rule: getFromRecordTypingRule,
  },
  // Utility
  tag: {
    rule: tagTypingRule,
  },
  type_of: {
    rule: typeOfTypingRule,
  },
  null_of: {
    rule: nullOfTypingRule,
  },
  implements: {
    rule: implementsTypingRule,
  },
  now: {
    rule: satisfiesSignature([], 'timestamp'),
  },
  not: {
    rule: satisfiesSignature(['boolean'], 'boolean'),
  },
  // Always a non-null boolean; the arguments are not inspected.
  is_null: {
    rule: (_name, _args) =>
      ok(
        Ty.shorthandToTy({
          ty: 'boolean',
          nullable: false,
        })
      ),
  },
  // Passes the type of its first argument through unchanged.
  impure: {
    rule: (_name, args) => {
      const arg = args[0];
      Assert.assert(arg !== undefined);
      return ok(arg);
    },
  },
  // Zero-argument generators; the result type is forced non-null.
  gen_random_uuid: {
    rule: (name, args) =>
      satisfiesSignature([], 'string')(name, args).map((ty) => ({
        ...ty,
        nullable: false,
      })),
  },
  random: {
    rule: (name, args) =>
      satisfiesSignature([], 'float')(name, args).map((ty) => ({
        ...ty,
        nullable: false,
      })),
  },
  // Aggregate functions
  // Result: a non-null array whose elements are the non-null argument type.
  // NOTE(review): in e.g. PostgreSQL, array_agg over zero rows yields NULL and
  // keeps NULL elements; confirm the generated SQL guarantees these non-null
  // claims.
  array_agg: {
    aggregate: true,
    rule: (_name, args) => {
      const arg = args[0];
      Assert.assert(arg !== undefined);
      return ok(Ty.nn(Ty.a(Ty.nn(arg))));
    },
  },
  // NOTE(review): two string arguments here, but the `string_agg` window
  // signature in WINDOW_FUNC_SIGNATURES_ takes one; confirm which is intended.
  string_agg: {
    aggregate: true,
    rule: withModifiers(satisfiesSignature(['string', 'string'], 'string'), [
      alwaysNullable,
    ]),
  },
  sum: {
    aggregate: true,
    rule: withModifiers(
      satisfiesOneOf([
        [['int'], 'float'],
        [['float'], 'float'],
      ]),
      [alwaysNullable]
    ),
  },
  // One argument of any primitive type, or none at all; result is int.
  // NOTE(review): SQL COUNT normally returns 0 rather than NULL for empty
  // input, so `alwaysNullable` may be more conservative than needed.
  count: {
    aggregate: true,
    rule: withModifiers(
      satisfiesOneOf([
        ...Ty.PRIMITIVE_ATTRIBUTE_TYPES.map((ty) => [[ty], 'int'] as const),
        [[], 'int'],
      ]),
      [alwaysNullable]
    ),
  },
  avg: {
    aggregate: true,
    rule: withModifiers(
      satisfiesOneOf(
        (['int', 'float'] as const).map((ty) => [[ty], 'float' as const])
      ),
      [alwaysNullable]
    ),
  },
  min: {
    aggregate: true,
    rule: withModifiers(
      satisfiesOneOf(Ty.PRIMITIVE_ATTRIBUTE_TYPES.map((ty) => [[ty], ty])),
      [alwaysNullable]
    ),
  },
  max: {
    aggregate: true,
    rule: withModifiers(
      satisfiesOneOf(Ty.PRIMITIVE_ATTRIBUTE_TYPES.map((ty) => [[ty], ty])),
      [alwaysNullable]
    ),
  },

  date_diff: {
    rule: satisfiesSignature(['timestamp', 'timestamp', 'string'], 'int'),
  },
  date_add: {
    rule: satisfiesSignature(['timestamp', 'int', 'string'], 'timestamp'),
  },
  date_trunc: {
    rule: dateTruncTypingRule,
  },
  date_part: {
    rule: satisfiesSignature(['timestamp', 'string'], 'int'),
  },
  nan: {
    rule: satisfiesSignature([], 'float'),
  },
  is_numeric_string: { rule: satisfiesSignature(['string'], 'boolean') },
  is_nan: {
    rule: satisfiesOneOf([
      [['float'], 'boolean'],
      [['int'], 'boolean'],
    ]),
  },
  // Two arguments of the same primitive type (or a float/int mix); always nullable.
  null_if: {
    rule: withModifiers(
      satisfiesOneOf([
        ...Ty.PRIMITIVE_ATTRIBUTE_TYPES.map(
          (ty) => [[ty, ty] as const, ty] as const
        ),
        [['float', 'int'], 'float'] as const,
        [['int', 'float'], 'float'] as const,
      ]),
      [alwaysNullable]
    ),
  },
  coalesce: {
    rule: withModifiers(
      satisfiesOneOf([
        ...Ty.PRIMITIVE_ATTRIBUTE_TYPES.map((ty) => [[ty, ty], ty] as const),
        [['float', 'int'], 'float'],
        [['int', 'float'], 'float'],
      ]),
      [
        // Nullable only when both the subject and the fallback are nullable.
        // The `!`s rely on a two-argument signature having matched, because
        // modifiers run only after the rule succeeds.
        (args, res) => {
          const [subject, fallback] = args;
          return { ...res, nullable: subject!.nullable && fallback!.nullable };
        },
      ]
    ),
  },
  corr: {
    aggregate: true,
    rule: withModifiers(
      satisfiesOneOf([
        [['int', 'float'], 'float'],
        [['float', 'int'], 'float'],
        [['int', 'int'], 'float'],
        [['float', 'float'], 'float'],
      ]),
      [alwaysNullable]
    ),
  },
  // NOTE(review): percentile_cont accepts every primitive type (including
  // string/boolean) and returns that type, while percentile_disc is limited to
  // int/float. Most SQL dialects work the other way round: percentile_cont
  // interpolates, so it needs numeric input, and percentile_disc picks an
  // existing value. Confirm against the target dialects.
  percentile_cont: {
    aggregate: true,
    rule: withModifiers(
      satisfiesOneOf(
        Ty.PRIMITIVE_ATTRIBUTE_TYPES.flatMap((ty) => [
          [[ty, 'int'], ty],
          [[ty, 'float'], ty],
        ])
      ),
      [alwaysNullable]
    ),
  },
  percentile_disc: {
    aggregate: true,
    rule: withModifiers(
      satisfiesOneOf(
        (['int', 'float'] as const).flatMap((ty) => [
          [[ty, 'int'], ty],
          [[ty, 'float'], ty],
        ])
      ),
      [alwaysNullable]
    ),
  },
  stddev_samp: {
    aggregate: true,
    rule: withModifiers(
      satisfiesOneOf([
        [['int'], 'float'],
        [['float'], 'float'],
      ]),
      [alwaysNullable]
    ),
  },
  stddev_pop: {
    aggregate: true,
    rule: withModifiers(
      satisfiesOneOf([
        [['int'], 'float'],
        [['float'], 'float'],
      ]),
      [alwaysNullable]
    ),
  },
  // Same signatures and nullability as `count` (see the note there).
  count_distinct: {
    aggregate: true,
    rule: withModifiers(
      satisfiesOneOf([
        ...Ty.PRIMITIVE_ATTRIBUTE_TYPES.map((ty) => [[ty], 'int'] as const),
        [[], 'int'],
      ]),
      [alwaysNullable]
    ),
  },
};

/**
 * Signatures shared by the `lead` and `lag` window functions. For every
 * primitive type `t`, either `(t) => t` or `(t, int offset) => t`.
 */
const LEAD_AND_LAG_FUNC_ARG_TYPES: FuncCallSignatureShorthand[] =
  Ty.PRIMITIVE_ATTRIBUTE_TYPES.flatMap(
    (valueTy): FuncCallSignatureShorthand[] => {
      const exprOnly: FuncCallSignatureShorthand = [[valueTy], valueTy];
      const exprWithOffset: FuncCallSignatureShorthand = [
        [valueTy, 'int'],
        valueTy,
      ];
      return [exprOnly, exprWithOffset];
    }
  );

/**
 * Shorthand signatures for each window function, keyed by window op.
 * These are expanded into full signatures by WINDOW_FUNC_SIGNATURES.
 */
const WINDOW_FUNC_SIGNATURES_: Record<
  AST.WindowOp,
  readonly FuncCallSignatureShorthand[]
> = {
  ntile: [[['int'], 'int']],
  rank: [[[], 'int']],
  row_number: [[[], 'int']],
  dense_rank: [[[], 'int']],
  percent_rank: [[[], 'float']],
  // TODO: lead/lag should force the result type to be nullable; these
  // signatures do not model that yet.
  lead: LEAD_AND_LAG_FUNC_ARG_TYPES,
  lag: LEAD_AND_LAG_FUNC_ARG_TYPES,
  first_value: Ty.PRIMITIVE_ATTRIBUTE_TYPES.map((ty) => [[ty], ty]),
  last_value: Ty.PRIMITIVE_ATTRIBUTE_TYPES.map((ty) => [[ty], ty]),
  min: Ty.PRIMITIVE_ATTRIBUTE_TYPES.map((ty) => [[ty], ty]),
  max: Ty.PRIMITIVE_ATTRIBUTE_TYPES.map((ty) => [[ty], ty]),
  sum: [
    [['int'], 'float'],
    [['float'], 'float'],
  ],
  avg: [
    [['int'], 'float'],
    [['float'], 'float'],
  ],
  // NOTE(review): unlike the aggregate `count` rule, there is no zero-argument
  // form here; confirm whether that is intentional.
  count: Ty.PRIMITIVE_ATTRIBUTE_TYPES.map((ty) => [[ty], 'int']),
  corr: [
    [['int', 'float'], 'float'],
    [['float', 'int'], 'float'],
    [['int', 'int'], 'float'],
    [['float', 'float'], 'float'],
  ],
  // NOTE(review): one argument here, but the aggregate `string_agg` rule in
  // FUNCTION_TYPING_RULES requires two; confirm which arity is intended.
  string_agg: [[['string'], 'string']],
};

/** Expands shorthand signatures, in order, into full `FuncCallSignature`s. */
const expandShorthandSignatures = (
  shorthands: readonly FuncCallSignatureShorthand[]
): readonly FuncCallSignature[] =>
  shorthands.map((shorthand) => funcCallShorthandToFuncCall(shorthand));

/**
 * Full (non-shorthand) signatures for every window function, derived from
 * WINDOW_FUNC_SIGNATURES_ by expanding each op's shorthand list.
 */
export const WINDOW_FUNC_SIGNATURES: Record<
  AST.WindowOp,
  readonly FuncCallSignature[]
> = _.mapValues(WINDOW_FUNC_SIGNATURES_, (shorthands) =>
  expandShorthandSignatures(shorthands)
);

/**
 * Frame-clause policy for each window function: `'required'` or `'invalid'`.
 * NOTE(review): presumably `'required'` means a frame clause must be present
 * and `'invalid'` means one is rejected; confirm against the consumer.
 */
export const WINDOW_FUNCTION_FRAME_RULES: Record<
  AST.WindowOp,
  'required' | 'invalid'
> = {
  // Aggregating
  sum: 'required',
  count: 'required',
  avg: 'required',
  min: 'required',
  max: 'required',
  first_value: 'required',
  last_value: 'required',
  string_agg: 'required',
  // Ranking (Redshift does not support frames here).
  // `corr` is an aggregate, but it is grouped with the frameless functions.
  corr: 'invalid',
  ntile: 'invalid',
  percent_rank: 'invalid',
  row_number: 'invalid',
  rank: 'invalid',
  dense_rank: 'invalid',
  // A frame clause is nonsensical for these
  lead: 'invalid',
  lag: 'invalid',
};
