/* 
   This file is part of libodbc++.
   
   Copyright (C) 1999 Manush Dodunekov <manush@litecom.net>
   
   This library is free software; you can redistribute it and/or
   modify it under the terms of the GNU Library General Public
   License as published by the Free Software Foundation; either
   version 2 of the License, or (at your option) any later version.
   
   This library is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
   Library General Public License for more details.
   
   You should have received a copy of the GNU Library General Public License
   along with this library; see the file COPYING.  If not, write to
   the Free Software Foundation, Inc., 59 Temple Place - Suite 330,
   Boston, MA 02111-1307, USA.
*/

#include <odbc++/statement.h>
#include <odbc++/resultset.h>
#include <odbc++/connection.h>
#include "driverinfo.h"

#include "dtconv.h"

using namespace odbc;
using namespace std;

// Wraps an already allocated statement handle and immediately pushes the
// requested result set type and concurrency to the driver. When that fails
// the handle is released here before rethrowing: a constructor that throws
// never gets its destructor run, so nobody else would free it.
Statement::Statement(Connection* con, SQLHSTMT hstmt,
                     int resultSetType, int resultSetConcurrency)
  :connection_(con),
   hstmt_(hstmt),
   currentResultSet_(NULL),
   fetchSize_(SQL_ROWSET_SIZE_DEFAULT),
   resultSetType_(resultSetType),
   resultSetConcurrency_(resultSetConcurrency),
   state_(STATE_CLOSED)
{
  try {
    _applyResultSetType();
  } catch(...) {
#if ODBCVER >= 0x0300
    SQLFreeHandle(SQL_HANDLE_STMT,hstmt_);
#else
    SQLFreeStmt(hstmt_,SQL_DROP);
#endif
    throw;
  }
}

// Destroys the statement. Any result set still open on it is deleted first,
// with its ownStatement_ flag cleared beforehand (presumably so it does not
// try to delete this statement in turn). Then the ODBC handle is freed and
// the owning connection is told to drop this statement.
Statement::~Statement()
{
  ResultSet* openResult=currentResultSet_;
  if(openResult!=NULL) {
    openResult->ownStatement_=false;
    delete openResult;
    currentResultSet_=NULL;
  }

#if ODBCVER >= 0x0300
  SQLFreeHandle(SQL_HANDLE_STMT,hstmt_);
#else
  SQLFreeStmt(hstmt_,SQL_DROP);
#endif

  connection_->_unregisterStatement(this);
}


//private
// Records rs as the result set currently open on this statement.
// At most one may be open at a time (checked in debug builds).
void Statement::_registerResultSet(ResultSet* rs)
{
  assert(currentResultSet_==NULL);
  this->currentResultSet_=rs;
}

// Forgets the open result set; rs must be the one previously registered
// (checked in debug builds).
void Statement::_unregisterResultSet(ResultSet* rs)
{
  assert(currentResultSet_==rs);
  this->currentResultSet_=NULL;
}


//protected
// Reads an integer statement attribute (ODBC 3) or option (ODBC 2);
// driver errors are reported through _checkStmtError.
// NOTE(review): ODBC 3.52 headers declare some attributes (e.g. max rows)
// as SQLULEN, which is wider than SQLUINTEGER on 64-bit platforms - confirm
// the drivers in use only write 4 bytes here.
SQLUINTEGER Statement::_getNumericOption(SQLINTEGER optnum)
{
  SQLUINTEGER value;
#if ODBCVER >= 0x0300
  SQLINTEGER unusedLength;
  SQLRETURN ret=SQLGetStmtAttr(hstmt_,optnum,(SQLPOINTER)&value,
                               SQL_IS_UINTEGER,&unusedLength);
#else
  SQLRETURN ret=SQLGetStmtOption(hstmt_,optnum,(SQLPOINTER)&value);
#endif
  this->_checkStmtError(hstmt_,ret,"Error fetching numeric statement option");
  return value;
}

//protected
// Writes an integer statement attribute (ODBC 3) or option (ODBC 2);
// driver errors are reported through _checkStmtError.
void Statement::_setNumericOption(SQLINTEGER optnum, SQLUINTEGER value)
{
#if ODBCVER >= 0x0300
  // integer attributes travel by value inside the pointer argument
  SQLRETURN ret=SQLSetStmtAttr(hstmt_,optnum,(SQLPOINTER)value,
                               SQL_IS_UINTEGER);
#else
  SQLRETURN ret=SQLSetStmtOption(hstmt_,optnum,value);
#endif
  this->_checkStmtError(hstmt_,ret,"Error setting numeric statement option");
}

