Skip to content

Commit

Permalink
Refactor dispersion around ScalarizedStateCriteron
Browse files Browse the repository at this point in the history
  • Loading branch information
frankier committed Oct 29, 2024
1 parent 7bf4e17 commit 038ce18
Show file tree
Hide file tree
Showing 7 changed files with 186 additions and 86 deletions.
6 changes: 4 additions & 2 deletions src/next_item_rules/NextItemRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import ForwardDiff

export ExpectationBasedItemCriterion, AbilityVarianceStateCriterion, init_thread
export NextItemRule, ItemStrategyNextItemRule
export UrryItemCriterion, InformationItemCriterion, DRuleItemCriterion, TRuleItemCriterion
export UrryItemCriterion, InformationItemCriterion
export RandomNextItemRule
export ExhaustiveSearch1Ply
export catr_next_item_aliases
Expand All @@ -43,7 +43,8 @@ export compute_criteria
export PointResponseExpectation, DistributionResponseExpectation
export MatrixScalarizer, DeterminantScalarizer, TraceScalarizer
export AbilityCovarianceStateCriteria, StateCriteria, ItemCriteria
export ScalarizedStateCriteron
export InformationMatrixCriteria
export ScalarizedStateCriteron, ScalarizedItemCriteron

"""
$(TYPEDEF)
Expand Down Expand Up @@ -72,6 +73,7 @@ end

include("./random.jl")
include("./information.jl")
include("./information_special.jl")
include("./objective_function.jl")
include("./expectation.jl")

Expand Down
8 changes: 7 additions & 1 deletion src/next_item_rules/aliases.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,13 @@ const mirtcat_next_item_aliases = Dict(
# 'MEPV' for minimum expected posterior variance
"MEPV" => _mirtcat_helper((bits, ability_estimator) -> ExpectationBasedItemCriterion(
ability_estimator,
AbilityVarianceStateCriterion(bits...)))
AbilityVarianceStateCriterion(bits...))),
"Drule" => _mirtcat_helper((bits, ability_estimator) -> ScalarizedItemCriteron(
InformationMatrixCriteria(ability_estimator),
DeterminantScalarizer())),
"Trule" => _mirtcat_helper((bits, ability_estimator) -> ScalarizedItemCriteron(
InformationMatrixCriteria(ability_estimator),
TraceScalarizer()))
)

# 'MLWI' for maximum likelihood weighted information
Expand Down
37 changes: 22 additions & 15 deletions src/next_item_rules/information.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,18 @@ using FittedItemBanks: CdfMirtItemBank,
using FittedItemBanks: inner_item_response, norm_abil, y_offset, irf_size
using StatsFuns: logaddexp

"""
    log_resp_vec(ir::ItemResponse{<:TransferItemBank}, θ)

Return a 2-element `SVector` holding, in order, `logccdf` and `logcdf` of the
item bank's distribution, both evaluated at `norm_abil(ir, θ)`.
"""
function log_resp_vec(ir::ItemResponse{<:TransferItemBank}, θ)
    # Bind the normalised ability once; both entries below read `nθ`.
    nθ = norm_abil(ir, θ)
    return SVector(
        logccdf(ir.item_bank.distribution, nθ),
        logcdf(ir.item_bank.distribution, nθ)
    )
end

"""
    log_resp(ir::ItemResponse{<:TransferItemBank}, resp, θ)

Return `logcdf` of the item bank's distribution evaluated at `norm_abil(ir, θ)`.
"""
function log_resp(ir::ItemResponse{<:TransferItemBank}, resp, θ)
    # NOTE(review): `resp` is accepted but never read, so every outcome gets the
    # same value — confirm whether one outcome should use `logccdf` instead.
    dist = ir.item_bank.distribution
    return logcdf(dist, norm_abil(ir, θ))
end

function log_resp_vec(ir::ItemResponse{<:CdfMirtItemBank}, θ)
= norm_abil(ir, θ)
SVector(logccdf(ir.item_bank.distribution, nθ),
Expand Down Expand Up @@ -52,26 +64,21 @@ function log_resp(ir::ItemResponse{<:AnySlipOrGuessItemBank}, val, θ)
log_transform_irf_y(ir, val, log_resp(inner_item_response(ir), val, θ))
end

# How does this compare with expected_item_information? Speeds/accuracies?
# TODO: Which response models is this valid for?
# TODO: Citation/source for this equation
# TODO: Do it in log space?
"""
    item_information(ir::ItemResponse, θ)

