#region Arthea License

/***********************************************************************
*  Arthea MUD by R. Jennings (2007)      http://arthea.googlecode.com/ *
*  By using this code you comply with the Artistic and GPLv2 Licenses. *
***********************************************************************/

#endregion

using System;
using System.Data.Odbc;
using System.IO;
using System.Reflection;
using System.Text;
using System.Xml;
using System.Xml.Serialization;
using Arthea.Database.Interfaces;
using Arthea.Environment;

namespace Arthea.Database
{
    /// <summary>
    /// Implementation of save methods for a database.
    /// </summary>
    /// <remarks>
    /// All members are static. Every public call opens its own ODBC connection
    /// and closes it before returning. Values are always sent to the driver as
    /// positional ODBC parameters ("?") and never concatenated into the SQL
    /// text; table and column names come from reflection over the saved type.
    /// </remarks>
    public struct DbHandler
    {
        #region [rgn] Fields (2)

        private static readonly BindingFlags bindings =
            BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance;

        private static readonly string ODBCConnectString;

        #endregion [rgn]

        /// <summary>
        /// Builds the ODBC connection string from the server configuration.
        /// </summary>
        /// <exception cref="InvalidOperationException">
        /// The configured database type is not one of "mssql", "mysql" or
        /// "msaccess", or the Access database path cannot be resolved.
        /// </exception>
        static DbHandler()
        {
            // A missing setting is handled by the default branch instead of
            // failing with a NullReferenceException.
            string databaseType = Server.Instance.DatabaseType;

            switch (databaseType == null ? string.Empty : databaseType.ToLowerInvariant())
            {
                case "mssql":
                    ODBCConnectString = "Driver={SQL Server};Server=" + Server.Instance.DatabaseServer +
                                        ";Trusted_Connection=yes;Database=" + Server.Instance.DatabaseName + ";";
                    break;

                case "mysql":
                    ODBCConnectString = "DRIVER={MySQL ODBC 3.51 Driver};" +
                                        "SERVER=" + Server.Instance.DatabaseServer + ";" +
                                        "DATABASE=" + Server.Instance.DatabaseName + ";" +
                                        "UID=" + Server.Instance.DatabaseUser + ";" +
                                        "PASSWORD=" + Server.Instance.DatabasePassword + ";" +
                                        "OPTION=3";
                    break;

                case "msaccess":
                {
                    string path;
                    try
                    {
                        // Resolving the path is the only step here that can
                        // actually fail, so it is the step that is guarded.
                        path = Path.GetFullPath(Server.Instance.DatabaseName + ".mdb");
                    }
                    catch (Exception ex)
                    {
                        Log.Fatal("Failed.");
                        throw new InvalidOperationException("Could not find database file.", ex);
                    }

                    Log.Info("Connecting to database at: " + path + "...");
                    ODBCConnectString = @"Driver={Microsoft Access Driver (*.mdb)};" +
                                        @"Dbq=" + path + ";" +
                                        @"Uid=Admin;" +
                                        @"Pwd=";
                    Log.Info("Succeeded.");
                    break;
                }

                default:
                {
                    string errorstring = "Error in " + Paths.ServerConfigFile +
                                         " file. The database type option is set to an invalid string.";
                    Log.Fatal(errorstring);
                    throw new InvalidOperationException(errorstring);
                }
            }
        }

        #region [rgn] Methods

        /// <summary>
        /// Executes the database statement.
        /// </summary>
        /// <param name="statement">The statement.</param>
        /// <returns>the first column of the first row</returns>
        /// <remarks>
        /// Driver errors are not caught here; <see cref="Save"/> relies on the
        /// <see cref="OdbcException"/> to detect a missing table.
        /// </remarks>
        public static object ExecuteScalar(string statement)
        {
            return RunScalar(statement, null);
        }

        /// <summary>
        /// Execute a statement that does not return any data.
        /// </summary>
        /// <param name="statement">The statement to execute.</param>
        /// <returns>
        /// The number of rows affected by the query, or 0 when the driver
        /// reports an error (the error is logged, not rethrown).
        /// </returns>
        public static int ExecuteNonQuery(string statement)
        {
            return RunNonQuery(statement, null);
        }

