﻿using System;
using System.Collections.Generic;

namespace Oni.Akira
{
    internal class OctreeNode
    {
        public const int FaceCount = 6;
        public const int ChildCount = 8;

        private const float MinNodeSize = 16.0f;
        private const int MaxQuadsPerLeaf = 4096;
        private const int MaxRoomsPerLeaf = 255;

        #region Private data
        private int index;
        private BoundingBox bbox;
        private Polygon[] polygons;
        private OctreeNode[] children;
        private OctreeNode[] adjacency = new OctreeNode[FaceCount];
        private Room[] rooms;
        #endregion

        #region public enum Axis

        public enum Axis
        {
            Z,
            Y,
            X
        }

        #endregion
        #region public enum Direction

        public enum Direction
        {
            Negative,
            Positive
        }

        #endregion
        #region public struct Face

        public struct Face
        {
            private readonly int index;

            public Face(int index)
            {
                this.index = index;
            }

            public int Index => index;
            public Axis Axis => (Axis)(2 - ((index & 6) >> 1));
            public Direction Direction => (Direction)(index & 1);

            public static IEnumerable<Face> All
            {
                get
                {
                    for (int i = 0; i < FaceCount; i++)
                        yield return new Face(i);
                }
            }
        }

        #endregion
        #region public struct ChildPosition

        public struct ChildPosition
        {
            private int index;

            public ChildPosition(int index)
            {
                this.index = index;
            }

            public int Index => index;

            public int X => this[Axis.X];
            public int Y => this[Axis.Y];
            public int Z => this[Axis.Z];

            public int this[Axis axis]
            {
                get
                {
                    return ((index >> (int)axis) & 1);
                }
                set
                {
                    int mask = (1 << (int)axis);

                    if (value == 0)
                        index &= ~mask;
                    else
                        index |= mask;
                }
            }

            public static IEnumerable<ChildPosition> All
            {
                get
                {
                    for (int i = 0; i < 8; i++)
                        yield return new ChildPosition(i);
                }
            }
        }

        #endregion

        public OctreeNode(BoundingBox bbox, IEnumerable<Polygon> polygons, IEnumerable<Room> rooms)
        {
            this.bbox = bbox;
            this.polygons = polygons.ToArray();
            this.rooms = rooms.ToArray();
        }

        private OctreeNode(BoundingBox bbox, Polygon[] polygons, Room[] rooms)
        {
            this.bbox = bbox;
            this.polygons = polygons;
            this.rooms = rooms;
        }

        public int Index
        {
            get { return index; }
            set { index = value; }
        }

        public BoundingBox BoundingBox => bbox;
        public OctreeNode[] Children => children;
        public OctreeNode[] Adjacency => adjacency;
        public bool IsLeaf => polygons != null;
        public ICollection<Polygon> Polygons => polygons;
        public ICollection<Room> Rooms => rooms;
        private Vector3 Center => (bbox.Min + bbox.Max) * 0.5f;
        private float Size => bbox.Max.X - bbox.Min.X;

        public void Build()
        {
            BuildRecursive();

            //
            // Force a split of the root node because the root cannot be a leaf.
            //

            if (children == null)
                Split();
        }

        private void BuildRecursive()
        {
            if ((polygons == null || polygons.Length <= 19) && (rooms == null || rooms.Length < 16))
            {
                return;
            }

            if (Size <= MinNodeSize)
            {
                if (polygons.Length > MaxQuadsPerLeaf)
                    throw new NotSupportedException(string.Format("Octtree: Quad density too big: current {0} max 4096 bbox {1}", polygons.Length, BoundingBox));

                if (rooms.Length > MaxRoomsPerLeaf)
                    throw new NotSupportedException(string.Format("Octtree: Room density too big: current {0} max 255 bbox {1}", rooms.Length, BoundingBox));

                return;
            }

            Split();
        }

        private void Split()
        {
            children = SplitCore();
            polygons = null;
            rooms = null;

            BuildSimpleAdjaceny();

            foreach (var child in children)
                child.BuildRecursive();
        }

