Skip to content

Commit

Permalink
Add get changes
Browse files Browse the repository at this point in the history
  • Loading branch information
saa committed Mar 15, 2015
1 parent 32149a3 commit 84cd542
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 6 deletions.
29 changes: 26 additions & 3 deletions c_src/sqlite3_drv.c
Original file line number Diff line number Diff line change
Expand Up @@ -307,12 +307,12 @@ static void stop(ErlDrvData handle) {
}
driver_free(drv->prepared_stmts);
}

close_result = sqlite3_close(drv->db);
if (close_result != SQLITE_OK) {
LOG_ERROR("Failed to close DB %s, some resources aren't finalized!", drv->db_name);
}

if (drv->log && (drv->log != stderr)) {
fclose(drv->log);
}
Expand Down Expand Up @@ -363,13 +363,36 @@ static ErlDrvSSizeT control(
case CMD_ENABLE_LOAD_EXTENSION:
enable_load_extension(drv, buf, (int) len);
break;
case CMD_CHANGES:
changes(drv, buf, (int) len);
break;
default:
unknown(drv, buf, (int) len);
}
}
return 0;
}

static int changes(sqlite3_drv_t *drv, char *buf, int len) {
int changes = sqlite3_changes(drv->db);
ErlDrvTermData spec[6];

spec[0] = ERL_DRV_PORT;
spec[1] = driver_mk_port(drv->port);
spec[2] = ERL_DRV_UINT;
spec[3] = changes;
spec[4] = ERL_DRV_TUPLE;
spec[5] = 2;

return
#ifdef PRE_R16B
driver_output_term(drv->port,
#else
erl_drv_output_term(spec[1],
#endif
spec, sizeof(spec) / sizeof(spec[0]));
}

static int enable_load_extension(sqlite3_drv_t* drv, char *buf, int len) {
#ifdef ERLANG_SQLITE3_LOAD_EXTENSION
char enable = buf[0];
Expand Down Expand Up @@ -433,7 +456,7 @@ static inline void exec_async_command(
} else {
async_invoke(async_command);
ready_async((ErlDrvData) drv, (ErlDrvThreadData) async_command);
}
}
}

