﻿using System;
using System.Collections.Generic;
using System.IO;
using Oni.Imaging;

namespace Oni.Akira
{
    /// <summary>
    /// Builds a room (BNV) <see cref="PolygonMesh"/> from a COLLADA scene: the
    /// polygon geometry of every node is transformed by the node hierarchy, split
    /// into polygons and sorted into the mesh's ghost and floor lists.
    /// </summary>
    internal class RoomDaeReader
    {
        private readonly PolygonMesh mesh;

        // Same list object as mesh.Points; every position read from the scene is appended here.
        private readonly List<Vector3> positions;

        // Parent transforms saved while walking the node hierarchy.
        private readonly Stack<Matrix> nodeTransformStack;

        // Scene being read. Assigned in ReadScene but not otherwise read by this class.
        private Dae.Scene scene;

        // Accumulated transform of the node currently being visited.
        private Matrix nodeTransform;

        /// <summary>
        /// Reads every node of <paramref name="scene"/> into a new room mesh.
        /// </summary>
        /// <param name="scene">The COLLADA scene to read.</param>
        /// <returns>A mesh whose Ghosts and Floors lists hold the accepted polygons.</returns>
        public static PolygonMesh Read(Dae.Scene scene)
        {
            var reader = new RoomDaeReader();
            reader.ReadScene(scene);
            return reader.mesh;
        }

        private RoomDaeReader()
        {
            mesh = new PolygonMesh(new MaterialLibrary());

            // Alias the mesh's point list so positions appended while reading
            // become the mesh's points directly.
            positions = mesh.Points;

            nodeTransformStack = new Stack<Matrix>();
            nodeTransform = Matrix.Identity;
        }

        private void ReadScene(Dae.Scene scene)
        {
            this.scene = scene;

            foreach (Dae.Node node in scene.Nodes)
                ReadNode(node);
        }

        /// <summary>
        /// Visits <paramref name="node"/> and its descendants, reading their geometry
        /// with the node's transforms composed onto the parent's accumulated transform.
        /// </summary>
        private void ReadNode(Dae.Node node)
        {
            // Save the parent transform so it can be restored once this subtree is done.
            nodeTransformStack.Push(nodeTransform);

            // Each node transform is multiplied on the left of the accumulated matrix.
            // NOTE(review): this presumes the row-vector convention used by
            // Vector3.Transform(v, ref m) — confirm against the math library.
            foreach (var transform in node.Transforms)
                nodeTransform = transform.ToMatrix() * nodeTransform;

            foreach (var geometryInstance in node.GeometryInstances)
                ReadGeometryInstance(node, geometryInstance);

            foreach (var child in node.Nodes)
                ReadNode(child);

            nodeTransform = nodeTransformStack.Pop();
        }

        /// <summary>
        /// Reads the polygon primitives of one geometry instance. Primitives of any
        /// other type are reported on stderr and skipped.
        /// </summary>
        private void ReadGeometryInstance(Dae.Node node, Dae.GeometryInstance instance)
        {
            var geometry = instance.Target;

            foreach (var primitives in geometry.Primitives)
            {
                if (primitives.PrimitiveType != Dae.MeshPrimitiveType.Polygons)
                {
                    Console.Error.WriteLine("Unsupported primitive type '{0}' found in geometry '{1}', ignoring.", primitives.PrimitiveType, geometry.Id);
                    continue;
                }

                ReadPolygonPrimitives(node, primitives, instance.Materials.Find(m => m.Symbol == primitives.MaterialSymbol));
            }
        }

        /// <summary>
        /// Converts one set of polygon primitives into BNV polygons and adds each
        /// accepted polygon to the mesh's ghost or floor list; rejected polygons
        /// are reported on stderr.
        /// </summary>
        /// <param name="node">Owning node; its FileName and Name are copied onto each polygon.</param>
        /// <param name="primitives">The polygon primitives to read.</param>
        /// <param name="materialInstance">Bound material instance, if any. Currently unused.</param>
        private void ReadPolygonPrimitives(Dae.Node node, Dae.MeshPrimitives primitives, Dae.MaterialInstance materialInstance)
        {
            var positionInput = FindInput(primitives, Dae.Semantic.Position);

            if (positionInput == null)
            {
                Console.Error.WriteLine("BNV polygon: discarded primitives, no position input");
                return;
            }

            var positionIndices = ReadInputIndexed(positionInput, positions, Dae.Source.ReadVector3);

            // Only the positions just appended are transformed; earlier ones already were.
            foreach (int i in positionIndices)
                positions[i] = Vector3.Transform(positions[i], ref nodeTransform);

            int startIndex = 0;

            foreach (int vertexCount in primitives.VertexCounts)
            {
                // Malformed vertex counts would otherwise index past the position data.
                if (vertexCount < 0 || startIndex + vertexCount > positionIndices.Length)
                {
                    Console.Error.WriteLine("BNV polygon: discarded remaining polygons, vertex counts exceed the {0} available indices", positionIndices.Length);
                    break;
                }

                var polygon = CreatePolygon(positionIndices, startIndex, vertexCount);
                startIndex += vertexCount;

                if (polygon == null)
                {
                    Console.Error.WriteLine("BNV polygon: discarded, polygon is degenerate");
                    continue;
                }

                polygon.FileName = node.FileName;
                polygon.ObjectName = node.Name;

                // A near-zero Y normal component marks a vertical polygon (assuming Y is up):
                // these become ghosts if they are tall enough and quadrilateral.
                if (Math.Abs(polygon.Plane.Normal.Y) < 0.0001f)
                {
                    if (polygon.BoundingBox.Height < 1.0f)
                    {
                        Console.Error.WriteLine("BNV polygon: discarded, ghost height must be greater than 1, it is {0}", polygon.BoundingBox.Height);
                        continue;
                    }

                    if (polygon.PointIndices.Length != 4)
                    {
                        Console.Error.WriteLine("BNV polygon: discarded, ghost is a {0}-gon", polygon.PointIndices.Length);
                        continue;
                    }

                    mesh.Ghosts.Add(polygon);
                }
                else if ((polygon.Flags & GunkFlags.Horizontal) != 0)
                {
                    mesh.Floors.Add(polygon);
                }
                else
                {
                    Console.Error.WriteLine("BNV polygon: discarded, not a ghost and not a floor");
                }
            }
        }

