Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reorganise the next item rules into subdirectories #74

Merged
merged 1 commit into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 27 additions & 162 deletions src/next_item_rules/NextItemRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,167 +46,32 @@ export AbilityCovarianceStateCriteria, StateCriteria, ItemCriteria
export InformationMatrixCriteria
export ScalarizedStateCriteron, ScalarizedItemCriteron

"""
$(TYPEDEF)

Abstract base type for all item selection rules. All descendants of this type
are expected to implement the interface
`(rule::NextItemRule)(responses::TrackedResponses, items::AbstractItemBank)::Int`

$(FUNCTIONNAME)(bits...; ability_estimator=nothing, parallel=true)

Implicit constructor for $(FUNCTIONNAME). Uses any given `NextItemRule` or
delegates to `ItemStrategyNextItemRule`.
"""
abstract type NextItemRule <: CatConfigBase end

function NextItemRule(bits...;
ability_estimator = nothing,
ability_tracker = nothing,
parallel = true)
@returnsome find1_instance(NextItemRule, bits)
@returnsome ItemStrategyNextItemRule(bits...,
ability_estimator = ability_estimator,
ability_tracker = ability_tracker,
parallel = parallel)
end

include("./random.jl")
include("./information.jl")
include("./information_special.jl")
include("./objective_function.jl")
include("./expectation.jl")

const default_prior = IntegralCoeffs.Prior(Cauchy(5, 2))

function choose_item_1ply(objective::ItemCriterionT,
responses::TrackedResponseT,
items::AbstractItemBank)::Tuple{
Int,
Float64
} where {ItemCriterionT <: ItemCriterion, TrackedResponseT <: TrackedResponses}
#pre_next_item(expectation_tracker, items)
objective_state = init_thread(objective, responses)
min_obj_idx::Int = -1
min_obj_val::Float64 = Inf
for item_idx in eachindex(items)
# TODO: Add these back in
#@init irf_states_storage = zeros(Int, length(responses) + 1)
if (findfirst(idx -> idx == item_idx, responses.responses.indices) !== nothing)
continue
end

obj_val = objective(objective_state, responses, item_idx)

if obj_val <= min_obj_val
min_obj_val = obj_val
min_obj_idx = item_idx
end
end
return (min_obj_idx, min_obj_val)
end

function init_thread(::ItemCriterion, ::TrackedResponses)
nothing
end

"""
$(TYPEDEF)
"""
abstract type NextItemStrategy <: CatConfigBase end

function NextItemStrategy(; parallel = true)
ExhaustiveSearch(parallel)
end

function NextItemStrategy(bits...; parallel = true)
@returnsome find1_instance(NextItemStrategy, bits)
@returnsome find1_type(NextItemStrategy, bits) typ->typ(; parallel = parallel)
@returnsome NextItemStrategy(; parallel = parallel)
end

"""
$(TYPEDEF)
$(TYPEDFIELDS)

"""
@with_kw struct ExhaustiveSearch <: NextItemStrategy
parallel::Bool = false
end

"""
$(TYPEDEF)
$(TYPEDFIELDS)

`ItemStrategyNextItemRule` which together with a `NextItemStrategy` acts as an
adapter by which an `ItemCriterion` can serve as a `NextItemRule`.

$(FUNCTIONNAME)(bits...; ability_estimator=nothing, parallel=true)

Implicit constructor for $(FUNCTIONNAME). Will default to
`ExhaustiveSearch` when no `NextItemStrategy` is given.
"""
struct ItemStrategyNextItemRule{
NextItemStrategyT <: NextItemStrategy,
ItemCriterionT <: ItemCriterion
} <: NextItemRule
strategy::NextItemStrategyT
criterion::ItemCriterionT
end

function ItemStrategyNextItemRule(bits...;
parallel = true,
ability_estimator = nothing,
ability_tracker = nothing)
strategy = NextItemStrategy(bits...; parallel = parallel)
criterion = ItemCriterion(bits...;
ability_estimator = ability_estimator,
ability_tracker = ability_tracker)
if strategy !== nothing && criterion !== nothing
return ItemStrategyNextItemRule(strategy, criterion)
end
end

function (rule::ItemStrategyNextItemRule{ExhaustiveSearch, ItemCriterionT})(responses,
items) where {ItemCriterionT <: ItemCriterion}
#, rule.strategy.parallel
choose_item_1ply(rule.criterion, responses, items)[1]
end

function (item_criterion::ItemCriterion)(::Nothing, tracked_responses, item_idx)
item_criterion(tracked_responses, item_idx)
end

function (item_criterion::ItemCriterion)(tracked_responses, item_idx)
criterion_state = init_thread(item_criterion, tracked_responses)
if criterion_state === nothing
error("Tried to run an state-requiring item criterion $(typeof(item_criterion)), but init_thread(...) returned nothing")
end
item_criterion(criterion_state, tracked_responses, item_idx)
end

function compute_criteria(
criterion::ItemCriterionT,
responses::TrackedResponseT,
items::AbstractItemBank
) where {ItemCriterionT <: ItemCriterion, TrackedResponseT <: TrackedResponses}
objective_state = init_thread(criterion, responses)
return [criterion(objective_state, responses, item_idx)
for item_idx in eachindex(items)]
end

function compute_criteria(
rule::ItemStrategyNextItemRule{StrategyT, ItemCriterionT},
responses,
items
) where {StrategyT, ItemCriterionT <: ItemCriterion}
compute_criteria(rule.criterion, responses, items)
end

include("./mirt.jl")
include("./aliases.jl")
include("./preallocate.jl")

