diff --git a/c_src/sqlite3_drv.c b/c_src/sqlite3_drv.c index 7706031..fa958b2 100644 --- a/c_src/sqlite3_drv.c +++ b/c_src/sqlite3_drv.c @@ -307,12 +307,12 @@ static void stop(ErlDrvData handle) { } driver_free(drv->prepared_stmts); } - + close_result = sqlite3_close(drv->db); if (close_result != SQLITE_OK) { LOG_ERROR("Failed to close DB %s, some resources aren't finalized!", drv->db_name); } - + if (drv->log && (drv->log != stderr)) { fclose(drv->log); } @@ -363,6 +363,9 @@ static ErlDrvSSizeT control( case CMD_ENABLE_LOAD_EXTENSION: enable_load_extension(drv, buf, (int) len); break; + case CMD_CHANGES: + changes(drv, buf, (int) len); + break; default: unknown(drv, buf, (int) len); } @@ -370,6 +373,26 @@ static ErlDrvSSizeT control( return 0; } +static int changes(sqlite3_drv_t *drv, char *buf, int len) { + int changes = sqlite3_changes(drv->db); + ErlDrvTermData spec[6]; + + spec[0] = ERL_DRV_PORT; + spec[1] = driver_mk_port(drv->port); + spec[2] = ERL_DRV_UINT; + spec[3] = changes; + spec[4] = ERL_DRV_TUPLE; + spec[5] = 2; + + return + #ifdef PRE_R16B + driver_output_term(drv->port, + #else + erl_drv_output_term(spec[1], + #endif + spec, sizeof(spec) / sizeof(spec[0])); +} + static int enable_load_extension(sqlite3_drv_t* drv, char *buf, int len) { #ifdef ERLANG_SQLITE3_LOAD_EXTENSION char enable = buf[0]; @@ -433,7 +456,7 @@ static inline void exec_async_command( } else { async_invoke(async_command); ready_async((ErlDrvData) drv, (ErlDrvThreadData) async_command); - } + } } static inline int sql_exec_statement( diff --git a/c_src/sqlite3_drv.h b/c_src/sqlite3_drv.h index bdf422b..45202cc 100644 --- a/c_src/sqlite3_drv.h +++ b/c_src/sqlite3_drv.h @@ -49,6 +49,7 @@ typedef int ErlDrvSSizeT; #define CMD_PREPARED_COLUMNS 11 #define CMD_SQL_EXEC_SCRIPT 12 #define CMD_ENABLE_LOAD_EXTENSION 13 +#define CMD_CHANGES 14 typedef struct ptr_list { void *head; @@ -118,6 +119,7 @@ static void sql_free_async(void *async_command); static void ready_async(ErlDrvData drv_data, ErlDrvThreadData thread_data); static int unknown(sqlite3_drv_t *bdb_drv, char *buf, int len); static int enable_load_extension(sqlite3_drv_t *drv, char *buf, int len); +static int changes(sqlite3_drv_t *drv, char *buf, int len); #if defined(_MSC_VER) #pragma warning(default: 4201) diff --git a/src/sqlite3.erl b/src/sqlite3.erl index 4094e18..f94be1b 100644 --- a/src/sqlite3.erl +++ b/src/sqlite3.erl @@ -40,6 +40,7 @@ -export([delete/2, delete/3, delete_timeout/4]). -export([drop_table/1, drop_table/2, drop_table_timeout/3]). -export([vacuum/0, vacuum/1, vacuum_timeout/2]). +-export([changes/1]). %% -export([create_function/3]). @@ -164,6 +165,18 @@ stop() -> enable_load_extension(Db, Value) -> gen_server:call(Db, {enable_load_extension, Value}). +%%-------------------------------------------------------------------- +%% @doc +%% Get affected rows. +%% @end +%%-------------------------------------------------------------------- + +changes(Db) -> + gen_server:call(Db, changes). + +changes(Db, Timeout) -> + gen_server:call(Db, changes, Timeout). + %%-------------------------------------------------------------------- %% @doc %% Executes the Sql statement directly. @@ -974,6 +987,9 @@ handle_call({enable_load_extension, _Value} = Payload, _From, State = #state{por Port, refs = _Refs}) -> Reply = exec(Port, Payload), {reply, Reply, State}; +handle_call(changes = Payload, _From, State = #state{port = Port, refs = _Refs}) -> + Reply = exec(Port, Payload), + {reply, Reply, State}; handle_call({Cmd, Ref}, _From, State = #state{port = Port, refs = Refs}) -> Reply = case dict:find(Ref, Refs) of {ok, Index} -> @@ -1033,7 +1049,7 @@ terminate(_Reason, #state{port = Port}) -> {error, permanent} -> %% Older Erlang versions mark any driver using driver_async %% as permanent - ok; + ok; {error, ErrorDesc} -> error_logger:error_msg("Error unloading sqlite3 driver: ~s~n", [erl_ddll:format_error(ErrorDesc)]) @@ -1074,6 +1090,7 @@ get_priv_dir() -> -define(PREPARED_COLUMNS, 11). -define(SQL_EXEC_SCRIPT, 12). -define(ENABLE_LOAD_EXTENSION, 13). +-define(CHANGES, 14). create_port_cmd(DbFile) -> atom_to_list(?DRIVER_NAME) ++ " " ++ DbFile. @@ -1126,6 +1143,9 @@ exec(Port, {enable_load_extension, Value}) -> end, port_control(Port, ?ENABLE_LOAD_EXTENSION, <>), wait_result(Port); +exec(Port, changes) -> + port_control(Port, ?CHANGES, <<"">>), + wait_result(Port); exec(Port, {Cmd, Index}) when is_integer(Index) -> CmdCode = case Cmd of next -> ?PREPARED_STEP; diff --git a/test.erl b/test.erl index 7fbe8f1..2309ba7 100755 --- a/test.erl +++ b/test.erl @@ -8,9 +8,11 @@ test() -> sqlite3:open(ct), sqlite3:create_table(ct, user, [{id, integer, [primary_key]}, {name, text}, {age, integer}, {wage, integer}]), [user] = sqlite3:list_tables(ct), + 0 = sqlite3:changes(ct), [{id, integer, [primary_key]}, {name, text}, {age, integer}, {wage, integer}] = sqlite3:table_info(ct, user), {rowid, Id1} = sqlite3:write(ct, user, [{name, "abby"}, {age, 20}, {wage, 2000}]), Id1 = 1, + 1 = sqlite3:changes(ct), {rowid, Id2} = sqlite3:write(ct, user, [{name, "marge"}, {age, 30}, {wage, 2000}]), Id2 = 2, [{columns, Columns}, {rows, Rows1}] = sqlite3:sql_exec(ct, "select * from user;"), diff --git a/test/sqlite3_test.erl b/test/sqlite3_test.erl index 5ad2af5..731c763 100644 --- a/test/sqlite3_test.erl +++ b/test/sqlite3_test.erl @@ -53,7 +53,8 @@ all_test_() -> ?FuncTest(large_offset), ?FuncTest(issue23), ?FuncTest(issue13), - ?FuncTest(enable_load_extension)]}. + ?FuncTest(enable_load_extension), + ?FuncTest(changes)]}. anonymous_test() -> {ok, Pid} = sqlite3:open(anonymous, []), @@ -349,6 +350,20 @@ issue13() -> enable_load_extension() -> ?assertEqual(ok, sqlite3:enable_load_extension(ct, 1)). +changes() -> + sqlite3:open(changes, [in_memory]), + sqlite3:sql_exec(changes, "CREATE TABLE person(id INTEGER);"), + ?assertEqual(0, sqlite3:changes(changes)), + {rowid, _} = sqlite3:sql_exec(changes, "INSERT INTO person (id) VALUES (1)"), + {rowid, _} = sqlite3:sql_exec(changes, "INSERT INTO person (id) VALUES (2)"), + {rowid, _} = sqlite3:sql_exec(changes, "INSERT INTO person (id) VALUES (3)"), + {rowid, _} = sqlite3:sql_exec(changes, "INSERT INTO person (id) VALUES (4)"), + {rowid, _} = sqlite3:sql_exec(changes, "INSERT INTO person (id) VALUES (5)"), + ?assertEqual(1, sqlite3:changes(changes)), + ok = sqlite3:sql_exec(changes, "UPDATE person SET id = 10"), + ?assertEqual(5, sqlite3:changes(changes)), + sqlite3:close(changes). + issue23() -> sqlite3:open(issue23, [in_memory]), ok = sqlite3:create_table(issue23, issue23, [{issue23, integer}]), @@ -362,7 +377,7 @@ non_db_file_test() -> process_flag(trap_exit, true), ?assertMatch({error, _}, sqlite3:start_link(bad_file, [{file, "/"}])), - receive + receive {'EXIT', _, _} -> ok; _ -> ?assert(false) end.