        /// <summary>
        /// Returns the first input of <paramref name="primitives"/> with the given
        /// semantic, or null when there is none.
        /// </summary>
        private static Dae.IndexedInput FindInput(Dae.MeshPrimitives primitives, Dae.Semantic semantic)
        {
            foreach (var input in primitives.Inputs)
            {
                if (input.Semantic == semantic)
                    return input;
            }

            return null;
        }

        /// <summary>
        /// Builds a polygon from <paramref name="vertexCount"/> consecutive entries of
        /// <paramref name="positionIndices"/> starting at <paramref name="startIndex"/>.
        /// A vertex is dropped when it repeats its predecessor's index or when its
        /// incoming and outgoing edges are collinear.
        /// </summary>
        /// <returns>
        /// The polygon, or null when fewer than three vertices survive or
        /// <see cref="CheckDegenerate"/> reports the outline as degenerate.
        /// </returns>
        private Polygon CreatePolygon(int[] positionIndices, int startIndex, int vertexCount)
        {
            int endIndex = startIndex + vertexCount;
            var indices = new List<int>(vertexCount);

            for (int i = startIndex; i < endIndex; i++)
            {
                // Neighbours wrap around so the closing edge is treated like any other edge.
                int prev = positionIndices[i == startIndex ? endIndex - 1 : i - 1];
                int current = positionIndices[i];
                int next = positionIndices[i + 1 == endIndex ? startIndex : i + 1];

                if (prev == current)
                {
                    Console.Error.WriteLine("BNV polygon: discarding degenerate edge {0}", positions[current]);
                    continue;
                }

                Vector3 p0 = positions[prev];
                Vector3 p1 = positions[current];
                Vector3 p2 = positions[next];

                // Collinear incoming/outgoing edges (straight-through or folding back)
                // contribute no corner. A vertex coinciding with a neighbour is caught
                // here as well, because a zero-length edge yields a zero cross product.
                if (Vector3.Cross(p2 - p1, p1 - p0).LengthSquared() < 0.000001f)
                    continue;

                indices.Add(current);
            }

            var indicesArray = indices.ToArray();

            if (CheckDegenerate(positions, indicesArray))
                return null;

            return new Polygon(mesh, indicesArray);
        }

        /// <summary>
        /// Reports whether the closed outline described by <paramref name="indices"/>
        /// is degenerate: fewer than three vertices, or some vertex where the outline
        /// folds back on itself (both neighbours collinear with it and on the same side).
        /// Every cyclic triple is examined, including the two that wrap around the
        /// closing edge.
        /// </summary>
        private static bool CheckDegenerate(List<Vector3> positions, int[] indices)
        {
            int count = indices.Length;

            if (count < 3)
                return true;

            Vector3 s0, s1, c;

            for (int i = 0; i < count; i++)
            {
                Vector3 p0 = positions[indices[i]];
                Vector3 p1 = positions[indices[(i + 1) % count]];
                Vector3 p2 = positions[indices[(i + 2) % count]];

                // Both edge vectors start at the middle vertex p1.
                Vector3.Substract(ref p0, ref p1, out s0);
                Vector3.Substract(ref p2, ref p1, out s1);
                Vector3.Cross(ref s0, ref s1, out c);

                // Collinear with a positive dot product: p0 and p2 lie on the same side of p1.
                if (c.LengthSquared() < 0.0001f && Vector3.Dot(ref s0, ref s1) > 0.0f)
                    return true;
            }

            return false;
        }

        /// <summary>
        /// Reads one element per index of <paramref name="input"/> from its source,
        /// appends each element to <paramref name="list"/> (no de-duplication) and
        /// returns the positions in <paramref name="list"/> where they were stored.
        /// </summary>
        private static int[] ReadInputIndexed<T>(Dae.IndexedInput input, List<T> list, Func<Dae.Source, int, T> elementReader)
            where T : struct
        {
            var indices = new int[input.Indices.Count];

            for (int i = 0; i < input.Indices.Count; i++)
            {
                var v = elementReader(input.Source, input.Indices[i]);
                indices[i] = list.Count;
                list.Add(v);
            }

            return indices;
        }
    }
}
