# -*- coding: utf-8 -*-
import os
from io import BytesIO
from warnings import warn

import numpy as np

from lutio.core import MetaRegistry

__all__ = ['LUT3DFormat']


class LUT3DFormat(MetaRegistry):
    """Base class for 3D LUT file formats.

    Subclasses describe a concrete format through the nested ``Meta`` class
    (``name`` and the usual file ``extensions``) and the
    ``supported_sizes`` / ``supported_bitdepths`` containers, and implement
    :meth:`to_file`. Shared validation and text formatting live in
    :meth:`_prepare_to_file_` and :meth:`_lut_to_str_`.
    """

    class Meta:
        # Human-readable format name, used in warning/error messages.
        name = None
        # Container of usual file extensions. They are compared against
        # ``os.path.splitext(filename)[1]``, so they include the leading dot.
        extensions = None

    # Container of accepted LUT edge lengths; expected to be set by subclasses.
    supported_sizes = None
    # Container of accepted integer bit depths. Left as None, any truthy
    # ``bitdepth`` passed to ``__init__`` is rejected.
    supported_bitdepths = None

    def __init__(self, size, bitdepth=None):
        """Validate and store the output parameters.

        Parameters
        ----------
        size : int
            Edge length of the LUT cube to write; must be in
            ``supported_sizes``.
        bitdepth : int, optional
            When truthy, output values are quantized to integers in
            ``[0, 2**bitdepth - 1]`` and it must be in
            ``supported_bitdepths``. ``None`` (or 0) keeps float output.

        Raises
        ------
        ValueError
            If ``size`` or ``bitdepth`` is not supported by this format.
        """
        if size not in self.supported_sizes:
            raise ValueError("`size` must be in {}."
                             .format(self.supported_sizes))

        self.size = size

        if bitdepth:
            if self.supported_bitdepths is None:
                # A format that declares no bit depths cannot quantize; say
                # so instead of failing with "NoneType is not iterable".
                raise ValueError("{} format does not support `bitdepth`."
                                 .format(self.Meta.name))
            if bitdepth not in self.supported_bitdepths:
                raise ValueError("`bitdepth` must be in {}."
                                 .format(self.supported_bitdepths))

        self.bitdepth = bitdepth

    def to_file(self, lut, filename):
        """Write ``lut`` to ``filename``; must be implemented by subclasses.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError

    def _prepare_to_file_(self, lut, filename):
        """Validate and normalize a LUT before a subclass writes it.

        Parameters
        ----------
        lut : array_like
            Array of shape ``(N, N, N, 3)``.
        filename : str
            Destination path; only its extension is inspected here.

        Returns
        -------
        lut : numpy.ndarray
            New array with ``lut`` clipped to ``[0, 1]``. If
            ``self.bitdepth`` is set, it is also scaled to
            ``[0, 2**bitdepth - 1]`` and rounded to integers.
        filename : str
            ``filename``, unchanged.

        Raises
        ------
        ValueError
            If ``lut`` does not have shape ``(N, N, N, 3)``.
        NotImplementedError
            If ``N`` differs from ``self.size`` (resampling is not supported).

        Warns
        -----
        UserWarning
            If the extension of ``filename`` is not in ``Meta.extensions``.
        """
        lut = np.asarray(lut)
        s = lut.shape

        # Compare the three spatial axes directly. Counting occurrences of
        # s[0] in the shape wrongly rejects a 3-point cube, whose shape
        # (3, 3, 3, 3) contains four 3s.
        if lut.ndim != 4 or s[3] != 3 or not s[0] == s[1] == s[2]:
            raise ValueError("`lut` is not a proper LUT3D")

        # TODO: resample the LUT to ``self.size`` instead of refusing.
        if s[0] != self.size:
            raise NotImplementedError("Cannot resize LUT3D")

        # ``clip`` returns a new array, so the caller's LUT is not modified.
        lut = lut.clip(0, 1)

        if self.bitdepth:
            max_ = 2 ** self.bitdepth - 1
            # ``np.int`` was removed in NumPy 1.24. It was an alias of the
            # builtin ``int``, so ``int`` keeps the same dtype.
            lut = np.round(lut * max_).astype(int).clip(0, max_)

        if os.path.splitext(filename)[1] not in self.Meta.extensions:
            warn("Using uncommon file extension for {} format."
                 .format(self.Meta.name), UserWarning)

        return lut, filename

    @staticmethod
    def _lut_to_str_(lut, fmt="%08.6f", **kwargs):
        """Format LUT values as text with ``numpy.savetxt``.

        Arrays with more than one dimension are reshaped to 2-D, keeping the
        last axis as the columns (e.g. ``(N, N, N, 3)`` -> ``(N**3, 3)``).
        Each row is written on its own line. ``fmt`` and any extra
        ``kwargs`` are forwarded to ``numpy.savetxt``. Trailing whitespace,
        including the final newline, is stripped from the result.
        """
        if lut.ndim > 1:
            lut = lut.reshape(-1, lut.shape[-1])

        with BytesIO() as buffer:
            np.savetxt(buffer, lut, fmt=fmt, **kwargs)
            out = buffer.getvalue().decode().rstrip()

        return out