//protected
// Reads a string statement attribute (ODBC 3) or option (ODBC 2).
// Under ODBC 3 a 256-byte stack buffer is tried first. If the driver reports
// a value that did not fit (together with its NUL terminator), the attribute
// is fetched again into a heap buffer of the reported size.
string Statement::_getStringOption(SQLINTEGER optnum)
{
  SQLRETURN r;
#if ODBCVER < 0x0300

  char buf[SQL_MAX_OPTION_STRING_LENGTH+1];

  r=SQLGetStmtOption(hstmt_,optnum,(SQLPOINTER)buf);

  this->_checkStmtError(hstmt_,r,"Error fetching string statement option");

#else

  char buf[256];
  SQLINTEGER dataSize=0;
  // BufferLength covers the whole buffer, including room for the NUL
  r=SQLGetStmtAttr(hstmt_,optnum,(SQLPOINTER)buf,sizeof(buf),&dataSize);
  this->_checkStmtError(hstmt_,r,"Error fetching string statement option");

  // dataSize excludes the terminator, so the value was truncated whenever
  // it needs sizeof(buf) or more characters
  if(dataSize>=(SQLINTEGER)sizeof(buf)) {
    char* tmp=new char[dataSize+1];
    odbc::Deleter<char> _tmp(tmp,true);

    // fetch into the larger buffer (not buf, which is too small)
    r=SQLGetStmtAttr(hstmt_,optnum,(SQLPOINTER)tmp,dataSize+1,&dataSize);

    this->_checkStmtError(hstmt_,r,"Error fetching string statement option");
    return string(tmp);
  }
#endif

  return string(buf);
}

//protected
// Writes a string statement attribute (ODBC 3) or option (ODBC 2);
// driver errors are reported through _checkStmtError.
void Statement::_setStringOption(SQLINTEGER optnum, const string& value)
{
  const char* text=value.c_str();
#if ODBCVER >= 0x0300
  SQLRETURN ret=SQLSetStmtAttr(hstmt_,optnum,(SQLPOINTER)text,value.length());
#else
  // the ODBC 2 call takes the string's address through its integer argument
  SQLRETURN ret=SQLSetStmtOption(hstmt_,optnum,(SQLUINTEGER)text);
#endif
  this->_checkStmtError(hstmt_,ret,"Error setting string statement option");
}


#if ODBCVER >= 0x0300

// Reads a pointer-valued statement attribute (ODBC 3 only).
SQLPOINTER Statement::_getPointerOption(SQLINTEGER optnum)
{
  SQLPOINTER value;
  SQLINTEGER unusedLength;
  SQLRETURN ret=SQLGetStmtAttr(hstmt_,optnum,(SQLPOINTER)&value,
                               SQL_IS_POINTER,&unusedLength);
  this->_checkStmtError(hstmt_,ret,"Error fetching pointer statement option");
  return value;
}


// Writes a pointer-valued statement attribute (ODBC 3 only).
void Statement::_setPointerOption(SQLINTEGER optnum, SQLPOINTER value)
{
  this->_checkStmtError(hstmt_,
                        SQLSetStmtAttr(hstmt_,optnum,value,SQL_IS_POINTER),
                        "Error setting pointer statement option");
}


#endif