Return `p′(θ)² / (p(θ) * (1 - p(θ)))` where `p(θ) = resp(ir, θ)` and `p′` is
obtained with `ForwardDiff.derivative`. Returns `0.0` when `p′(θ)` is exactly
zero.
"""
function item_information(ir::ItemResponse, θ)
    slope = ForwardDiff.derivative(x -> resp(ir, x), θ)
    prob = resp(ir, θ)
    # A zero slope short-circuits the quotient, which would otherwise be 0/0
    # whenever `prob` is exactly 0 or 1.
    slope == 0.0 && return 0.0
    return (slope * slope) / (prob * (1 - prob))
end

"""
    vector_hessian(f, x, n)

Apply `ForwardDiff.jacobian` to the Jacobian of `f` at `x` and return the
result reshaped to an `n × n × n` array.
"""
function vector_hessian(f, x, n)
    first_order(y) = ForwardDiff.jacobian(f, y)
    second_order = ForwardDiff.jacobian(first_order, x)
    return reshape(second_order, n, n, n)
end

"""
    double_derivative(f, x)

Return the second derivative of `f` at `x`, computed by nesting two
`ForwardDiff.derivative` calls.
"""
function double_derivative(f, x)
    first_derivative(y) = ForwardDiff.derivative(f, y)
    return ForwardDiff.derivative(first_derivative, x)
end

"""
    expected_item_information(ir::ItemResponse, θ::Float64)

Return the negated sum of `resp_vec(ir, θ)` multiplied elementwise by the
second derivative, with respect to `θ`, of `log_resp_vec(ir, θ)`.
"""
function expected_item_information(ir::ItemResponse, θ::Float64)
    exp_resp = resp_vec(ir, θ)
    # Second derivative of each log-response entry; read by the sum below.
    d² = double_derivative((θ -> log_resp_vec(ir, θ)), θ)
    return -sum(exp_resp .* d²)
end

# TODO: Unclear whether this should be implemented with ExpectationBasedItemCriterion
# TODO: This is not implementing DRule but postposterior DRule
function expected_item_information(ir::ItemResponse, θ::Vector{Float64})
Expand Down
65 changes: 65 additions & 0 deletions src/next_item_rules/information_special.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#=
This file contains some specialised ways to calculate information.
For some models analytical solutions are possible for information.
Most are simple applications of the chain rule.
However, I haven't taken a systematic approach yet.
So these are just from equations in the literature.
There aren't really any type guards on these, so it's up to the caller to make sure they are using the right ones.
=#

"""
    alt_expected_1d_item_information(ir::ItemResponse, θ)

This is a special case of the expected_item_information function for

  * 1-dimensional ability
  * Dichotomous items
  * It should be valid for at least up to the 3PL model, probably others too

Computes `p′(θ)² / (p(θ) * (1 - p(θ)))` with `p(θ) = resp(ir, θ)`, returning
`0.0` when the derivative is exactly zero.

TODO: citation
"""
function alt_expected_1d_item_information(ir::ItemResponse, θ)
    irθ_prime = ForwardDiff.derivative(x -> resp(ir, x), θ)
    irθ = resp(ir, θ)
    # A zero derivative short-circuits the quotient, which would otherwise be
    # 0/0 whenever the response probability is exactly 0 or 1.
    if irθ_prime == 0.0
        return 0.0
    else
        return (irθ_prime * irθ_prime) / (irθ * (1 - irθ))
    end
end

"""
    alt_expected_mirt_item_information(ir::ItemResponse, θ)

This is a special case of the expected_item_information function for

  * Multidimensional
  * Dichotomous items
  * It should be valid for at least up to the 3PL model, probably others too

Returns the outer product of the gradient of `resp(ir, θ)` with itself,
divided by `pθ * qθ`.

TODO: citation
"""
function alt_expected_mirt_item_information(ir::ItemResponse, θ)
    irθ_prime = ForwardDiff.gradient(x -> resp(ir, x), θ)
    # `pθ` is the response value at θ and `qθ` its complement; both are read by
    # the denominator below.
    pθ = resp(ir, θ)
    qθ = 1 - pθ
    return (irθ_prime * irθ_prime') / (pθ * qθ)
end

"""
    alt_expected_mirt_3pl_item_information(ir::ItemResponse, θ)

