/
etc/
lib/
src/Abilities/
src/Abilities/Skills/
src/Abilities/Spells/
src/Abilities/Spells/Enums/
src/Affects/
src/ArtheaConsole/
src/ArtheaConsole/Properties/
src/ArtheaGUI/Properties/
src/Clans/Enums/
src/Commands/Communication/
src/Commands/ItemCommands/
src/Connections/
src/Connections/Colors/
src/Connections/Enums/
src/Connections/Players/
src/Connections/Players/Enums/
src/Continents/
src/Continents/Areas/
src/Continents/Areas/Characters/
src/Continents/Areas/Characters/Enums/
src/Continents/Areas/Items/
src/Continents/Areas/Items/Enums/
src/Continents/Areas/Rooms/
src/Continents/Areas/Rooms/Enums/
src/Continents/Areas/Rooms/Exits/
src/Creation/
src/Creation/Attributes/
src/Creation/Interfaces/
src/Database/
src/Database/Interfaces/
src/Environment/
src/Properties/
src/Scripts/Enums/
src/Scripts/Interfaces/
#region Arthea License

/***********************************************************************
*  Arthea MUD by R. Jennings (2007)      http://arthea.googlecode.com/ *
*  By using this code you comply with the Artistic and GPLv2 Licenses. *
***********************************************************************/

#endregion

using System;
using System.Data.Odbc;
using System.Globalization;
using System.IO;
using System.Reflection;
using System.Text;
using System.Xml;
using System.Xml.Serialization;
using Arthea.Database.Interfaces;
using Arthea.Environment;

namespace Arthea.Database
{
    /// <summary>
    /// Implementation of save methods for a database.
    /// </summary>
    /// <summary>
    /// Implementation of save methods for a database. Objects implementing
    /// <see cref="Indexed"/> are stored in a table named after their runtime type,
    /// with one column per instance field.
    /// </summary>
    public struct DbHandler
    {
        #region [rgn] Fields (2)

        private static readonly BindingFlags bindings = BindingFlags.Public |
                                                 BindingFlags.NonPublic | BindingFlags.Instance;

        private static readonly string ODBCConnectString;

        #endregion [rgn]

        /// <summary>
        /// Builds the ODBC connection string from the server's database settings
        /// ("mssql", "mysql" or "msaccess", compared case-insensitively).
        /// Exceptions thrown here reach callers wrapped in a TypeInitializationException.
        /// </summary>
        static DbHandler()
        {
            switch (Server.Instance.DatabaseType.ToLowerInvariant())
            {
                case "mssql":
                    ODBCConnectString = "Driver={SQL Server};Server=" + Server.Instance.DatabaseServer +
                                        ";Trusted_Connection=yes;Database=" + Server.Instance.DatabaseName + ";";
                    break;

                case "mysql":
                    ODBCConnectString = "DRIVER={MySQL ODBC 3.51 Driver};" +
                                        "SERVER=" + Server.Instance.DatabaseServer + ";" +
                                        "DATABASE=" + Server.Instance.DatabaseName + ";" +
                                        "UID=" + Server.Instance.DatabaseUser + ";" +
                                        "PASSWORD=" + Server.Instance.DatabasePassword + ";" +
                                        "OPTION=3";
                    break;

                case "msaccess":
                    string path = Path.GetFullPath(Server.Instance.DatabaseName + ".mdb");
                    Log.Info("Connecting to database at: " + path + "...");

                    // Concatenating the connection string cannot fail, so check the file itself;
                    // otherwise a missing database would go unnoticed until the first Open().
                    if (!File.Exists(path))
                    {
                        Log.Fatal("Failed.");
                        throw new FileNotFoundException("Could not find database file.", path);
                    }

                    ODBCConnectString = @"Driver={Microsoft Access Driver (*.mdb)};" +
                                        @"Dbq=" + path + ";" +
                                        @"Uid=Admin;" +
                                        @"Pwd=";
                    Log.Info("Succeeded.");
                    break;

                default:
                    string errorstring = "Error in " + Paths.ServerConfigFile +
                                         " file. The database type option is set to an invalid string.";
                    Log.Fatal(errorstring);
                    throw new InvalidOperationException(errorstring);
            }
        }

        #region [rgn] Methods (17)

