﻿using System;

namespace Oni
{
    /// <summary>
    /// Single-precision quaternion with vector part (X, Y, Z) and scalar part W,
    /// used to represent rotations.
    /// </summary>
    internal struct Quaternion : IEquatable<Quaternion>
    {
        public float X;
        public float Y;
        public float Z;
        public float W;

        /// <summary>Creates a quaternion from a vector part and a scalar part.</summary>
        public Quaternion(Vector3 xyz, float w)
        {
            X = xyz.X;
            Y = xyz.Y;
            Z = xyz.Z;
            W = w;
        }

        /// <summary>Creates a quaternion from its four components.</summary>
        public Quaternion(float x, float y, float z, float w)
        {
            X = x;
            Y = y;
            Z = z;
            W = w;
        }

        /// <summary>Creates a quaternion whose (X, Y, Z, W) are copied from <paramref name="xyzw"/>.</summary>
        public Quaternion(Vector4 xyzw)
        {
            X = xyzw.X;
            Y = xyzw.Y;
            Z = xyzw.Z;
            W = xyzw.W;
        }

        /// <summary>The vector part (X, Y, Z).</summary>
        private Vector3 XYZ => new Vector3(X, Y, Z);

        /// <summary>
        /// Creates a rotation of <paramref name="angle"/> around <paramref name="axis"/>.
        /// </summary>
        /// <param name="axis">
        /// Rotation axis. It is scaled by sin(angle / 2) but not normalized here, so it should
        /// be a unit vector for the result to be a unit quaternion.
        /// </param>
        /// <param name="angle">
        /// Rotation angle, passed straight to FMath.Sin/Cos. CreateFromEulerXYZ converts to radians
        /// before calling this, so radians are expected.
        /// </param>
        public static Quaternion CreateFromAxisAngle(Vector3 axis, float angle)
        {
            float halfAngle = angle * 0.5f;
            float sin = FMath.Sin(halfAngle);
            float cos = FMath.Cos(halfAngle);

            return new Quaternion(axis * sin, cos);
        }

        /// <summary>
        /// Decomposes this quaternion (assumed to be unit length) into a rotation axis and angle.
        /// </summary>
        /// <param name="axis">
        /// The rotation axis. When the rotation is (near) zero, the axis is undefined; the raw
        /// vector part is returned instead.
        /// </param>
        /// <param name="angle">The rotation angle, 2 * acos(W); 0 for a (near) identity rotation.</param>
        public void ToAxisAngle(out Vector3 axis, out float angle)
        {
            // Clamp W into the domain of acos / sqrt. After normalization, float rounding can leave
            // |W| a hair above 1, which would otherwise make both results NaN. A NaN sin also fails
            // the small-angle test below, so the NaN would reach the axis and the angle.
            float w = W > 1.0f ? 1.0f : (W < -1.0f ? -1.0f : W);

            float halfAngle = FMath.Acos(w);
            float sin = FMath.Sqrt(1 - w * w);

            if (sin < 1e-5f)
            {
                axis = XYZ;
                angle = 0.0f;
            }
            else
            {
                axis = XYZ / sin;
                angle = halfAngle * 2.0f;
            }
        }

        /// <summary>
        /// Creates a rotation from Euler angles given in DEGREES, composed as
        /// Rx(x) * Ry(y) * Rz(z).
        /// </summary>
        public static Quaternion CreateFromEulerXYZ(float x, float y, float z)
        {
            x = MathHelper.ToRadians(x);
            y = MathHelper.ToRadians(y);
            z = MathHelper.ToRadians(z);

            return CreateFromAxisAngle(Vector3.UnitX, x)
                 * CreateFromAxisAngle(Vector3.UnitY, y)
                 * CreateFromAxisAngle(Vector3.UnitZ, z);
        }

