diff --git a/lib/feeb/db/schema.ex b/lib/feeb/db/schema.ex index ce57e5c..4e802b9 100644 --- a/lib/feeb/db/schema.ex +++ b/lib/feeb/db/schema.ex @@ -74,10 +74,20 @@ defmodule Feeb.DB.Schema do |> Map.new() |> Map.keys() + after_read_fields = + Enum.reduce(normalized_schema, [], fn {field, {_, opts, _}}, acc -> + if after_read = opts[:after_read] do + [{field, after_read} | acc] + else + acc + end + end) + @schema normalized_schema @modded_fields modded_fields @sorted_cols sorted_cols @virtual_cols virtual_cols + @after_read_fields after_read_fields if is_nil(Module.get_attribute(__MODULE__, :derived_fields)) do @derived_fields [] @@ -90,6 +100,7 @@ defmodule Feeb.DB.Schema do def __schema__, do: @schema def __cols__, do: @sorted_cols def __virtual_cols__, do: @virtual_cols + def __after_read_fields__, do: @after_read_fields def __table__, do: @table def __context__, do: @context def __modded_fields__, do: @modded_fields @@ -260,6 +271,7 @@ defmodule Feeb.DB.Schema do schema = model.__schema__() table_fields = model.__cols__() virtual_fields = model.__virtual_cols__() + after_read_fields = model.__after_read_fields__() fields_to_populate = if fields == [:*], do: table_fields, else: fields # TODO: Test this a lot... @@ -290,9 +302,9 @@ defmodule Feeb.DB.Schema do values = fields_to_populate |> Enum.zip(row) - |> Enum.map(fn {field, v} -> + |> Enum.map(fn {field, raw_value} -> {type_module, opts, _mod} = Map.fetch!(schema, field) - {field, type_module.load!(v, opts, {model, field})} + {field, type_module.load!(raw_value, opts, {model, field})} end) model @@ -300,6 +312,7 @@ defmodule Feeb.DB.Schema do |> Map.put(:__meta__, %{origin: :db}) |> add_missing_values(table_fields, fields_to_populate) |> add_virtual_fields(virtual_fields, schema) + |> trigger_after_read_callbacks(after_read_fields) end def cast_value!(schema_mod, schema, field, raw_value) do @@ -376,4 +389,14 @@ defmodule Feeb.DB.Schema do Map.put(acc, field_name, value) end) end + + defp trigger_after_read_callbacks(struct, []), do: struct + + defp trigger_after_read_callbacks(struct, after_read_fields) do + Enum.reduce(after_read_fields, struct, fn {field, callback}, acc -> + old_value = Map.get(struct, field) + new_value = apply(struct.__struct__, callback, [old_value, struct]) + Map.put(acc, field, new_value) + end) + end end diff --git a/priv/test/migrations/test/241020150201_friends.sql b/priv/test/migrations/test/241020150201_friends.sql index 7e58000..461919c 100644 --- a/priv/test/migrations/test/241020150201_friends.sql +++ b/priv/test/migrations/test/241020150201_friends.sql @@ -1,4 +1,5 @@ CREATE TABLE friends ( id INTEGER PRIMARY KEY, - name TEXT + name TEXT, + sibling_count INTEGER ) STRICT, WITHOUT ROWID; diff --git a/test/db/schema_test.exs b/test/db/schema_test.exs index 564b271..992bbe1 100644 --- a/test/db/schema_test.exs +++ b/test/db/schema_test.exs @@ -29,7 +29,7 @@ defmodule DB.SchemaTest do describe "generated: __cols__/0" do test "includes all non-virtual fields in order" do - assert [:id, :name] == Friend.__cols__() + assert [:id, :name, :sibling_count] == Friend.__cols__() assert [ :boolean, @@ -123,4 +123,18 @@ defmodule DB.SchemaTest do assert monica.repo_config == expected_repo_config end end + + describe "after_read" do + test "columns with after_read are post-processed", %{shard_id: shard_id} do + DB.begin(@context, shard_id, :read) + + joey = DB.one({:friends, :get_by_name}, "Joey") + rachel = DB.one({:friends, :get_by_name}, "Rachel") + pheebs = DB.one({:friends, :get_by_name}, "Phoebe") + + assert joey.sibling_count == 7 + assert rachel.sibling_count == 2 + assert pheebs.sibling_count == 1 + end + end end diff --git a/test/db/sqlite_test.exs b/test/db/sqlite_test.exs index 5b7ad61..9403177 100644 --- a/test/db/sqlite_test.exs +++ b/test/db/sqlite_test.exs @@ -36,7 +36,7 @@ defmodule Feeb.DB.SQLiteTest do test "returns an entry if found", %{c: c} do {:ok, stmt} = SQLite.prepare(c, "SELECT * FROM friends WHERE id = ?") :ok = SQLite.bind(c, stmt, [1]) - assert {:ok, [1, "Phoebe"]} == SQLite.one(c, stmt) + assert {:ok, [1, "Phoebe", nil]} == SQLite.one(c, stmt) end test "returns nil if not found", %{c: c} do diff --git a/test/support/db/schemas/friend.ex b/test/support/db/schemas/friend.ex index 2cc960d..1517e4b 100644 --- a/test/support/db/schemas/friend.ex +++ b/test/support/db/schemas/friend.ex @@ -10,6 +10,7 @@ defmodule Sample.Friend do {:id, :integer}, {:name, :string}, {:divorce_count, {:integer, virtual: :get_divorce_count}}, + {:sibling_count, {:integer, nullable: true, after_read: :get_sibling_count}}, {:repo_config, {:map, virtual: :get_repo_config}} ] @@ -38,4 +39,17 @@ defmodule Sample.Friend do 0 end end + + def get_sibling_count(_, %{name: name}) do + case name do + "Joey" -> + 7 + + "Rachel" -> + 2 + + _ -> + 1 + end + end end