        // [rgn] Public Methods (5)

        /// <summary>
        /// Executes the database statement.
        /// </summary>
        /// <param name="statement">The statement.</param>
        /// <returns>the first column of the first row</returns>
        /// <exception cref="OdbcException">The connection or the statement failed.</exception>
        public static object ExecuteScalar(string statement)
        {
            return ExecuteScalar(statement, null);
        }

        /// <summary>
        /// Execute a statement that does not return any data.
        /// ODBC errors raised by the statement are logged and reported as zero rows.
        /// </summary>
        /// <param name="statement">The statement to execute.</param>
        /// <returns>The number of rows affected by the query.</returns>
        public static int ExecuteNonQuery(string statement)
        {
            return ExecuteNonQuery(statement, null);
        }

        /// <summary>
        /// Execute a statement that has a count(*) and return the row count. If the
        /// results are null (or a database NULL), then the return value is zero.
        /// </summary>
        /// <param name="statement">The statement to execute.</param>
        /// <returns>The number of rows in the query.</returns>
        public static int ExecuteRowCount(string statement)
        {
            return ExecuteRowCount(statement, null);
        }

        /// <summary>
        /// Determines whether a row with the object's id exists in the table
        /// named after the object's type.
        /// </summary>
        /// <param name="obj">The object to look up.</param>
        /// <returns><c>true</c> if at least one matching row exists.</returns>
        public static bool Exists(Indexed obj)
        {
            // The id is bound as a parameter rather than spliced into the SQL text.
            string sql = "SELECT count(*) FROM " + obj.GetType().Name + " WHERE id=?;";

            return ExecuteRowCount(sql, new object[] { obj.Id }) != 0;
        }

        /// <summary>
        /// Saves an object: creates its table if probing it fails, inserts a new row
        /// if none exists for <c>obj.Id</c>, and otherwise updates every column.
        /// </summary>
        /// <param name="obj">The object to save.</param>
        public static void Save(Indexed obj)
        {
            Type type = obj.GetType();
            string sql = "SELECT count(*) FROM " + type.Name + ";";

            // check if a table for the data type exists
            try
            {
                ExecuteScalar(sql);
            }
            catch (OdbcException)
            {
                CreateTable(obj);
            }

            // insert if an instance of this object doesn't exist
            if (!Exists(obj))
            {
                Insert(obj);
                return;
            }

            // otherwise update the object, all columns in a single statement
            FieldInfo[] fields = type.GetFields(bindings);
            if (fields.Length == 0)
            {
                return;
            }

            string[] assignments = new string[fields.Length];
            object[] values = new object[fields.Length + 1];
            for (int i = 0; i < fields.Length; i++)
            {
                assignments[i] = fields[i].Name + " = ?";
                values[i] = GetFieldValue(obj, fields[i]);
            }
            values[fields.Length] = obj.Id;

            sql = "UPDATE " + type.Name + " SET " + string.Join(", ", assignments) + " WHERE Id = ?;";
            ExecuteNonQuery(sql, values);
        }

        // [rgn] Private Methods (12)

        /// <summary>
        /// Runs a scalar query, binding <paramref name="values"/> to its '?' markers in order.
        /// </summary>
        /// <param name="statement">The SQL text.</param>
        /// <param name="values">Positional parameter values, or <c>null</c> for none.</param>
        /// <returns>The first column of the first row.</returns>
        private static object ExecuteScalar(string statement, object[] values)
        {
            LogSQL(statement);
            using (OdbcConnection conn = GetConnection())
            using (OdbcCommand command = CreateCommand(conn, statement, values))
            {
                return command.ExecuteScalar();
            }
        }

        /// <summary>
        /// Runs a non-query statement, binding <paramref name="values"/> to its '?' markers in order.
        /// ODBC errors from the statement are logged and reported as zero rows.
        /// </summary>
        /// <param name="statement">The SQL text.</param>
        /// <param name="values">Positional parameter values, or <c>null</c> for none.</param>
        /// <returns>The number of rows affected, or 0 on an ODBC error.</returns>
        private static int ExecuteNonQuery(string statement, object[] values)
        {
            LogSQL(statement);
            using (OdbcConnection conn = GetConnection())
            {
                try
                {
                    using (OdbcCommand command = CreateCommand(conn, statement, values))
                    {
                        return command.ExecuteNonQuery();
                    }
                }
                catch (OdbcException ex)
                {
                    Log.Error("ODBC EXCEPTION (Running non query): " + ex.Message + ex.StackTrace);
                    return 0;
                }
            }
        }

