Skip to content

Commit

Permalink
add blending and change mmp_fun
Browse files Browse the repository at this point in the history
  • Loading branch information
itsdfish committed Aug 21, 2021
1 parent 4163564 commit f0cfede
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ACTRModels"
uuid = "c095b0ea-a6ca-5cbd-afed-dbab2e976880"
authors = ["itsdfish"]
version = "0.7.7"
version = "0.8.0"

[deps]
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
Expand Down
5 changes: 3 additions & 2 deletions src/ACTRModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,17 @@ module ACTRModels
import Distributions: pdf, logpdf
import SequentialSamplingModels: LNR
import Base: rand, match
export ACTR, Declarative, Imaginal, Chunk, BufferState, Mod
export AbstractACTR, ACTR, Declarative, Imaginal, Chunk, BufferState, Mod
export Goal, Visual, Motor, VisualLocation, Procedural, Rule
export AbstractVisualObject, VisualObject, Parms, AbstractParms
export defaultFun, LNR, reduce_data, get_buffer, set_buffer!
export default_penalty, LNR, reduce_data, get_buffer, set_buffer!
export get_chunks, update_lags!, update_recent!, update_chunk!, modify!, add_chunk!
export retrieval_prob, retrieval_probs, retrieve, compute_activation!, get_parm
export spreading_activation!, match, compute_RT, retrieval_request, get_subset
export first_chunk, posterior_predictive, find_index, find_indices, get_mean_activations
export get_visicon, get_iconic_memory
export get_time, add_time, reset_time!, rnd_time
export blend_chunks, blended_activation