        /// <summary>
        /// Extracts Euler angles in DEGREES. This is intended as the inverse of
        /// <see cref="CreateFromEulerXYZ"/>.
        /// </summary>
        /// <remarks>
        /// Near gimbal lock (|s| &gt; 0.999), Y is pinned to -90 or +90 and all of the remaining
        /// rotation is attributed to X, with Z forced to 0.
        /// </remarks>
        public Vector3 ToEulerXYZ()
        {
            Vector3 r;

            // Sign-adjusted component aliases; e selects the axis-order parity used in the formulas.
            var p0 = -W;
            var p1 = X;
            var p2 = Y;
            var p3 = Z;
            var e = -1.0f;

            // s is the sine of the middle (Y) rotation, up to sign.
            var s = 2.0f * (p0 * p2 + e * p1 * p3);

            if (s > 0.999f)
            {
                r.X = MathHelper.ToDegrees(-2.0f * (float)Math.Atan2(p1, p0));
                r.Y = -90.0f;
                r.Z = 0.0f;
            }
            else if (s < -0.999f)
            {
                r.X = MathHelper.ToDegrees(2.0f * (float)Math.Atan2(p1, p0));
                r.Y = 90.0f;
                r.Z = 0.0f;
            }
            else
            {
                r.X = -MathHelper.ToDegrees((float)Math.Atan2(2.0f * (p0 * p1 - e * p2 * p3), 1.0f - 2.0f * (p1 * p1 + p2 * p2)));
                r.Y = -MathHelper.ToDegrees((float)Math.Asin(s));
                r.Z = -MathHelper.ToDegrees((float)Math.Atan2(2.0f * (p0 * p3 - e * p1 * p2), 1.0f - 2.0f * (p2 * p2 + p3 * p3)));
            }

            return r;
        }

        /// <summary>
        /// Creates a rotation from yaw, pitch and roll. Unlike <see cref="CreateFromEulerXYZ"/>,
        /// no degree conversion is applied; the angles go straight to FMath.Sin/Cos
        /// (radians, presumably).
        /// </summary>
        public static Quaternion CreateFromYawPitchRoll(float yaw, float pitch, float roll)
        {
            float halfRoll = roll * 0.5f;
            float sinRoll = FMath.Sin(halfRoll);
            float cosRoll = FMath.Cos(halfRoll);

            float halfPitch = pitch * 0.5f;
            float sinPitch = FMath.Sin(halfPitch);
            float cosPitch = FMath.Cos(halfPitch);

            float halfYaw = yaw * 0.5f;
            float sinYaw = FMath.Sin(halfYaw);
            float cosYaw = FMath.Cos(halfYaw);

            Quaternion r;

            r.X = (cosYaw * sinPitch * cosRoll) + (sinYaw * cosPitch * sinRoll);
            r.Y = (sinYaw * cosPitch * cosRoll) - (cosYaw * sinPitch * sinRoll);
            r.Z = (cosYaw * cosPitch * sinRoll) - (sinYaw * sinPitch * cosRoll);
            r.W = (cosYaw * cosPitch * cosRoll) + (sinYaw * sinPitch * sinRoll);

            return r;
        }

        /// <summary>
        /// Extracts the rotation stored in the upper-left 3x3 of <paramref name="m"/>, using the same
        /// element layout that <see cref="ToMatrix"/> produces.
        /// </summary>
        /// <remarks>
        /// The branch is chosen by the largest of the trace and the diagonal elements. This keeps the
        /// square-root argument well away from zero, which avoids dividing by a tiny number.
        /// </remarks>
        public static Quaternion CreateFromRotationMatrix(Matrix m)
        {
            Quaternion q;

            float trace = m.M11 + m.M22 + m.M33;

            if (trace > 0.0f)
            {
                float s = FMath.Sqrt(1.0f + trace);
                float inv2s = 0.5f / s;
                q.X = (m.M23 - m.M32) * inv2s;
                q.Y = (m.M31 - m.M13) * inv2s;
                q.Z = (m.M12 - m.M21) * inv2s;
                q.W = s * 0.5f;
            }
            else if (m.M11 >= m.M22 && m.M11 >= m.M33)
            {
                float s = FMath.Sqrt(1.0f + m.M11 - m.M22 - m.M33);
                float inv2s = 0.5f / s;
                q.X = s * 0.5f;
                q.Y = (m.M12 + m.M21) * inv2s;
                q.Z = (m.M13 + m.M31) * inv2s;
                q.W = (m.M23 - m.M32) * inv2s;
            }
            else if (m.M22 > m.M33)
            {
                float s = FMath.Sqrt(1.0f - m.M11 + m.M22 - m.M33);
                float inv2s = 0.5f / s;
                q.X = (m.M21 + m.M12) * inv2s;
                q.Y = s * 0.5f;
                q.Z = (m.M32 + m.M23) * inv2s;
                q.W = (m.M31 - m.M13) * inv2s;
            }
            else
            {
                float s = FMath.Sqrt(1.0f - m.M11 - m.M22 + m.M33);
                float inv2s = 0.5f / s;
                q.X = (m.M31 + m.M13) * inv2s;
                q.Y = (m.M32 + m.M23) * inv2s;
                q.Z = s * 0.5f;
                q.W = (m.M12 - m.M21) * inv2s;
            }

            return q;
        }