        /// <summary>
        /// Execute a statement that has a count(*) and return the row count. If the
        /// results are null, then the return value is zero.
        /// </summary>
        /// <param name="statement">The statement to execute.</param>
        /// <returns>The number of rows in the query.</returns>
        public static int ExecuteRowCount(string statement)
        {
            return RunRowCount(statement, null);
        }

        /// <summary>
        /// Checks whether a row with the object's id exists in the table named
        /// after the object's type.
        /// </summary>
        /// <param name="obj">The object to look up.</param>
        /// <returns><c>true</c> when at least one matching row exists.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="obj"/> is null.</exception>
        public static bool Exists(Indexed obj)
        {
            if (obj == null)
            {
                throw new ArgumentNullException("obj");
            }

            string sql = "SELECT count(*) FROM " + obj.GetType().Name + " WHERE id = ?;";
            return RunRowCount(sql, new object[] { obj.Id }) != 0;
        }

        /// <summary>
        /// Saves the object: creates its table when the probe query fails,
        /// inserts a row when none exists for its id, otherwise updates the
        /// existing row.
        /// </summary>
        /// <param name="obj">The object to persist.</param>
        /// <exception cref="ArgumentNullException"><paramref name="obj"/> is null.</exception>
        public static void Save(Indexed obj)
        {
            if (obj == null)
            {
                throw new ArgumentNullException("obj");
            }

            string table = obj.GetType().Name;

            // check if a table for the data type exists
            try
            {
                ExecuteScalar("SELECT count(*) FROM " + table + ";");
            }
            catch (OdbcException)
            {
                CreateTable(obj);
            }

            // insert if an instance of this object doesn't exist
            if (!Exists(obj))
            {
                Insert(obj);
                return;
            }

            // Otherwise update the object. Each column keeps its own statement
            // so that one failing column (logged by RunNonQuery) does not stop
            // the others from being written, but all of them share a single
            // connection instead of opening one per field.
            FieldInfo[] fields = obj.GetType().GetFields(bindings);
            if (fields.Length == 0)
            {
                return;
            }

            object id = obj.Id;
            using (OdbcConnection conn = GetConnection())
            {
                foreach (FieldInfo fi in fields)
                {
                    string sql = "UPDATE " + table + " SET " + fi.Name + " = ? WHERE Id = ?;";
                    RunNonQuery(conn, sql, new object[] { GetFieldValue(obj, fi), id });
                }
            }
        }

        /// <summary>
        /// Runs a scalar query with optional positional parameters on a fresh connection.
        /// </summary>
        private static object RunScalar(string statement, object[] parameters)
        {
            LogSQL(statement);

            using (OdbcConnection conn = GetConnection())
            using (OdbcCommand command = CreateCommand(conn, statement, parameters))
            {
                return command.ExecuteScalar();
            }
        }

        /// <summary>
        /// Runs a non-query with optional positional parameters on a fresh connection.
        /// </summary>
        private static int RunNonQuery(string statement, object[] parameters)
        {
            using (OdbcConnection conn = GetConnection())
            {
                return RunNonQuery(conn, statement, parameters);
            }
        }

        /// <summary>
        /// Runs a non-query on an already open connection. Driver errors are
        /// logged and reported as zero affected rows.
        /// </summary>
        private static int RunNonQuery(OdbcConnection conn, string statement, object[] parameters)
        {
            LogSQL(statement);

            try
            {
                using (OdbcCommand command = CreateCommand(conn, statement, parameters))
                {
                    return command.ExecuteNonQuery();
                }
            }
            catch (OdbcException ex)
            {
                Log.Error("ODBC EXCEPTION (Running non query): " + ex.Message + ex.StackTrace);
                return 0;
            }
        }

        /// <summary>
        /// Validates that the statement is a count(*) query, runs it and
        /// converts the result to an integer (null / DBNull become zero).
        /// </summary>
        private static int RunRowCount(string statement, object[] parameters)
        {
            // SQL keywords are case-insensitive, so "COUNT(*)" must be accepted too.
            if (statement == null ||
                statement.IndexOf("count(*)", StringComparison.OrdinalIgnoreCase) < 0)
            {
                Log.Bug("ODBC: invalid statement for execute row count");
                return 0;
            }

            object result = RunScalar(statement, parameters);
            if (result == null || result == DBNull.Value)
            {
                return 0;
            }

            return Convert.ToInt32(result, System.Globalization.CultureInfo.InvariantCulture);
        }

