# -*- coding: utf-8 -*-
import itertools

import numpy as np
from lutio.ops import is3dlut
from scipy.interpolate import RegularGridInterpolator

# Public names exported by ``from <this module> import *``.
__all__ = ['resize', 'TetrahedralInterpolator']


class TetrahedralInterpolator:
    """Tetrahedral interpolation on a regular 3D lattice.

    Each lattice cell is split into six tetrahedra that all share the cell's
    main diagonal from corner ``(0, 0, 0)`` to corner ``(1, 1, 1)``. A query
    point is assigned to the tetrahedron selected by the ordering of its
    fractional in-cell coordinates and is interpolated with that
    tetrahedron's barycentric weights.

    Parameters
    ----------
    values : array_like, shape (Nx, Ny, Nz, C)
        Lattice samples. Along each of the first three axes, query coordinate
        ``0`` maps to the first node and ``1`` to the last node.

    Raises
    ------
    TypeError
        If ``values`` is not 4-dimensional.
    """

    # Corner indices (into ``tetrahedral_vertices``) of the tetrahedron used
    # for a point, keyed by ``9 * r0 + 3 * r1 + r2`` where ``rk`` is the rank
    # (0 = smallest, 2 = largest) of the point's fractional coordinate along
    # axis k. Each tetrahedron walks from corner 0 to corner 7, stepping the
    # axis with the largest fraction first, then the middle, then the
    # smallest.
    tetrahedral_faces = {
        5: np.array([[0, 1, 3, 7]]),
        7: np.array([[0, 2, 3, 7]]),
        11: np.array([[0, 1, 5, 7]]),
        15: np.array([[0, 2, 6, 7]]),
        19: np.array([[0, 4, 5, 7]]),
        21: np.array([[0, 4, 6, 7]]),
    }

    # The 8 cell corners as (d0, d1, d2) offsets; corner i's offsets are the
    # binary digits of i, e.g. 0 -> (0, 0, 0), 4 -> (1, 0, 0), 7 -> (1, 1, 1).
    tetrahedral_vertices = np.array(list(itertools.product([0, 1], repeat=3)))

    def __init__(self, values):
        values = np.asanyarray(values)
        if values.ndim != 4:
            raise TypeError("`values` must be 4-dim")

        self.values = values
        # Number of lattice nodes along each of the three spatial axes.
        self.shape = np.asarray(values.shape[:3])

    def __call__(self, xi):
        """Interpolate the lattice at the query points ``xi``."""
        return self._compute_tetrahedral(xi)

    def _compute_tetrahedral(self, xi):
        """Tetrahedrally interpolate ``self.values`` at ``xi``.

        Parameters
        ----------
        xi : array_like, shape (..., 3)
            Query coordinates; ``[0, 1]`` spans the lattice along each axis.
            A single point of shape ``(3,)`` is treated as ``(1, 3)``.

        Returns
        -------
        ndarray, shape (..., C)
            Interpolated values; ``(N, 3)`` input gives ``(N, C)`` output.

        Raises
        ------
        ValueError
            If the last axis of ``xi`` does not have length 3.
        """
        xi = np.atleast_2d(xi)
        if xi.shape[-1] != 3:
            raise ValueError("`xi` must have 3 coordinates along its last axis")
        lead_shape = xi.shape[:-1]
        xi = xi.reshape(-1, 3)

        # Integer part selects the lattice cell, fractional part is the
        # position inside that cell.
        frac, cell = np.modf(xi * (self.shape - 1))

        # order[n, j]: axis holding the j-th smallest fraction of point n.
        # rank[n, k]:  position of axis k in that ascending ordering.
        order = frac.argsort(axis=-1)
        rank = order.argsort(axis=-1)

        # `tetrahedral_faces` is keyed by the base-3 code of the ranks.
        codes = rank.dot([9, 3, 1])
        lookup = np.zeros((27, 4), dtype=np.intp)
        for code, face in self.tetrahedral_faces.items():
            lookup[code] = face
        corners = self.tetrahedral_vertices[lookup[codes]]  # (N, 4, 3)

        # Lattice indices of the four tetrahedron corners. Clipping keeps
        # points on the upper lattice boundary (fraction 0) in range.
        ixs = cell.astype(np.intp)[:, np.newaxis, :] + corners
        ixs = ixs.clip(0, self.shape - 1)

        # Fractions sorted ascending per point (a <= b <= c); the barycentric
        # weights of the four corners along the walk are then
        # 1 - c, c - b, b - a, a.
        a, b, c = np.take_along_axis(frac, order, axis=-1).T
        weights = np.stack([1 - c, c - b, b - a, a], axis=-1)  # (N, 4)

        samples = self.values[ixs[..., 0], ixs[..., 1], ixs[..., 2]]  # (N, 4, C)
        out = (samples * weights[..., np.newaxis]).sum(axis=1)
        return out.reshape(lead_shape + out.shape[1:])


def resize(lut, shape, method='tetrahedral'):
    """Resample ``lut`` onto a lattice of a different size.

    Parameters
    ----------
    lut : ndarray
        A 3D LUT of shape ``(Nx, Ny, Nz, C)`` or a 1D LUT of shape ``(N, C)``;
        the branch is chosen by ``is3dlut(lut)``.
    shape : sequence of int
        Target lattice size. The 3D path uses ``shape[:3]``, the 1D path uses
        ``shape[0]``.
    method : {'tetrahedral', 'trilinear'}
        Interpolation scheme for 3D LUTs. Ignored for 1D LUTs, which are
        always resampled channel by channel with ``np.interp``.

    Returns
    -------
    ndarray
        3D LUT: shape ``(shape[0] * shape[1] * shape[2], C)``, one row per
        target node in C order (last lattice axis varying fastest), so it can
        be reshaped to ``(shape[0], shape[1], shape[2], C)``.
        1D LUT: shape ``(shape[0], C)``.

    Raises
    ------
    NotImplementedError
        If ``method`` is not supported for a 3D LUT.
    """
    if not is3dlut(lut):
        # 1D LUT: resample each channel (column) independently over [0, 1].
        src = np.linspace(0, 1, lut.shape[0])
        dst = np.linspace(0, 1, shape[0])
        # NumPy's stacking functions require a sequence, not a generator.
        return np.column_stack(
            [np.interp(dst, src, channel) for channel in lut.T])

    if method == 'tetrahedral':
        interp = TetrahedralInterpolator(lut)
    elif method == 'trilinear':
        src_axes = tuple(np.linspace(0, 1, n) for n in lut.shape[:3])
        interp = RegularGridInterpolator(src_axes, lut)
    else:
        raise NotImplementedError(
            'unsupported 3D LUT resize method: {!r}'.format(method))

    # Target node coordinates as an (Nx', Ny', Nz', 3) grid; 'ij' indexing
    # keeps axis order, and stacking on the last axis keeps each (x, y, z)
    # triplet together when flattened.
    dst_axes = [np.linspace(0, 1, n) for n in shape[:3]]
    points = np.stack(np.meshgrid(*dst_axes, indexing='ij'), axis=-1)
    return interp(points.reshape(-1, 3))