        /// <summary>
        /// Runs a count(*) query with positional parameters and converts the result to an int.
        /// </summary>
        /// <param name="statement">The SQL text; must contain count(*) (any letter case).</param>
        /// <param name="values">Positional parameter values, or <c>null</c> for none.</param>
        /// <returns>The row count, or 0 for an invalid statement or a null result.</returns>
        private static int ExecuteRowCount(string statement, object[] values)
        {
            if (statement.IndexOf("count(*)", StringComparison.OrdinalIgnoreCase) < 0)
            {
                Log.Bug("ODBC: invalid statement for execute row count");
                return 0;
            }

            object result = ExecuteScalar(statement, values);

            // ODBC reports SQL NULL as DBNull.Value, not as a null reference.
            if (result == null || result is DBNull)
            {
                return 0;
            }

            return Convert.ToInt32(result, CultureInfo.InvariantCulture);
        }

        /// <summary>
        /// Creates a command on <paramref name="conn"/> and adds one positional
        /// parameter per entry of <paramref name="values"/>.
        /// </summary>
        private static OdbcCommand CreateCommand(OdbcConnection conn, string statement, object[] values)
        {
            OdbcCommand command = new OdbcCommand(statement, conn);
            if (values != null)
            {
                foreach (object value in values)
                {
                    // ODBC parameters are positional; the name is ignored.
                    command.Parameters.AddWithValue("?", ToDbValue(value));
                }
            }
            return command;
        }

        /// <summary>
        /// Converts a field value into something an ODBC parameter can carry.
        /// </summary>
        private static object ToDbValue(object value)
        {
            if (value == null)
            {
                return DBNull.Value;
            }

            Type type = value.GetType();

            // These get varchar columns in DbDataType, so send their text form.
            if (type.IsEnum || type.IsArray || value is Flag)
            {
                return value.ToString();
            }

            return value;
        }

        /// <summary>
        /// Open a new ODBC connection. Don't forget to close it :)
        /// If the connection's current database differs from the configured name
        /// (case-insensitively), a CREATE DATABASE is attempted on this same connection
        /// and the connection is switched to that database.
        /// </summary>
        /// <returns>An open ODBC connection.</returns>
        private static OdbcConnection GetConnection()
        {
            string databaseName = Server.Instance.DatabaseName;
            OdbcConnection conn = new OdbcConnection(ODBCConnectString);
            try
            {
                conn.Open();

                // NOTE(review): confirm what conn.Database reports for the Access driver;
                // a path-style value would trigger this branch on every connection.
                if (!string.Equals(conn.Database, databaseName, StringComparison.OrdinalIgnoreCase))
                {
                    // Run on this connection: going through ExecuteNonQuery would call
                    // GetConnection again and recurse without end.
                    string sql = "CREATE DATABASE " + databaseName + ";";
                    LogSQL(sql);
                    try
                    {
                        using (OdbcCommand create = new OdbcCommand(sql, conn))
                        {
                            create.ExecuteNonQuery();
                        }
                    }
                    catch (OdbcException ex)
                    {
                        Log.Error("ODBC EXCEPTION (Creating database): " + ex.Message + ex.StackTrace);
                    }

                    conn.ChangeDatabase(databaseName);
                }

                return conn;
            }
            catch
            {
                // Never leak a connection that failed to initialise.
                conn.Dispose();
                throw;
            }
        }

        /// <summary>
        /// Inserts a new row for the object, one parameterized value per instance field.
        /// </summary>
        /// <param name="obj">The object to insert.</param>
        private static void Insert(Indexed obj)
        {
            Type type = obj.GetType();
            FieldInfo[] fields = type.GetFields(bindings);

            string[] names = new string[fields.Length];
            string[] markers = new string[fields.Length];
            object[] values = new object[fields.Length];

            for (int i = 0; i < fields.Length; i++)
            {
                names[i] = fields[i].Name;
                markers[i] = "?";
                values[i] = GetFieldValue(obj, fields[i]);
            }

            string sql = "INSERT INTO " + type.Name + " (" + string.Join(",", names) +
                         ") VALUES (" + string.Join(",", markers) + ");";

            ExecuteNonQuery(sql, values);
        }