        /// <summary>
        /// Creates a command bound to <paramref name="conn"/> and attaches the
        /// given values as positional parameters.
        /// </summary>
        private static OdbcCommand CreateCommand(OdbcConnection conn, string statement, object[] parameters)
        {
            OdbcCommand command = new OdbcCommand(statement, conn);

            if (parameters != null)
            {
                for (int i = 0; i < parameters.Length; i++)
                {
                    // ODBC binds by position; the name only labels the parameter.
                    command.Parameters.AddWithValue("@p" + i, ToDbValue(parameters[i]));
                }
            }

            return command;
        }

        /// <summary>
        /// Converts a CLR value into something that can be bound as a parameter.
        /// Null becomes DBNull; the primitive types that <see cref="DbDataType"/>
        /// maps to native columns are passed through (unsigned integers are
        /// widened to a signed type); everything else is sent as its string
        /// form, matching the varchar column it is stored in.
        /// </summary>
        private static object ToDbValue(object value)
        {
            if (value == null)
            {
                return DBNull.Value;
            }

            Type type = value.GetType();

            // Type.GetTypeCode reports the underlying integer type for enums,
            // but enum columns are created as varchar.
            if (type.IsEnum)
            {
                return value.ToString();
            }

            switch (Type.GetTypeCode(type))
            {
                case TypeCode.String:
                case TypeCode.Boolean:
                case TypeCode.Byte:
                case TypeCode.Int16:
                case TypeCode.Int32:
                case TypeCode.Int64:
                case TypeCode.Single:
                case TypeCode.Double:
                case TypeCode.Decimal:
                case TypeCode.DateTime:
                    return value;
                case TypeCode.UInt16:
                    return (int) (ushort) value;
                case TypeCode.UInt32:
                    return (long) (uint) value;
                case TypeCode.UInt64:
                    return (decimal) (ulong) value;
                default:
                    if (value is TimeSpan)
                    {
                        return value;
                    }
                    return value.ToString();
            }
        }

        /// <summary>
        /// Open a new ODBC connection. Don't forget to close it :)
        /// </summary>
        /// <remarks>
        /// When the connection is not on the configured database, a
        /// CREATE DATABASE is attempted on this same connection (a failure is
        /// logged and ignored) and the connection is then switched to it. The
        /// statement must not go through <see cref="ExecuteNonQuery"/>, because
        /// that method obtains its connection from here and would recurse.
        /// </remarks>
        /// <returns>An open ODBC connection.</returns>
        private static OdbcConnection GetConnection()
        {
            OdbcConnection conn = new OdbcConnection(ODBCConnectString);

            try
            {
                conn.Open();

                string databaseName = Server.Instance.DatabaseName;
                if (!string.Equals(conn.Database, databaseName, StringComparison.OrdinalIgnoreCase))
                {
                    string createSql = "CREATE DATABASE " + databaseName + ";";
                    LogSQL(createSql);

                    try
                    {
                        using (OdbcCommand create = new OdbcCommand(createSql, conn))
                        {
                            create.ExecuteNonQuery();
                        }
                    }
                    catch (OdbcException ex)
                    {
                        Log.Error("ODBC EXCEPTION (Running non query): " + ex.Message + ex.StackTrace);
                    }

                    conn.ChangeDatabase(databaseName);
                }

                return conn;
            }
            catch
            {
                // The caller never receives the connection, so release it here.
                conn.Dispose();
                throw;
            }
        }

        /// <summary>
        /// Inserts a new row for the object, one column per instance field.
        /// </summary>
        private static void Insert(Indexed obj)
        {
            FieldInfo[] fields = obj.GetType().GetFields(bindings);
            object[] values = new object[fields.Length];
            StringBuilder fieldNames = new StringBuilder();
            StringBuilder placeholders = new StringBuilder();

            for (int i = 0; i < fields.Length; i++)
            {
                if (i > 0)
                {
                    fieldNames.Append(",");
                    placeholders.Append(",");
                }

                fieldNames.Append(fields[i].Name);
                placeholders.Append("?");
                values[i] = GetFieldValue(obj, fields[i]);
            }

            string sql = "INSERT INTO " + obj.GetType().Name +
                         " (" + fieldNames + ") VALUES (" + placeholders + ");";

            RunNonQuery(sql, values);
        }