void Statement::_applyResultSetType()
{
  // In general, for a scrollable cursor we set the type to 
  // dynamic - it will fall back to what the driver supports

  int cm=this->_getDriverInfo()->getCursorMask();
  
  // don't do anything for read-only concurrency
  switch(resultSetType_) {
  case ResultSet::TYPE_FORWARD_ONLY:
    // do nothing, this is the default
    break;

  case ResultSet::TYPE_SCROLL_INSENSITIVE:
    if((cm&SQL_SO_STATIC)!=0) {
      this->_setNumericOption
	(ODBC3_C(SQL_ATTR_CURSOR_TYPE,SQL_CURSOR_TYPE),
	 SQL_CURSOR_STATIC);
    } else {
      throw SQLException
	("[libodbc++]: Data source does not support insensitive cursors");
    }
    break;

  case ResultSet::TYPE_SCROLL_SENSITIVE:
    if((cm&SQL_SO_DYNAMIC)!=0) {
      this->_setNumericOption
	(ODBC3_C(SQL_ATTR_CURSOR_TYPE,SQL_CURSOR_TYPE),
	 SQL_CURSOR_DYNAMIC);
    } else if((cm&SQL_SO_KEYSET_DRIVEN)!=0) {
      this->_setNumericOption
	(ODBC3_C(SQL_ATTR_CURSOR_TYPE,SQL_CURSOR_TYPE),
	 SQL_CURSOR_KEYSET_DRIVEN);
    } else {
      throw SQLException
	("[libodbc++]: Data source does not support sensitive cursors");
    }
    break;

  default:
    throw SQLException
      ("[libodbc++]: Internal error: Invalid ResultSet type "+
       intToString(resultSetType_));
  }


  switch(resultSetConcurrency_) {
  case ResultSet::CONCUR_READ_ONLY:
    break;
    
  case ResultSet::CONCUR_UPDATABLE:
    this->_setNumericOption
      (ODBC3_C(SQL_ATTR_CONCURRENCY,SQL_CONCURRENCY),
       SQL_CONCUR_LOCK);
    break;

  default:
    throw SQLException("[libodbc++]: Invalid concurrency level "+
		       intToString(resultSetConcurrency_));
  }

}


//protected
bool Statement::_checkForResults()
{
  SQLSMALLINT nc;
  SQLRETURN r=SQLNumResultCols(hstmt_,&nc);
  return r==SQL_SUCCESS && nc>0;
}

//protected
// Wraps hstmt_'s current result in a newly allocated ResultSet and records
// it as this statement's open result set. hideMe is forwarded to the
// ResultSet constructor (the catalog functions in this file pass true).
ResultSet* Statement::_getResultSet(bool hideMe)
{
  ResultSet* created=new ResultSet(this,hstmt_,hideMe);
  _registerResultSet(created);
  return created;
}


//protected
void Statement::_beforeExecute()
{
  this->clearWarnings();

  if(currentResultSet_!=NULL) {
    throw SQLException
      ("[libodbc++]: Cannot re-execute; statement has an open resultset");
  }

  if(state_==STATE_OPEN) {
    SQLRETURN r=SQLFreeStmt(hstmt_,SQL_CLOSE);
    this->_checkStmtError(hstmt_,r,"Error closing statement");
    
    state_=STATE_CLOSED;
  }
}

//protected
void Statement::_afterExecute()
{
  state_=STATE_OPEN;
}

// Private catalog functions.
// The statement used for these is hidden behind the returned ResultSet,
// but since it can still be obtained through ResultSet::getStatement()
// we keep tracking its state (_beforeExecute/_afterExecute).

// Returns NULL for an empty string, otherwise the string's character data
// cast to SQLCHAR*. The caller passes the length to ODBC separately.
inline SQLCHAR* valueOrNull(const string& str)
{
  if(str.empty()) {
    return NULL;
  }
  return (SQLCHAR*)str.data();
}

// Runs SQLGetTypeInfo(SQL_ALL_TYPES) and returns the hidden result set.
ResultSet* Statement::_getTypeInfo()
{
  this->_beforeExecute();

  const SQLRETURN ret=SQLGetTypeInfo(hstmt_,SQL_ALL_TYPES);
  this->_checkStmtError(hstmt_,ret,"Error fetching type information");
  this->_afterExecute();

  return this->_getResultSet(true);
}


// Runs SQLColumns and returns the hidden result set. Empty catalog, schema,
// table and column strings are passed to the driver as NULL pointers.
ResultSet* Statement::_getColumns(const string& catalog,
                                  const string& schema,
                                  const string& tableName,
                                  const string& columnName)
{
  this->_beforeExecute();
  SQLRETURN r=SQLColumns(hstmt_,
                         valueOrNull(catalog),
                         catalog.length(),
                         valueOrNull(schema),
                         schema.length(),
                         valueOrNull(tableName),
                         tableName.length(),
                         valueOrNull(columnName),
                         columnName.length());

  this->_checkStmtError(hstmt_,r,"Error fetching column information");

  // Track the open cursor like the other catalog functions do; without this
  // the next _beforeExecute() would not close it before re-executing.
  this->_afterExecute();

  ResultSet* rs=this->_getResultSet(true);

  return rs;
}


