#!/usr/bin/python
#
# 	DbWrappers: C# database code generator
#
#	http://www.myelin.co.nz/dbwrappers/
#
#	Copyright (C) 2002 Phillip Pearson <pp@myelin.co.nz>
# 
# Permission is hereby granted, free of charge, to any person obtaining a 
# copy of this software and associated documentation files (the 
# "Software"), to deal in the Software without restriction, including 
# without limitation the rights to use, copy, modify, merge, publish, 
# distribute, sublicense, and/or sell copies of the Software, and to 
# permit persons to whom the Software is furnished to do so, subject to 
# the following conditions:
# 
# The above copyright notice and this permission notice shall be included 
# in all copies or substantial portions of the Software.
# 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 
# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE 
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
#	This isn't required, but if you find this software useful, a 
#	link back to http://www.myelin.co.nz/dbwrappers from your web 
#	site or weblog would be much appreciated!
#
# See http://www.pycs.net/devlog/ for news and other developments.

import string
import sys
from xml.dom import minidom

# File extension for generated code
SUFFIX = '.cs'

# Prefix for the database classes, e.g. 'OleDb' for OLE DB, 'Sql' for MS SQL
DBPREFIX = "OleDb"

# Database type -> C# data type mapping
CSTYPES = {
	'AutoNumber': 'int',
	'Number': 'int',
	'DateTime': 'DateTime',
	'Memo': 'string',
	'Text': 'string',
	'Boolean': 'bool',
	}

# Database type -> OleDbType mapping
DBTYPES = {
	'AutoNumber': 'Integer',
	'Number': 'Integer',
	'DateTime': 'Date',
	'Memo': 'LongVarWChar',
	'Text': 'VarWChar',
	'Boolean': 'Boolean',
	}

# Database type -> DataReader.Get___() mapping
DBGETTYPES = {
	'AutoNumber': 'Int32', # use GetInt32( index ) to get an AutoNumber ...
	'Number': 'Int32',
	'DateTime': 'DateTime',
	'Memo': 'String',
	'Text': 'String',
	'Boolean': 'Boolean',
	}
	
HEADER = """/*

	This file has been automatically generated by MakeWrappers.py; don't 
	edit it directly or you will lose your changes next time it is run. 
	Edit Tables.xml instead.
	
	DbWrappers (C) 2002 Phillip Pearson, pp@myelin.co.nz
	
		http://www.myelin.co.nz/dbwrappers/
	
*/

using System;
using System.Data.<db>;

namespace <namespace>
{
"""

def fetch_attrs( node ):
	"""Pulls all the attributes of a node out as a dict.

	node -- a minidom Element
	Returns a dict mapping each attribute name to its string value (an
	empty dict for an element with no attributes).

	Uses the public 'attributes' map rather than minidom's private
	'_attrs' dict, which some minidom versions leave as None on elements
	that have no attributes.
	"""
	return dict( node.attributes.items() )


def _parse_named( node, kind ):
	"""Shared body of parse_field / parse_index.

	Reads the attributes of 'node', echoes them to stdout prefixed with
	'kind' as progress output, and returns ( name attribute, attribute dict ).
	Raises ValueError if the element has no 'name' attribute.
	"""
	d = fetch_attrs( node )
	# sys.stdout.write rather than the print statement keeps this helper
	# usable on both Python 2 and 3; the text written is the same.
	sys.stdout.write( "\t\t%s %s\n" % ( kind, d ) )
	if 'name' not in d:
		raise ValueError( "<%s> element has no 'name' attribute: %s" % ( kind, d ) )
	return d['name'], d


def parse_field( field ):
	"Parses a <field> element; returns ( field name, attribute dict )"
	return _parse_named( field, 'field' )


def parse_index( index ):
	"Parses an <index> element; returns ( index name, attribute dict )"
	return _parse_named( index, 'index' )

dom = minidom.parse( open( 'Tables.xml', 'r' ) )

tables = []