        /// <summary>
        /// Normalized linear interpolation from <paramref name="q1"/> (amount = 0) to
        /// <paramref name="q2"/> (amount = 1).
        /// </summary>
        /// <remarks>
        /// If the inputs lie in opposite hemispheres (negative dot product), q2's weight is negated so
        /// the blend follows the shorter arc. The result is renormalized.
        /// </remarks>
        public static Quaternion Lerp(Quaternion q1, Quaternion q2, float amount)
        {
            // Computed before the possible negation below on purpose: only q2's weight flips sign.
            float invAmount = 1.0f - amount;

            if (Dot(q1, q2) < 0.0f)
                amount = -amount;

            q1.X = invAmount * q1.X + amount * q2.X;
            q1.Y = invAmount * q1.Y + amount * q2.Y;
            q1.Z = invAmount * q1.Z + amount * q2.Z;
            q1.W = invAmount * q1.W + amount * q2.W;

            q1.Normalize();

            return q1;
        }

        /// <summary>Four-component dot product.</summary>
        public static float Dot(Quaternion q1, Quaternion q2)
            => q1.X * q2.X + q1.Y * q2.Y + q1.Z * q2.Z + q1.W * q2.W;

        /// <summary>Component-wise sum.</summary>
        public static Quaternion operator +(Quaternion q1, Quaternion q2)
        {
            q1.X += q2.X;
            q1.Y += q2.Y;
            q1.Z += q2.Z;
            q1.W += q2.W;

            return q1;
        }

        /// <summary>Component-wise difference.</summary>
        public static Quaternion operator -(Quaternion q1, Quaternion q2)
        {
            q1.X -= q2.X;
            q1.Y -= q2.Y;
            q1.Z -= q2.Z;
            q1.W -= q2.W;

            return q1;
        }

        /// <summary>Hamilton product q1 * q2 (not commutative).</summary>
        public static Quaternion operator *(Quaternion q1, Quaternion q2) => new Quaternion
        {
            X = q1.X * q2.W + q1.Y * q2.Z - q1.Z * q2.Y + q1.W * q2.X,
            Y = -q1.X * q2.Z + q1.Y * q2.W + q1.Z * q2.X + q1.W * q2.Y,
            Z = q1.X * q2.Y - q1.Y * q2.X + q1.Z * q2.W + q1.W * q2.Z,
            W = -q1.X * q2.X - q1.Y * q2.Y - q1.Z * q2.Z + q1.W * q2.W,
        };

        /// <summary>Scales all four components by <paramref name="s"/>.</summary>
        public static Quaternion operator *(Quaternion q, float s)
        {
            q.X *= s;
            q.Y *= s;
            q.Z *= s;
            q.W *= s;

            return q;
        }

        /// <summary>Exact component-wise equality; see <see cref="Equals(Quaternion)"/>.</summary>
        public static bool operator ==(Quaternion q1, Quaternion q2) => q1.Equals(q2);

        /// <summary>Negation of <see cref="op_Equality"/>.</summary>
        public static bool operator !=(Quaternion q1, Quaternion q2) => !q1.Equals(q2);