// Runs SQLTables and returns the hidden result set. Empty catalog, schema
// and table strings become NULL pointers. types is always passed as its
// character data.
ResultSet* Statement::_getTables(const string& catalog,
                                 const string& schema,
                                 const string& tableName,
                                 const string& types)
{
  this->_beforeExecute();

  const SQLRETURN ret=SQLTables(hstmt_,
                                valueOrNull(catalog),catalog.length(),
                                valueOrNull(schema),schema.length(),
                                valueOrNull(tableName),tableName.length(),
                                (SQLCHAR*)types.data(),types.length());
  this->_checkStmtError(hstmt_,ret,"Error fetching table information");
  this->_afterExecute();

  return this->_getResultSet(true);
}


// Runs SQLTablePrivileges and returns the hidden result set. Empty catalog
// and schema strings become NULL pointers; tableName is always passed.
ResultSet* Statement::_getTablePrivileges(const string& catalog,
                                          const string& schema,
                                          const string& tableName)
{
  this->_beforeExecute();

  const SQLRETURN ret=SQLTablePrivileges(hstmt_,
                                         valueOrNull(catalog),catalog.length(),
                                         valueOrNull(schema),schema.length(),
                                         (SQLCHAR*)tableName.data(),
                                         tableName.length());
  this->_checkStmtError(hstmt_,ret,"Error fetching table privileges information");
  this->_afterExecute();

  return this->_getResultSet(true);
}

// Runs SQLColumnPrivileges and returns the hidden result set. Empty catalog
// and schema strings become NULL pointers; table and column names are
// always passed.
ResultSet* Statement::_getColumnPrivileges(const string& catalog,
                                           const string& schema,
                                           const string& tableName,
                                           const string& columnName)
{
  this->_beforeExecute();

  const SQLRETURN ret=SQLColumnPrivileges(hstmt_,
                                          valueOrNull(catalog),catalog.length(),
                                          valueOrNull(schema),schema.length(),
                                          (SQLCHAR*)tableName.data(),
                                          tableName.length(),
                                          (SQLCHAR*)columnName.data(),
                                          columnName.length());
  this->_checkStmtError(hstmt_,ret,"Error fetching column privileges information");
  this->_afterExecute();

  return this->_getResultSet(true);
}



// Runs SQLPrimaryKeys and returns the hidden result set. Empty catalog and
// schema strings become NULL pointers; tableName is always passed.
ResultSet* Statement::_getPrimaryKeys(const string& catalog,
                                      const string& schema,
                                      const string& tableName)
{
  this->_beforeExecute();

  const SQLRETURN ret=SQLPrimaryKeys(hstmt_,
                                     valueOrNull(catalog),catalog.length(),
                                     valueOrNull(schema),schema.length(),
                                     (SQLCHAR*)tableName.data(),
                                     tableName.length());
  this->_checkStmtError(hstmt_,ret,"Error fetching primary keys information");
  this->_afterExecute();

  return this->_getResultSet(true);
}


// Runs SQLForeignKeys for primary table (pc, ps, pt) and foreign table
// (fc, fs, ft) and returns the hidden result set. Empty catalog and schema
// strings become NULL pointers; table names are always passed.
ResultSet* Statement::_getCrossReference(const string& pc,
                                         const string& ps,
                                         const string& pt,
                                         const string& fc,
                                         const string& fs,
                                         const string& ft)
{
  this->_beforeExecute();

  const SQLRETURN ret=SQLForeignKeys(hstmt_,
                                     valueOrNull(pc),pc.length(),
                                     valueOrNull(ps),ps.length(),
                                     (SQLCHAR*)pt.data(),pt.length(),
                                     valueOrNull(fc),fc.length(),
                                     valueOrNull(fs),fs.length(),
                                     (SQLCHAR*)ft.data(),ft.length());
  this->_checkStmtError(hstmt_,ret,"Error fetching foreign keys information");
  this->_afterExecute();

  return this->_getResultSet(true);
}



// Runs SQLStatistics and returns the hidden result set. unique selects
// SQL_INDEX_UNIQUE over SQL_INDEX_ALL; approximate selects SQL_QUICK over
// SQL_ENSURE.
ResultSet* Statement::_getIndexInfo(const string& catalog,
                                    const string& schema,
                                    const string& tableName,
                                    bool unique, bool approximate)
{
  this->_beforeExecute();

  const SQLUSMALLINT indexKind=unique?SQL_INDEX_UNIQUE:SQL_INDEX_ALL;
  const SQLUSMALLINT accuracy=approximate?SQL_QUICK:SQL_ENSURE;

  const SQLRETURN ret=SQLStatistics(hstmt_,
                                    valueOrNull(catalog),catalog.length(),
                                    valueOrNull(schema),schema.length(),
                                    (SQLCHAR*)tableName.data(),
                                    tableName.length(),
                                    indexKind,accuracy);
  this->_checkStmtError(hstmt_,ret,"Error fetching index information");
  this->_afterExecute();

  return this->_getResultSet(true);
}


