﻿using System;
using System.Collections.Generic;
using System.IO;
using Oni.Imaging;

namespace Oni.Akira
{
    /// <summary>
    /// Builds an Akira <see cref="PolygonMesh"/> from one or more Collada (DAE) scenes.
    /// Geometry from every scene read by the same instance is appended to the same mesh.
    /// </summary>
    internal class AkiraDaeReader
    {
        #region Private data
        // Threshold on the squared length of the cross product of two adjacent polygon
        // edges below which the edges are treated as parallel. This is an absolute value;
        // it is not scaled by the edge lengths.
        private const float CollinearityEpsilon = 0.0001f;

        private readonly PolygonMesh mesh;
        private readonly List<Vector3> positions;
        private readonly Dictionary<Vector3, int> uniquePositions;
        private readonly List<Vector3> normals;
        private readonly List<Vector2> texCoords;
        private readonly Dictionary<Dae.Material, Material> materialMap;
        private readonly Dictionary<string, Material> materialFileMap;
        private readonly Stack<Matrix> nodeTransformStack;

        private Dictionary<string, AkiraDaeNodeProperties> properties;
        private Matrix nodeTransform;
        #endregion

        /// <summary>
        /// Reads every file in <paramref name="filePaths"/>, in order, into a single mesh.
        /// </summary>
        /// <param name="filePaths">Paths of the Collada files to read.</param>
        /// <returns>The mesh containing the geometry of all files.</returns>
        public static PolygonMesh Read(IEnumerable<string> filePaths)
        {
            var reader = new AkiraDaeReader();

            // No per-node properties are supplied through this entry point; a single
            // empty dictionary is shared by all scenes.
            var properties = new Dictionary<string, AkiraDaeNodeProperties>();

            foreach (var filePath in filePaths)
                reader.ReadScene(Dae.Reader.ReadFile(filePath), properties);

            return reader.mesh;
        }

        /// <summary>
        /// Creates a reader with an empty mesh that uses a new material library.
        /// </summary>
        public AkiraDaeReader()
        {
            mesh = new PolygonMesh(new MaterialLibrary());

            // Vertex attributes are appended straight into the mesh's own lists.
            positions = mesh.Points;
            uniquePositions = new Dictionary<Vector3, int>();
            texCoords = mesh.TexCoords;
            normals = mesh.Normals;

            materialMap = new Dictionary<Dae.Material, Material>();

            // Texture file paths are compared case-insensitively.
            materialFileMap = new Dictionary<string, Material>(StringComparer.OrdinalIgnoreCase);

            nodeTransformStack = new Stack<Matrix>();
            nodeTransform = Matrix.Identity;
        }

        /// <summary>
        /// Gets the mesh that geometry is read into.
        /// </summary>
        public PolygonMesh Mesh => mesh;

        /// <summary>
        /// Appends the geometry of every node in <paramref name="scene"/> to <see cref="Mesh"/>.
        /// </summary>
        /// <param name="scene">The scene to read.</param>
        /// <param name="properties">
        /// Per-node properties keyed by scene/node id. A <c>null</c> dictionary is treated
        /// the same as an empty one.
        /// </param>
        /// <exception cref="ArgumentNullException"><paramref name="scene"/> is <c>null</c>.</exception>
        public void ReadScene(Dae.Scene scene, Dictionary<string, AkiraDaeNodeProperties> properties)
        {
            if (scene == null)
                throw new ArgumentNullException(nameof(scene));

            this.properties = properties ?? new Dictionary<string, AkiraDaeNodeProperties>();

            // Properties registered for the scene id are inherited by top level nodes
            // that have no properties of their own.
            var sceneProperties = FindProperties(scene.Id);

            foreach (var node in scene.Nodes)
                ReadNode(node, sceneProperties);
        }

        /// <summary>
        /// Looks up the properties registered for <paramref name="id"/>.
        /// </summary>
        /// <param name="id">A scene or node id; may be <c>null</c>.</param>
        /// <returns>
        /// The registered properties, or <c>null</c> if <paramref name="id"/> is <c>null</c>
        /// or has no entry.
        /// </returns>
        private AkiraDaeNodeProperties FindProperties(string id)
        {
            AkiraDaeNodeProperties result;

            // Guard against elements without an id (presumably surfaced as null by the
            // Dae reader); Dictionary lookups throw on a null key.
            if (id == null || !properties.TryGetValue(id, out result))
                return null;

            return result;
        }