for tablesNode in dom.childNodes:
	if tablesNode.nodeName == 'tables':
		
		print "tables:"
		
		for tableNode in tablesNode.childNodes:
			if tableNode.nodeName == 'table':
				
				table = { 'data': fetch_attrs( tableNode ),
					'fields': {},
					'indices': {},
					}
				
				print "\n\ttable:", table['data'], "\n"

				for innerNode in tableNode.childNodes:
					if innerNode.nodeName == 'fields':
						
						for fieldNode in innerNode.childNodes:
							if fieldNode.nodeName == 'field':
								name, data = parse_field( fieldNode )
								table['fields'][name] = data
						
					elif innerNode.nodeName == 'indices':
						
						for indexNode in innerNode.childNodes:
							if indexNode.nodeName == 'index':
								name, data = parse_index( indexNode )
								table['indices'][name] = data
								
				tables.append( table )
		
import pprint
pprint.pprint( tables )

class CodeWriter:
	"""Generates the C# wrapper classes for a single table definition.

	'table' is one of the dicts built from Tables.xml:
	{ 'data': <table> attributes, 'fields': { name: <field> attributes },
	  'indices': { name: <index> attributes } }.

	Creating a CodeWriter opens (and truncates) <basename>Table<SUFFIX>,
	<basename>Reader<SUFFIX> and <basename>Row<SUFFIX> in the current
	directory; Go() fills them in and closes them.
	"""

	def __init__( self, table ):
		self.table = table

		self.baseName = self.table['data']['basename']

		self.tableCode = open( self.baseName + 'Table' + SUFFIX, 'wt' )
		self.readerCode = open( self.baseName + 'Reader' + SUFFIX, 'wt' )
		self.rowCode = open( self.baseName + 'Row' + SUFFIX, 'wt' )

	def Close( self ):
		"Closes the three output files; calling it more than once is harmless"
		for f in ( self.tableCode, self.readerCode, self.rowCode ):
			f.close()

	def Write( self, file, text ):
		"""Writes 'text' to 'file' after filling in the template placeholders:
		<base> -> table basename, <namespace> -> table 'namespace' attribute,
		<db> -> DBPREFIX (substituted in that order).
		"""
		text = text.replace( '<base>', self.baseName )
		text = text.replace( '<namespace>', self.table['data']['namespace'] )
		text = text.replace( '<db>', DBPREFIX )
		file.write( text )

	def Go( self ):
		"""Generates all three files for the table, then closes them.

		Raises ValueError if the table has no index with type="pkey" (the
		generated Row class needs it for Reload/SelectAgain/Commit), or if a
		field has a 'scriptkey' attribute but no 'propname'.
		"""
		try:
			self.pkey = self._FindPrimaryKey()
			self._WriteTableFile()
			self._WriteReaderFile()
			self._WriteRowFile()
		finally:
			# Release the output files even if generation fails part-way
			self.Close()

	def _FindPrimaryKey( self ):
		"Returns the name of the table's pkey index (the last one if several)"
		pkey = None
		for indexName, d in self.table['indices'].items():
			if d.get( 'type' ) == 'pkey':
				pkey = indexName
		if pkey is None:
			raise ValueError( "Table '%s' has no index with type=\"pkey\"" % self.baseName )
		return pkey

	def _DbFields( self ):
		"Returns the fields that map to database columns (those with a 'dbtype')"
		return [ f for f in self.table['fields'].values() if 'dbtype' in f ]

	def _CsType( self, field ):
		"C# type for a field: its 'cstype' attribute, else mapped from 'dbtype'"
		if 'cstype' in field:
			return field['cstype']
		return CSTYPES[ field['dbtype'] ]

	def _WriteTableFile( self ):
		"Writes <base>Table.cs: the <base>Column enum and the <base>Table class"
		w = lambda s: self.Write( self.tableCode, s )
		dbColumns = [ f['name'] for f in self._DbFields() ]

		w(
HEADER + """	public enum <base>Column
	{
""" )
		w( ",\n".join( [ "\t\t" + name for name in dbColumns ] ) )
		w( """
	}
	
	public class <base>Table
	{
		<db>Connection connection;
		internal <db>Connection Connection
		{
			get
			{
				return this.connection;
			}
		}

		internal <base>Reader ActiveReader;		
		
		private static <base>Column[] allColumns;
		public static <base>Column[] AllColumns
		{
			get {
				if ( <base>Table.allColumns == null )
				{
					<base>Table.allColumns = new <base>Column[]
					{
""" )
		w( ",\n".join( [ "\t\t\t\t\t\t<base>Column." + name for name in dbColumns ] ) )
		w( """
					};
				}
				return <base>Table.allColumns;
			}
		}
	
		public <base>Table( <db>Connection connection )
		{
			this.connection = connection;
		}
		
		// Make an SQL SELECT string out of a 'where' condition and
		// a list of columns to return
		private string MakeQuery( <base>Column[] columns, string queryWhere )
		{
			string query = "SELECT ";
			bool first = true;
			foreach ( <base>Column column in columns )
			{
				if ( first ) { first = false; }
				else { query += ", "; }
				
				switch ( column )
				{
""" )
		for name in dbColumns:
			w(
"""					case <base>Column.""" + name + ''':
						query += "''' + name + '''";
						break;
''' )
		w(
"""					default:
						throw new Exception( "Invalid column" );
				}
			}
			query += " FROM """ + self.table['data']['name'] + '''";
			if ( queryWhere != null && queryWhere != "" )
			{
				query += " WHERE " + queryWhere;
			}
			
			return query;
		}
		
		// Selects a number of records out of the database, returning a
		// <base>Reader which produces <base>Row objects with the specified
		// columns populated from the database.
		public <base>Reader SelectMany( string queryWhere, <base>Column[] columns )
		{
			<base>Column[] queryColumns = ( columns == null ) ? <base>Table.AllColumns : columns;
			<db>Command cmd = new <db>Command( this.MakeQuery( queryColumns, queryWhere ), this.connection );
			return new <base>Reader( this, cmd.ExecuteReader(), queryColumns );
		}

		// Selects a single record out of the database, returning a
		// <base>Row for it, or null if there are no matches.
		public <base>Row SelectOne( string queryWhere, <base>Column[] columns )
		{
			using ( <base>Reader r = this.SelectMany( queryWhere, columns ) )
			{
				if ( ! r.Read() ) return null;
				return r.Row;
			}
		}
			
		
		public <base>Row NewRow()
		{
			return new <base>Row( this );
		}
	}
}

''' ) # tableCode

	def _WriteReaderFile( self ):
		"Writes <base>Reader.cs: iterates the results of a SELECT as <base>Rows"
		w = lambda s: self.Write( self.readerCode, s )

		w(
HEADER + """	public class <base>Reader : IDisposable
	{
		// Parent table
		<base>Table table;
		
		// Underlying link to the database that provides us with rows to process
		<db>DataReader reader;
		internal <db>DataReader Reader
		{
			get { return this.reader; }
		}
		
		// The columns specified in the SELECT query that produced this reader
		<base>Column[] columns;
		internal <base>Column[] Columns
		{
			get { return this.columns; }
		}
		
		internal <base>Reader( <base>Table table, <db>DataReader reader, <base>Column[] columns )
		{
			this.table = table;
			this.reader = reader;
			this.columns = columns;

			this.table.ActiveReader = this;			
		}
		
		// Advance to the next row, returning 'false' if we have run off the 
		// end of the results, otherwise 'true'.  NB: You need to call this
		// once to get the first result.
		public bool Read()
		{
			return this.reader.Read();
		}
		
		// Get the current row
		public <base>Row Row
		{
			get
			{
				return new <base>Row( this.table, this );
			}
		}
		
		public void Close()
		{
			this.Dispose();
		}
		
		public void Dispose()
		{
			this.reader.Close();
			GC.SuppressFinalize( this );
			this.table.ActiveReader = null;
		}
		
		~<base>Reader()
		{
			this.Dispose();
		}
	}
}

""" ) # readerCode

	def _WriteRowFile( self ):
		"Writes <base>Row.cs: field storage, script access, load and commit"
		w = lambda s: self.Write( self.rowCode, s )
		pkey = self.pkey

		# Columns the application supplies itself; AutoNumber columns are
		# left out of INSERT/UPDATE (their value is read back via @@IDENTITY).
		# Built as a list because it is iterated several times below.
		self.manualFields = [ c for c in self.table['fields'].values()
			if 'dbtype' in c and c['dbtype'] != 'AutoNumber' ]

		if 'base' in self.table['data']:
			parentClass = " : " + self.table['data']['base']
		else:
			parentClass = ""
		w(
HEADER + """	public class <base>Row""" + parentClass + """
	{
		// Pointer to a user-defined object (TreeNode or something)
		public object Tag;
	
		// Reference back to this row's parent table
		<base>Table table;
		
		// Flag to say if this row has already been inserted into the database
		// (used to decide whether to use an INSERT or UPDATE when committing)
		bool inDb;
		
""" )

		# A data_<name> member for every field, plus an optional property
		for field in self.table['fields'].values():
			cstype = self._CsType( field )
			w(
"""		// Database field '""" + field['name'] + """'
		""" + cstype + " data_" + field['name'] + """;
""" )
			if 'propname' in field:
				if 'getset' in field:
					# Accessor bodies supplied verbatim by Tables.xml
					w(
"""		public """ + cstype + " " + field['propname'] + """
		{
""" + field['getset'] + """
		}
""" )
				else:
					# Plain pass-through to the data_ member
					w(
"""		public """ + cstype + " " + field['propname'] + """
		{
			get { return (""" + cstype + """) this.data_""" + field['name'] + """; }
			set { this.data_""" + field['name'] + """ = value; }
		}
""" )
			elif 'getset' in field:
				w( field['getset'] )
			w( "\n" )

		# GetItem(): maps each field's 'scriptkey' to its property's text
		w(
"""

		// Return text representations of data for use in scripts
		public string GetItem( string key )
		{
			switch ( key )
			{
""" )
		for field in self.table['fields'].values():
			if 'scriptkey' not in field:
				continue
			if 'propname' not in field:
				raise ValueError( "Don't have a propname for scriptkey '" + field['scriptkey'] + "'" )
			if self._CsType( field ) != 'string':
				typconv = ".ToString()"
			else:
				typconv = ""
			w(
'''				case "''' + field['scriptkey'] + '''":
					return this.''' + field['propname'] + typconv + """;
""" )

		w( """
				default:
					return "[Can't render unknown key '" + key + "']";
			}
		}
		
		// Construct a <base>Row object with data from the row pointed
		// to by the <db>DataReader
		internal <base>Row( <base>Table table, <base>Reader reader )
		{
			this.table = table;
			this.inDb = true;
			this.InitFromDb( reader );
		}

		string MungeString( string text )
		{
			if ( text == null || text.Length == 0 )
				return "?";
			if ( text[0] == '?' )
				return "?" + text;
			return text;
		}

		string UnMungeString( string munged )
		{
			if ( munged == null || munged.Length == 0 )
				return "";
			if ( munged[0] == '?' )
				return munged.Substring( 1 );
			return munged;
		}

		void InitFromDb( <base>Reader reader )
		{
			int colIdx = 0;
			foreach ( <base>Column column in reader.Columns )
			{
				if ( ! reader.Reader.IsDBNull( colIdx ) )
				{
					switch ( column )
					{
""" )
		# Text columns were stored via MungeString, so undo that on read
		for col in self._DbFields():
			if 'name' not in col:
				continue
			if col['dbtype'] in ( 'Memo', 'Text' ):
				unmunge = "UnMungeString"
			else:
				unmunge = ""
			w( """
						case <base>Column.""" + col['name'] + """:
							this.data_""" + col['name'] + " = " + unmunge + "( reader.Reader.Get" + DBGETTYPES[ col['dbtype'] ] + """( colIdx ) );
							break;
""" )
		w(
'''					}
				}
				++ colIdx;
			}
		}
		
		internal <base>Row( <base>Table table )
		{
			this.table = table;
			this.inDb = false;

			// set up default values			
''' )
		for col in self.table['fields'].values():
			if 'default' in col:
				w( "\t\t\tthis.data_" + col['name'] + " = " + col['default'] + ";\n" )

		tableName = self.table['data']['name']
		manualNames = [ col['name'] for col in self.manualFields ]
		updateSets = ", ".join( [ "%s=@%s" % ( n, n ) for n in manualNames ] )
		insertColumns = ", ".join( manualNames )
		insertParams = ", ".join( [ "@" + n for n in manualNames ] )
		# C# expression for the WHERE clause selecting this row by primary key
		pkeyWhere = '"' + pkey + '=" + this.data_' + pkey + '.ToString()'
		w(
"""		}

		public void Reload( <base>Column[] columns )
		{
			using ( <base>Reader r = this.table.SelectMany( """ + pkeyWhere + """, columns ) )
			{
				if ( ! r.Read() ) throw new Exception( "Can't reload <base>Row w/ primary key " + this.data_""" + pkey + """.ToString() );
				this.InitFromDb( r );
			}
		}

		// Re-selects this row into a new <base>Row (carrying over Tag);
		// returns null if the row is no longer in the database.
		public <base>Row SelectAgain( <base>Column[] columns )
		{
			<base>Row ret = this.table.SelectOne( """ + pkeyWhere + """, columns );
			if ( ret != null ) ret.Tag = this.Tag;
			return ret;
		}
		
		public int Commit()
		{
			// Prepare the insert or update command
			string cmdText;
			if ( this.inDb ) {
				// Make an UPDATE command
				cmdText = "UPDATE """ + tableName + " SET " + updateSets + " WHERE " + pkey + '''=" + this.data_''' + pkey + """.ToString();
			} else {
				// Make an INSERT command
				cmdText = "INSERT INTO """ + tableName + " (" + insertColumns + ") VALUES (" + insertParams + ''')";
			}
			<db>Command cmd = new <db>Command( cmdText, this.table.Connection );

''' )
		w(
"""			// Put in all the values
			<db>Parameter par;
""" )
		for col in self.manualFields:
			# Text columns go through MungeString (undone by UnMungeString
			# in InitFromDb)
			if col['dbtype'] in ( 'Memo', 'Text' ):
				munge = "MungeString"
			else:
				munge = ""
			w(
"""			par = cmd.Parameters.Add( "@""" + col['name'] + '''", <db>Type.''' + DBTYPES[ col['dbtype'] ] + """ );
			par.Value = """ + munge + "( this.data_" + col['name'] + """ );
			
""" )
		w(
"""			// Execute (return number of rows affected - should be 1)
			int ret = cmd.ExecuteNonQuery();

			if ( ! this.inDb )
			{
				// Get the generated primary key
				<db>Command cmdGetId = new <db>Command( "SELECT @@IDENTITY", this.table.Connection );
				this.data_""" + pkey + """ = (int) cmdGetId.ExecuteScalar();

				this.inDb = true;				
			}

			return ret;			
		}
	}
}

""" ) # rowCode
			
for table in tables:
	CodeWriter( table ).Go()

print """

DbWrappers: http://www.myelin.co.nz/dbwrappers/
	Copyright (C) 2002 Phillip Pearson <pp@myelin.co.nz>

All done!

Now you can include the generated files into your project and recompile, then 
you can use the new database classes:
"""

for table in tables:
	print ( "\t<>Table, <>Reader, <>Row\n\t\t(in <>*%s)" % ( SUFFIX, ) ).replace( '<>', table['data']['basename'] )

print """
See ExampleTables.xml and the DbWrappers web site for more information."""