        private OctreeNode[] SplitCore()
        {
            var children = new OctreeNode[ChildCount];
            var center = Center;
            var childPolygons = new List<Polygon>(polygons.Length);
            var childRooms = new List<Room>(rooms.Length);

            foreach (var position in ChildPosition.All)
            {
                var childNodeBBox = new BoundingBox(center, center);

                if (position.X == 0)
                    childNodeBBox.Min.X = bbox.Min.X;
                else
                    childNodeBBox.Max.X = bbox.Max.X;

                if (position.Y == 0)
                    childNodeBBox.Min.Y = bbox.Min.Y;
                else
                    childNodeBBox.Max.Y = bbox.Max.Y;

                if (position.Z == 0)
                    childNodeBBox.Min.Z = bbox.Min.Z;
                else
                    childNodeBBox.Max.Z = bbox.Max.Z;

                childPolygons.Clear();
                childRooms.Clear();

                var boxIntersector = new PolygonBoxIntersector(ref childNodeBBox);

                foreach (var polygon in polygons)
                {
                    if (boxIntersector.Intersects(polygon))
                        childPolygons.Add(polygon);
                }

                foreach (var room in rooms)
                {
                    if (room.Intersect(childNodeBBox))
                        childRooms.Add(room);
                }

                children[position.Index] = new OctreeNode(childNodeBBox, childPolygons.ToArray(), childRooms.ToArray());
            }

            return children;
        }

        private void BuildSimpleAdjaceny()
        {
            foreach (ChildPosition position in ChildPosition.All)
            {
                var child = children[position.Index];

                foreach (var face in Face.All)
                    child.Adjacency[face.Index] = GetChildAdjacency(position, face);
            }
        }

        private OctreeNode GetChildAdjacency(ChildPosition position, Face face)
        {
            if (face.Direction == Direction.Positive)
            {
                if (position[face.Axis] == 0)
                {
                    position[face.Axis] = 1;
                    return children[position.Index];
                }
            }
            else
            {
                if (position[face.Axis] == 1)
                {
                    position[face.Axis] = 0;
                    return children[position.Index];
                }
            }

            return adjacency[face.Index];
        }

        public void RefineAdjacency()
        {
            Vector3 center = Center;
            float size = Size;

            foreach (var face in Face.All)
            {
                var node = adjacency[face.Index];

                if (node != null && !node.IsLeaf && node.Size > Size)
                {
                    Vector3 adjacentCenter = MovePoint(center, face, size);
                    adjacency[face.Index] = node.FindLargestOrEqual(adjacentCenter, size);
                }
            }
        }

        public QuadtreeNode BuildFaceQuadTree(Face face)
        {
            Vector3 faceCenter = MovePoint(Center, face, Size * 0.5f);
            var quadTreeNode = new QuadtreeNode(faceCenter, Size, face);
            quadTreeNode.Build(adjacency[face.Index]);
            return quadTreeNode;
        }

        public void DfsTraversal(Action<OctreeNode> action)
        {
            action(this);

            if (!IsLeaf)
            {
                foreach (var child in children)
                    child.DfsTraversal(action);
            }
        }

        public static Vector3 MovePoint(Vector3 point, Face face, float delta)
        {
            if (face.Direction == Direction.Negative)
                delta = -delta;

            if (face.Axis == Axis.X)
                point.X += delta;
            else if (face.Axis == Axis.Y)
                point.Y += delta;
            else
                point.Z += delta;

            return point;
        }

        private struct TriangleBoxIntersector
        {
            private Vector3 center;
            private Vector3 size;
            private Vector3[] triangle;
            private Vector3 edge;

            public TriangleBoxIntersector(ref BoundingBox box)
            {
                center = (box.Min + box.Max) * 0.5f;
                size = (box.Max - box.Min) * 0.5f;

                triangle = new Vector3[3];
                edge = Vector3.Zero;
            }

            public Vector3[] Triangle => triangle;

