﻿using System;
using System.Collections.Generic;

namespace Oni.Dae
{
    /// <summary>
    /// A set of inputs (key input values, output values, interpolation names and
    /// optional in/out tangents) that can be evaluated into one value per frame.
    /// </summary>
    internal class Sampler : Entity
    {
        // Number of frames per unit of key input value. Key input values are
        // presumably times in seconds, sampled at 60 frames per second.
        private const float FrameRate = 60.0f;

        private readonly List<Input> inputs = new List<Input>();
        private float outputScale = 1.0f;

        /// <summary>
        /// The inputs of this sampler. The returned list is the live backing list.
        /// </summary>
        public List<Input> Inputs => inputs;

        /// <summary>
        /// The number of frames covered by this sampler, derived from the last value
        /// of the <see cref="Semantic.Input"/> source. Returns 0 when the sampler has
        /// no such input.
        /// </summary>
        public int FrameCount
        {
            get
            {
                var input = inputs.Find(i => i.Semantic == Semantic.Input);

                if (input == null)
                    return 0;

                // +1 because frame 0 is included in the count.
                return FMath.RoundToInt32(input.Source.FloatData.Last() * FrameRate) + 1;
            }
        }

        /// <summary>
        /// Creates a new sampler that shares this sampler's inputs and multiplies
        /// every sampled value by <paramref name="scale"/>.
        /// </summary>
        /// <param name="scale">The factor applied to sampled values.</param>
        /// <returns>A new sampler; this instance is not modified.</returns>
        /// <remarks>
        /// The new scale replaces the current output scale of this sampler, it is not
        /// combined with it.
        /// </remarks>
        public Sampler Scale(float scale)
        {
            var newSampler = new Sampler
            {
                outputScale = scale
            };

            newSampler.inputs.AddRange(inputs);

            return newSampler;
        }

        /// <summary>
        /// Creates a new sampler whose output contains only the component at
        /// <paramref name="offset"/> of each output element of this sampler.
        /// </summary>
        /// <param name="offset">The index of the component to extract from each output element.</param>
        /// <returns>
        /// A new sampler with a stride 1 output and, when tangents are present,
        /// stride 2 tangents. Input and interpolation inputs are shared with this sampler.
        /// </returns>
        /// <remarks>
        /// The output scale of this sampler is not carried over to the new sampler.
        /// </remarks>
        public Sampler Split(int offset)
        {
            var newSampler = new Sampler();

            foreach (var input in inputs)
            {
                var source = input.Source;

                switch (input.Semantic)
                {
                    case Semantic.Input:
                        newSampler.inputs.Add(input);
                        break;

                    case Semantic.Interpolation:
                        newSampler.inputs.Add(input);
                        break;

                    case Semantic.Output:
                        {
                            var data = new float[source.Count];

                            for (int i = 0; i < data.Length; i++)
                                data[i] = source.FloatData[i * source.Stride + offset];

                            newSampler.inputs.Add(new Input
                            {
                                Source = new Source(data, 1),
                                Semantic = input.Semantic
                            });
                        }
                        break;

                    case Semantic.InTangent:
                    case Semantic.OutTangent:
                        {
                            //
                            // Build one (X, Y) pair per key: X is the first value of the source
                            // tangent element and Y is the value at offset + 1.
                            // NOTE(review): this assumes each source tangent element is laid out
                            // as one X followed by one Y per output component - confirm.
                            //

                            var data = new float[source.Count * 2];

                            for (int i = 0; i < source.Count; i++)
                            {
                                data[i * 2 + 0] = source.FloatData[i * source.Stride];
                                data[i * 2 + 1] = source.FloatData[i * source.Stride + (offset + 1)];
                            }

                            newSampler.inputs.Add(new Input
                            {
                                Source = new Source(data, 2),
                                Semantic = input.Semantic
                            });
                        }
                        break;
                }
            }

            return newSampler;
        }

        /// <summary>
        /// Samples all frames, from frame 0 to frame <see cref="FrameCount"/> - 1.
        /// </summary>
        /// <returns>One scaled value per frame.</returns>
        public float[] Sample() => Sample(0, FrameCount - 1);

        /// <summary>
        /// Samples the frames from <paramref name="start"/> to <paramref name="end"/>
        /// (both inclusive) and applies the output scale to the result.
        /// </summary>
        /// <param name="start">The first frame to sample.</param>
        /// <param name="end">The last frame to sample.</param>
        /// <returns>An array of <c>end - start + 1</c> values.</returns>
        public float[] Sample(int start, int end)
        {
            var result = Sample(start, end, 0);

            if (outputScale != 1.0f)
            {
                for (int i = 0; i < result.Length; i++)
                    result[i] *= outputScale;
            }

            return result;
        }