        /// <summary>Returns the conjugate (negated vector part, same W).</summary>
        public static Quaternion Conjugate(Quaternion q)
        {
            q.X = -q.X;
            q.Y = -q.Y;
            q.Z = -q.Z;

            return q;
        }

        /// <summary>
        /// Returns the multiplicative inverse, conjugate / |q|^2. A zero quaternion divides by zero
        /// and yields non-finite components.
        /// </summary>
        public Quaternion Inverse()
        {
            float inv = 1.0f / SquaredLength();

            Quaternion r;
            r.X = -X * inv;
            r.Y = -Y * inv;
            r.Z = -Z * inv;
            r.W = W * inv;
            return r;
        }

        /// <summary>
        /// Scales this quaternion in place to unit length. A zero-length quaternion divides by zero
        /// and yields non-finite components.
        /// </summary>
        public void Normalize()
        {
            float f = 1.0f / Length();

            X *= f;
            Y *= f;
            Z *= f;
            W *= f;
        }

        /// <summary>Euclidean length of the four components.</summary>
        public float Length() => FMath.Sqrt(SquaredLength());

        /// <summary>Squared Euclidean length (avoids the square root).</summary>
        public float SquaredLength() => X * X + Y * Y + Z * Z + W * W;

        /// <summary>
        /// Exact component-wise comparison with <c>==</c>: no tolerance, and a NaN component never
        /// compares equal.
        /// </summary>
        public bool Equals(Quaternion other) => X == other.X && Y == other.Y && Z == other.Z && W == other.W;

        /// <inheritdoc/>
        public override bool Equals(object obj) => obj is Quaternion && Equals((Quaternion)obj);

        /// <summary>
        /// Order-sensitive hash of the four components. Unlike a plain XOR, this does not collide
        /// for permuted components or for patterns such as (a, a, b, b).
        /// </summary>
        public override int GetHashCode()
        {
            unchecked
            {
                int hash = X.GetHashCode();
                hash = (hash * 397) ^ Y.GetHashCode();
                hash = (hash * 397) ^ Z.GetHashCode();
                hash = (hash * 397) ^ W.GetHashCode();
                return hash;
            }
        }

        /// <summary>Formats as "{X Y Z W}" using the current culture's number formatting.</summary>
        public override string ToString() => $"{{{X} {Y} {Z} {W}}}";

        /// <summary>
        /// Builds a 4x4 rotation matrix (no translation) from this quaternion, which is assumed
        /// to be unit length.
        /// </summary>
        public Matrix ToMatrix()
        {
            float xx = X * X;
            float yy = Y * Y;
            float zz = Z * Z;
            float xy = X * Y;
            float zw = Z * W;
            float zx = Z * X;
            float yw = Y * W;
            float yz = Y * Z;
            float xw = X * W;

            Matrix m;

            m.M11 = 1.0f - 2.0f * (yy + zz);
            m.M12 = 2.0f * (xy + zw);
            m.M13 = 2.0f * (zx - yw);
            m.M14 = 0.0f;

            m.M21 = 2.0f * (xy - zw);
            m.M22 = 1.0f - 2.0f * (zz + xx);
            m.M23 = 2.0f * (yz + xw);
            m.M24 = 0.0f;

            m.M31 = 2.0f * (zx + yw);
            m.M32 = 2.0f * (yz - xw);
            m.M33 = 1.0f - 2.0f * (yy + xx);
            m.M34 = 0.0f;

            m.M41 = 0.0f;
            m.M42 = 0.0f;
            m.M43 = 0.0f;
            m.M44 = 1.0f;

            return m;
        }

        /// <summary>Returns the components as a Vector4 (X, Y, Z, W).</summary>
        public Vector4 ToVector4() => new Vector4(X, Y, Z, W);

        private static readonly Quaternion identity = new Quaternion(0.0f, 0.0f, 0.0f, 1.0f);

        /// <summary>The identity rotation (0, 0, 0, 1).</summary>
        public static Quaternion Identity => identity;
    }
}