This is a special case of the expected_item_information function for

  * Multidimensional
  * Dichotomous items
  * 3PL model only

Returns `(qθ * (pθ - c)^2) / (pθ * (1 - c)^2) * (a * a')`, where `a` is the
item's discrimination and `c` its guess parameter.

Mulder J, van der Linden WJ.
Multidimensional Adaptive Testing with Optimal Design Criteria for Item Selection.
Psychometrika. 2009 Jun;74(2):273-296. doi: 10.1007/s11336-008-9097-5.
Equation 4
"""
function alt_expected_mirt_3pl_item_information(ir::ItemResponse, θ)
    # XXX: Should avoid using item_params
    # NOTE(review): `item_params` is handed `ir.item_bank.discriminations` rather
    # than the item bank itself — confirm this is the intended argument.
    params = item_params(ir.item_bank.discriminations, ir.index)
    # `pθ` is the response value at θ and `qθ` its complement; both are read by
    # `common_factor` below.
    pθ = resp(ir, θ)
    qθ = 1 - pθ
    a = params.discrimination
    c = params.guess
    common_factor = (qθ * (pθ - c)^2) / (pθ * (1 - c)^2)
    return common_factor * (a * a')
end
67 changes: 67 additions & 0 deletions src/next_item_rules/mirt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,70 @@ function (ssc::ScalarizedStateCriteron)(tracked_responses)
end
res
end

"""
    InformationMatrixCriteria(ability_estimator, expected_item_information)

Item criteria holding an ability estimator together with the callable used to
compute a candidate item's expected information.
"""
struct InformationMatrixCriteria{AbilityEstimatorT <: AbilityEstimator, F} <: ItemCriteria
# Estimator from which the current ability estimate is read.
ability_estimator::AbilityEstimatorT
# Callable invoked as `expected_item_information(item_response, ability)`.
expected_item_information::F
end

"""
    InformationMatrixCriteria(ability_estimator)

Build an `InformationMatrixCriteria` that uses `expected_item_information` as
its information function.
"""
InformationMatrixCriteria(ability_estimator) = InformationMatrixCriteria(
    ability_estimator, expected_item_information)

"""
    init_thread(item_criterion::InformationMatrixCriteria, responses::TrackedResponses)

Return `responses_information` for the responses recorded so far, evaluated at
the ability estimate given by the criterion's estimator.
"""
function init_thread(item_criterion::InformationMatrixCriteria,
        responses::TrackedResponses)
    # TODO: This does not need to run once per thread; once per θ update would do.
    # TODO: Update this to use track!(...) mechanism
    estimator = item_criterion.ability_estimator
    θ_hat = maybe_tracked_ability_estimate(responses, estimator)
    return responses_information(responses.item_bank, responses.responses, θ_hat)
end

# Return `acc_info` plus the expected information of item `item_idx`, evaluated
# at the ability estimate given by the criterion's estimator.
function (item_criterion::InformationMatrixCriteria)(acc_info::Matrix{Float64},
        tracked_responses::TrackedResponses,
        item_idx)
    # TODO: Add in information from the prior
    θ_hat = maybe_tracked_ability_estimate(
        tracked_responses, item_criterion.ability_estimator)
    candidate = ItemResponse(tracked_responses.item_bank, item_idx)
    item_info = item_criterion.expected_item_information(candidate, θ_hat)
    return acc_info .+ item_info
end

# Information matrices are reported as "not to be minimised".
function should_minimize(::InformationMatrixCriteria)
    return false
end

"""
    ScalarizedItemCriteron(criteria, scalarizer)