static inline int sql_exec_statement(
Expand Down
2 changes: 2 additions & 0 deletions c_src/sqlite3_drv.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ typedef int ErlDrvSSizeT;
#define CMD_PREPARED_COLUMNS 11
#define CMD_SQL_EXEC_SCRIPT 12
#define CMD_ENABLE_LOAD_EXTENSION 13
#define CMD_CHANGES 14

typedef struct ptr_list {
void *head;
Expand Down Expand Up @@ -118,6 +119,7 @@ static void sql_free_async(void *async_command);
static void ready_async(ErlDrvData drv_data, ErlDrvThreadData thread_data);
static int unknown(sqlite3_drv_t *bdb_drv, char *buf, int len);
static int enable_load_extension(sqlite3_drv_t *drv, char *buf, int len);
static int changes(sqlite3_drv_t *drv, char *buf, int len);

#if defined(_MSC_VER)
#pragma warning(default: 4201)
Expand Down
22 changes: 21 additions & 1 deletion src/sqlite3.erl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
-export([delete/2, delete/3, delete_timeout/4]).
-export([drop_table/1, drop_table/2, drop_table_timeout/3]).
-export([vacuum/0, vacuum/1, vacuum_timeout/2]).
-export([changes/1]).

%% -export([create_function/3]).

Expand Down Expand Up @@ -164,6 +165,18 @@ stop() ->
enable_load_extension(Db, Value) ->
gen_server:call(Db, {enable_load_extension, Value}).

%%--------------------------------------------------------------------
%% @doc
%% Get affected rows.
%% @end
%%--------------------------------------------------------------------

changes(Db) ->
gen_server:call(Db, changes).

changes(Db, Timeout) ->
gen_server:call(Db, changes, Timeout).

%%--------------------------------------------------------------------
%% @doc
%% Executes the Sql statement directly.
Expand Down Expand Up @@ -974,6 +987,9 @@ handle_call({enable_load_extension, _Value} = Payload, _From, State = #state{por
Port, refs = _Refs}) ->
Reply = exec(Port, Payload),
{reply, Reply, State};
handle_call(changes = Payload, _From, State = #state{port = Port, refs = _Refs}) ->
Reply = exec(Port, Payload),
{reply, Reply, State};
handle_call({Cmd, Ref}, _From, State = #state{port = Port, refs = Refs}) ->
Reply = case dict:find(Ref, Refs) of
{ok, Index} ->
Expand Down Expand Up @@ -1033,7 +1049,7 @@ terminate(_Reason, #state{port = Port}) ->
{error, permanent} ->
%% Older Erlang versions mark any driver using driver_async
%% as permanent
ok;
ok;
{error, ErrorDesc} ->
error_logger:error_msg("Error unloading sqlite3 driver: ~s~n",
[erl_ddll:format_error(ErrorDesc)])
Expand Down Expand Up @@ -1074,6 +1090,7 @@ get_priv_dir() ->
-define(PREPARED_COLUMNS, 11).
-define(SQL_EXEC_SCRIPT, 12).
-define(ENABLE_LOAD_EXTENSION, 13).
-define(CHANGES, 14).

create_port_cmd(DbFile) ->
atom_to_list(?DRIVER_NAME) ++ " " ++ DbFile.
Expand Down Expand Up @@ -1126,6 +1143,9 @@ exec(Port, {enable_load_extension, Value}) ->
end,
port_control(Port, ?ENABLE_LOAD_EXTENSION, <<Payload>>),
wait_result(Port);
exec(Port, changes) ->
port_control(Port, ?CHANGES, <<"">>),
wait_result(Port);
exec(Port, {Cmd, Index}) when is_integer(Index) ->
CmdCode = case Cmd of
next -> ?PREPARED_STEP;
Expand Down
2 changes: 2 additions & 0 deletions test.erl
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ test() ->
sqlite3:open(ct),
sqlite3:create_table(ct, user, [{id, integer, [primary_key]}, {name, text}, {age, integer}, {wage, integer}]),
[user] = sqlite3:list_tables(ct),
0 = sqlite3:changes(ct),
[{id, integer, [primary_key]}, {name, text}, {age, integer}, {wage, integer}] = sqlite3:table_info(ct, user),
{rowid, Id1} = sqlite3:write(ct, user, [{name, "abby"}, {age, 20}, {wage, 2000}]),
Id1 = 1,
1 = sqlite3:changes(ct),
{rowid, Id2} = sqlite3:write(ct, user, [{name, "marge"}, {age, 30}, {wage, 2000}]),
Id2 = 2,
[{columns, Columns}, {rows, Rows1}] = sqlite3:sql_exec(ct, "select * from user;"),
Expand Down
19 changes: 17 additions & 2 deletions test/sqlite3_test.erl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ all_test_() ->
?FuncTest(large_offset),
?FuncTest(issue23),
?FuncTest(issue13),
?FuncTest(enable_load_extension)]}.
?FuncTest(enable_load_extension),
?FuncTest(changes)]}.

anonymous_test() ->
{ok, Pid} = sqlite3:open(anonymous, []),
Expand Down Expand Up @@ -349,6 +350,20 @@ issue13() ->
enable_load_extension() ->
?assertEqual(ok, sqlite3:enable_load_extension(ct, 1)).

changes() ->
sqlite3:open(changes, [in_memory]),
sqlite3:sql_exec(changes, "CREATE TABLE person(id INTEGER);"),
?assertEqual(0, sqlite3:changes(changes)),
{rowid, _} = sqlite3:sql_exec(changes, "INSERT INTO person (id) VALUES (1)"),
{rowid, _} = sqlite3:sql_exec(changes, "INSERT INTO person (id) VALUES (2)"),
{rowid, _} = sqlite3:sql_exec(changes, "INSERT INTO person (id) VALUES (3)"),
{rowid, _} = sqlite3:sql_exec(changes, "INSERT INTO person (id) VALUES (4)"),
{rowid, _} = sqlite3:sql_exec(changes, "INSERT INTO person (id) VALUES (5)"),
?assertEqual(1, sqlite3:changes(changes)),
ok = sqlite3:sql_exec(changes, "UPDATE person SET id = 10"),
?assertEqual(5, sqlite3:changes(changes)),
sqlite3:close(changes).

issue23() ->
sqlite3:open(issue23, [in_memory]),
ok = sqlite3:create_table(issue23, issue23, [{issue23, integer}]),
Expand All @@ -362,7 +377,7 @@ non_db_file_test() ->
process_flag(trap_exit, true),
?assertMatch({error, _},
sqlite3:start_link(bad_file, [{file, "/"}])),
receive
receive
{'EXIT', _, _} -> ok;
_ -> ?assert(false)
end.
Expand Down

0 comments on commit 84cd542

Please sign in to comment.