        /// <summary>
        /// Reads the geometry of <paramref name="node"/> and, recursively, of its children.
        /// A node whose own properties have physics is skipped together with all of its
        /// children.
        /// </summary>
        /// <param name="node">The node to read.</param>
        /// <param name="parentNodeProperties">
        /// Properties inherited from the parent; used when the node has none of its own.
        /// </param>
        private void ReadNode(Dae.Node node, AkiraDaeNodeProperties parentNodeProperties)
        {
            var nodeProperties = FindProperties(node.Id);

            if (nodeProperties == null)
                nodeProperties = parentNodeProperties;
            else if (nodeProperties.HasPhysics)
                return;

            nodeTransformStack.Push(nodeTransform);

            try
            {
                // Compose the node's local transforms on top of the parent's transform.
                foreach (var transform in node.Transforms)
                    nodeTransform = transform.ToMatrix() * nodeTransform;

                foreach (var geometryInstance in node.GeometryInstances)
                    ReadGeometryInstance(node, nodeProperties, geometryInstance);

                foreach (var child in node.Nodes)
                    ReadNode(child, nodeProperties);
            }
            finally
            {
                // Restore the parent's transform even if reading this subtree failed.
                nodeTransform = nodeTransformStack.Pop();
            }
        }

        /// <summary>
        /// Reads the primitives of a geometry instance. Only polygon primitives are
        /// supported; other primitive types are reported and ignored.
        /// </summary>
        private void ReadGeometryInstance(Dae.Node node, AkiraDaeNodeProperties nodeProperties, Dae.GeometryInstance instance)
        {
            foreach (var primitives in instance.Target.Primitives)
            {
                if (primitives.PrimitiveType != Dae.MeshPrimitiveType.Polygons)
                {
                    Console.Error.WriteLine("Unsupported primitive type '{0}' found in geometry '{1}', ignoring.", primitives.PrimitiveType, instance.Name);
                    continue;
                }

                // Bind the primitives to the material instance with the matching symbol (if any).
                var materialInstance = instance.Materials.Find(m => m.Symbol == primitives.MaterialSymbol);

                ReadPolygonPrimitives(node, nodeProperties, primitives, materialInstance);
            }
        }

        /// <summary>
        /// Maps a Collada material to a mesh material named after its diffuse texture file
        /// (without extension).
        /// </summary>
        /// <remarks>
        /// Results are cached both per Collada material and per texture file path, so all
        /// materials using the same texture file share one mesh material. When the
        /// transparent channel uses the same sampler as the diffuse channel, the material
        /// is flagged transparent, non-occluding and two-sided.
        /// </remarks>
        /// <returns>
        /// The mesh material, or <c>null</c> when the material has no effect or no diffuse
        /// texture file.
        /// </returns>
        private Material ReadMaterial(Dae.Material material)
        {
            if (material == null || material.Effect == null)
                return null;

            Material polygonMaterial;

            if (materialMap.TryGetValue(material, out polygonMaterial))
                return polygonMaterial;

            Dae.EffectSampler diffuseSampler = null;
            Dae.EffectSampler transparentSampler = null;

            foreach (var texture in material.Effect.Textures)
            {
                if (texture.Channel == Dae.EffectTextureChannel.Diffuse)
                    diffuseSampler = texture.Sampler;
                else if (texture.Channel == Dae.EffectTextureChannel.Transparent)
                    transparentSampler = texture.Sampler;
            }

            if (diffuseSampler == null || diffuseSampler.Surface == null || diffuseSampler.Surface.InitFrom == null)
            {
                //
                // this material doesn't have a diffuse texture
                //

                return null;
            }

            var image = diffuseSampler.Surface.InitFrom;

            // Without a file path there is nothing to name or cache the material by;
            // treat it like a missing diffuse texture.
            if (string.IsNullOrEmpty(image.FilePath))
                return null;

            if (materialFileMap.TryGetValue(image.FilePath, out polygonMaterial))
            {
                // Remember the mapping so later lookups of this Collada material are direct.
                materialMap.Add(material, polygonMaterial);
                return polygonMaterial;
            }

            polygonMaterial = mesh.Materials.GetMaterial(Path.GetFileNameWithoutExtension(image.FilePath));
            polygonMaterial.ImageFilePath = image.FilePath;

            if (transparentSampler == diffuseSampler)
                polygonMaterial.Flags |= GunkFlags.Transparent | GunkFlags.NoOcclusion | GunkFlags.TwoSided;

            materialFileMap.Add(image.FilePath, polygonMaterial);
            materialMap.Add(material, polygonMaterial);

            return polygonMaterial;
        }