Item criterion pairing matrix-valued item criteria with a scalarizer that is
applied to their result.
"""
struct ScalarizedItemCriteron{
ItemCriteriaT <: ItemCriteria,
MatrixScalarizerT <: MatrixScalarizer
} <: ItemCriterion
# Wrapped criteria whose result is passed to `scalarizer`.
criteria::ItemCriteriaT
# Callable applied to the result of `criteria`.
scalarizer::MatrixScalarizerT
end

# Evaluate the wrapped criteria for `item_idx`, scalarize the result, and negate
# it unless the wrapped criteria report `should_minimize`.
function (ssc::ScalarizedItemCriteron)(tracked_responses, item_idx)
    thread_state = init_thread(ssc.criteria, tracked_responses)
    matrix_value = ssc.criteria(thread_state, tracked_responses, item_idx)
    scalar = ssc.scalarizer(matrix_value)
    return should_minimize(ssc.criteria) ? scalar : -scalar
end

"""
    WeightedStateCriteria(weights, criteria)

State criteria that wrap other state criteria together with a weight vector.
"""
struct WeightedStateCriteria{InnerT <: StateCriteria} <: StateCriteria
# Weight vector applied on both sides of the wrapped result.
weights::Vector{Float64}
# Wrapped state criteria.
criteria::InnerT
end

# Return the quadratic form `weights' * M * weights`, where `M` is the result of
# the wrapped criteria.
# NOTE(review): the wrapped criteria are called with `(tracked_responses, item_idx)`;
# confirm the wrapped type actually defines a method with that signature.
function (wsc::WeightedStateCriteria)(tracked_responses, item_idx)
    w = wsc.weights
    inner = wsc.criteria(tracked_responses, item_idx)
    return w' * inner * w
end

"""
    WeightedItemCriteria(weights, criteria)

Item criteria that wrap other item criteria together with a weight vector.
"""
struct WeightedItemCriteria{InnerT <: ItemCriteria} <: ItemCriteria
# Weight vector applied on both sides of the wrapped result.
weights::Vector{Float64}
# Wrapped item criteria.
criteria::InnerT
end

# Return the quadratic form `weights' * M * weights`, where `M` is the result of
# the wrapped criteria.
# NOTE(review): the wrapped criteria are called with `(tracked_responses, item_idx)`;
# confirm the wrapped type actually defines a method with that signature.
function (wsc::WeightedItemCriteria)(tracked_responses, item_idx)
    w = wsc.weights
    inner = wsc.criteria(tracked_responses, item_idx)
    return w' * inner * w
end
76 changes: 8 additions & 68 deletions src/next_item_rules/objective_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,79 +126,19 @@ function (item_criterion::UrryItemCriterion)(tracked_responses::TrackedResponses
end

# TODO: Should have Variants for point ability versus distribution ability
"""
    InformationItemCriterion(ability_estimator, expected_item_information)

Item criterion holding a point ability estimator together with the callable
used to compute a candidate item's information.
"""
struct InformationItemCriterion{AbilityEstimatorT <: PointAbilityEstimator, F} <: ItemCriterion
    # Estimator from which the current ability estimate is read.
    ability_estimator::AbilityEstimatorT
    # Callable invoked as `expected_item_information(item_response, ability)`.
    expected_item_information::F
end

"""
    InformationItemCriterion(ability_estimator)

Build an `InformationItemCriterion` that uses `expected_item_information` as
its information function.
"""
InformationItemCriterion(ability_estimator) = InformationItemCriterion(
    ability_estimator, expected_item_information)

# Score item `item_idx` as the negated information at the ability estimate given
# by the criterion's estimator.
function (item_criterion::InformationItemCriterion)(tracked_responses::TrackedResponses,
        item_idx)
    ability = maybe_tracked_ability_estimate(tracked_responses,
        item_criterion.ability_estimator)
    ir = ItemResponse(tracked_responses.item_bank, item_idx)
    # Call the information function stored on the criterion so that a
    # caller-supplied function is honoured instead of a hard-coded one.
    return -item_criterion.expected_item_information(ir, ability)
end

# Abstract parent of the item criteria below that score an information matrix.
abstract type InformationMatrixCriterion <: ItemCriterion end

"""
    init_thread(item_criterion::InformationMatrixCriterion, responses::TrackedResponses)

Return `responses_information` for the responses recorded so far, evaluated at
the ability estimate given by the criterion's estimator.
"""
function init_thread(item_criterion::InformationMatrixCriterion,
        responses::TrackedResponses)
    # TODO: This does not need to run once per thread; once per θ update would do.
    # TODO: Update this to use track!(...) mechanism
    estimator = item_criterion.ability_estimator
    θ_hat = maybe_tracked_ability_estimate(responses, estimator)
    return responses_information(responses.item_bank, responses.responses, θ_hat)
end

"""
    information_matrix(ability_estimator, acc_info, tracked_responses, item_idx)