            public bool Intersect()
            {
                for (int i = 0; i < triangle.Length; i++)
                    triangle[i] -= center;

                edge = triangle[1] - triangle[0];

                if (AxisTest(Y, Z, 0, 2) || AxisTest(Z, X, 0, 2) || AxisTest(X, Y, 2, 1))
                    return false;

                edge = triangle[2] - triangle[1];

                if (AxisTest(Y, Z, 0, 2) || AxisTest(Z, X, 0, 2) || AxisTest(X, Y, 0, 1))
                    return false;

                edge = triangle[0] - triangle[2];

                if (AxisTest(Y, Z, 0, 1) || AxisTest(Z, X, 0, 1) || AxisTest(X, Y, 2, 1))
                    return false;

                return true;
            }

            private const int X = 0;
            private const int Y = 1;
            private const int Z = 2;

            private bool AxisTest(int a1, int a2, int p0, int p1)
            {
                Vector3 v0 = triangle[p0];
                Vector3 v1 = triangle[p1];
                float e1 = edge[a1];
                float e2 = edge[a2];

                float c0 = e2 * v0[a1] - e1 * v0[a2];
                float c1 = e2 * v1[a1] - e1 * v1[a2];
                float rad = Math.Abs(e2) * size[a1] + Math.Abs(e1) * size[a2];

                return (c0 < c1) ? (c0 > rad || c1 < -rad) : (c1 > rad || c0 < -rad);
            }
        }

        private struct PolygonBoxIntersector
        {
            private BoundingBox bbox;
            private TriangleBoxIntersector triangleBoxIntersector;

            public PolygonBoxIntersector(ref BoundingBox bbox)
            {
                this.bbox = bbox;
                this.triangleBoxIntersector = new TriangleBoxIntersector(ref bbox);
            }

            public bool Intersects(Polygon polygon)
            {
                if (!bbox.Intersects(polygon.BoundingBox))
                    return false;

                if (!bbox.Intersects(polygon.Plane))
                    return false;

                var intersector = new TriangleBoxIntersector(ref bbox);
                var points = polygon.Mesh.Points;
                var indices = polygon.PointIndices;

                intersector.Triangle[0] = points[indices[0]];
                intersector.Triangle[1] = points[indices[1]];
                intersector.Triangle[2] = points[indices[2]];

                if (intersector.Intersect())
                    return true;

                if (indices.Length > 3)
                {
                    intersector.Triangle[0] = points[indices[2]];
                    intersector.Triangle[1] = points[indices[3]];
                    intersector.Triangle[2] = points[indices[0]];

                    if (intersector.Intersect())
                        return true;
                }

                return false;
            }
        }

        public OctreeNode FindLargestOrEqual(Vector3 point, float largestSize)
        {
            var node = this;

            while (!node.IsLeaf && node.Size > largestSize)
            {
                Vector3 center = node.Center;

                int nx = (point.X < center.X) ? 0 : 4;
                int ny = (point.Y < center.Y) ? 0 : 2;
                int nz = (point.Z < center.Z) ? 0 : 1;

                var childNode = node.children[nx + ny + nz];

                if (childNode.Size < largestSize)
                    break;

                node = childNode;
            }

            return node;
        }

        public OctreeNode FindLeaf(Vector3 point)
        {
            if (!bbox.Contains(point))
                return null;

            if (children == null)
                return this;

            Vector3 center = Center;

            int nx = (point.X < center.X) ? 0 : 4;
            int ny = (point.Y < center.Y) ? 0 : 2;
            int nz = (point.Z < center.Z) ? 0 : 1;

            OctreeNode childNode = children[nx + ny + nz];

            return childNode.FindLeaf(point);
        }

        public IEnumerable<OctreeNode> FindLeafs(BoundingBox box)
        {
            var stack = new Stack<OctreeNode>();
            stack.Push(this);

            while (stack.Count > 0)
            {
                var node = stack.Pop();

                if (node.bbox.Intersects(box))
                {
                    if (node.children != null)
                    {
                        foreach (OctreeNode child in node.children)
                            stack.Push(child);
                    }
                    else
                    {
                        yield return node;
                    }
                }
            }
        }
    }
}