        /// <summary>
        /// Converts a set of Collada polygon primitives into mesh polygons.
        /// </summary>
        /// <remarks>
        /// Degenerate polygons (see <see cref="CheckDegenerate"/>) are skipped and counted.
        /// Polygons are routed by material: the ghost marker goes to <c>mesh.Ghosts</c>, the
        /// door frame marker to <c>mesh.Doors</c>, materials named <c>bnv_grid_*</c> to
        /// <c>mesh.Floors</c>, and everything else to <c>mesh.Polygons</c>.
        /// </remarks>
        private void ReadPolygonPrimitives(Dae.Node node, AkiraDaeNodeProperties nodeProperties, Dae.MeshPrimitives primitives, Dae.MaterialInstance materialInstance)
        {
            Material material = null;

            if (materialInstance != null)
                material = ReadMaterial(materialInstance.Target);

            if (material == null)
                material = mesh.Materials.NotFound;

            int[] positionIndices = null;
            int[] texCoordIndices = null;
            int[] normalIndices = null;
            Color[] colors = null;

            foreach (var input in primitives.Inputs)
            {
                switch (input.Semantic)
                {
                    case Dae.Semantic.Position:
                        // Positions are transformed into world space and de-duplicated.
                        positionIndices = ReadInputIndexed(input, positions, uniquePositions, PositionReader);
                        break;

                    case Dae.Semantic.TexCoord:
                        texCoordIndices = ReadInputIndexed(input, texCoords, Dae.Source.ReadTexCoord);
                        break;

                    case Dae.Semantic.Normal:
                        // NOTE(review): unlike positions, normals are not transformed by the
                        // node transform; confirm this is intended for rotated nodes.
                        normalIndices = ReadInputIndexed(input, normals, Dae.Source.ReadVector3);
                        break;

                    case Dae.Semantic.Color:
                        colors = ReadInput(input, Dae.Source.ReadColor);
                        break;
                }
            }

            if (positionIndices == null)
            {
                // Without positions no polygon can be built.
                Console.Error.WriteLine("Geometry '{0}' does not contain vertex positions, ignoring.", node.Name);
                return;
            }

            if (texCoordIndices == null)
                Console.Error.WriteLine("Geometry '{0}' does not contain texture coordinates.", node.Name);

            int startIndex = 0;
            int degeneratePolygonCount = 0;

            foreach (int vertexCount in primitives.VertexCounts)
            {
                // Offset of this polygon's first vertex in the flattened per-vertex arrays.
                int polygonStart = startIndex;
                startIndex += vertexCount;

                var polygonPointIndices = CopySlice(positionIndices, polygonStart, vertexCount);

                if (CheckDegenerate(positions, polygonPointIndices))
                {
                    degeneratePolygonCount++;
                    continue;
                }

                var polygon = new Polygon(mesh, polygonPointIndices)
                {
                    FileName = node.FileName,
                    ObjectName = node.Name,
                    Material = material
                };

                // Polygons from geometry without texture coordinates get all-zero indices.
                polygon.TexCoordIndices = texCoordIndices != null
                    ? CopySlice(texCoordIndices, polygonStart, vertexCount)
                    : new int[vertexCount];

                if (normalIndices != null)
                    polygon.NormalIndices = CopySlice(normalIndices, polygonStart, vertexCount);

                if (colors != null)
                    polygon.Colors = CopySlice(colors, polygonStart, vertexCount);

                if (nodeProperties != null)
                {
                    polygon.ScriptId = nodeProperties.ScriptId;
                    polygon.Flags |= nodeProperties.GunkFlags;
                }

                if (material == mesh.Materials.Markers.Ghost)
                    mesh.Ghosts.Add(polygon);
                else if (material == mesh.Materials.Markers.DoorFrame)
                    mesh.Doors.Add(polygon);
                else if (material.Name.StartsWith("bnv_grid_", StringComparison.Ordinal))
                    mesh.Floors.Add(polygon);
                else
                    mesh.Polygons.Add(polygon);
            }

            if (degeneratePolygonCount > 0)
            {
                Console.Error.WriteLine("Ignoring {0} degenerate polygons", degeneratePolygonCount);
            }
        }