Return `acc_info` plus `expected_item_information` of item `item_idx`,
evaluated at the ability estimate given by `ability_estimator`.
"""
function information_matrix(ability_estimator,
        acc_info,
        tracked_responses::TrackedResponses,
        item_idx)
    # TODO: Add in information from the prior
    θ_hat = maybe_tracked_ability_estimate(tracked_responses, ability_estimator)
    candidate = ItemResponse(tracked_responses.item_bank, item_idx)
    return acc_info .+ expected_item_information(candidate, θ_hat)
end

"""
    DRuleItemCriterion(ability_estimator)

Information-matrix criterion whose score is the negated determinant of the
information matrix.
"""
struct DRuleItemCriterion{AbilityEstimatorT <: PointAbilityEstimator} <:
InformationMatrixCriterion
# Estimator from which the current ability estimate is read.
ability_estimator::AbilityEstimatorT
end

# Return the negated determinant of the information matrix obtained by adding
# item `item_idx` to `acc_info`.
function (item_criterion::DRuleItemCriterion)(acc_info::Matrix{Float64},
        tracked_responses::TrackedResponses,
        item_idx)
    total_info = information_matrix(
        item_criterion.ability_estimator, acc_info, tracked_responses, item_idx)
    return -det(total_info)
end

# TODO: Weighted version
"""
    TRuleItemCriterion(ability_estimator)

Information-matrix criterion whose score is the negated trace of the
information matrix.
"""
struct TRuleItemCriterion{AbilityEstimatorT <: PointAbilityEstimator} <:
InformationMatrixCriterion
# Estimator from which the current ability estimate is read.
ability_estimator::AbilityEstimatorT
end

# Return the negated trace of the information matrix obtained by adding item
# `item_idx` to `acc_info`.
function (item_criterion::TRuleItemCriterion)(acc_info::Matrix{Float64},
        tracked_responses,
        item_idx)
    # XXX: Should not strictly need to calculate whole information matrix to get this.
    # Should just be able to calculate Laplacians as we go, but ForwardDiff doesn't support this (yet?).
    total_info = information_matrix(
        item_criterion.ability_estimator, acc_info, tracked_responses, item_idx)
    return -tr(total_info)
end

"""
    ARuleItemCriterion(ability_estimator)

Item criterion holding a point ability estimator; its call method is still a
placeholder.
"""
struct ARuleItemCriterion{AbilityEstimatorT <: PointAbilityEstimator} <: ItemCriterion
# Estimator from which the current ability estimate would be read.
ability_estimator::AbilityEstimatorT
end

# Placeholder for the A-rule score: the body holds only TODO notes, so calling
# this method currently returns `nothing`.
function (item_criterion::ARuleItemCriterion)(acc_info::Nothing,
tracked_responses,
item_idx)
# TODO
# Step 1. Get covariance of ability estimate
# Basically the same idea as AbilityVarianceStateCriterion
# Step 2. Get the (weighted) trace
end
return -item_criterion.expected_item_information(ir, ability)
end
13 changes: 13 additions & 0 deletions test/ability_estimator_2d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,19 @@ end
end

# Compares the determinant-scalarized information criterion for two candidate
# items, one nearer to and one further from the current 2-D ability estimate.
@testcase "2 dim information higher closer to current estimate" begin
information_matrix_criteria = InformationMatrixCriteria(mle_mean_2d)
information_criterion = ScalarizedItemCriteron(
information_matrix_criteria, DeterminantScalarizer())

# Item closer to the current estimate (1, 1)
close_item = 5
# Item further from the current estimate
far_item = 6

close_info = information_criterion(tracked_responses_2d, close_item)
far_info = information_criterion(tracked_responses_2d, far_item)

# NOTE(review): ScalarizedItemCriteron negates its result when the wrapped
# criteria report `should_minimize == false`, so a more informative item yields
# a smaller value here — confirm that `>` is the intended direction.
@test close_info > far_info
end

@testcase "2 dim variance smaller closer to current estimate" begin
Expand Down

0 comments on commit 038ce18

Please sign in to comment.