// Runs SQLProcedures and returns the hidden result set. Empty catalog and
// schema strings become NULL pointers; procName is always passed.
ResultSet* Statement::_getProcedures(const string& catalog,
                                     const string& schema,
                                     const string& procName)
{
  this->_beforeExecute();
  SQLRETURN r=SQLProcedures(hstmt_,
                            valueOrNull(catalog),catalog.length(),
                            valueOrNull(schema),schema.length(),
                            (SQLCHAR*)procName.data(),procName.length());

  this->_checkStmtError(hstmt_,r,"Error fetching procedures information");

  // Track the open cursor like the other catalog functions do; without this
  // the next _beforeExecute() would not close it before re-executing.
  this->_afterExecute();

  ResultSet* rs=this->_getResultSet(true);

  return rs;
}


// Runs SQLProcedureColumns and returns the hidden result set. Empty catalog
// and schema strings become NULL pointers; procedure and column names are
// always passed.
ResultSet* Statement::_getProcedureColumns(const string& catalog,
                                           const string& schema,
                                           const string& procName,
                                           const string& colName)
{
  this->_beforeExecute();
  SQLRETURN r=SQLProcedureColumns(hstmt_,
                                  valueOrNull(catalog),catalog.length(),
                                  valueOrNull(schema),schema.length(),
                                  (SQLCHAR*)procName.data(),procName.length(),
                                  (SQLCHAR*)colName.data(),colName.length());

  this->_checkStmtError(hstmt_,r,"Error fetching procedures information");

  // Track the open cursor like the other catalog functions do; without this
  // the next _beforeExecute() would not close it before re-executing.
  this->_afterExecute();

  ResultSet* rs=this->_getResultSet(true);

  return rs;
}

// Runs SQLSpecialColumns and returns the hidden result set. what, scope and
// nullable are forwarded unchanged. Empty catalog and schema strings become
// NULL pointers; table is always passed.
ResultSet* Statement::_getSpecialColumns(const string& catalog,
                                         const string& schema,
                                         const string& table,
                                         int what, int scope,
                                         int nullable)
{
  this->_beforeExecute();
  SQLRETURN r=SQLSpecialColumns(hstmt_,what,
                                valueOrNull(catalog),catalog.length(),
                                valueOrNull(schema),schema.length(),
                                (SQLCHAR*)table.data(),table.length(),
                                scope,nullable);
  this->_checkStmtError(hstmt_,r,"Error fetching special columns");

  // Track the open cursor like the other catalog functions do; without this
  // the next _beforeExecute() would not close it before re-executing.
  this->_afterExecute();

  ResultSet* rs=this->_getResultSet(true);

  return rs;
}



// Returns the connection this statement was created on (not owned).
Connection* Statement::getConnection()
{
  return this->connection_;
}


int Statement::getQueryTimeout()
{
  return this->_getNumericOption
    (ODBC3_C(SQL_ATTR_QUERY_TIMEOUT,SQL_QUERY_TIMEOUT));
}


void Statement::setQueryTimeout(int seconds)
{
  this->_setNumericOption
    (ODBC3_C(SQL_ATTR_QUERY_TIMEOUT,SQL_QUERY_TIMEOUT),seconds);
}


int Statement::getMaxRows()
{
  return this->_getNumericOption
    (ODBC3_C(SQL_ATTR_MAX_ROWS,SQL_MAX_ROWS));
}


void Statement::setMaxRows(int maxRows)
{
  this->_setNumericOption
    (ODBC3_C(SQL_ATTR_MAX_ROWS,SQL_MAX_ROWS),maxRows);
}

int Statement::getMaxFieldSize()
{
  return this->_getNumericOption
    (ODBC3_C(SQL_ATTR_MAX_LENGTH,SQL_MAX_LENGTH));
}

void Statement::setMaxFieldSize(int maxFieldSize)
{
  this->_setNumericOption
    (ODBC3_C(SQL_ATTR_MAX_LENGTH,SQL_MAX_LENGTH),maxFieldSize);
}