        /// <summary>
        /// Samples the output component at <paramref name="offset"/> for the frames from
        /// <paramref name="start"/> to <paramref name="end"/> (both inclusive), without
        /// applying the output scale.
        /// </summary>
        /// <param name="start">The first frame to sample.</param>
        /// <param name="end">The last frame to sample.</param>
        /// <param name="offset">The index of the component within each output element.</param>
        /// <returns>
        /// An array of <c>end - start + 1</c> values. All values are 0 when the input,
        /// output or interpolation data is missing.
        /// </returns>
        /// <exception cref="ArgumentException">
        /// <paramref name="offset"/> is greater than or equal to the output stride.
        /// </exception>
        /// <exception cref="System.IO.InvalidDataException">
        /// A segment uses BEZIER interpolation but the in or out tangents are missing.
        /// </exception>
        private float[] Sample(int start, int end, int offset)
        {
            float[] input = null;
            float[] output = null;
            int outputStride = 1;
            Vector2[] inTangent = null;
            Vector2[] outTangent = null;
            string[] interpolation = null;

            foreach (var i in inputs)
            {
                switch (i.Semantic)
                {
                    case Semantic.Input:
                        input = i.Source.FloatData;
                        break;

                    case Semantic.Output:
                        output = i.Source.FloatData;
                        outputStride = i.Source.Stride;
                        break;

                    case Semantic.InTangent:
                        inTangent = FloatArrayToVector2Array(i.Source.FloatData);
                        break;

                    case Semantic.OutTangent:
                        outTangent = FloatArrayToVector2Array(i.Source.FloatData);
                        break;

                    case Semantic.Interpolation:
                        interpolation = i.Source.NameData;
                        break;
                }
            }

            if (offset >= outputStride)
                throw new ArgumentException("The offset must be less than the output stride", nameof(offset));

            float[] result = new float[end - start + 1];

            if (input == null || output == null || interpolation == null)
            {
                //
                // If we don't have enough data to sample then we just return 0 for all frames.
                //

                return result;
            }

            if (output.Length == outputStride)
            {
                //
                // If the output contains only one element then use that for all frames.
                //

                for (int i = 0; i < result.Length; i++)
                    result[i] = output[offset];

                return result;
            }

            float inputFirst = input.First();
            float outputFirst = output[offset];

            float inputLast = input.Last();
            float outputLast = output[output.Length - outputStride + offset];

            for (int frame = 0; frame < result.Length; frame++)
            {
                float t = (frame + start) / FrameRate;

                //
                // Frames outside the key range are clamped to the first/last output value.
                //

                if (t <= inputFirst)
                {
                    result[frame] = outputFirst;
                    continue;
                }

                if (t >= inputLast)
                {
                    result[frame] = outputLast;
                    continue;
                }

                int index = Array.BinarySearch(input, t);

                if (index >= 0)
                {
                    //
                    // The frame falls exactly on a key, no interpolation is needed.
                    //

                    result[frame] = output[index * outputStride + offset];
                    continue;
                }

                //
                // Not found: the bitwise complement of the result is the index of the
                // first key whose input value is greater than t.
                //

                index = ~index;

                if (index == 0)
                {
                    result[frame] = outputFirst;
                    continue;
                }

                if (index * outputStride + offset >= output.Length)
                {
                    result[frame] = outputLast;
                    continue;
                }

                var p0 = new Vector2(input[index - 1], output[(index - 1) * outputStride + offset]);
                var p1 = new Vector2(input[index], output[index * outputStride + offset]);

                // Normalized position of t between the two surrounding keys.
                float s = (t - p0.X) / (p1.X - p0.X);

                // The interpolation type of a segment is taken from the key that starts it.
                switch (interpolation[index - 1])
                {
                    default:
                        Console.Error.WriteLine("Interpolation type '{0}' is not supported, using LINEAR", interpolation[index - 1]);
                        goto case "LINEAR";

                    case "LINEAR":
                        result[frame] = p0.Y + s * (p1.Y - p0.Y);
                        break;

                    case "BEZIER":
                        if (inTangent == null || outTangent == null)
                            throw new System.IO.InvalidDataException("Bezier interpolation was specified but in/out tangents are not present");

                        var c0 = outTangent[index - 1];
                        var c1 = inTangent[index];

                        float invS = 1.0f - s;

                        //
                        // Cubic Bezier evaluated on the Y components only; s is used directly
                        // as the curve parameter and the X components of the control points
                        // are not used.
                        //

                        result[frame] =
                              p0.Y * invS * invS * invS
                            + 3.0f * c0.Y * invS * invS * s
                            + 3.0f * c1.Y * invS * s * s
                            + p1.Y * s * s * s;

                        break;
                }
            }

            return result;
        }

        /// <summary>
        /// Converts a flat array of consecutive (X, Y) pairs into an array of vectors.
        /// A trailing unpaired value, if any, is ignored.
        /// </summary>
        /// <param name="array">The flat array of values.</param>
        /// <returns>An array of <c>array.Length / 2</c> vectors.</returns>
        private static Vector2[] FloatArrayToVector2Array(float[] array)
        {
            var result = new Vector2[array.Length / 2];

            for (int i = 0; i < result.Length; i++)
            {
                result[i].X = array[i * 2 + 0];
                result[i].Y = array[i * 2 + 1];
            }

            return result;
        }
    }
}