        /// <summary>
        /// Maps a CLR type to the SQL column type used by <see cref="CreateTable"/>.
        /// </summary>
        /// <param name="type">The field or property type.</param>
        /// <returns>The SQL type name.</returns>
        /// <exception cref="NotSupportedException">The type has no mapping.</exception>
        private static string DbDataType(Type type)
        {
            // A bare "varchar" is rejected by MySQL and means varchar(1) on SQL Server.
            const string textColumn = "varchar(255)";

            if (type.IsEnum || type.IsArray || type.IsSubclassOf(typeof(Flag)))
            {
                return textColumn;
            }

            switch (type.Name)
            {
                case "String":
                    return textColumn;
                case "Boolean":
                    return "bit";
                case "Byte":
                    return "tinyint";
                case "Int16":
                case "UInt16":
                    return "smallint";
                case "Int32":
                case "UInt32":
                    return "integer";
                case "UInt64":
                case "Int64":
                    return "bigint";
                case "Double":
                    // NOTE(review): SQL Server has no "double" type (it uses float) - confirm per target database.
                    return "double";
                case "DateTime":
                    // "date" would drop the time of day.
                    return "datetime";
                case "TimeSpan":
                    return "time";
                default:
                    throw new NotSupportedException("Unknown db data type: " + type.Name);
            }
        }

        /// <summary>
        /// Finds the property that should be read instead of a raw field: either a
        /// non-[XmlIgnore] property with the field's exact name, or a property whose
        /// [XmlElement] element name equals the field's name.
        /// </summary>
        /// <returns>The matching property, or <c>null</c> to use the field directly.</returns>
        private static PropertyInfo GetFieldProperty(Indexed obj, FieldInfo fi)
        {
            Type type = obj.GetType();
            PropertyInfo sameName = type.GetProperty(fi.Name, bindings);

            if (sameName != null && sameName.GetCustomAttributes(typeof(XmlIgnoreAttribute), true).Length == 0)
            {
                return sameName;
            }

            foreach (PropertyInfo pi in type.GetProperties(bindings))
            {
                // [XmlElement] is XmlElementAttribute; System.Xml.XmlElement is the DOM node type.
                foreach (XmlElementAttribute ele in pi.GetCustomAttributes(typeof(XmlElementAttribute), true))
                {
                    if (ele.ElementName == fi.Name)
                    {
                        return pi;
                    }
                }
            }

            return null;
        }

        /// <summary>
        /// Reads a field's value, preferring the property chosen by <see cref="GetFieldProperty"/>.
        /// </summary>
        private static object GetFieldValue(Indexed obj, FieldInfo fi)
        {
            PropertyInfo pi = GetFieldProperty(obj, fi);
            return (pi != null) ? pi.GetValue(obj, null) : fi.GetValue(obj);
        }

        /// <summary>
        /// Creates a table named after the object's type with one column per instance field.
        /// </summary>
        /// <param name="obj">An instance of the type to create a table for.</param>
        private static void CreateTable(Indexed obj)
        {
            Type type = obj.GetType();
            FieldInfo[] fields = type.GetFields(bindings);
            string[] columns = new string[fields.Length];

            for (int i = 0; i < fields.Length; i++)
            {
                PropertyInfo pi = GetFieldProperty(obj, fields[i]);
                Type columnType = (pi != null) ? pi.PropertyType : fields[i].FieldType;
                columns[i] = fields[i].Name + " " + DbDataType(columnType);
            }

            // no table, create one
            string sql = "CREATE TABLE " + type.Name + " (" + string.Join(",", columns) + ");";
            ExecuteNonQuery(sql);
        }

        /// <summary>
        /// Writes the SQL text to the info log when database logging is enabled.
        /// </summary>
        /// <param name="sql">The SQL string.</param>
        private static void LogSQL(string sql)
        {
            if (Server.Instance.DatabaseLog)
            {
                Log.Info("ODBC: " + sql);
            }
        }
        #endregion [rgn]
    }
}