void Statement::cancel()
{
  SQLRETURN r=SQLCancel(hstmt_);
  this->_checkStmtError(hstmt_,r,"Error canceling statement");
}


int Statement::getUpdateCount()
{
  SQLINTEGER res;
  SQLRETURN r=SQLRowCount(hstmt_,&res);
  this->_checkStmtError(hstmt_,r,"Error fetching update count");
  return res;
}


void Statement::setCursorName(const string& name)
{
  SQLRETURN r=SQLSetCursorName(hstmt_,
			     (SQLCHAR*)name.data(),
			     name.length());
  this->_checkStmtError(hstmt_,r,"Error setting cursor name");
}


bool Statement::execute(const string& sql)
{
  
  this->_beforeExecute();

  SQLRETURN r=SQLExecDirect(hstmt_,(SQLCHAR*)sql.data(),sql.length());

  string msg="Error executing \""+sql+"\"";

  this->_checkStmtError(hstmt_,r,msg.c_str());

  this->_afterExecute();

  return this->_checkForResults();
}

// Executes sql and wraps its result in a ResultSet (caller takes it over).
// The ResultSet is created whether or not execute() reported results.
ResultSet* Statement::executeQuery(const string& sql)
{
  (void)this->execute(sql);
  ResultSet* result=this->_getResultSet();
  return result;
}


// Executes sql and returns the row count reported for it.
int Statement::executeUpdate(const string& sql)
{
  (void)this->execute(sql);
  const int affected=this->getUpdateCount();
  return affected;
}

// Returns a ResultSet for the current result, or NULL when the statement
// has no result columns.
ResultSet* Statement::getResultSet()
{
  return this->_checkForResults() ? this->_getResultSet() : NULL;
}


bool Statement::getMoreResults()
{
  SQLRETURN r=SQLMoreResults(hstmt_);
  this->_checkStmtError(hstmt_,r,"Error checking for more results");
  
  if(r==SQL_SUCCESS || r==SQL_SUCCESS_WITH_INFO) {
    return true;
  }
  return false;
}


// Stores the requested fetch size: 0 selects SQL_ROWSET_SIZE_DEFAULT and
// negative values are rejected with SQLException.
void Statement::setFetchSize(int fs)
{
  if(fs<0) {
    throw SQLException("Invalid fetch size");
  }

  if(fs==0) {
    fetchSize_=SQL_ROWSET_SIZE_DEFAULT;
  } else {
    fetchSize_=fs;
  }
}


void Statement::setEscapeProcessing(bool on)
{
  this->_setNumericOption
    (ODBC3_C(SQL_ATTR_NOSCAN,SQL_NOSCAN),on?SQL_NOSCAN_OFF:SQL_NOSCAN_ON);
}

bool Statement::getEscapeProcessing()
{
  return this->_getNumericOption
    (ODBC3_C(SQL_ATTR_NOSCAN,SQL_NOSCAN))==SQL_NOSCAN_OFF;
}


#if 0
// (compiled out by the enclosing #if 0)
// Appends sql to the pending batch in batches_.
void Statement::addBatch(const std::string& sql)
{
  batches_.push_back(sql);
}


// (compiled out by the enclosing #if 0)
// Discards all pending batched commands.
void Statement::clearBatch()
{
  batches_.clear();
}


// (compiled out by the enclosing #if 0)
// Open design questions for batch execution:
// - executeBatch() should return one update count per batched command,
//   but it is unclear what to report when a command is a SELECT.
// - should the batch be sent as a single ';'-separated string (what the
//   draft below does) or should ExecDirect be called once per command?

// Joins all batched commands with "; " and executes them in one call.
// NOTE(review): as written this would not compile if enabled - it is
// declared void but returns a vector<int>*, and that vector is never
// filled with update counts (it would also leak unless the caller frees it).
void Statement::executeBatch()
{
  if(batches_.size()==0) {
    throw SQLException
      ("[libodbc++]: nothing to execute");
  }

  // build "cmd1; cmd2; ..." - the separator is only inserted between commands
  string q;
  unsigned int cnt=0;
  for(vector<string>::iterator i=batches_.begin();
      i!=batches_.end(); i++) {
    if(cnt>0) {
      q+="; ";
    }
    q+=*i;
    cnt++;
  }

  this->execute(q);

  vector<int>* res=new vector<int>();

  return res;
}

#endif