include("./ka.jl")
# Prelude
include("./prelude/abstract.jl")
include("./prelude/next_item_rule.jl")
include("./prelude/strategy.jl")
include("./prelude/criteria.jl")
include("./prelude/preallocate.jl")

# Selection strategies
include("./strategies/random.jl")
include("./strategies/exhaustive.jl")

# Combinators
include("./combinators/expectation.jl")
include("./combinators/scalarizers.jl")

# Criteria
include("./criteria/item/information_special.jl")
include("./criteria/item/information_support.jl")
include("./criteria/item/information.jl")
include("./criteria/item/urry.jl")
include("./criteria/state/ability_variance.jl")

# Porcelain
include("./porcelain/aliases.jl")

# Experimental
include("./experimental/ka.jl")

end
57 changes: 57 additions & 0 deletions src/next_item_rules/combinators/scalarizers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
struct DeterminantScalarizer <: MatrixScalarizer end
(::DeterminantScalarizer)(mat) = det(mat)

struct TraceScalarizer <: MatrixScalarizer end
(::TraceScalarizer)(mat) = tr(mat)

struct ScalarizedItemCriteron{
ItemCriteriaT <: ItemCriteria,
MatrixScalarizerT <: MatrixScalarizer
} <: ItemCriterion
criteria::ItemCriteriaT
scalarizer::MatrixScalarizerT
end

function (ssc::ScalarizedItemCriteron)(tracked_responses, item_idx)
res = ssc.criteria(
init_thread(ssc.criteria, tracked_responses), tracked_responses, item_idx) |>
ssc.scalarizer
if !should_minimize(ssc.criteria)
res = -res
end
res
end

struct ScalarizedStateCriteron{
StateCriteriaT <: StateCriteria,
MatrixScalarizerT <: MatrixScalarizer
} <: StateCriterion
criteria::StateCriteriaT
scalarizer::MatrixScalarizerT
end

function (ssc::ScalarizedStateCriteron)(tracked_responses)
res = ssc.criteria(tracked_responses) |> ssc.scalarizer
if !should_minimize(ssc.criteria)
res = -res
end
res
end

struct WeightedStateCriteria{InnerT <: StateCriteria} <: StateCriteria
weights::Vector{Float64}
criteria::InnerT
end

function (wsc::WeightedStateCriteria)(tracked_responses, item_idx)
wsc.weights' * wsc.criteria(tracked_responses, item_idx) * wsc.weights
end

struct WeightedItemCriteria{InnerT <: ItemCriteria} <: ItemCriteria
weights::Vector{Float64}
criteria::InnerT
end

function (wsc::WeightedItemCriteria)(tracked_responses, item_idx)
wsc.weights' * wsc.criteria(tracked_responses, item_idx) * wsc.weights
end
49 changes: 49 additions & 0 deletions src/next_item_rules/criteria/item/information.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# TODO: Should have Variants for point ability versus distribution ability
struct InformationItemCriterion{AbilityEstimatorT <: PointAbilityEstimator, F} <:
ItemCriterion
ability_estimator::AbilityEstimatorT
expected_item_information::F
end

function InformationItemCriterion(ability_estimator)
InformationItemCriterion(ability_estimator, expected_item_information)
end

function (item_criterion::InformationItemCriterion)(tracked_responses::TrackedResponses,
item_idx)
ability = maybe_tracked_ability_estimate(tracked_responses,
item_criterion.ability_estimator)
ir = ItemResponse(tracked_responses.item_bank, item_idx)
return -item_criterion.expected_item_information(ir, ability)
end

struct InformationMatrixCriteria{AbilityEstimatorT <: AbilityEstimator, F} <: ItemCriteria
ability_estimator::AbilityEstimatorT
expected_item_information::F
end

function InformationMatrixCriteria(ability_estimator)
InformationMatrixCriteria(ability_estimator, expected_item_information)
end

function init_thread(item_criterion::InformationMatrixCriteria,
responses::TrackedResponses)
# TODO: No need to do this one per thread. It just need to be done once per
# θ update.
# TODO: Update this to use track!(...) mechanism
ability = maybe_tracked_ability_estimate(responses, item_criterion.ability_estimator)
responses_information(responses.item_bank, responses.responses, ability)
end

function (item_criterion::InformationMatrixCriteria)(acc_info::Matrix{Float64},
tracked_responses::TrackedResponses,
item_idx)
# TODO: Add in information from the prior
ability = maybe_tracked_ability_estimate(
tracked_responses, item_criterion.ability_estimator)
return acc_info .+
item_criterion.expected_item_information(
ItemResponse(tracked_responses.item_bank, item_idx), ability)
end

should_minimize(::InformationMatrixCriteria) = false
22 changes: 22 additions & 0 deletions src/next_item_rules/criteria/item/urry.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""
$(TYPEDEF)
$(TYPEDFIELDS)

This item criterion just picks the item with the raw difficulty closest to the
current ability estimate.
"""
struct UrryItemCriterion{AbilityEstimatorT <: PointAbilityEstimator} <: ItemCriterion
ability_estimator::AbilityEstimatorT
end

# TODO: Slow + poor error handling
function raw_difficulty(item_bank, item_idx)
item_params(item_bank, item_idx).difficulty
end

function (item_criterion::UrryItemCriterion)(tracked_responses::TrackedResponses, item_idx)
ability = maybe_tracked_ability_estimate(tracked_responses,
item_criterion.ability_estimator)
diff = raw_difficulty(tracked_responses.item_bank, item_idx)
abs(ability - diff)
end
Loading
Loading