﻿using System;
using System.Collections.Generic;
using Oni.Collections;

namespace Oni.Dae
{
    /// <summary>
    /// Re-expresses a COLLADA scene authored with one up axis (<c>&lt;up_axis&gt;</c>)
    /// in the coordinate system of another, modifying the scene in place.
    /// </summary>
    /// <remarks>
    /// Each up-axis convention is mapped to the (right, up, in) frame of the COLLADA
    /// <c>&lt;up_axis&gt;</c> table (X_UP: -y, +x, +z; Y_UP: +x, +y, +z; Z_UP: +x, +z, -y).
    /// Composing two such mappings yields a proper rotation that is a signed permutation
    /// of the X/Y/Z components: destination component <c>i</c> is source component
    /// <c>axisMap[i]</c>, negated when <c>axisNegate[i]</c> is set. The same table drives
    /// positions, normals, translations, scales, animations and matrices so they stay
    /// consistent with each other.
    /// </remarks>
    internal class AxisConverter
    {
        private Scene scene;
        private Axis fromUpAxis;
        private Axis toUpAxis;

        // Arrays already converted, so an array reached more than once (e.g. a geometry
        // instanced by several nodes) is only converted once.
        private Set<float[]> convertedValues;

        // Signed permutation for fromUpAxis -> toUpAxis; built by Convert() before traversal.
        private int[] axisMap;
        private bool[] axisNegate;

        /// <summary>
        /// Converts <paramref name="scene"/> in place from <paramref name="fromUpAxis"/>
        /// to <paramref name="toUpAxis"/>. Does nothing when the two axes are equal or when
        /// either value is not X, Y or Z.
        /// </summary>
        public static void Convert(Scene scene, Axis fromUpAxis, Axis toUpAxis)
        {
            var converter = new AxisConverter {
                scene = scene,
                fromUpAxis = fromUpAxis,
                toUpAxis = toUpAxis,
                convertedValues = new Set<float[]>()
            };

            converter.Convert();
        }

        private void Convert()
        {
            // Identity or unrecognized conversions leave the scene untouched.
            if (!TryBuildAxisMap(fromUpAxis, toUpAxis, out axisMap, out axisNegate))
                return;

            Convert(scene);
        }

        /// <summary>
        /// Builds the signed permutation taking <paramref name="from"/>-up coordinates to
        /// <paramref name="to"/>-up coordinates. Returns false when the axes are equal or
        /// when either is not X, Y or Z.
        /// </summary>
        private static bool TryBuildAxisMap(Axis from, Axis to, out int[] map, out bool[] negate)
        {
            map = null;
            negate = null;

            int[] fromMap, toMap;
            bool[] fromNegate, toNegate;

            if (from == to
                || !TryGetCanonicalMap(from, out fromMap, out fromNegate)
                || !TryGetCanonicalMap(to, out toMap, out toNegate))
                return false;

            // Invert the destination mapping (canonical -> "to" coordinates). Each entry is
            // a signed unit, so the inverse of a signed permutation keeps the same signs.
            var toInverseMap = new int[3];
            var toInverseNegate = new bool[3];

            for (int i = 0; i < 3; i++)
            {
                toInverseMap[toMap[i]] = i;
                toInverseNegate[toMap[i]] = toNegate[i];
            }

            // Compose: "from" -> canonical -> "to".
            map = new int[3];
            negate = new bool[3];

            for (int i = 0; i < 3; i++)
            {
                int canonical = toInverseMap[i];
                map[i] = fromMap[canonical];
                negate[i] = toInverseNegate[i] != fromNegate[canonical];
            }

            return true;
        }

        /// <summary>
        /// Gets the mapping from <paramref name="upAxis"/>-up coordinates to the canonical
        /// (right, up, in) frame, which coincides with Y_UP coordinates.
        /// </summary>
        private static bool TryGetCanonicalMap(Axis upAxis, out int[] map, out bool[] negate)
        {
            if (upAxis == Axis.X)
            {
                // right = -y, up = +x, in = +z
                map = new[] { 1, 0, 2 };
                negate = new[] { true, false, false };
                return true;
            }

            if (upAxis == Axis.Y)
            {
                map = new[] { 0, 1, 2 };
                negate = new[] { false, false, false };
                return true;
            }

            if (upAxis == Axis.Z)
            {
                // right = +x, up = +z, in = -y
                map = new[] { 0, 2, 1 };
                negate = new[] { false, false, true };
                return true;
            }

            map = null;
            negate = null;
            return false;
        }

        private void Convert(Node node)
        {
            foreach (var transform in node.Transforms)
                Convert(transform);

            foreach (var instance in node.Instances)
                Convert(instance);

            foreach (var child in node.Nodes)
                Convert(child);
        }

        private void Convert(Instance instance)
        {
            // Only geometry instances carry data that needs converting here.
            var geometryInstance = instance as GeometryInstance;

            if (geometryInstance != null)
                Convert(geometryInstance.Target);
        }

        private void Convert(Geometry geometry)
        {
            foreach (var primitives in geometry.Primitives)
            {
                //
                // HACK: this assumes that position and normal sources are never reused with other semantic
                //

                foreach (var input in primitives.Inputs)
                {
                    if (input.Semantic == Semantic.Position || input.Semantic == Semantic.Normal)
                        ConvertPosition(input.Source.FloatData, input.Source.Stride);
                }
            }
        }