        /// <summary>
        /// Returns a new array holding <paramref name="count"/> elements of
        /// <paramref name="source"/> starting at <paramref name="startIndex"/>.
        /// </summary>
        private static T[] CopySlice<T>(T[] source, int startIndex, int count)
        {
            var slice = new T[count];
            Array.Copy(source, startIndex, slice, 0, count);
            return slice;
        }

        /// <summary>
        /// Determines whether a polygon should be discarded as degenerate.
        /// </summary>
        /// <remarks>
        /// A polygon is degenerate if it has fewer than three vertices, or if at any vertex
        /// the two adjacent edges are (nearly) parallel and point the same way, i.e. the
        /// outline folds back on itself. This includes the first and last vertices, whose
        /// neighbours wrap around the outline. A vertex lying between its neighbours on a
        /// straight line is not, by itself, treated as degenerate.
        /// </remarks>
        /// <param name="positions">The mesh positions the indices refer to.</param>
        /// <param name="positionIndices">The polygon's vertex indices, in outline order.</param>
        private static bool CheckDegenerate(List<Vector3> positions, int[] positionIndices)
        {
            int count = positionIndices.Length;

            if (count < 3)
                return true;

            for (int i = 0; i < count; i++)
            {
                var prev = positions[positionIndices[(i + count - 1) % count]];
                var curr = positions[positionIndices[i]];
                var next = positions[positionIndices[(i + 1) % count]];

                var s0 = prev - curr;
                var s1 = next - curr;

                Vector3 c;
                Vector3.Cross(ref s0, ref s1, out c);

                // Parallel edges pointing the same way: the outline doubles back here.
                if (c.LengthSquared() < CollinearityEpsilon && Vector3.Dot(s0, s1) > 0.0f)
                    return true;
            }

            return false;
        }

        /// <summary>
        /// Appends every element referenced by <paramref name="input"/> to
        /// <paramref name="list"/> (duplicates included).
        /// </summary>
        /// <returns>For each input index, the position of its element in <paramref name="list"/>.</returns>
        private static int[] ReadInputIndexed<T>(Dae.IndexedInput input, List<T> list, Func<Dae.Source, int, T> elementReader)
            where T : struct
        {
            var indices = new int[input.Indices.Count];

            for (int i = 0; i < input.Indices.Count; i++)
            {
                var v = elementReader(input.Source, input.Indices[i]);
                indices[i] = list.Count;
                list.Add(v);
            }

            return indices;
        }

        /// <summary>
        /// Appends the elements referenced by <paramref name="input"/> to
        /// <paramref name="list"/>, reusing the existing entry for any value already
        /// recorded in <paramref name="uniqueList"/>.
        /// </summary>
        /// <returns>For each input index, the position of its element in <paramref name="list"/>.</returns>
        private static int[] ReadInputIndexed<T>(Dae.IndexedInput input, List<T> list, Dictionary<T, int> uniqueList, Func<Dae.Source, int, T> elementReader)
            where T : struct
        {
            var indices = new int[input.Indices.Count];

            for (int i = 0; i < input.Indices.Count; i++)
            {
                var v = elementReader(input.Source, input.Indices[i]);

                int index;

                if (!uniqueList.TryGetValue(v, out index))
                {
                    index = list.Count;
                    list.Add(v);
                    uniqueList.Add(v, index);
                }

                indices[i] = index;
            }

            return indices;
        }

        /// <summary>
        /// Reads the element referenced by each index of <paramref name="input"/> into a
        /// new array, one entry per index.
        /// </summary>
        private static T[] ReadInput<T>(Dae.IndexedInput input, Func<Dae.Source, int, T> elementReader)
            where T : struct
        {
            var values = new T[input.Indices.Count];

            for (int i = 0; i < input.Indices.Count; i++)
                values[i] = elementReader(input.Source, input.Indices[i]);

            return values;
        }

        /// <summary>
        /// Reads a position from <paramref name="source"/> and transforms it by the
        /// current node transform.
        /// </summary>
        private Vector3 PositionReader(Dae.Source source, int index)
        {
            Vector3 p = Dae.Source.ReadVector3(source, index);
            Vector3 r;
            Vector3.Transform(ref p, ref nodeTransform, out r);
            return r;
        }
    }
}