        /// <summary>
        /// Maps a CLR type to the SQL column type used by <see cref="CreateTable"/>.
        /// </summary>
        /// <exception cref="NotSupportedException">The type has no mapping.</exception>
        private static string DbDataType(Type type)
        {
            // An explicit length is given because a bare "varchar" is rejected
            // by MySQL and is only one character wide on SQL Server.
            const string text = "varchar(255)";

            if (type.IsEnum || type.IsArray || type.IsSubclassOf(typeof(Flag)))
            {
                return text;
            }

            switch (type.Name)
            {
                case "String":
                    return text;
                case "Boolean":
                    return "bit";
                case "Byte":
                    return "tinyint";
                case "Int16":
                case "UInt16":
                    return "smallint";
                case "Int32":
                case "UInt32":
                    return "integer";
                case "UInt64":
                case "Int64":
                    return "bigint";
                case "Double":
                    return "double";
                case "DateTime":
                    return "date";
                case "TimeSpan":
                    return "time";
                default:
                    throw new NotSupportedException("Unknown db data type: " + type.Name);
            }
        }

        /// <summary>
        /// Finds the property that stands in for a field: either a property
        /// with the same name that is not marked XmlIgnore, or a property whose
        /// XmlElement attribute names the field.
        /// </summary>
        /// <returns>The matching property, or null when there is none.</returns>
        private static PropertyInfo GetFieldProperty(Indexed obj, FieldInfo fi)
        {
            Type type = obj.GetType();

            PropertyInfo sameName = type.GetProperty(fi.Name, bindings);
            if (sameName != null &&
                sameName.GetCustomAttributes(typeof(XmlIgnoreAttribute), true).Length == 0)
            {
                return sameName;
            }

            foreach (PropertyInfo pi in type.GetProperties(bindings))
            {
                // XmlElementAttribute is the serialization attribute;
                // System.Xml.XmlElement is a DOM node and is never attached
                // to a member as an attribute.
                foreach (XmlElementAttribute ele in
                    pi.GetCustomAttributes(typeof(XmlElementAttribute), true))
                {
                    if (ele.ElementName == fi.Name)
                    {
                        return pi;
                    }
                }
            }

            return null;
        }

        /// <summary>
        /// Reads the value to store for a field, going through the matching
        /// property (see <see cref="GetFieldProperty"/>) when there is one.
        /// </summary>
        private static object GetFieldValue(Indexed obj, FieldInfo fi)
        {
            PropertyInfo pi = GetFieldProperty(obj, fi);
            return (pi != null) ? pi.GetValue(obj, null) : fi.GetValue(obj);
        }

        /// <summary>
        /// Creates a table named after the object's type with one column per
        /// instance field.
        /// </summary>
        private static void CreateTable(Indexed obj)
        {
            StringBuilder columns = new StringBuilder();

            foreach (FieldInfo fi in obj.GetType().GetFields(bindings))
            {
                PropertyInfo pi = GetFieldProperty(obj, fi);
                Type t = (pi != null) ? pi.PropertyType : fi.FieldType;

                if (columns.Length > 0)
                {
                    columns.Append(",");
                }

                columns.Append(fi.Name);
                columns.Append(" ");
                columns.Append(DbDataType(t));
            }

            // no table, create one
            string sql = "CREATE TABLE " + obj.GetType().Name + " (" + columns + ");";

            ExecuteNonQuery(sql);
        }

        /// <summary>
        /// Dumps the contents of the Odbc SQL command to the console window.
        /// </summary>
        /// <param name="sql">The SQL string.</param>
        private static void LogSQL(string sql)
        {
            if (Server.Instance.DatabaseLog)
            {
                Log.Info("ODBC: " + sql);
            }
        }

        #endregion [rgn]
    }
}