        private void Convert(Transform transform)
        {
            var scale = transform as TransformScale;

            if (scale != null)
            {
                ConvertScale(scale.Values, 3);

                if (transform.HasAnimations)
                    ConvertScaleAnimation(transform);

                return;
            }

            var rotate = transform as TransformRotate;

            if (rotate != null)
            {
                // NOTE(review): the static axis in rotate.Values is not converted while its
                // animation channels are -- confirm this asymmetry is intended.

                if (transform.HasAnimations)
                    ConvertRotationAnimation(transform);

                return;
            }

            var translate = transform as TransformTranslate;

            if (translate != null)
            {
                ConvertPosition(translate.Values, 3);

                if (transform.HasAnimations)
                    ConvertPositionAnimation(transform);

                return;
            }

            var matrix = transform as TransformMatrix;

            if (matrix != null)
            {
                ConvertMatrix(matrix);

                //
                // TODO: matrix animation???
                //
            }
        }

        /// <summary>
        /// Rewrites the matrix as R^-1 * M * R, where R is the axis permutation (row-vector
        /// layout with translation in M41..M43, as the original Z-to-Y code assumed).
        /// For a signed permutation this is M'[i,j] = s[i] * s[j] * M[map[i], map[j]], and
        /// the translation row is converted like a vector.
        /// NOTE(review): the 4th column (M14, M24, M34, M44) is left as is, which assumes an
        /// affine matrix.
        /// </summary>
        private void ConvertMatrix(TransformMatrix transform)
        {
            Matrix m = transform.Matrix;

            // Snapshot everything before writing so each read sees the original value.
            float[,] source = {
                { m.M11, m.M12, m.M13 },
                { m.M21, m.M22, m.M23 },
                { m.M31, m.M32, m.M33 },
                { m.M41, m.M42, m.M43 }
            };

            var result = new float[4, 3];

            for (int row = 0; row < 4; row++)
            {
                // Rows 0..2 are basis vectors and get permuted; row 3 is the translation,
                // which keeps its place and only has its columns permuted.
                int sourceRow = row < 3 ? axisMap[row] : 3;
                bool negateRow = row < 3 && axisNegate[row];

                for (int col = 0; col < 3; col++)
                {
                    float value = source[sourceRow, axisMap[col]];
                    result[row, col] = negateRow != axisNegate[col] ? -value : value;
                }
            }

            m.M11 = result[0, 0];
            m.M12 = result[0, 1];
            m.M13 = result[0, 2];
            m.M21 = result[1, 0];
            m.M22 = result[1, 1];
            m.M23 = result[1, 2];
            m.M31 = result[2, 0];
            m.M32 = result[2, 1];
            m.M33 = result[2, 2];
            m.M41 = result[3, 0];
            m.M42 = result[3, 1];
            m.M43 = result[3, 2];

            transform.Matrix = m;
        }

        /// <summary>
        /// Converts the first three components of every <paramref name="stride"/>-sized record
        /// in <paramref name="values"/> as a direction/position vector. Used for vertex
        /// positions, normals and translations.
        /// </summary>
        /// <exception cref="NotSupportedException">The stride cannot hold an XYZ triple.</exception>
        private void ConvertPosition(float[] values, int stride)
        {
            if (!convertedValues.Add(values))
                return;

            // A stride below 3 would loop forever (stride <= 0) or convert overlapping records.
            if (stride < 3)
                throw new NotSupportedException("Cannot convert vector data with a stride less than 3.");

            for (int i = 0; i + stride <= values.Length; i += stride)
                Convert(values, i, f => -f);
        }

        private void ConvertPositionAnimation(Transform transform)
        {
            Convert(transform.Animations, 0, s => s != null ? s.Scale(-1.0f) : null);
        }

        private void ConvertRotationAnimation(Transform transform)
        {
            // Rotation axis components transform like a vector under a proper rotation.
            ConvertPositionAnimation(transform);
        }

        /// <summary>
        /// Permutes scale factors without negation: a scale factor applies to an axis
        /// regardless of that axis' direction.
        /// </summary>
        /// <exception cref="NotSupportedException">The stride cannot hold an XYZ triple.</exception>
        private void ConvertScale(float[] values, int stride)
        {
            // Deduplicated like ConvertPosition: a second pass would permute the values back.
            if (!convertedValues.Add(values))
                return;

            if (stride < 3)
                throw new NotSupportedException("Cannot convert vector data with a stride less than 3.");

            for (int i = 0; i + stride <= values.Length; i += stride)
                Convert(values, i, null);
        }

        private void ConvertScaleAnimation(Transform transform)
        {
            Convert(transform.Animations, 0, null);
        }

        /// <summary>
        /// Applies the axis permutation to the three items starting at
        /// <paramref name="baseIndex"/>. Items landing on a negated axis are passed through
        /// <paramref name="negate"/>; when it is null they are only permuted.
        /// </summary>
        private void Convert<T>(IList<T> list, int baseIndex, Func<T, T> negate)
        {
            T[] source = { list[baseIndex + 0], list[baseIndex + 1], list[baseIndex + 2] };

            for (int i = 0; i < 3; i++)
            {
                T value = source[axisMap[i]];
                list[baseIndex + i] = axisNegate[i] && negate != null ? negate(value) : value;
            }
        }
    }
}