include("Structs.jl")
include("MemoryFunctions.jl")
Expand Down
109 changes: 105 additions & 4 deletions src/MemoryFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ Computes activation for partial matching component
- `request...`: optional keyword arguments for retrieval request
"""
function partial_matching!(actr, chunk; request...)
p = actr.parms.mmpFun(actr, chunk; request...)
p = actr.parms.mmp_fun(actr, chunk; request...)
chunk.act_pm = p
return nothing
end
Expand Down Expand Up @@ -236,13 +236,13 @@ function spreading_activation!(actr, chunk)
imaginal = actr.imaginal
isempty(imaginal.buffer) ? (return nothing) : nothing
w = compute_weights(imaginal)
r = 0.0; sa = 0.0; γ = actr.parms.γ
γ = actr.parms.γ; r = zero(γ); sa = zero(γ)
slots = imaginal.buffer[1].slots
denoms = imaginal.denoms
for (v,d) in zip(slots, denoms)
num = count_values(chunk, v)
fan = num / (d + 1)
r = fan == 0 ? 0.0 : γ + log(fan)
r = fan == 0 ? zero(γ) : γ + log(fan)
sa += w * r
end
chunk.act_sa = sa
Expand Down Expand Up @@ -914,4 +914,105 @@ function get_parm(actr, p)
return misc[p]
end
return getfield(actr.parms, p)
end
end

"""
blend_chunks(actr, cur_time::Float64=0.0; request...)
Computes blended value over chunks given a retrieval request. By default,
values are blended over the set of slots formed by the set difference between all
slots of a chunk and the slots specified in the retrieval request. Currently, blended
is only supported for numeric slot-values.
# Arguments
- `actr`: an `ACTR` model object
- `cur_time::Float64=0.0`: current simulated time
- `request...`: optional keywords for the retrieval request
"""
function blend_chunks(actr, cur_time::Float64=0.0; request...)
blended_slots = setdiff(keys(actr.declarative.memory[1].slots), keys(request))
return blend_chunks(actr, blended_slots, cur_time; request...)
end

"""
blend_chunks(actr, cur_time::Float64=0.0; request...)
Computes blended value over chunks given a retrieval request. Values are blended
over the slots specified in `blended_slots`. Currently, blended is only supported
for numeric slot-values.
# Arguments
- `actr`: an `ACTR` model object
- `blended_slots`: a set of slots over which slot-values are blended
- `cur_time::Float64=0.0`: current simulated time
- `request...`: optional keywords for the retrieval request
"""
function blend_chunks(actr, blended_slots, cur_time=0.0; request...)
chunks = retrieval_request(actr; request...)
compute_activation!(actr, chunks, cur_time; request...)
probs = soft_max(actr, chunks)
return blend_slots(chunks, probs, blended_slots)
end

blend_slots(chunks, probs, slots) = map(s -> blend_slots(chunks, probs, s), slots)

"""
blend_slots(chunks, probs, slot::Symbol)
Computes an expected value over chunks for a specified slot.
# Arguments
- `chunks`: a set of chunks over which slot-values are blended
- `probs`: a vector of retrieval probabilities
- `slot::Symbol`: a slot over which slot-values are blended
"""
function blend_slots(chunks, probs, slot::Symbol)
v = 0.0
for (c,p) in zip(chunks, probs)
v += p * c.slots[slot]
end
return v
end

# non numeric
# blend_chunk(actr, chunks, probs, chunk) = blend_chunk(actr, chunks, probs; chunk.slots...)

# function blend_chunk(actr, chunks, probs; slots...)
# penalties = actr.parms.mmp_fun.(actr, chunks; slots...)
# return probs' * penalties
# end

function soft_max(actr, chunks)
σ = actr.parms.s * sqrt(2)
v = map(x -> exp(x.act / σ), chunks)
return v ./ sum(v)
end

"""
blended_activation(chunks)
Computes a blended activation value by exponentiating, summing and taking the
log of activations across a set of chunks.
# Arguments
- `chunks`: a set of chunks over which slot-values are blended
"""
function blended_activation(chunks)
exp_act = map(x->exp(x.act_mean), chunks)
return log(sum(exp_act))
end

"""
compute_RT(blended_act)
Computes retrieval time for a given blended activation value.
# Arguments
- `blended_act`: a blended activation value
"""
compute_RT(blended_act) = exp(-blended_act)
12 changes: 6 additions & 6 deletions src/Structs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ ACT-R parameters with default values. Default values are overwritten with keywor
- `blc=0.0`: base level constant
- `δ=0.0`: mismatch penalty
- `ter=0.0`: a constant for encoding and responding time
- `mmpFun`: a mismatch penalty function. By default, `mmpFun` subtracts `δ` from each non-matching slot value
- `mmp_fun`: a mismatch penalty function. By default, `mmp_fun` subtracts `δ` from each non-matching slot value
- `sa_fun`: a function for spreading activation which requires arguments for actr and chunk
- `select_rule`: a function for selecting production rule
- `lf=1.0:` latency factor parameter
Expand All @@ -53,7 +53,7 @@ ACT-R parameters with default values. Default values are overwritten with keywor
δ
blc
ter
mmpFun
mmp_fun
sa_fun
select_rule
lf
Expand All @@ -73,7 +73,7 @@ function Parms(;
δ = 0.0,
blc = 0.0,
ter = 0.0,
mmpFun = defaultFun,
mmp_fun = default_penalty,
sa_fun = spreading_activation!,
select_rule = exact_match,
lf = 1.0,
Expand All @@ -93,7 +93,7 @@ function Parms(;
δ,
blc,
ter,
mmpFun,
mmp_fun,
sa_fun,
select_rule,
lf,
Expand Down Expand Up @@ -304,7 +304,7 @@ function Declarative(;memory=Chunk[], filtered=(:isa,:retrieved))
end

"""
defaultFun(actr, chunk; request...)
default_penalty(actr, chunk; request...)
A default function for mismatch penalty. Subtracts δ if
slot does not exist or slot value does not match
Expand All @@ -318,7 +318,7 @@ slot does not exist or slot value does not match
- `request`: a variable size collection of slot-value pairs for the retrieval request
"""
function defaultFun(actr, chunk; request...)
function default_penalty(actr, chunk; request...)
slots = chunk.slots
p = 0.0; δ = actr.parms.δ
for (k,v) in request
Expand Down
60 changes: 60 additions & 0 deletions test/Memory_Tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -466,4 +466,64 @@ using SafeTestsets
# creation time: 45.185 decay: 0.5 Optimized-learning: 1
# base-level value: -0.905594
end

@safetestset "blend_chunks" begin
using ACTRModels, Test, Random
Random.seed!(598)
chunks = [Chunk(;a=1, b=0), Chunk(;a=1, b=3)]
parms = (mmp = true, δ=1.0, noise=true, s=.2)
declarative = Declarative(;memory=chunks)
actr = ACTR(;declarative, parms...)

request = (a=2,)
blended_slots = :b
n_sim = 10_000
mean_value1 = map(_->blend_chunks(actr, blended_slots; request...), 1:n_sim) |> mean
@test mean_value1 1.5 atol = .01

chunks = [Chunk(;a=1, b=0), Chunk(;a=2, b=3)]
parms = (mmp = true, δ=1.0, noise=true, s=.2)
declarative = Declarative(;memory=chunks)
actr = ACTR(;declarative, parms...)
request = (a=2,)
blended_slots = :b
n_sim = 10_000
mean_value2 = map(_->blend_chunks(actr, blended_slots; request...), 1:n_sim) |> mean
@test mean_value1 < mean_value2

chunks = [Chunk(;a=2, b=0), Chunk(;a=1, b=3)]
parms = (mmp = true, δ=1.0, noise=true, s=.2)
declarative = Declarative(;memory=chunks)
actr = ACTR(;declarative, parms...)
request = (a=2,)
blended_slots = :b
n_sim = 10_000
mean_value3 = map(_->blend_chunks(actr, blended_slots; request...), 1:n_sim) |> mean
@test mean_value1 > mean_value3

chunks = [Chunk(;a=2, b=0), Chunk(;a=1, b=3)]
parms = (mmp = true, δ=.5, noise=true, s=.2)
declarative = Declarative(;memory=chunks)
actr = ACTR(;declarative, parms...)
request = (a=2,)
blended_slots = :b
n_sim = 10_000
mean_value4 = map(_->blend_chunks(actr, blended_slots; request...), 1:n_sim) |> mean
@test mean_value4 > mean_value3
end

@safetestset "blend_slots" begin
using ACTRModels, Test, Random
import ACTRModels: blend_slots
Random.seed!(598)
chunks = [Chunk(;a=1, b=0), Chunk(;a=1, b=3)]
parms = (mmp = true, δ=1.0, noise=true, s=.2)
declarative = Declarative(;memory=chunks)
actr = ACTR(;declarative, parms...)

blended_slots = :b
probs = [.3,.7]
v = blend_slots(chunks, probs, blended_slots)
@test v 2.1 atol = 1e-4
end
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ tests = [

for test in tests
include(test * ".jl")
end
end

0 comments on commit f0cfede

Please sign in to comment.